From ae3ef4e8e8e3b7ae45407ba9cea86e236c435b86 Mon Sep 17 00:00:00 2001 From: binbash Date: Sat, 31 Aug 2024 10:11:40 +0200 Subject: [PATCH] dataset type integration contd. --- code/datasets/transforms.py | 1 - code/experiments/utils/run_experiment.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/code/datasets/transforms.py b/code/datasets/transforms.py index bce2ce1..ddf03ee 100644 --- a/code/datasets/transforms.py +++ b/code/datasets/transforms.py @@ -48,7 +48,6 @@ def __call__(self, data): data.face = torch.tensor(data.triangulation).T - 1 data.num_nodes = data.face.max() + 1 data.triangulation = None - print(data, "m") return data diff --git a/code/experiments/utils/run_experiment.py b/code/experiments/utils/run_experiment.py index 6177f93..19631d4 100644 --- a/code/experiments/utils/run_experiment.py +++ b/code/experiments/utils/run_experiment.py @@ -31,7 +31,9 @@ def get_setup( ) -> Tuple[SimplicialDataModule, BaseModel, L.Trainer, WandbLogger]: run_id = str(uuid.uuid4()) transforms = transforms_lookup[config.transforms] - task_lookup: Dict[TaskType, Task] = get_task_lookup(transforms) + task_lookup: Dict[TaskType, Task] = get_task_lookup( + transforms, ds_type=config.ds_type + ) dm = SimplicialDataModule( ds_type=config.ds_type,