Skip to content

Commit

Permalink
added MonaiDiceScore and switched class-wise metrics to monai
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 29, 2024
1 parent fc0c09d commit 0766787
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 11 deletions.
39 changes: 30 additions & 9 deletions src/eva/core/metrics/structs/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import List

from torch import nn

from eva.core.metrics.structs import collection, schemas
Expand Down Expand Up @@ -46,6 +48,7 @@ def from_metrics(
test: MetricModuleType | None,
*,
separator: str = "/",
compute_groups: bool | List[List[str]] = True,
) -> MetricModule:
"""Initializes a metric module from a list of metrics.
Expand All @@ -55,32 +58,42 @@ def from_metrics(
test: Metrics for the test stage.
separator: The separator between the group name of the metric
and the metric itself.
compute_groups: All metrics in a compute group share the same metric state
and are therefore only different in their compute step. To disable this
behavior, set to `False`.
"""
return cls(
train=_create_collection_from_metrics(train, prefix="train" + separator),
val=_create_collection_from_metrics(val, prefix="val" + separator),
test=_create_collection_from_metrics(test, prefix="test" + separator),
train=_create_collection_from_metrics(
train, prefix="train" + separator, compute_groups=compute_groups
),
val=_create_collection_from_metrics(
val, prefix="val" + separator, compute_groups=compute_groups
),
test=_create_collection_from_metrics(
test, prefix="test" + separator, compute_groups=compute_groups
),
)

@classmethod
def from_schema(
cls,
schema: schemas.MetricsSchema,
*,
separator: str = "/",
cls, schema: schemas.MetricsSchema, *, separator: str = "/", compute_groups: bool = True
) -> MetricModule:
"""Initializes a metric module from the metrics schema.
Args:
schema: The dataclass metric schema.
separator: The separator between the group name of the metric
and the metric itself.
compute_groups: All metrics in a compute group share the same metric state
and are therefore only different in their compute step. To disable this
behavior, set to `False`.
"""
return cls.from_metrics(
train=schema.training_metrics,
val=schema.evaluation_metrics,
test=schema.evaluation_metrics,
separator=separator,
compute_groups=compute_groups,
)

@property
Expand All @@ -100,16 +113,24 @@ def test_metrics(self) -> collection.MetricCollection:


def _create_collection_from_metrics(
metrics: MetricModuleType | None, *, prefix: str | None = None
metrics: MetricModuleType | None,
*,
prefix: str | None = None,
compute_groups: bool | List[List[str]] = True,
) -> collection.MetricCollection | None:
"""Create a unique collection from metrics.
Args:
metrics: The desired metrics.
prefix: A prefix to added to the collection.
compute_groups: All metrics in a compute group share the same metric state
and are therefore only different in their compute step. To disable this
behavior, set to `False`.
Returns:
The resulted metrics collection.
"""
metrics_collection = collection.MetricCollection(metrics or [], prefix=prefix) # type: ignore
metrics_collection = collection.MetricCollection(
metrics=metrics or [], prefix=prefix, compute_groups=compute_groups # type: ignore
)
return metrics_collection.clone()
12 changes: 12 additions & 0 deletions src/eva/vision/metrics/defaults/segmentation/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ def __init__(
"""
super().__init__(
metrics={
"MonaiDiceScore": segmentation.MonaiDiceScore(
num_classes=num_classes,
include_background=include_background,
ignore_index=ignore_index,
ignore_empty=True,
),
"MonaiDiceScore (ignore_empty=False)": segmentation.MonaiDiceScore(
num_classes=num_classes,
include_background=include_background,
ignore_index=ignore_index,
ignore_empty=False,
),
"DiceScore (micro)": segmentation.DiceScore(
num_classes=num_classes,
include_background=include_background,
Expand Down
2 changes: 2 additions & 0 deletions src/eva/vision/metrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from eva.vision.metrics.segmentation.dice import DiceScore
from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
from eva.vision.metrics.segmentation.mean_iou import MeanIoU
from eva.vision.metrics.segmentation.monai_dice import MonaiDiceScore

__all__ = [
"DiceScore",
"MonaiDiceScore",
"GeneralizedDiceScore",
"MeanIoU",
]
59 changes: 59 additions & 0 deletions src/eva/vision/metrics/segmentation/monai_dice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Wrapper for dice score metric from MONAI."""

from monai.metrics.meandice import DiceMetric
from typing_extensions import override

from eva.vision.metrics import wrappers
from eva.vision.metrics.segmentation import _utils


class MonaiDiceScore(wrappers.MonaiMetricWrapper):
"""Wrapper to make MONAI's `DiceMetric` compatible with `torchmetrics`."""

def __init__(
self,
num_classes: int,
include_background: bool = True,
reduction: str = "mean",
ignore_index: int | None = None,
**kwargs,
):
"""Initializes metric.
Args:
num_classes: The number of classes in the dataset.
include_background: Whether to include the background class in the computation.
reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`.
ignore_index: Integer specifying a target class to ignore. If given, this class
index does not contribute to the returned score.
kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class.
"""
super().__init__(
DiceMetric(
include_background=include_background,
reduction=reduction,
num_classes=num_classes,
**kwargs,
)
)

self.reduction = reduction
self.num_classes = num_classes
self.ignore_index = ignore_index

@override
def update(self, preds, target):
preds = _utils.index_to_one_hot(preds, num_classes=self.num_classes)
target = _utils.index_to_one_hot(target, num_classes=self.num_classes)
if self.ignore_index is not None:
preds, target = _utils.apply_ignore_index(
preds, target, self.ignore_index, self.num_classes
)
return super().update(preds, target)

@override
def compute(self):
result = super().compute()
if self.reduction == "none" and len(result) > 1:
result = result.nanmean(dim=0)
return result
5 changes: 5 additions & 0 deletions src/eva/vision/metrics/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Metrics wrappers API."""

from eva.vision.metrics.wrappers.monai import MonaiMetricWrapper

__all__ = ["MonaiMetricWrapper"]
32 changes: 32 additions & 0 deletions src/eva/vision/metrics/wrappers/monai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Monai metrics wrappers."""

import torch
import torchmetrics
from monai.metrics.metric import CumulativeIterationMetric
from typing_extensions import override


class MonaiMetricWrapper(torchmetrics.Metric):
"""Wrapper class to make MONAI metrics compatible with `torchmetrics`."""

def __init__(self, monai_metric: CumulativeIterationMetric):
"""Initializes the monai metric wrapper.
Args:
monai_metric: The MONAI metric to wrap.
"""
super().__init__()
self._monai_metric = monai_metric

@override
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
self._monai_metric(preds, target)

@override
def compute(self) -> torch.Tensor:
return self._monai_metric.aggregate()

@override
def reset(self) -> None:
super().reset()
self._monai_metric.reset()
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
PREDS_ONE = torch.randint(0, NUM_CLASSES_ONE, (NUM_BATCHES, BATCH_SIZE, 32, 32))
TARGET_ONE = torch.randint(0, NUM_CLASSES_ONE, (NUM_BATCHES, BATCH_SIZE, 32, 32))
EXPECTED_ONE = {
"MonaiDiceScore": torch.tensor(0.34805023670196533),
"MonaiDiceScore (ignore_empty=False)": torch.tensor(0.34805023670196533),
"DiceScore (micro)": torch.tensor(0.3482658863067627),
"DiceScore (macro)": torch.tensor(0.34805023670196533),
"DiceScore (weighted)": torch.tensor(0.3484232723712921),
"MeanIoU": torch.tensor(0.2109210342168808),
}
"""Test features."""
assert EXPECTED_ONE["MonaiDiceScore (ignore_empty=False)"] == EXPECTED_ONE["DiceScore (macro)"]


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/eva/vision/test_vision_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"configs/vision/pathology/online/segmentation/bcss.yaml",
"configs/vision/pathology/online/segmentation/consep.yaml",
"configs/vision/pathology/online/segmentation/monusac.yaml",
"configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml",
"configs/vision/radiology/online/segmentation/total_segmentator_2d.yaml",
"configs/vision/radiology/online/segmentation/lits.yaml",
# | offline
# classification
Expand All @@ -38,7 +38,7 @@
"configs/vision/pathology/offline/segmentation/bcss.yaml",
"configs/vision/pathology/offline/segmentation/consep.yaml",
"configs/vision/pathology/offline/segmentation/monusac.yaml",
"configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml",
"configs/vision/radiology/offline/segmentation/total_segmentator_2d.yaml",
"configs/vision/radiology/offline/segmentation/lits.yaml",
],
)
Expand Down

0 comments on commit 0766787

Please sign in to comment.