diff --git a/code/datasets/simplicial_ds.py b/code/datasets/simplicial_ds.py index 80ce31e..d701489 100644 --- a/code/datasets/simplicial_ds.py +++ b/code/datasets/simplicial_ds.py @@ -49,17 +49,14 @@ def __init__( pre_transform=None, pre_filter=None, ): - + self.manifold = manifold self.task_type = task_type self.split = mode self.split_config = SplitConfig(split, seed, use_stratified) - is_full = pre_filter is None self.raw_simplicial_ds = SimplicialDataset( os.path.join( root, - "raw_simplicial", - f"manifold_{manifold}", - f"is_full_{is_full}", + "raw_simplicial" ), manifold, version, @@ -108,12 +105,18 @@ def process(self): indices = range(self.raw_simplicial_ds.len()) for task_type in TaskType: + + # no name classification on 3 manifolds + if self.manifold == "3" and task_type == TaskType.NAME: + continue + # apply class transform class_transform = Compose(class_transforms_lookup[task_type]) data_list_processed = [ class_transform(self.raw_simplicial_ds.get(idx)) for idx in indices ] + # train test split stratified = torch.vstack([data.y for data in data_list_processed]) train_val_indices, test_indices = train_test_split( diff --git a/code/datasets/transforms.py b/code/datasets/transforms.py index 9ccfd94..3ab726e 100644 --- a/code/datasets/transforms.py +++ b/code/datasets/transforms.py @@ -20,6 +20,7 @@ def __call__(self, data): class OrientableToClassTransform: def __call__(self, data): + print(data) data.y = data.orientable.long() return data