Skip to content

Commit

Permalink
fixed NamedMetrics is not a Module bug
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Jun 24, 2024
1 parent f68de12 commit 76b8afa
Showing 1 changed file with 25 additions and 28 deletions.
53 changes: 25 additions & 28 deletions metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __init__(self, metric: Metric, name: str) -> None:


class BettiNumbersMetricCollection:
betti_0: ModuleList
betti_1: ModuleList
betti_2: ModuleList
betti_0: List[NamedMetric]
betti_1: List[NamedMetric]
betti_2: List[NamedMetric]

def __init__(
self,
Expand Down Expand Up @@ -127,33 +127,30 @@ def get_name_metrics(num_classes=5):

def get_betti_numbers_metrics():

betti_0_metrics = ModuleList([GeneralAccuracy()])
betti_1_metrics = ModuleList(
[
NamedMetric(GeneralAccuracy(), "Accuracy"),
NamedMetric(
torchmetrics.classification.MulticlassAccuracy(
num_classes=7,
average="macro",
),
"BalancedAccuracy",
betti_0_metrics = [NamedMetric(GeneralAccuracy(), "Accuracy")]
betti_1_metrics = [
NamedMetric(GeneralAccuracy(), "Accuracy"),
NamedMetric(
torchmetrics.classification.MulticlassAccuracy(
num_classes=7,
average="macro",
),
]
)
betti_2_metrics = ModuleList(
[
NamedMetric(GeneralAccuracy(), "Accuracy"),
NamedMetric(MatthewsCorrCoeff(), "MCC"),
NamedMetric(torchmetrics.classification.BinaryF1Score(), "F1"),
NamedMetric(
torchmetrics.classification.MulticlassAccuracy(
num_classes=2,
average="macro",
),
"BalancedAccuracy",
"BalancedAccuracy",
),
]

betti_2_metrics = [
NamedMetric(GeneralAccuracy(), "Accuracy"),
NamedMetric(MatthewsCorrCoeff(), "MCC"),
NamedMetric(torchmetrics.classification.BinaryF1Score(), "F1"),
NamedMetric(
torchmetrics.classification.MulticlassAccuracy(
num_classes=2,
average="macro",
),
]
)
"BalancedAccuracy",
),
]

collection = BettiNumbersMetricCollection(
betti_0=betti_0_metrics,
Expand Down

0 comments on commit 76b8afa

Please sign in to comment.