-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added MonaiDiceScore and switched class-wise metrics to monai
- Loading branch information
Showing
8 changed files
with
145 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
"""Wrapper for dice score metric from MONAI.""" | ||
|
||
from monai.metrics.meandice import DiceMetric | ||
from typing_extensions import override | ||
|
||
from eva.vision.metrics import wrappers | ||
from eva.vision.metrics.segmentation import _utils | ||
|
||
|
||
class MonaiDiceScore(wrappers.MonaiMetricWrapper): | ||
"""Wrapper to make MONAI's `DiceMetric` compatible with `torchmetrics`.""" | ||
|
||
def __init__( | ||
self, | ||
num_classes: int, | ||
include_background: bool = True, | ||
reduction: str = "mean", | ||
ignore_index: int | None = None, | ||
**kwargs, | ||
): | ||
"""Initializes metric. | ||
Args: | ||
num_classes: The number of classes in the dataset. | ||
include_background: Whether to include the background class in the computation. | ||
reduction: The method to reduce the dice score. Options are `"mean"`, `"sum"`, `"none"`. | ||
ignore_index: Integer specifying a target class to ignore. If given, this class | ||
index does not contribute to the returned score. | ||
kwargs: Additional keyword arguments for instantiating monai's `DiceMetric` class. | ||
""" | ||
super().__init__( | ||
DiceMetric( | ||
include_background=include_background, | ||
reduction=reduction, | ||
num_classes=num_classes, | ||
**kwargs, | ||
) | ||
) | ||
|
||
self.reduction = reduction | ||
self.num_classes = num_classes | ||
self.ignore_index = ignore_index | ||
|
||
@override | ||
def update(self, preds, target): | ||
preds = _utils.index_to_one_hot(preds, num_classes=self.num_classes) | ||
target = _utils.index_to_one_hot(target, num_classes=self.num_classes) | ||
if self.ignore_index is not None: | ||
preds, target = _utils.apply_ignore_index( | ||
preds, target, self.ignore_index, self.num_classes | ||
) | ||
return super().update(preds, target) | ||
|
||
@override | ||
def compute(self): | ||
result = super().compute() | ||
if self.reduction == "none" and len(result) > 1: | ||
result = result.nanmean(dim=0) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""Metrics wrappers API.""" | ||
|
||
from eva.vision.metrics.wrappers.monai import MonaiMetricWrapper | ||
|
||
__all__ = ["MonaiMetricWrapper"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Monai metrics wrappers.""" | ||
|
||
import torch | ||
import torchmetrics | ||
from monai.metrics.metric import CumulativeIterationMetric | ||
from typing_extensions import override | ||
|
||
|
||
class MonaiMetricWrapper(torchmetrics.Metric): | ||
"""Wrapper class to make MONAI metrics compatible with `torchmetrics`.""" | ||
|
||
def __init__(self, monai_metric: CumulativeIterationMetric): | ||
"""Initializes the monai metric wrapper. | ||
Args: | ||
monai_metric: The MONAI metric to wrap. | ||
""" | ||
super().__init__() | ||
self._monai_metric = monai_metric | ||
|
||
@override | ||
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: | ||
self._monai_metric(preds, target) | ||
|
||
@override | ||
def compute(self) -> torch.Tensor: | ||
return self._monai_metric.aggregate() | ||
|
||
@override | ||
def reset(self) -> None: | ||
super().reset() | ||
self._monai_metric.reset() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters