From 76b8afa59cf33a9b577556874221e840b343063a Mon Sep 17 00:00:00 2001 From: Daniel Bin Schmid Date: Mon, 24 Jun 2024 21:13:42 +0200 Subject: [PATCH] fixed NamedMetrics is not a Module bug --- metrics/metrics.py | 53 ++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/metrics/metrics.py b/metrics/metrics.py index 5a37d08..1964e91 100644 --- a/metrics/metrics.py +++ b/metrics/metrics.py @@ -17,9 +17,9 @@ def __init__(self, metric: Metric, name: str) -> None: class BettiNumbersMetricCollection: - betti_0: ModuleList - betti_1: ModuleList - betti_2: ModuleList + betti_0: List[NamedMetric] + betti_1: List[NamedMetric] + betti_2: List[NamedMetric] def __init__( self, @@ -127,33 +127,30 @@ def get_name_metrics(num_classes=5): def get_betti_numbers_metrics(): - betti_0_metrics = ModuleList([GeneralAccuracy()]) - betti_1_metrics = ModuleList( - [ - NamedMetric(GeneralAccuracy(), "Accuracy"), - NamedMetric( - torchmetrics.classification.MulticlassAccuracy( - num_classes=7, - average="macro", - ), - "BalancedAccuracy", + betti_0_metrics = [NamedMetric(GeneralAccuracy(), "Accuracy")] + betti_1_metrics = [ + NamedMetric(GeneralAccuracy(), "Accuracy"), + NamedMetric( + torchmetrics.classification.MulticlassAccuracy( + num_classes=7, + average="macro", ), - ] - ) - betti_2_metrics = ModuleList( - [ - NamedMetric(GeneralAccuracy(), "Accuracy"), - NamedMetric(MatthewsCorrCoeff(), "MCC"), - NamedMetric(torchmetrics.classification.BinaryF1Score(), "F1"), - NamedMetric( - torchmetrics.classification.MulticlassAccuracy( - num_classes=2, - average="macro", - ), - "BalancedAccuracy", + "BalancedAccuracy", + ), + ] + + betti_2_metrics = [ + NamedMetric(GeneralAccuracy(), "Accuracy"), + NamedMetric(MatthewsCorrCoeff(), "MCC"), + NamedMetric(torchmetrics.classification.BinaryF1Score(), "F1"), + NamedMetric( + torchmetrics.classification.MulticlassAccuracy( + num_classes=2, + average="macro", ), - ] - ) + "BalancedAccuracy", + ), + ] collection = BettiNumbersMetricCollection( betti_0=betti_0_metrics,