Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Aug 31, 2024
1 parent 1836c48 commit d542406
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 20 deletions.
9 changes: 5 additions & 4 deletions code/datasets/simplicial_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def processed_file_names(self):
else:
f_names = [
self._data_filename(task_type, mode)
for task_type in [TaskType.BETTI_NUMBERS, TaskType.ORIENTABILITY]
for task_type in [
TaskType.BETTI_NUMBERS,
TaskType.ORIENTABILITY,
]
for mode in ["train", "test", "val"]
]
return f_names
Expand All @@ -115,9 +118,7 @@ def process(self):
for task_type in TaskType:

# no name classification on 3 manifolds
if self.manifold == "3" and (
task_type == TaskType.NAME
):
if self.manifold == "3" and (task_type == TaskType.NAME):
continue

# apply class transform
Expand Down
21 changes: 15 additions & 6 deletions code/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from enum import Enum
from datasets.dataset_types import DatasetType
from typing import List, Callable

NAME_TO_CLASS = {"Klein bottle": 0, "RP^2": 1, "T^2": 2, "S^2": 3, "": 4}


Expand All @@ -21,7 +22,7 @@ def __call__(self, data):

class OrientableToClassTransform:
def __call__(self, data):
data.orientable = torch.tensor(data.betti_numbers)[...,-1]
data.orientable = torch.tensor(data.betti_numbers)[..., -1]
data.y = data.orientable.long()
return data

Expand Down Expand Up @@ -52,6 +53,7 @@ def __call__(self, data):
data.triangulation = None
return data


class TriangulationToFaceTransform:
"""
Transforms tetrahedra to faces.
Expand All @@ -71,7 +73,9 @@ def __call__(self, data):
idx = [[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]]
if hasattr(data, "triangulation"):
assert data.triangulation is not None
face = torch.cat([torch.tensor(data.triangulation)[i] for i in idx], dim=1)
face = torch.cat(
[torch.tensor(data.triangulation)[i] for i in idx], dim=1
)

# Remove duplicate triangles in
data.face = torch.unique(face, dim=1)
Expand All @@ -81,6 +85,7 @@ def __call__(self, data):

return data


class SimplicialComplexTransform:
def __call__(self, data):
data.sc = SimplicialComplex(data.triangulation)
Expand Down Expand Up @@ -242,8 +247,9 @@ class TransformType(Enum):
random_simplices_features = "random_simplices_features"



def transforms_lookup(tr_type: TransformType, ds_type: DatasetType) -> List[Callable]:
def transforms_lookup(
tr_type: TransformType, ds_type: DatasetType
) -> List[Callable]:
_transforms_lookup = {
TransformType.degree_transform: degree_transform,
TransformType.degree_transform_onehot: degree_transform_onehot,
Expand All @@ -253,7 +259,10 @@ def transforms_lookup(tr_type: TransformType, ds_type: DatasetType) -> List[Call
}

tr = _transforms_lookup[tr_type]
if tr_type != TransformType.degree_transform_sc and tr_type != TransformType.random_simplices_features:
if (
tr_type != TransformType.degree_transform_sc
and tr_type != TransformType.random_simplices_features
):
tr[0] = TriangulationToFaceTransform()

return tr
16 changes: 7 additions & 9 deletions code/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
betti_0: List[NamedMetric],
betti_1: List[NamedMetric],
betti_2: List[NamedMetric],
betti_3: Optional[List[NamedMetric]] = None
betti_3: Optional[List[NamedMetric]] = None,
) -> None:
self.betti_0 = betti_0
self.betti_1 = betti_1
Expand All @@ -53,13 +53,9 @@ def __init__(
def as_list(self):
if self.betti_3 is None:
return [self.betti_0, self.betti_1, self.betti_2]
else:
return [
self.betti_0,
self.betti_1,
self.betti_2,
self.betti_3
]
else:
return [self.betti_0, self.betti_1, self.betti_2, self.betti_3]


class MetricTrainValTest:
"""
Expand Down Expand Up @@ -154,7 +150,9 @@ def get_betti_numbers_metrics(ds_type: DatasetType):
betti_2_metrics = [
NamedMetric(GeneralAccuracy(), "Accuracy"),
NamedMetric(MatthewsCorrCoeff(), "MCC"),
NamedMetric(torchmetrics.classification.BinaryAUROC(), "BinaryAUROC"),
NamedMetric(
torchmetrics.classification.BinaryAUROC(), "BinaryAUROC"
),
NamedMetric(
BettiNumbersMultiClassAccuracy(num_classes=2),
"BalancedAccuracy",
Expand Down
2 changes: 1 addition & 1 deletion code/metrics/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ def get_task_lookup(

class_transforms_lookup_3manifold: Dict[TaskType, List[Callable]] = {
TaskType.BETTI_NUMBERS: betti_numbers_transforms_3manifold,
TaskType.ORIENTABILITY: orientability_transforms
TaskType.ORIENTABILITY: orientability_transforms,
}

0 comments on commit d542406

Please sign in to comment.