Skip to content

Commit

Permalink
removed redundant folder bloat
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Aug 29, 2024
1 parent 53fb522 commit aefe40b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
13 changes: 8 additions & 5 deletions code/datasets/simplicial_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,14 @@ def __init__(
pre_transform=None,
pre_filter=None,
):

self.manifold = manifold
self.task_type = task_type
self.split = mode
self.split_config = SplitConfig(split, seed, use_stratified)
is_full = pre_filter is None
self.raw_simplicial_ds = SimplicialDataset(
os.path.join(
root,
"raw_simplicial",
f"manifold_{manifold}",
f"is_full_{is_full}",
"raw_simplicial"
),
manifold,
version,
Expand Down Expand Up @@ -108,12 +105,18 @@ def process(self):
indices = range(self.raw_simplicial_ds.len())

for task_type in TaskType:

# no name classification on 3 manifolds
if self.manifold == "3" and task_type == TaskType.NAME:
continue

# apply class transform
class_transform = Compose(class_transforms_lookup[task_type])
data_list_processed = [
class_transform(self.raw_simplicial_ds.get(idx))
for idx in indices
]

# train test split
stratified = torch.vstack([data.y for data in data_list_processed])
train_val_indices, test_indices = train_test_split(
Expand Down
1 change: 1 addition & 0 deletions code/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __call__(self, data):

class OrientableToClassTransform:
def __call__(self, data):
print(data)
data.y = data.orientable.long()
return data

Expand Down

0 comments on commit aefe40b

Please sign in to comment.