From d54240613775f5edd7509cb89fda0c403d901401 Mon Sep 17 00:00:00 2001 From: binbash Date: Sat, 31 Aug 2024 12:22:30 +0200 Subject: [PATCH] formatting --- code/datasets/simplicial_ds.py | 9 +++++---- code/datasets/transforms.py | 21 +++++++++++++++------ code/metrics/metrics.py | 16 +++++++--------- code/metrics/tasks.py | 2 +- 4 files changed, 28 insertions(+), 20 deletions(-) diff --git a/code/datasets/simplicial_ds.py b/code/datasets/simplicial_ds.py index d2d57f6..4ab9a0e 100644 --- a/code/datasets/simplicial_ds.py +++ b/code/datasets/simplicial_ds.py @@ -103,7 +103,10 @@ def processed_file_names(self): else: f_names = [ self._data_filename(task_type, mode) - for task_type in [TaskType.BETTI_NUMBERS, TaskType.ORIENTABILITY] + for task_type in [ + TaskType.BETTI_NUMBERS, + TaskType.ORIENTABILITY, + ] for mode in ["train", "test", "val"] ] return f_names @@ -115,9 +118,7 @@ def process(self): for task_type in TaskType: # no name classification on 3 manifolds - if self.manifold == "3" and ( - task_type == TaskType.NAME - ): + if self.manifold == "3" and (task_type == TaskType.NAME): continue # apply class transform diff --git a/code/datasets/transforms.py b/code/datasets/transforms.py index 3f074e3..bdebc84 100644 --- a/code/datasets/transforms.py +++ b/code/datasets/transforms.py @@ -10,6 +10,7 @@ from enum import Enum from datasets.dataset_types import DatasetType from typing import List, Callable + NAME_TO_CLASS = {"Klein bottle": 0, "RP^2": 1, "T^2": 2, "S^2": 3, "": 4} @@ -21,7 +22,7 @@ def __call__(self, data): class OrientableToClassTransform: def __call__(self, data): - data.orientable = torch.tensor(data.betti_numbers)[...,-1] + data.orientable = torch.tensor(data.betti_numbers)[..., -1] data.y = data.orientable.long() return data @@ -52,6 +53,7 @@ def __call__(self, data): data.triangulation = None return data + class TriangulationToFaceTransform: """ Transforms tetrahedra to faces. @@ -71,7 +73,9 @@ def __call__(self, data): idx = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]] if hasattr(data, "triangulation"): assert data.triangulation is not None - face = torch.cat([torch.tensor(data.triangulation)[i] for i in idx], dim=1) + face = torch.cat( + [torch.tensor(data.triangulation)[i] for i in idx], dim=1 + ) # Remove duplicate triangles in data.face = torch.unique(face, dim=1) @@ -81,6 +85,7 @@ def __call__(self, data): return data + class SimplicialComplexTransform: def __call__(self, data): data.sc = SimplicialComplex(data.triangulation) @@ -242,8 +247,9 @@ class TransformType(Enum): random_simplices_features = "random_simplices_features" - -def transforms_lookup(tr_type: TransformType, ds_type: DatasetType) -> List[Callable]: +def transforms_lookup( + tr_type: TransformType, ds_type: DatasetType +) -> List[Callable]: _transforms_lookup = { TransformType.degree_transform: degree_transform, TransformType.degree_transform_onehot: degree_transform_onehot, @@ -253,7 +259,10 @@ def transforms_lookup(tr_type: TransformType, ds_type: DatasetType) -> List[Call } tr = _transforms_lookup[tr_type] - if tr_type != TransformType.degree_transform_sc and tr_type != TransformType.random_simplices_features: + if ( + tr_type != TransformType.degree_transform_sc + and tr_type != TransformType.random_simplices_features + ): tr[0] = TriangulationToFaceTransform() - + return tr diff --git a/code/metrics/metrics.py b/code/metrics/metrics.py index 8c1675c..ced7aec 100644 --- a/code/metrics/metrics.py +++ b/code/metrics/metrics.py @@ -43,7 +43,7 @@ def __init__( betti_0: List[NamedMetric], betti_1: List[NamedMetric], betti_2: List[NamedMetric], - betti_3: Optional[List[NamedMetric]] = None + betti_3: Optional[List[NamedMetric]] = None, ) -> None: self.betti_0 = betti_0 self.betti_1 = betti_1 @@ -53,13 +53,9 @@ def __init__( def as_list(self): if self.betti_3 is None: return [self.betti_0, self.betti_1, self.betti_2] - else: - return [ - self.betti_0, - self.betti_1, - self.betti_2, - self.betti_3 - ] + else: + return [self.betti_0, self.betti_1, self.betti_2, self.betti_3] + class MetricTrainValTest: """ @@ -154,7 +150,9 @@ def get_betti_numbers_metrics(ds_type: DatasetType): betti_2_metrics = [ NamedMetric(GeneralAccuracy(), "Accuracy"), NamedMetric(MatthewsCorrCoeff(), "MCC"), - NamedMetric(torchmetrics.classification.BinaryAUROC(), "BinaryAUROC"), + NamedMetric( + torchmetrics.classification.BinaryAUROC(), "BinaryAUROC" + ), NamedMetric( BettiNumbersMultiClassAccuracy(num_classes=2), "BalancedAccuracy", diff --git a/code/metrics/tasks.py b/code/metrics/tasks.py index 2ace393..130870c 100644 --- a/code/metrics/tasks.py +++ b/code/metrics/tasks.py @@ -106,5 +106,5 @@ def get_task_lookup( class_transforms_lookup_3manifold: Dict[TaskType, List[Callable]] = { TaskType.BETTI_NUMBERS: betti_numbers_transforms_3manifold, - TaskType.ORIENTABILITY: orientability_transforms + TaskType.ORIENTABILITY: orientability_transforms, }