diff --git a/code/datasets/simplicial_ds.py b/code/datasets/simplicial_ds.py index 78f5989..04de232 100644 --- a/code/datasets/simplicial_ds.py +++ b/code/datasets/simplicial_ds.py @@ -8,7 +8,7 @@ from sklearn.model_selection import train_test_split from typing import Dict, List, Literal, Tuple, Optional from torch_geometric.transforms import Compose -from mantra.simplicial import SimplicialDataset +from mantra.datasets import ManifoldTriangulations from torch_geometric.data import InMemoryDataset import os @@ -56,7 +56,7 @@ def __init__( self.task_type = task_type self.split = mode self.split_config = SplitConfig(split, seed, use_stratified) - self.raw_simplicial_ds = SimplicialDataset( + self.raw_simplicial_ds = ManifoldTriangulations( os.path.join(root, "raw_simplicial"), manifold, version, diff --git a/code/run.py b/code/run.py index 741a7a9..b06f044 100644 --- a/code/run.py +++ b/code/run.py @@ -30,7 +30,7 @@ # model model_config = CellMPConfig( - num_input_features=1, + num_input_features=8, num_classes=3, ) transform_type = TransformType.degree_transform_sc @@ -49,7 +49,7 @@ config = ConfigExperimentRun( seed=10, ds_type=DatasetType.FULL_2D, - transforms=TransformType.degree_transform_sc, + transforms=TransformType.random_simplices_features, use_stratified=True, task_type=TaskType.BETTI_NUMBERS, trainer_config=trainer_config, @@ -57,7 +57,7 @@ ) # data and logging -data_dir = "/data" +data_dir = "./data" use_logger = False devices = [0] run_id = str(uuid.uuid4()) diff --git a/pyproject.toml b/pyproject.toml index d3ebd7e..d221a3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ omegaconf = "2.3.0" pydantic-settings = "2.2.1" tabulate = "0.9.0" ogb = "1.3.6" +mantra-dataset = "0.0.6" [tool.poetry.dev-dependencies] black = {version = "^21.12b0", allow-prereleases = true}