diff --git a/configs/vision/pathology/offline/segmentation/bcss.yaml b/configs/vision/pathology/offline/segmentation/bcss.yaml
index 34acad17..e2de4b40 100644
--- a/configs/vision/pathology/offline/segmentation/bcss.yaml
+++ b/configs/vision/pathology/offline/segmentation/bcss.yaml
@@ -80,13 +80,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/offline/segmentation/consep.yaml b/configs/vision/pathology/offline/segmentation/consep.yaml
index 1b95c49a..0e397a12 100644
--- a/configs/vision/pathology/offline/segmentation/consep.yaml
+++ b/configs/vision/pathology/offline/segmentation/consep.yaml
@@ -80,13 +80,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/offline/segmentation/monusac.yaml b/configs/vision/pathology/offline/segmentation/monusac.yaml
index f1a15473..b44e688c 100644
--- a/configs/vision/pathology/offline/segmentation/monusac.yaml
+++ b/configs/vision/pathology/offline/segmentation/monusac.yaml
@@ -82,14 +82,14 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
             ignore_index: *IGNORE_INDEX
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
index f0c02b09..763432b4 100644
--- a/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
+++ b/configs/vision/pathology/offline/segmentation/total_segmentator_2d.yaml
@@ -79,13 +79,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/online/segmentation/bcss.yaml b/configs/vision/pathology/online/segmentation/bcss.yaml
index 7b4cede6..edc9901a 100644
--- a/configs/vision/pathology/online/segmentation/bcss.yaml
+++ b/configs/vision/pathology/online/segmentation/bcss.yaml
@@ -73,13 +73,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/online/segmentation/consep.yaml b/configs/vision/pathology/online/segmentation/consep.yaml
index c626c53b..a36f2625 100644
--- a/configs/vision/pathology/online/segmentation/consep.yaml
+++ b/configs/vision/pathology/online/segmentation/consep.yaml
@@ -73,13 +73,13 @@ model:
      common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/online/segmentation/monusac.yaml b/configs/vision/pathology/online/segmentation/monusac.yaml
index 64cbef6a..234b30e4 100644
--- a/configs/vision/pathology/online/segmentation/monusac.yaml
+++ b/configs/vision/pathology/online/segmentation/monusac.yaml
@@ -74,14 +74,14 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
             ignore_index: *IGNORE_INDEX
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
index 39cf92e3..4cd23269 100644
--- a/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
+++ b/configs/vision/pathology/online/segmentation/total_segmentator_2d.yaml
@@ -72,13 +72,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: eva.core.metrics.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/radiology/offline/segmentation/lits.yaml b/configs/vision/radiology/offline/segmentation/lits.yaml
index a4604a33..c3b8d4d0 100644
--- a/configs/vision/radiology/offline/segmentation/lits.yaml
+++ b/configs/vision/radiology/offline/segmentation/lits.yaml
@@ -79,13 +79,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
index 28e76c59..a34b7a3b 100644
--- a/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/offline/segmentation/lits_balanced.yaml
@@ -79,13 +79,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/radiology/online/segmentation/lits.yaml b/configs/vision/radiology/online/segmentation/lits.yaml
index 6cc8c87b..1ce7731e 100644
--- a/configs/vision/radiology/online/segmentation/lits.yaml
+++ b/configs/vision/radiology/online/segmentation/lits.yaml
@@ -72,13 +72,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/configs/vision/radiology/online/segmentation/lits_balanced.yaml b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
index e767c097..a0a7fc5d 100644
--- a/configs/vision/radiology/online/segmentation/lits_balanced.yaml
+++ b/configs/vision/radiology/online/segmentation/lits_balanced.yaml
@@ -72,13 +72,13 @@ model:
       common:
         - class_path: eva.metrics.AverageLoss
       evaluation:
-        - class_path: eva.core.metrics.defaults.MulticlassSegmentationMetrics
+        - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics
           init_args:
             num_classes: *NUM_CLASSES
         - class_path: torchmetrics.ClasswiseWrapper
           init_args:
             metric:
-              class_path: torchmetrics.segmentation.GeneralizedDiceScore
+              class_path: eva.vision.metrics.GeneralizedDiceScore
               init_args:
                 num_classes: *NUM_CLASSES
                 weight_type: linear
diff --git a/src/eva/core/metrics/__init__.py b/src/eva/core/metrics/__init__.py
index b58d938c..aed8c33e 100644
--- a/src/eva/core/metrics/__init__.py
+++ b/src/eva/core/metrics/__init__.py
@@ -3,8 +3,6 @@
 from eva.core.metrics.average_loss import AverageLoss
 from eva.core.metrics.binary_balanced_accuracy import BinaryBalancedAccuracy
 from eva.core.metrics.defaults import BinaryClassificationMetrics, MulticlassClassificationMetrics
-from eva.core.metrics.generalized_dice import GeneralizedDiceScore
-from eva.core.metrics.mean_iou import MeanIoU
 from eva.core.metrics.structs import Metric, MetricCollection, MetricModule, MetricsSchema
 
 __all__ = [
@@ -12,8 +10,6 @@
     "BinaryBalancedAccuracy",
     "BinaryClassificationMetrics",
     "MulticlassClassificationMetrics",
-    "GeneralizedDiceScore",
-    "MeanIoU",
     "Metric",
     "MetricCollection",
     "MetricModule",
diff --git a/src/eva/core/metrics/defaults/__init__.py b/src/eva/core/metrics/defaults/__init__.py
index 84120acf..be65d757 100644
--- a/src/eva/core/metrics/defaults/__init__.py
+++ b/src/eva/core/metrics/defaults/__init__.py
@@ -4,10 +4,8 @@
     BinaryClassificationMetrics,
     MulticlassClassificationMetrics,
 )
-from eva.core.metrics.defaults.segmentation import MulticlassSegmentationMetrics
 
 __all__ = [
     "MulticlassClassificationMetrics",
     "BinaryClassificationMetrics",
-    "MulticlassSegmentationMetrics",
 ]
diff --git a/src/eva/core/metrics/mean_iou.py b/src/eva/core/metrics/mean_iou.py
deleted file mode 100644
index 58ca6b54..00000000
--- a/src/eva/core/metrics/mean_iou.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""Mean Intersection over Union (mIoU) metric for semantic segmentation."""
-
-from typing import Any, Literal, Tuple
-
-import torch
-import torchmetrics
-
-
-class MeanIoU(torchmetrics.Metric):
-    """Computes Mean Intersection over Union (mIoU) for semantic segmentation.
-
-    Fixes the torchmetrics implementation
-    (issue https://github.com/Lightning-AI/torchmetrics/issues/2558)
-    """
-
-    def __init__(
-        self,
-        num_classes: int,
-        include_background: bool = True,
-        ignore_index: int | None = None,
-        per_class: bool = False,
-        **kwargs: Any,
-    ) -> None:
-        """Initializes the metric.
-
-        Args:
-            num_classes: The number of classes in the segmentation problem.
-            include_background: Whether to include the background class in the computation
-            ignore_index: Integer specifying a target class to ignore. If given, this class
-                index does not contribute to the returned score, regardless of reduction method.
-            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
-                the metric will compute the mean IoU over all classes.
-            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
-        """
-        super().__init__(**kwargs)
-
-        self.num_classes = num_classes
-        self.include_background = include_background
-        self.ignore_index = ignore_index
-        self.per_class = per_class
-
-        self.add_state("intersection", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-        self.add_state("union", default=torch.zeros(num_classes), dist_reduce_fx="sum")
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
-        """Update the state with the new data."""
-        intersection, union = _compute_intersection_and_union(
-            preds,
-            target,
-            num_classes=self.num_classes,
-            include_background=self.include_background,
-            ignore_index=self.ignore_index,
-        )
-        self.intersection += intersection.sum(0)
-        self.union += union.sum(0)
-
-    def compute(self) -> torch.Tensor:
-        """Compute the final mean IoU score."""
-        iou_valid = torch.gt(self.union, 0)
-        iou = torch.where(
-            iou_valid,
-            torch.divide(self.intersection, self.union),
-            torch.nan,
-        )
-        if not self.per_class:
-            iou = torch.mean(iou[iou_valid])
-        return iou
-
-
-def _compute_intersection_and_union(
-    preds: torch.Tensor,
-    target: torch.Tensor,
-    num_classes: int,
-    include_background: bool = False,
-    input_format: Literal["one-hot", "index"] = "index",
-    ignore_index: int | None = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Compute the intersection and union for semantic segmentation tasks.
-
-    Args:
-        preds: Predicted tensor with shape (N, ...) where N is the batch size.
-            The shape can be (N, H, W) for 2D data or (N, D, H, W) for 3D data.
-        target: Ground truth tensor with the same shape as preds.
-        num_classes: Number of classes in the segmentation task.
-        include_background: Whether to include the background class in the computation.
-        input_format: Format of the input tensors.
-        ignore_index: Integer specifying a target class to ignore. If given, this class
-            index does not contribute to the returned score, regardless of reduction method.
-
-    Returns:
-        Two tensors representing the intersection and union for each class.
-        Shape of each tensor is (N, num_classes).
-
-    Note:
-        - If input_format is "index", the tensors are converted to one-hot encoding.
-        - If include_background is `False`, the background class
-          (assumed to be the first channel) is ignored in the computation.
-    """
-    if ignore_index is not None:
-        mask = target != ignore_index
-        mask = mask.all(dim=-1, keepdim=True)
-        preds = preds * mask
-        target = target * mask
-
-    if input_format == "index":
-        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
-        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
-
-    if not include_background:
-        preds[..., 0] = 0
-        target[..., 0] = 0
-
-    reduce_axis = list(range(1, preds.ndim - 1))
-
-    intersection = torch.sum(torch.logical_and(preds, target), dim=reduce_axis)
-    target_sum = torch.sum(target, dim=reduce_axis)
-    pred_sum = torch.sum(preds, dim=reduce_axis)
-    union = target_sum + pred_sum - intersection
-
-    return intersection, union
diff --git a/src/eva/vision/metrics/__init__.py b/src/eva/vision/metrics/__init__.py
new file mode 100644
index 00000000..c879d08f
--- /dev/null
+++ b/src/eva/vision/metrics/__init__.py
@@ -0,0 +1,11 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
diff --git a/src/eva/vision/metrics/defaults/__init__.py b/src/eva/vision/metrics/defaults/__init__.py
new file mode 100644
index 00000000..14bcecda
--- /dev/null
+++ b/src/eva/vision/metrics/defaults/__init__.py
@@ -0,0 +1,7 @@
+"""Default metric collections API."""
+
+from eva.vision.metrics.defaults.segmentation import MulticlassSegmentationMetrics
+
+__all__ = [
+    "MulticlassSegmentationMetrics",
+]
diff --git a/src/eva/core/metrics/defaults/segmentation/__init__.py b/src/eva/vision/metrics/defaults/segmentation/__init__.py
similarity index 50%
rename from src/eva/core/metrics/defaults/segmentation/__init__.py
rename to src/eva/vision/metrics/defaults/segmentation/__init__.py
index 31c397dd..34d11a38 100644
--- a/src/eva/core/metrics/defaults/segmentation/__init__.py
+++ b/src/eva/vision/metrics/defaults/segmentation/__init__.py
@@ -1,5 +1,5 @@
 """Default segmentation metric collections API."""
 
-from eva.core.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
+from eva.vision.metrics.defaults.segmentation.multiclass import MulticlassSegmentationMetrics
 
 __all__ = ["MulticlassSegmentationMetrics"]
diff --git a/src/eva/core/metrics/defaults/segmentation/multiclass.py b/src/eva/vision/metrics/defaults/segmentation/multiclass.py
similarity index 93%
rename from src/eva/core/metrics/defaults/segmentation/multiclass.py
rename to src/eva/vision/metrics/defaults/segmentation/multiclass.py
index f4d4017a..031ab732 100644
--- a/src/eva/core/metrics/defaults/segmentation/multiclass.py
+++ b/src/eva/vision/metrics/defaults/segmentation/multiclass.py
@@ -1,6 +1,7 @@
 """Default metric collection for multiclass semantic segmentation tasks."""
 
-from eva.core.metrics import generalized_dice, mean_iou, structs
+from eva.core.metrics import structs
+from eva.vision.metrics.segmentation import generalized_dice, mean_iou
 
 
 class MulticlassSegmentationMetrics(structs.MetricCollection):
diff --git a/src/eva/vision/metrics/segmentation/BUILD b/src/eva/vision/metrics/segmentation/BUILD
new file mode 100644
index 00000000..db46e8d6
--- /dev/null
+++ b/src/eva/vision/metrics/segmentation/BUILD
@@ -0,0 +1 @@
+python_sources()
diff --git a/src/eva/vision/metrics/segmentation/__init__.py b/src/eva/vision/metrics/segmentation/__init__.py
new file mode 100644
index 00000000..2fc51387
--- /dev/null
+++ b/src/eva/vision/metrics/segmentation/__init__.py
@@ -0,0 +1,9 @@
+"""Segmentation metrics API."""
+
+from eva.vision.metrics.segmentation.generalized_dice import GeneralizedDiceScore
+from eva.vision.metrics.segmentation.mean_iou import MeanIoU
+
+__all__ = [
+    "GeneralizedDiceScore",
+    "MeanIoU",
+]
diff --git a/src/eva/vision/metrics/segmentation/_utils.py b/src/eva/vision/metrics/segmentation/_utils.py
new file mode 100644
index 00000000..b1dba5d4
--- /dev/null
+++ b/src/eva/vision/metrics/segmentation/_utils.py
@@ -0,0 +1,69 @@
+"""Utils for segmentation metric collections."""
+
+from typing import Tuple
+
+import torch
+
+
+def apply_ignore_index(
+    preds: torch.Tensor, target: torch.Tensor, ignore_index: int, num_classes: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Applies the ignore index to the predictions and target tensors.
+
+    1. Masks the values in the target tensor that correspond to the ignored index.
+    2. Remove the channel corresponding to the ignored index from both tensors.
+
+    Args:
+        preds: The predictions tensor. Expected to be of shape `(N,C,...)`.
+        target: The target tensor. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index to ignore.
+        num_classes: The number of classes.
+
+    Returns:
+        The modified predictions and target tensors of shape `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+
+    ignore_mask = preds[:, ignore_index] == 1
+    target = target * (~ignore_mask.unsqueeze(1))
+
+    preds = _ignore_tensor_channel(preds, ignore_index)
+    target = _ignore_tensor_channel(target, ignore_index)
+
+    return preds, target
+
+
+def index_to_one_hot(tensor: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """Converts an index tensor to a one-hot tensor.
+
+    Args:
+        tensor: The index tensor to convert. Expected to be of shape `(N,...)`.
+        num_classes: The number of classes to one-hot encode.
+
+    Returns:
+        A one-hot tensor of shape `(N,C,...)`.
+    """
+    if not _is_one_hot(tensor):
+        tensor = torch.nn.functional.one_hot(tensor.long(), num_classes=num_classes).movedim(-1, 1)
+    return tensor
+
+
+def _ignore_tensor_channel(tensor: torch.Tensor, ignore_index: int) -> torch.Tensor:
+    """Removes the channel corresponding to the specified ignore index.
+
+    Args:
+        tensor: The tensor to remove the channel from. Expected to be of shape `(N,C,...)`.
+        ignore_index: The index of the channel dimension (C) to remove.
+
+    Returns:
+        A tensor without the specified channel `(N,C-1,...)`.
+    """
+    if ignore_index < 0:
+        raise ValueError("ignore_index must be a non-negative integer")
+    return torch.cat([tensor[:, :ignore_index], tensor[:, ignore_index + 1 :]], dim=1)
+
+
+def _is_one_hot(tensor: torch.Tensor, expected_dim: int = 4) -> bool:
+    """Checks if the tensor is a one-hot tensor."""
+    return bool((tensor.bool() == tensor).all()) and tensor.ndim == expected_dim
diff --git a/src/eva/core/metrics/generalized_dice.py b/src/eva/vision/metrics/segmentation/generalized_dice.py
similarity index 75%
rename from src/eva/core/metrics/generalized_dice.py
rename to src/eva/vision/metrics/segmentation/generalized_dice.py
index 03994506..0fdfead3 100644
--- a/src/eva/core/metrics/generalized_dice.py
+++ b/src/eva/vision/metrics/segmentation/generalized_dice.py
@@ -6,6 +6,8 @@
 from torchmetrics import segmentation
 from typing_extensions import override
 
+from eva.vision.metrics.segmentation import _utils
+
 
 class GeneralizedDiceScore(segmentation.GeneralizedDiceScore):
     """Defines the Generalized Dice Score.
@@ -30,8 +32,6 @@ def __init__(
             include_background: Whether to include the background class in the computation
             weight_type: The type of weight to apply to each class. Can be one of `"square"`,
                 `"simple"`, or `"linear"`.
-            input_format: What kind of input the function receives. Choose between ``"one-hot"``
-                for one-hot encoded tensors or ``"index"`` for index tensors.
             ignore_index: Integer specifying a target class to ignore. If given, this class
                 index does not contribute to the returned score, regardless of reduction method.
             per_class: Whether to compute the IoU for each class separately. If set to ``False``,
@@ -39,21 +39,23 @@
             kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
         """
         super().__init__(
-            num_classes=num_classes,
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
             include_background=include_background,
             weight_type=weight_type,
             per_class=per_class,
             **kwargs,
         )
-
+        self.orig_num_classes = num_classes
         self.ignore_index = ignore_index
 
     @override
     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
         if self.ignore_index is not None:
-            mask = target != self.ignore_index
-            mask = mask.all(dim=-1, keepdim=True)
-            preds = preds * mask
-            target = target * mask
-
-        super().update(preds=preds, target=target)
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
diff --git a/src/eva/vision/metrics/segmentation/mean_iou.py b/src/eva/vision/metrics/segmentation/mean_iou.py
new file mode 100644
index 00000000..4fbd09ac
--- /dev/null
+++ b/src/eva/vision/metrics/segmentation/mean_iou.py
@@ -0,0 +1,57 @@
+"""MeanIoU metric for semantic segmentation."""
+
+from typing import Any
+
+import torch
+from torchmetrics import segmentation
+from typing_extensions import override
+
+from eva.vision.metrics.segmentation import _utils
+
+
+class MeanIoU(segmentation.MeanIoU):
+    """MeanIoU (mIOU) metric for semantic segmentation.
+
+    It expands the `torchmetrics` class by including an `ignore_index`
+    functionality.
+    """
+
+    def __init__(
+        self,
+        num_classes: int,
+        include_background: bool = True,
+        ignore_index: int | None = None,
+        per_class: bool = False,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes the metric.
+
+        Args:
+            num_classes: The number of classes in the segmentation problem.
+            include_background: Whether to include the background class in the computation
+            ignore_index: Integer specifying a target class to ignore. If given, this class
+                index does not contribute to the returned score, regardless of reduction method.
+            per_class: Whether to compute the IoU for each class separately. If set to ``False``,
+                the metric will compute the mean IoU over all classes.
+            kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+        """
+        super().__init__(
+            num_classes=num_classes
+            - (ignore_index is not None)
+            + (ignore_index == 0 and not include_background),
+            include_background=include_background,
+            per_class=per_class,
+            **kwargs,
+        )
+        self.orig_num_classes = num_classes
+        self.ignore_index = ignore_index
+
+    @override
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        preds = _utils.index_to_one_hot(preds, num_classes=self.orig_num_classes)
+        target = _utils.index_to_one_hot(target, num_classes=self.orig_num_classes)
+        if self.ignore_index is not None:
+            preds, target = _utils.apply_ignore_index(
+                preds, target, self.ignore_index, self.num_classes
+            )
+        super().update(preds=preds.long(), target=target.long())
diff --git a/tests/eva/vision/metrics/defaults/__init__.py b/tests/eva/vision/metrics/defaults/__init__.py
new file mode 100644
index 00000000..4ed9e825
--- /dev/null
+++ b/tests/eva/vision/metrics/defaults/__init__.py
@@ -0,0 +1 @@
+"""Tests default metric groups."""
diff --git a/tests/eva/core/metrics/defaults/segmentation/__init__.py b/tests/eva/vision/metrics/defaults/segmentation/__init__.py
similarity index 100%
rename from tests/eva/core/metrics/defaults/segmentation/__init__.py
rename to tests/eva/vision/metrics/defaults/segmentation/__init__.py
diff --git a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py b/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
similarity index 97%
rename from tests/eva/core/metrics/defaults/segmentation/test_multiclass.py
rename to tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
index bbc6ea5a..a86ec9ec 100644
--- a/tests/eva/core/metrics/defaults/segmentation/test_multiclass.py
+++ b/tests/eva/vision/metrics/defaults/segmentation/test_multiclass.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from eva.core.metrics import defaults
+from eva.vision.metrics import defaults
 
 NUM_BATCHES = 2
 BATCH_SIZE = 4
diff --git a/tests/eva/vision/metrics/segmentation/__init__.py b/tests/eva/vision/metrics/segmentation/__init__.py
new file mode 100644
index 00000000..8ccc3632
--- /dev/null
+++ b/tests/eva/vision/metrics/segmentation/__init__.py
@@ -0,0 +1 @@
+"""Tests for the vision segmentation metric collections."""
diff --git a/tests/eva/vision/metrics/segmentation/_utils.py b/tests/eva/vision/metrics/segmentation/_utils.py
new file mode 100644
index 00000000..0ae46def
--- /dev/null
+++ b/tests/eva/vision/metrics/segmentation/_utils.py
@@ -0,0 +1,32 @@
+from typing import Callable, Tuple
+
+import pytest
+import torch
+import torchmetrics
+
+
+def _test_ignore_index(
+    metric_cls: Callable[..., torchmetrics.Metric],
+    batch_size: int,
+    num_classes: int,
+    image_size: Tuple[int, int],
+    ignore_index: int,
+) -> None:
+    """Test ignore index functionality of a torchmetric."""
+    generator = torch.Generator()
+    generator.manual_seed(42)
+
+    metric = metric_cls(num_classes=num_classes)
+    preds = torch.randint(0, num_classes, (batch_size,) + image_size, generator=generator)
+    target = preds.clone()
+    result_one = metric(preds=preds, target=target)
+
+    random_mask = torch.randint(0, 2, (batch_size,) + image_size).bool()
+    preds[random_mask] = ignore_index  # simulate wrong predictions
+    result_two = metric(preds=preds, target=target)
+
+    metric_with_ignore = metric_cls(num_classes=num_classes, ignore_index=ignore_index)
+    result_three = metric_with_ignore(preds=preds, target=target)
+
+    assert result_one != pytest.approx(result_two, abs=1e-6)
+    assert result_one == pytest.approx(result_three, abs=1e-6)
diff --git a/tests/eva/vision/metrics/segmentation/test_generalized_dice.py b/tests/eva/vision/metrics/segmentation/test_generalized_dice.py
new file mode 100644
index 00000000..c2400497
--- /dev/null
+++ b/tests/eva/vision/metrics/segmentation/test_generalized_dice.py
@@ -0,0 +1,24 @@
+"""GeneralizedDiceScore metric tests."""
+
+from typing import Tuple
+
+import pytest
+
+from eva.vision.metrics import segmentation
+from tests.eva.vision.metrics.segmentation import _utils
+
+
+@pytest.mark.parametrize(
+    "batch_size, num_classes, image_size, ignore_index",
+    [
+        (4, 3, (16, 16), 0),
+        (16, 5, (20, 20), 1),
+    ],
+)
+def test_ignore_index(
+    batch_size: int, num_classes: int, image_size: Tuple[int, int], ignore_index: int
+) -> None:
+    """Tests the `ignore_index` functionality."""
+    _utils._test_ignore_index(
+        segmentation.GeneralizedDiceScore, batch_size, num_classes, image_size, ignore_index
+    )
diff --git a/tests/eva/vision/metrics/segmentation/test_mean_iou.py b/tests/eva/vision/metrics/segmentation/test_mean_iou.py
new file mode 100644
index 00000000..7c9ded6b
--- /dev/null
+++ b/tests/eva/vision/metrics/segmentation/test_mean_iou.py
@@ -0,0 +1,24 @@
+"""MeanIoU metric tests."""
+
+from typing import Tuple
+
+import pytest
+
+from eva.vision.metrics import segmentation
+from tests.eva.vision.metrics.segmentation import _utils
+
+
+@pytest.mark.parametrize(
+    "batch_size, num_classes, image_size, ignore_index",
+    [
+        (4, 3, (16, 16), 0),
+        (16, 5, (20, 20), 1),
+    ],
+)
+def test_ignore_index(
+    batch_size: int, num_classes: int, image_size: Tuple[int, int], ignore_index: int
+) -> None:
+    """Tests the `ignore_index` functionality."""
+    _utils._test_ignore_index(
+        segmentation.MeanIoU, batch_size, num_classes, image_size, ignore_index
+    )
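
As a quick sanity check after applying the patch, usage along these lines should work: the module paths, class names, and constructor arguments follow the diff and its tests, while the tensor shapes and the ignore_index value below are only illustrative.

# Illustrative usage of the relocated metrics (not part of the patch above);
# tensor shapes and the ignore_index choice are made up for the example.
import torch

from eva.vision.metrics import GeneralizedDiceScore, MeanIoU

num_classes = 5
preds = torch.randint(0, num_classes, (4, 16, 16))   # index-format predictions (N, H, W)
target = torch.randint(0, num_classes, (4, 16, 16))  # index-format targets (N, H, W)

# Both metrics one-hot encode index tensors internally and, when ignore_index
# is set, drop that class channel before delegating to torchmetrics.
dice = GeneralizedDiceScore(num_classes=num_classes, weight_type="linear", ignore_index=0)
miou = MeanIoU(num_classes=num_classes, ignore_index=0)

print(dice(preds=preds, target=target))
print(miou(preds=preds, target=target))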