From 63c7bbe6ac163659e9631badce296ff1c658d45e Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Mon, 28 Aug 2023 18:16:10 +0200
Subject: [PATCH] Add argument `average` to `MeanAveragePrecision` (#2018)

---
 CHANGELOG.md                          |  3 ++
 src/torchmetrics/detection/mean_ap.py | 69 ++++++++++++++++++++-------
 tests/unittests/detection/test_map.py | 65 ++++++++++++++++++-------
 3 files changed, 102 insertions(+), 35 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 648cde46870..05d848e91cc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added `average` argument to `MeanAveragePrecision` ([#2018](https://github.com/Lightning-AI/torchmetrics/pull/2018))
+
+
 - Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008))
 
 
diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py
index 0d47049a608..6cbe4f2b16a 100644
--- a/src/torchmetrics/detection/mean_ap.py
+++ b/src/torchmetrics/detection/mean_ap.py
@@ -187,6 +187,9 @@ class MeanAveragePrecision(Metric):
              IoU thresholds, ``K`` is the number of classes, ``A`` is the number of areas and ``M`` is the number
              of max detections per image.
 
+        average:
+            Method for averaging scores over labels. Choose between "``macro``" and "``micro``". Default is "macro".
+
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
@@ -329,6 +332,7 @@ def __init__(
         max_detection_thresholds: Optional[List[int]] = None,
         class_metrics: bool = False,
         extended_summary: bool = False,
+        average: Literal["macro", "micro"] = "macro",
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -379,6 +383,10 @@ def __init__(
             raise ValueError("Expected argument `extended_summary` to be a boolean")
         self.extended_summary = extended_summary
 
+        if average not in ("macro", "micro"):
+            raise ValueError(f"Expected argument `average` to be one of ('macro', 'micro') but got {average}")
+        self.average = average
+
         self.add_state("detection_box", default=[], dist_reduce_fx=None)
         self.add_state("detection_mask", default=[], dist_reduce_fx=None)
         self.add_state("detection_scores", default=[], dist_reduce_fx=None)
@@ -434,27 +442,10 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
 
     def compute(self) -> dict:
         """Computes the metric."""
-        coco_target, coco_preds = COCO(), COCO()
-
-        coco_target.dataset = self._get_coco_format(
-            labels=self.groundtruth_labels,
-            boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
-            masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
-            crowds=self.groundtruth_crowds,
-            area=self.groundtruth_area,
-        )
-        coco_preds.dataset = self._get_coco_format(
-            labels=self.detection_labels,
-            boxes=self.detection_box if len(self.detection_box) > 0 else None,
-            masks=self.detection_mask if len(self.detection_mask) > 0 else None,
-            scores=self.detection_scores,
-        )
+        coco_preds, coco_target = self._get_coco_datasets(average=self.average)
 
         result_dict = {}
         with contextlib.redirect_stdout(io.StringIO()):
-            coco_target.createIndex()
-            coco_preds.createIndex()
-
             for i_type in self.iou_type:
                 prefix = "" if len(self.iou_type) == 1 else f"{i_type}_"
                 if len(self.iou_type) > 1:
@@ -487,6 +478,15 @@ def compute(self) -> dict:
 
                 # if class mode is enabled, evaluate metrics per class
                 if self.class_metrics:
+                    if self.average == "micro":
+                        # since micro averaging has all the data in one class, we need to reinitialize the coco_eval
+                        # object in macro mode to get the per-class stats
+                        coco_preds, coco_target = self._get_coco_datasets(average="macro")
+                        coco_eval = COCOeval(coco_target, coco_preds, iouType=i_type)
+                        coco_eval.params.iouThrs = np.array(self.iou_thresholds, dtype=np.float64)
+                        coco_eval.params.recThrs = np.array(self.rec_thresholds, dtype=np.float64)
+                        coco_eval.params.maxDets = self.max_detection_thresholds
+
                     map_per_class_list = []
                     mar_100_per_class_list = []
                     for class_id in self._get_classes():
@@ -516,8 +516,41 @@ def compute(self) -> dict:
 
         return result_dict
 
+    def _get_coco_datasets(self, average: Literal["macro", "micro"]) -> Tuple[COCO, COCO]:
+        """Returns the coco datasets for the target and the predictions."""
+        if average == "micro":
+            # for micro averaging we set everything to be the same class
+            groundtruth_labels = apply_to_collection(self.groundtruth_labels, Tensor, lambda x: torch.zeros_like(x))
+            detection_labels = apply_to_collection(self.detection_labels, Tensor, lambda x: torch.zeros_like(x))
+        else:
+            groundtruth_labels = self.groundtruth_labels
+            detection_labels = self.detection_labels
+
+        coco_target, coco_preds = COCO(), COCO()
+
+        coco_target.dataset = self._get_coco_format(
+            labels=groundtruth_labels,
+            boxes=self.groundtruth_box if len(self.groundtruth_box) > 0 else None,
+            masks=self.groundtruth_mask if len(self.groundtruth_mask) > 0 else None,
+            crowds=self.groundtruth_crowds,
+            area=self.groundtruth_area,
+        )
+        coco_preds.dataset = self._get_coco_format(
+            labels=detection_labels,
+            boxes=self.detection_box if len(self.detection_box) > 0 else None,
+            masks=self.detection_mask if len(self.detection_mask) > 0 else None,
+            scores=self.detection_scores,
+        )
+
+        with contextlib.redirect_stdout(io.StringIO()):
+            coco_target.createIndex()
+            coco_preds.createIndex()
+
+        return coco_preds, coco_target
+
     @staticmethod
     def _coco_stats_to_tensor_dict(stats: List[float], prefix: str) -> Dict[str, Tensor]:
+        """Converts the output of COCOeval.stats to a dict of tensors."""
         return {
             f"{prefix}map": torch.tensor([stats[0]], dtype=torch.float32),
             f"{prefix}map_50": torch.tensor([stats[1]], dtype=torch.float32),
diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py
index de3b5805254..438e5dea952 100644
--- a/tests/unittests/detection/test_map.py
+++ b/tests/unittests/detection/test_map.py
@@ -22,6 +22,7 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities import apply_to_collection
 from pycocotools.coco import COCO
 from pycocotools.cocoeval import COCOeval
 from torch import IntTensor, Tensor
@@ -474,37 +475,32 @@ def test_empty_preds_cxcywh():
     metric.compute()
 
 
-_gpu_test_condition = not torch.cuda.is_available()
-
-
-def _move_to_gpu(inputs):
-    for x in inputs:
-        for key in x:
-            if torch.is_tensor(x[key]):
-                x[key] = x[key].to("cuda")
-    return inputs
-
-
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
-@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability")
 @pytest.mark.parametrize("inputs", [_inputs, _inputs2, _inputs3])
 def test_map_gpu(inputs):
     """Test predictions on single gpu."""
     metric = MeanAveragePrecision()
     metric = metric.to("cuda")
-    for preds, targets in zip(inputs.preds, inputs.target):
-        metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
+    for preds, targets in zip(deepcopy(inputs.preds), deepcopy(inputs.target)):
+        metric.update(
+            apply_to_collection(preds, Tensor, lambda x: x.to("cuda")),
+            apply_to_collection(targets, Tensor, lambda x: x.to("cuda")),
+        )
     metric.compute()
 
 
 @pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
-@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA availability")
 def test_map_with_custom_thresholds():
     """Test that map works with custom iou thresholds."""
     metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2])
     metric = metric.to("cuda")
-    for preds, targets in zip(_inputs.preds, _inputs.target):
-        metric.update(_move_to_gpu(preds), _move_to_gpu(targets))
+    for preds, targets in zip(deepcopy(_inputs.preds), deepcopy(_inputs.target)):
+        metric.update(
+            apply_to_collection(preds, Tensor, lambda x: x.to("cuda")),
+            apply_to_collection(targets, Tensor, lambda x: x.to("cuda")),
+        )
     res = metric.compute()
     assert res["map_50"].item() == -1
     assert res["map_75"].item() == -1
@@ -794,3 +790,38 @@ def test_for_extended_stats(preds, target, expected_iou_len, iou_keys, precision
     recall = result["recall"]
     assert isinstance(recall, Tensor)
     assert recall.shape == recall_shape
+
+
+@pytest.mark.parametrize("class_metrics", [False, True])
+def test_average_argument(class_metrics):
+    """Test that average argument works.
+
+    Calculating macro on inputs that only have one label should be the same as micro. Calculating class metrics
+    should be the same regardless of average argument.
+
+    """
+    if class_metrics:
+        _preds = _inputs.preds
+        _target = _inputs.target
+    else:
+        _preds = apply_to_collection(deepcopy(_inputs.preds), IntTensor, lambda x: torch.ones_like(x))
+        _target = apply_to_collection(deepcopy(_inputs.target), IntTensor, lambda x: torch.ones_like(x))
+
+    metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics)
+    metric_macro.update(_preds[0], _target[0])
+    metric_macro.update(_preds[1], _target[1])
+    result_macro = metric_macro.compute()
+
+    metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics)
+    metric_micro.update(_inputs.preds[0], _inputs.target[0])
+    metric_micro.update(_inputs.preds[1], _inputs.target[1])
+    result_micro = metric_micro.compute()
+
+    if class_metrics:
+        assert torch.allclose(result_macro["map_per_class"], result_micro["map_per_class"])
+        assert torch.allclose(result_macro["mar_100_per_class"], result_micro["mar_100_per_class"])
+    else:
+        for key in result_macro:
+            if key == "classes":
+                continue
+            assert torch.allclose(result_macro[key], result_micro[key])
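
For reference, a minimal usage sketch of the new argument (not part of the patch itself): it assumes a torchmetrics build that contains this change, and the boxes, scores, and labels below are made-up values chosen purely for illustration. With "macro" (the default) the score is averaged over classes; with "micro" all detections are pooled into a single class before evaluation.

    # Usage sketch (assumes a torchmetrics build containing this patch).
    # The boxes, scores, and labels are invented for illustration only.
    import torch
    from torchmetrics.detection import MeanAveragePrecision

    preds = [
        {
            "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0], [60.0, 60.0, 90.0, 90.0]]),
            "scores": torch.tensor([0.9, 0.8]),
            "labels": torch.tensor([0, 1]),
        }
    ]
    target = [
        {
            "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0], [62.0, 62.0, 90.0, 90.0]]),
            "labels": torch.tensor([0, 1]),
        }
    ]

    # Compare the two averaging modes on the same inputs.
    for average in ("macro", "micro"):
        metric = MeanAveragePrecision(average=average)
        metric.update(preds, target)
        print(average, metric.compute()["map"])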