Skip to content

Commit

Permalink
Adding workable example and mantra dataset to the pyproject dependenc…
Browse files Browse the repository at this point in the history
…ies.
  • Loading branch information
rballeba committed Nov 17, 2024
1 parent f5e9c79 commit 9f042b1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions code/datasets/simplicial_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.model_selection import train_test_split
from typing import Dict, List, Literal, Tuple, Optional
from torch_geometric.transforms import Compose
from mantra.simplicial import SimplicialDataset
from mantra.datasets import ManifoldTriangulations
from torch_geometric.data import InMemoryDataset
import os

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
self.task_type = task_type
self.split = mode
self.split_config = SplitConfig(split, seed, use_stratified)
self.raw_simplicial_ds = SimplicialDataset(
self.raw_simplicial_ds = ManifoldTriangulations(
os.path.join(root, "raw_simplicial"),
manifold,
version,
Expand Down
6 changes: 3 additions & 3 deletions code/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

# model
model_config = CellMPConfig(
num_input_features=1,
num_input_features=8,
num_classes=3,
)
transform_type = TransformType.degree_transform_sc
Expand All @@ -49,15 +49,15 @@
config = ConfigExperimentRun(
seed=10,
ds_type=DatasetType.FULL_2D,
transforms=TransformType.degree_transform_sc,
transforms=TransformType.random_simplices_features,
use_stratified=True,
task_type=TaskType.BETTI_NUMBERS,
trainer_config=trainer_config,
conf_model=model_config,
)

# data and logging
data_dir = "/data"
data_dir = "./data"
use_logger = False
devices = [0]
run_id = str(uuid.uuid4())
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ omegaconf = "2.3.0"
pydantic-settings = "2.2.1"
tabulate = "0.9.0"
ogb = "1.3.6"
mantra-dataset = "0.0.6"

[tool.poetry.dev-dependencies]
black = {version = "^21.12b0", allow-prereleases = true}
Expand Down

0 comments on commit 9f042b1

Please sign in to comment.