diff --git a/CHANGELOG.md b/CHANGELOG.md index bef7846751f..003648f8e08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,8 +35,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017) +- Fixed bug in detection intersection metrics when `class_metrics=True` resulting in wrong values ([#1924](https://github.com/Lightning-AI/torchmetrics/pull/1924)) + + - Fixed missing attributes `higher_is_better`, `is_differentiable` for some metrics ([#2028](https://github.com/Lightning-AI/torchmetrics/pull/2028) + ## [1.1.0] - 2023-08-22 ### Added diff --git a/Makefile b/Makefile index 653d837ed87..004f5f5a625 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ export FREEZE_REQUIREMENTS=1 # assume you have installed need packages export SPHINX_MOCK_REQUIREMENTS=1 +export SPHINX_FETCH_ASSETS=0 clean: # clean all temp runs diff --git a/src/torchmetrics/detection/ciou.py b/src/torchmetrics/detection/ciou.py index 5b62679a396..d45e73880d6 100644 --- a/src/torchmetrics/detection/ciou.py +++ b/src/torchmetrics/detection/ciou.py @@ -37,8 +37,6 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection boxes of the format specified in the constructor. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - - ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores - for the boxes. - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection classes for the boxes. @@ -48,14 +46,14 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` ground truth boxes of the format specified in the constructor. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed ground truth + - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection classes for the boxes. As output of ``forward`` and ``compute`` the metric returns the following output: - ``ciou_dict``: A dictionary containing the following key-values: - - ciou: (:class:`~torch.Tensor`) + - ciou: (:class:`~torch.Tensor`) with overall ciou value over all classes and samples. - ciou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True`` Args: @@ -65,6 +63,9 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. class_metrics: Option to enable per-class metrics for IoU. Has a performance impact. + respect_labels: + Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou + between all pairs of boxes. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -86,7 +87,7 @@ class CompleteIntersectionOverUnion(IntersectionOverUnion): ... ] >>> metric = CompleteIntersectionOverUnion() >>> metric(preds, target) - {'ciou': tensor(-0.5694)} + {'ciou': tensor(0.8611)} Raises: ModuleNotFoundError: @@ -105,6 +106,7 @@ def __init__( box_format: str = "xyxy", iou_threshold: Optional[float] = None, class_metrics: bool = False, + respect_labels: bool = True, **kwargs: Any, ) -> None: if not _TORCHVISION_GREATER_EQUAL_0_13: @@ -112,7 +114,7 @@ def __init__( f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." ) - super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs) @staticmethod def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: diff --git a/src/torchmetrics/detection/diou.py b/src/torchmetrics/detection/diou.py index 6778979b1c0..063b67a7f61 100644 --- a/src/torchmetrics/detection/diou.py +++ b/src/torchmetrics/detection/diou.py @@ -37,8 +37,6 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection boxes of the format specified in the constructor. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - - ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores - for the boxes. - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection classes for the boxes. @@ -55,7 +53,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): - ``diou_dict``: A dictionary containing the following key-values: - - diou: (:class:`~torch.Tensor`) + - diou: (:class:`~torch.Tensor`) with overall diou value over all classes and samples. - diou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class_metrics=True`` Args: @@ -65,6 +63,9 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. class_metrics: Option to enable per-class metrics for IoU. Has a performance impact. + respect_labels: + Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou + between all pairs of boxes. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -86,7 +87,7 @@ class DistanceIntersectionOverUnion(IntersectionOverUnion): ... ] >>> metric = DistanceIntersectionOverUnion() >>> metric(preds, target) - {'diou': tensor(-0.0694)} + {'diou': tensor(0.8611)} Raises: ModuleNotFoundError: @@ -105,6 +106,7 @@ def __init__( box_format: str = "xyxy", iou_threshold: Optional[float] = None, class_metrics: bool = False, + respect_labels: bool = True, **kwargs: Any, ) -> None: if not _TORCHVISION_GREATER_EQUAL_0_13: @@ -112,7 +114,7 @@ def __init__( f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed." " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." ) - super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs) @staticmethod def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: diff --git a/src/torchmetrics/detection/giou.py b/src/torchmetrics/detection/giou.py index e4ec9aee65c..43edd76c0c5 100644 --- a/src/torchmetrics/detection/giou.py +++ b/src/torchmetrics/detection/giou.py @@ -37,8 +37,6 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection boxes of the format specified in the constructor. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - - ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores - for the boxes. - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection classes for the boxes. @@ -55,7 +53,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): - ``giou_dict``: A dictionary containing the following key-values: - - giou: (:class:`~torch.Tensor`) + - giou: (:class:`~torch.Tensor`) with overall giou value over all classes and samples. - giou/cl_{cl}: (:class:`~torch.Tensor`), if argument ``class metrics=True`` Args: @@ -65,6 +63,9 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): Optional IoU thresholds for evaluation. If set to `None` the threshold is ignored. class_metrics: Option to enable per-class metrics for IoU. Has a performance impact. + respect_labels: + Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou + between all pairs of boxes. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -86,7 +87,7 @@ class GeneralizedIntersectionOverUnion(IntersectionOverUnion): ... ] >>> metric = GeneralizedIntersectionOverUnion() >>> metric(preds, target) - {'giou': tensor(-0.0694)} + {'giou': tensor(0.8613)} Raises: ModuleNotFoundError: @@ -105,9 +106,10 @@ def __init__( box_format: str = "xyxy", iou_threshold: Optional[float] = None, class_metrics: bool = False, + respect_labels: bool = True, **kwargs: Any, ) -> None: - super().__init__(box_format, iou_threshold, class_metrics, **kwargs) + super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs) @staticmethod def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: diff --git a/src/torchmetrics/detection/helpers.py b/src/torchmetrics/detection/helpers.py index f3681545b96..1550493e55a 100644 --- a/src/torchmetrics/detection/helpers.py +++ b/src/torchmetrics/detection/helpers.py @@ -20,6 +20,7 @@ def _input_validator( preds: Sequence[Dict[str, Tensor]], targets: Sequence[Dict[str, Tensor]], iou_type: Union[Literal["bbox", "segm"], Tuple[Literal["bbox", "segm"]]] = "bbox", + ignore_score: bool = False, ) -> None: """Ensure the correct input format of `preds` and `targets`.""" if isinstance(iou_type, str): @@ -39,7 +40,7 @@ def _input_validator( f"Expected argument `preds` and `target` to have the same length, but got {len(preds)} and {len(targets)}" ) - for k in [*item_val_name, "scores", "labels"]: + for k in [*item_val_name, "labels"] + (["scores"] if not ignore_score else []): if any(k not in p for p in preds): raise ValueError(f"Expected all dicts in `preds` to contain the `{k}` key") @@ -50,7 +51,7 @@ def _input_validator( for ivn in item_val_name: if any(type(pred[ivn]) is not Tensor for pred in preds): raise ValueError(f"Expected all {ivn} in `preds` to be of type Tensor") - if any(type(pred["scores"]) is not Tensor for pred in preds): + if not ignore_score and any(type(pred["scores"]) is not Tensor for pred in preds): raise ValueError("Expected all scores in `preds` to be of type Tensor") if any(type(pred["labels"]) is not Tensor for pred in preds): raise ValueError("Expected all labels in `preds` to be of type Tensor") @@ -67,6 +68,8 @@ def _input_validator( f"Input '{ivn}' and labels of sample {i} in targets have a" f" different length (expected {item[ivn].size(0)} labels, got {item['labels'].size(0)})" ) + if ignore_score: + return for i, item in enumerate(preds): for ivn in item_val_name: if not (item[ivn].size(0) == item["labels"].size(0) == item["scores"].size(0)): diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index be8734dbc61..1b884fbfb9f 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import defaultdict from typing import Any, Dict, List, Optional, Sequence, Union import torch @@ -46,10 +45,8 @@ class IntersectionOverUnion(Metric): - ``boxes`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes, 4)`` containing ``num_boxes`` detection boxes of the format specified in the constructor. By default, this method expects ``(xmin, ymin, xmax, ymax)`` in absolute image coordinates. - - ``scores`` (:class:`~torch.Tensor`): float tensor of shape ``(num_boxes)`` containing detection scores - for the boxes. - - ``labels`` (:class:`~torch.Tensor`): integer tensor of shape ``(num_boxes)`` containing 0-indexed detection - classes for the boxes. + - labels: ``IntTensor`` of shape ``(num_boxes)`` containing 0-indexed detection classes for + the boxes. - ``target`` (:class:`~List`): A list consisting of dictionaries each containing the key-values (each dictionary corresponds to a single image). Parameters that should be provided per dict: @@ -75,17 +72,20 @@ class IntersectionOverUnion(Metric): class_metrics: Option to enable per-class metrics for IoU. Has a performance impact. respect_labels: - Replace IoU values with the `invalid_val` if the labels do not match. + Ignore values from boxes that do not have the same label as the ground truth box. Else will compute Iou + between all pairs of boxes. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Example: + Example:: + >>> import torch >>> from torchmetrics.detection import IntersectionOverUnion >>> preds = [ ... { - ... "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), - ... "scores": torch.tensor([0.236, 0.56]), + ... "boxes": torch.tensor([ + ... [296.55, 93.96, 314.97, 152.79], + ... [298.55, 98.96, 314.97, 151.79]]), ... "labels": torch.tensor([4, 5]), ... } ... ] @@ -97,7 +97,34 @@ class IntersectionOverUnion(Metric): ... ] >>> metric = IntersectionOverUnion() >>> metric(preds, target) - {'iou': tensor(0.4307)} + {'iou': tensor(0.8614)} + + Example:: + + The metric can also return the score per class: + + >>> import torch + >>> from torchmetrics.detection import IntersectionOverUnion + >>> preds = [ + ... { + ... "boxes": torch.tensor([ + ... [296.55, 93.96, 314.97, 152.79], + ... [298.55, 98.96, 314.97, 151.79]]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> target = [ + ... { + ... "boxes": torch.tensor([ + ... [300.00, 100.00, 315.00, 150.00], + ... [300.00, 100.00, 315.00, 150.00] + ... ]), + ... "labels": torch.tensor([4, 5]), + ... } + ... ] + >>> metric = IntersectionOverUnion(class_metrics=True) + >>> metric(preds, target) + {'iou': tensor(0.7756), 'iou/cl_4': tensor(0.6898), 'iou/cl_5': tensor(0.8614)} Raises: ModuleNotFoundError: @@ -108,15 +135,10 @@ class IntersectionOverUnion(Metric): higher_is_better: Optional[bool] = True full_state_update: bool = True - detections: List[Tensor] - detection_scores: List[Tensor] - detection_labels: List[Tensor] - groundtruths: List[Tensor] groundtruth_labels: List[Tensor] - results: List[Tensor] - labels_eq: List[Tensor] + iou_matrix: List[Tensor] _iou_type: str = "iou" - _invalid_val: float = 0.0 + _invalid_val: float = -1.0 def __init__( self, @@ -149,13 +171,8 @@ def __init__( raise ValueError("Expected argument `respect_labels` to be a boolean") self.respect_labels = respect_labels - self.add_state("detections", default=[], dist_reduce_fx=None) - self.add_state("detection_scores", default=[], dist_reduce_fx=None) - self.add_state("detection_labels", default=[], dist_reduce_fx=None) - self.add_state("groundtruths", default=[], dist_reduce_fx=None) self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None) - self.add_state("results", default=[], dist_reduce_fx=None) - self.add_state("labels_eq", default=[], dist_reduce_fx=None) + self.add_state("iou_matrix", default=[], dist_reduce_fx=None) @staticmethod def _iou_update_fn(*args: Any, **kwargs: Any) -> Tensor: @@ -166,50 +183,19 @@ def _iou_compute_fn(*args: Any, **kwargs: Any) -> Tensor: return _iou_compute(*args, **kwargs) def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: - """Update state with predictions and targets. - - Raises: - ValueError: - If ``preds`` is not of type List[Dict[str, Tensor]] - ValueError: - If ``target`` is not of type List[Dict[str, Tensor]] - ValueError: - If ``preds`` and ``target`` are not of the same length - ValueError: - If any of ``preds.boxes``, ``preds.scores`` - and ``preds.labels`` are not of the same length - ValueError: - If any of ``target.boxes`` and ``target.labels`` are not of the same length - ValueError: - If any box is not type float and of length 4 - ValueError: - If any class is not type int and of length 1 - ValueError: - If any score is not type float and of length 1 - - """ - _input_validator(preds, target) + """Update state with predictions and targets.""" + _input_validator(preds, target, ignore_score=True) for p, t in zip(preds, target): det_boxes = self._get_safe_item_values(p["boxes"]) - self.detections.append(det_boxes) - self.detection_labels.append(p["labels"]) - self.detection_scores.append(p["scores"]) - gt_boxes = self._get_safe_item_values(t["boxes"]) - self.groundtruths.append(gt_boxes) self.groundtruth_labels.append(t["labels"]) - label_eq = torch.equal(p["labels"], t["labels"]) - # Workaround to persist state, which only works with tensors - self.labels_eq.append(torch.tensor([label_eq], dtype=torch.int, device=self.device)) - - ious = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) - if self.respect_labels and not label_eq: - label_diff = p["labels"].unsqueeze(0).T - t["labels"].unsqueeze(0) - labels_not_eq = label_diff != 0.0 - ious[labels_not_eq] = self._invalid_val - self.results.append(ious.to(dtype=torch.float, device=self.device)) + iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M + if self.respect_labels: + label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M + iou_matrix[~label_eq] = self._invalid_val + self.iou_matrix.append(iou_matrix) def _get_safe_item_values(self, boxes: Tensor) -> Tensor: boxes = _fix_empty_tensors(boxes) @@ -225,22 +211,19 @@ def _get_gt_classes(self) -> List: def compute(self) -> dict: """Computes IoU based on inputs passed in to ``update`` previously.""" - aggregated_iou = dim_zero_cat( - [self._iou_compute_fn(iou, bool(lbl_eq)) for iou, lbl_eq in zip(self.results, self.labels_eq)] - ) - results: Dict[str, Tensor] = {f"{self._iou_type}": aggregated_iou.mean()} + score = torch.cat([mat[mat != self._invalid_val] for mat in self.iou_matrix], 0).mean() + results: Dict[str, Tensor] = {f"{self._iou_type}": score} if self.class_metrics: - class_results: Dict[int, List[Tensor]] = defaultdict(list) - for iou, label in zip(self.results, self.groundtruth_labels): - for cl in self._get_gt_classes(): - masked_iou = iou[:, label == cl] - if masked_iou.numel() > 0: - class_results[cl].append(self._iou_compute_fn(masked_iou, False)) - - results.update( - {f"{self._iou_type}/cl_{cl}": dim_zero_cat(class_results[cl]).mean() for cl in class_results} - ) + gt_labels = dim_zero_cat(self.groundtruth_labels) + classes = gt_labels.unique().tolist() if len(gt_labels) > 0 else [] + for cl in classes: + masked_iou, observed = torch.zeros_like(score), torch.zeros_like(score) + for mat, gt_lab in zip(self.iou_matrix, self.groundtruth_labels): + scores = mat[:, gt_lab == cl] + masked_iou += scores[scores != self._invalid_val].sum() + observed += scores[scores != self._invalid_val].numel() + results.update({f"{self._iou_type}/cl_{cl}": masked_iou / observed}) return results def plot( diff --git a/src/torchmetrics/functional/detection/ciou.py b/src/torchmetrics/functional/detection/ciou.py index d0f5533b802..3edffc6d193 100644 --- a/src/torchmetrics/functional/detection/ciou.py +++ b/src/torchmetrics/functional/detection/ciou.py @@ -35,10 +35,10 @@ def _ciou_update( return iou -def _ciou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: - if labels_eq: - return iou.diag().mean() - return iou.mean() +def _ciou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor: + if not aggregate: + return iou + return iou.diag().mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device) def complete_intersection_over_union( @@ -62,15 +62,53 @@ def complete_intersection_over_union( replacement_val: Value to replace values under the threshold with. aggregate: - Return the average value instead of the complete IoU matrix. + Return the average value instead of the full matrix of values + + Example:: + By default iou is aggregated across all box pairs e.g. mean along the diagonal of the IoU matrix: - Example: >>> import torch >>> from torchmetrics.functional.detection import complete_intersection_over_union - >>> preds = torch.Tensor([[100, 100, 200, 200]]) - >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) >>> complete_intersection_over_union(preds, target) - tensor(0.6724) + tensor(0.5790) + + Example:: + By setting `aggregate=False` the IoU score per prediction and target boxes is returned: + + >>> import torch + >>> from torchmetrics.functional.detection import complete_intersection_over_union + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) + >>> complete_intersection_over_union(preds, target, aggregate=False) + tensor([[ 0.6883, -0.2072, -0.3352], + [-0.2217, 0.4881, -0.1913], + [-0.3971, -0.1543, 0.5606]]) """ if not _TORCHVISION_GREATER_EQUAL_0_13: @@ -80,4 +118,4 @@ def complete_intersection_over_union( " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." ) iou = _ciou_update(preds, target, iou_threshold, replacement_val) - return _ciou_compute(iou) if aggregate else iou + return _ciou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py index 42554df0f20..cef9ed0e1c6 100644 --- a/src/torchmetrics/functional/detection/diou.py +++ b/src/torchmetrics/functional/detection/diou.py @@ -35,10 +35,10 @@ def _diou_update( return iou -def _diou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: - if labels_eq: - return iou.diag().mean() - return iou.mean() +def _diou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor: + if not aggregate: + return iou + return iou.diag().mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device) def distance_intersection_over_union( @@ -62,15 +62,53 @@ def distance_intersection_over_union( replacement_val: Value to replace values under the threshold with. aggregate: - Return the average value instead of the complete IoU matrix. + Return the average value instead of the full matrix of values + + Example:: + By default diou is aggregated across all box pairs e.g. mean along the diagonal of the dIoU matrix: - Example: >>> import torch >>> from torchmetrics.functional.detection import distance_intersection_over_union - >>> preds = torch.Tensor([[100, 100, 200, 200]]) - >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) >>> distance_intersection_over_union(preds, target) - tensor(0.6724) + tensor(0.5793) + + Example:: + By setting `aggregate=False` the IoU score per prediction and target boxes is returned: + + >>> import torch + >>> from torchmetrics.functional.detection import distance_intersection_over_union + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) + >>> distance_intersection_over_union(preds, target, aggregate=False) + tensor([[ 0.6883, -0.2043, -0.3351], + [-0.2214, 0.4886, -0.1913], + [-0.3971, -0.1510, 0.5609]]) """ if not _TORCHVISION_GREATER_EQUAL_0_13: @@ -80,4 +118,4 @@ def distance_intersection_over_union( " Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`." ) iou = _diou_update(preds, target, iou_threshold, replacement_val) - return _diou_compute(iou) if aggregate else iou + return _diou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py index 980784f7e4f..946b5bf0726 100644 --- a/src/torchmetrics/functional/detection/giou.py +++ b/src/torchmetrics/functional/detection/giou.py @@ -35,10 +35,10 @@ def _giou_update( return iou -def _giou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: - if labels_eq: - return iou.diag().mean() - return iou.mean() +def _giou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor: + if not aggregate: + return iou + return iou.diag().mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device) def generalized_intersection_over_union( @@ -62,15 +62,53 @@ def generalized_intersection_over_union( replacement_val: Value to replace values under the threshold with. aggregate: - Return the average value instead of the complete IoU matrix. + Return the average value instead of the full matrix of values + + Example:: + By default giou is aggregated across all box pairs e.g. mean along the diagonal of the gIoU matrix: - Example: >>> import torch >>> from torchmetrics.functional.detection import generalized_intersection_over_union - >>> preds = torch.Tensor([[100, 100, 200, 200]]) - >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) >>> generalized_intersection_over_union(preds, target) - tensor(0.6641) + tensor(0.5638) + + Example:: + By setting `aggregate=False` the full IoU matrix is returned: + + >>> import torch + >>> from torchmetrics.functional.detection import generalized_intersection_over_union + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) + >>> generalized_intersection_over_union(preds, target, aggregate=False) + tensor([[ 0.6895, -0.4964, -0.4944], + [-0.5105, 0.4673, -0.3434], + [-0.6024, -0.4021, 0.5345]]) """ if not _TORCHVISION_GREATER_EQUAL_0_8: @@ -80,4 +118,4 @@ def generalized_intersection_over_union( " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." ) iou = _giou_update(preds, target, iou_threshold, replacement_val) - return _giou_compute(iou) if aggregate else iou + return _giou_compute(iou, aggregate) diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 1f37ef01a27..82c4742bfcd 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -35,10 +35,10 @@ def _iou_update( return iou -def _iou_compute(iou: torch.Tensor, labels_eq: bool = True) -> torch.Tensor: - if labels_eq: - return iou.diag().mean() - return iou.mean() if iou.numel() > 0 else torch.tensor(0.0).to(iou.device) +def _iou_compute(iou: torch.Tensor, aggregate: bool = True) -> torch.Tensor: + if not aggregate: + return iou + return iou.diag().mean() if iou.numel() > 0 else torch.tensor(0.0, device=iou.device) def intersection_over_union( @@ -62,15 +62,53 @@ def intersection_over_union( replacement_val: Value to replace values under the threshold with. aggregate: - Return the average value instead of the complete IoU matrix. + Return the average value instead of the full matrix of values + + Example:: + By default iou is aggregated across all box pairs e.g. mean along the diagonal of the IoU matrix: - Example: >>> import torch >>> from torchmetrics.functional.detection import intersection_over_union - >>> preds = torch.Tensor([[100, 100, 200, 200]]) - >>> target = torch.Tensor([[110, 110, 210, 210]]) + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) >>> intersection_over_union(preds, target) - tensor(0.6807) + tensor(0.5879) + + Example:: + By setting `aggregate=False` the full IoU matrix is returned: + + >>> import torch + >>> from torchmetrics.functional.detection import intersection_over_union + >>> preds = torch.tensor( + ... [ + ... [296.55, 93.96, 314.97, 152.79], + ... [328.94, 97.05, 342.49, 122.98], + ... [356.62, 95.47, 372.33, 147.55], + ... ] + ... ) + >>> target = torch.tensor( + ... [ + ... [300.00, 100.00, 315.00, 150.00], + ... [330.00, 100.00, 350.00, 125.00], + ... [350.00, 100.00, 375.00, 150.00], + ... ] + ... ) + >>> intersection_over_union(preds, target, aggregate=False) + tensor([[0.6898, 0.0000, 0.0000], + [0.0000, 0.5086, 0.0000], + [0.0000, 0.0000, 0.5654]]) """ if not _TORCHVISION_GREATER_EQUAL_0_8: @@ -79,4 +117,4 @@ def intersection_over_union( " Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`." ) iou = _iou_update(preds, target, iou_threshold, replacement_val) - return _iou_compute(iou) if aggregate else iou + return _iou_compute(iou, aggregate) diff --git a/tests/unittests/detection/base_iou_test.py b/tests/unittests/detection/base_iou_test.py deleted file mode 100644 index 15132811297..00000000000 --- a/tests/unittests/detection/base_iou_test.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from abc import ABC -from collections import namedtuple -from dataclasses import dataclass -from functools import partial -from typing import Any, Callable, ClassVar, Dict - -import pytest -import torch -from torch import IntTensor, Tensor -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 - -Input = namedtuple("Input", ["preds", "target"]) - - -@dataclass -class TestCaseData: - """Test data sample.""" - - data: Input - result: Any - - -_preds = torch.Tensor( - [ - [296.55, 93.96, 314.97, 152.79], - [328.94, 97.05, 342.49, 122.98], - [356.62, 95.47, 372.33, 147.55], - ] -) -_target = torch.Tensor( - [ - [300.00, 100.00, 315.00, 150.00], - [330.00, 100.00, 350.00, 125.00], - [350.00, 100.00, 375.00, 150.00], - ] -) - -_inputs = Input( - preds=[ - [ - { - "boxes": Tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), - "scores": Tensor([0.236, 0.56]), - "labels": IntTensor([4, 5]), - } - ], - [ - { - "boxes": Tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), - "scores": Tensor([0.236, 0.56]), - "labels": IntTensor([4, 5]), - } - ], - [ - { - "boxes": Tensor([[328.94, 97.05, 342.49, 122.98]]), - "scores": Tensor([0.456]), - "labels": IntTensor([4]), - }, - { - "boxes": Tensor([[356.62, 95.47, 372.33, 147.55]]), - "scores": Tensor([0.791]), - "labels": IntTensor([4]), - }, - ], - [ - { - "boxes": Tensor([[328.94, 97.05, 342.49, 122.98]]), - "scores": Tensor([0.456]), - "labels": IntTensor([5]), - }, - { - "boxes": Tensor([[356.62, 95.47, 372.33, 147.55]]), - "scores": Tensor([0.791]), - "labels": IntTensor([5]), - }, - ], - ], - target=[ - [ - { - "boxes": Tensor([[300.00, 100.00, 315.00, 150.00]]), - "labels": IntTensor([5]), - } - ], - [ - { - "boxes": Tensor([[300.00, 100.00, 315.00, 150.00]]), - "labels": IntTensor([5]), - } - ], - [ - { - "boxes": Tensor([[330.00, 100.00, 350.00, 125.00]]), - "labels": IntTensor([4]), - }, - { - "boxes": Tensor([[350.00, 100.00, 375.00, 150.00]]), - "labels": IntTensor([4]), - }, - ], - [ - { - "boxes": Tensor([[330.00, 100.00, 350.00, 125.00]]), - "labels": IntTensor([5]), - }, - { - "boxes": Tensor([[350.00, 100.00, 375.00, 150.00]]), - "labels": IntTensor([4]), - }, - ], - ], -) - -_box_inputs = Input(preds=_preds, target=_target) - -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) - - -def compare_fn(preds: Any, target: Any, result: Any): - """Mock compare function by returning additional parameter results directly.""" - return result - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -@pytest.mark.parametrize("compute_on_cpu", [True, False]) -@pytest.mark.parametrize("ddp", [False, True]) -class BaseTestIntersectionOverUnion(ABC): - """Base Test the Intersection over Union metric for object detection predictions.""" - - data: ClassVar[Dict[str, TestCaseData]] = { - "iou_variant": TestCaseData(data=_inputs, result={"iou": torch.Tensor([0])}), - "fn_iou_variant": TestCaseData(data=_box_inputs, result=None), - } - metric_class: ClassVar - metric_fn: Callable[[Tensor, Tensor, bool, float], Tensor] - - def test_iou_variant(self, compute_on_cpu: bool, ddp: bool): - """Test modular implementation for correctness.""" - key = "iou_variant" - - self.run_class_metric_test( # type: ignore - ddp=ddp, - preds=self.data[key].data.preds, - target=self.data[key].data.target, - metric_class=self.metric_class, - reference_metric=partial(compare_fn, result=self.data[key].result), - dist_sync_on_step=False, - check_batch=False, - metric_args={"compute_on_cpu": compute_on_cpu}, - ) - - def test_iou_variant_dont_respect_labels(self, compute_on_cpu: bool, ddp: bool): - """Test modular implementation for correctness while ignoring labels.""" - key = "iou_variant_respect" - - self.run_class_metric_test( # type: ignore - ddp=ddp, - preds=self.data[key].data.preds, - target=self.data[key].data.target, - metric_class=self.metric_class, - reference_metric=partial(compare_fn, result=self.data[key].result), - dist_sync_on_step=False, - check_batch=False, - metric_args={"compute_on_cpu": compute_on_cpu, "respect_labels": False}, - ) - - def test_fn(self, compute_on_cpu: bool, ddp: bool): - """Test functional implementation for correctness.""" - key = "fn_iou_variant" - self.run_functional_metric_test( - self.data[key].data.preds[0].unsqueeze(0), # pass as batch, otherwise it attempts to pass element wise - self.data[key].data.target[0].unsqueeze(0), - self.metric_fn.__func__, - partial(compare_fn, result=self.data[key].result), - ) - - def test_error_on_wrong_input(self, compute_on_cpu: bool, ddp: bool): - """Test class input validation.""" - metric = self.metric_class() - - metric.update([], []) # no error - - with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"): - metric.update(Tensor(), []) # type: ignore - - with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"): - metric.update([], Tensor()) # type: ignore - - with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"): - metric.update([{}], [{}, {}]) - - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"): - metric.update( - [{"scores": Tensor(), "labels": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `scores` key"): - metric.update( - [{"boxes": Tensor(), "labels": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"): - metric.update( - [{"boxes": Tensor(), "scores": IntTensor}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"): - metric.update( - [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], - [{"labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"): - metric.update( - [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], - [{"boxes": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"): - metric.update( - [{"boxes": [], "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all scores in `preds` to be of type Tensor"): - metric.update( - [{"boxes": Tensor(), "scores": [], "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"): - metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": []}], - [{"boxes": Tensor(), "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"): - metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": [], "labels": IntTensor()}], - ) - - with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"): - metric.update( - [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], - [{"boxes": Tensor(), "labels": []}], - ) diff --git a/tests/unittests/detection/test_ciou.py b/tests/unittests/detection/test_ciou.py deleted file mode 100644 index b622e7e34ce..00000000000 --- a/tests/unittests/detection/test_ciou.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, ClassVar, Dict - -import pytest -import torch -from torch import Tensor -from torchmetrics import Metric -from torchmetrics.detection.ciou import CompleteIntersectionOverUnion -from torchmetrics.functional.detection.ciou import complete_intersection_over_union -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 - -from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs -from unittests.helpers.testers import MetricTester - -ciou = torch.Tensor( - [ - [-0.2669985], - ] -) -ciou_dontrespect = torch.Tensor( - [ - [0.6078202], - ] -) -box_ciou = torch.Tensor( - [ - [0.6883, -0.2072, -0.3352], - [-0.2217, 0.4881, -0.1913], - [-0.3971, -0.1543, 0.5606], - ] -) - - -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13) - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.13.0 is installed") -class TestCompleteIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion): - """Test the Complete Intersection over Union metric for object detection predictions.""" - - data: ClassVar[Dict[str, TestCaseData]] = { - "iou_variant": TestCaseData(data=_inputs, result={CompleteIntersectionOverUnion._iou_type: ciou}), - "iou_variant_respect": TestCaseData( - data=_inputs, result={CompleteIntersectionOverUnion._iou_type: ciou_dontrespect} - ), - "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_ciou), - } - metric_class: ClassVar[Metric] = CompleteIntersectionOverUnion - metric_fn: ClassVar[Callable[[Tensor, Tensor, bool, float], Tensor]] = complete_intersection_over_union diff --git a/tests/unittests/detection/test_diou.py b/tests/unittests/detection/test_diou.py deleted file mode 100644 index 1bc6323045c..00000000000 --- a/tests/unittests/detection/test_diou.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, ClassVar, Dict - -import pytest -import torch -from torch import Tensor -from torchmetrics.detection.diou import DistanceIntersectionOverUnion -from torchmetrics.functional.detection.diou import distance_intersection_over_union -from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 - -from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs -from unittests.helpers.testers import MetricTester - -diou = torch.Tensor( - [ - [0.06653749], - ] -) -diou_dontrespect = torch.Tensor( - [ - [0.6080749], - ] -) -box_diou = torch.Tensor( - [ - [0.6883, -0.2043, -0.3351], - [-0.2214, 0.4886, -0.1913], - [-0.3971, -0.1510, 0.5609], - ] -) - -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13) - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.13.0 is installed") -class TestDistanceIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion): - """Test the Distance Intersection over Union metric for object detection predictions.""" - - data: ClassVar[Dict[str, TestCaseData]] = { - "iou_variant": TestCaseData(data=_inputs, result={DistanceIntersectionOverUnion._iou_type: diou}), - "iou_variant_respect": TestCaseData( - data=_inputs, result={DistanceIntersectionOverUnion._iou_type: diou_dontrespect} - ), - "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_diou), - } - metric_class: ClassVar[Metric] = DistanceIntersectionOverUnion - metric_fn: ClassVar[Callable[[Tensor, Tensor, bool, float], Tensor]] = distance_intersection_over_union diff --git a/tests/unittests/detection/test_giou.py b/tests/unittests/detection/test_giou.py deleted file mode 100644 index 8c8ffc70341..00000000000 --- a/tests/unittests/detection/test_giou.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, ClassVar, Dict - -import pytest -import torch -from torch import Tensor -from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion -from torchmetrics.functional.detection.giou import generalized_intersection_over_union -from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 - -from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs -from unittests.helpers.testers import MetricTester - -giou = torch.Tensor( - [ - [0.05507809], - ] -) -giou_dontrespect = torch.Tensor( - [ - [0.59242314], - ] -) -box_giou = torch.Tensor( - [ - [0.6895, -0.4964, -0.4944], - [-0.5105, 0.4673, -0.3434], - [-0.6024, -0.4021, 0.5345], - ] -) - -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -class TestGeneralizedIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion): - """Test the Generalized Intersection over Union metric for object detection predictions.""" - - data: ClassVar[Dict[str, TestCaseData]] = { - "iou_variant": TestCaseData(data=_inputs, result={GeneralizedIntersectionOverUnion._iou_type: giou}), - "iou_variant_respect": TestCaseData( - data=_inputs, result={GeneralizedIntersectionOverUnion._iou_type: giou_dontrespect} - ), - "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_giou), - } - metric_class: ClassVar[Metric] = GeneralizedIntersectionOverUnion - metric_fn: ClassVar[Callable[[Tensor, Tensor, bool, float], Tensor]] = generalized_intersection_over_union diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py new file mode 100644 index 00000000000..50ed46cf32e --- /dev/null +++ b/tests/unittests/detection/test_intersection.py @@ -0,0 +1,342 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import pytest +import torch +from torch import IntTensor, Tensor +from torchmetrics.detection.ciou import CompleteIntersectionOverUnion +from torchmetrics.detection.diou import DistanceIntersectionOverUnion +from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion +from torchmetrics.detection.iou import IntersectionOverUnion +from torchmetrics.functional.detection.ciou import complete_intersection_over_union +from torchmetrics.functional.detection.diou import distance_intersection_over_union +from torchmetrics.functional.detection.giou import generalized_intersection_over_union +from torchmetrics.functional.detection.iou import intersection_over_union +from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 + +# todo: check if some older versions have these functions too? +if _TORCHVISION_GREATER_EQUAL_0_13: + from torchvision.ops import box_iou as tv_iou + from torchvision.ops import complete_box_iou as tv_ciou + from torchvision.ops import distance_box_iou as tv_diou + from torchvision.ops import generalized_box_iou as tv_giou +else: + tv_iou, tv_ciou, tv_diou, tv_giou = ..., ..., ..., ... + +from unittests.helpers.testers import MetricTester + + +def _tv_wrapper(preds, target, base_fn, aggregate=True, iou_threshold=None): + out = base_fn(preds, target) + if iou_threshold is not None: + out[out < iou_threshold] = 0 + if aggregate: + return out.diag().mean() + return out + + +def _tv_wrapper_class(preds, target, base_fn, respect_labels, iou_threshold, class_metrics): + iou = [] + classes = [] + for p, t in zip(preds, target): + out = base_fn(p["boxes"], t["boxes"]) + if iou_threshold is not None: + out[out < iou_threshold] = -1 + if respect_labels: + labels_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) + out[~labels_eq] = -1 + iou.append(out) + classes.append(t["labels"]) + score = torch.cat([i[i != -1] for i in iou]).mean() + base_name = {tv_ciou: "ciou", tv_diou: "diou", tv_giou: "giou", tv_iou: "iou"}[base_fn] + + result = {f"{base_name}": score.cpu()} + if class_metrics: + for cl in torch.cat(classes).unique().tolist(): + class_score, numel = 0, 0 + for s, c in zip(iou, classes): + masked_s = s[:, c == cl] + class_score += masked_s[masked_s != -1].sum() + numel += masked_s[masked_s != -1].numel() + result.update({f"{base_name}/cl_{cl}": class_score.cpu() / numel}) + + return result + + +_preds_fn = ( + torch.tensor( + [ + [296.55, 93.96, 314.97, 152.79], + [328.94, 97.05, 342.49, 122.98], + [356.62, 95.47, 372.33, 147.55], + ] + ) + .unsqueeze(0) + .repeat(4, 1, 1) +) +_target_fn = ( + torch.tensor( + [ + [300.00, 100.00, 315.00, 150.00], + [330.00, 100.00, 350.00, 125.00], + [350.00, 100.00, 375.00, 150.00], + ] + ) + .unsqueeze(0) + .repeat(4, 1, 1) +) + +_preds_class = [ + [ + { + "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + "labels": torch.tensor([4, 5]), + } + ], + [ + { + "boxes": torch.tensor([[296.55, 93.96, 314.97, 152.79], [298.55, 98.96, 314.97, 151.79]]), + "labels": torch.tensor([4, 5]), + } + ], + [ + { + "boxes": torch.tensor([[328.94, 97.05, 342.49, 122.98]]), + "labels": torch.tensor([4]), + }, + { + "boxes": torch.tensor([[356.62, 95.47, 372.33, 147.55]]), + "labels": torch.tensor([4]), + }, + ], + [ + { + "boxes": torch.tensor([[328.94, 97.05, 342.49, 122.98]]), + "labels": torch.tensor([5]), + }, + { + "boxes": torch.tensor([[356.62, 95.47, 372.33, 147.55]]), + "labels": torch.tensor([5]), + }, + ], +] +_target_class = [ + [ + { + "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + "labels": torch.tensor([5]), + } + ], + [ + { + "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]), + "labels": torch.tensor([5]), + } + ], + [ + { + "boxes": torch.tensor([[330.00, 100.00, 350.00, 125.00]]), + "labels": torch.tensor([4]), + }, + { + "boxes": torch.tensor([[350.00, 100.00, 375.00, 150.00]]), + "labels": torch.tensor([4]), + }, + ], + [ + { + "boxes": torch.tensor([[330.00, 100.00, 350.00, 125.00]]), + "labels": torch.tensor([5]), + }, + { + "boxes": torch.tensor([[350.00, 100.00, 375.00, 150.00]]), + "labels": torch.tensor([4]), + }, + ], +] + + +def _add_noise(x, scale=10): + """Add noise to boxes and labels to make testing non-deterministic.""" + if isinstance(x, torch.Tensor): + return x + scale * torch.rand_like(x) + for batch in x: + for sample in batch: + sample["boxes"] = _add_noise(sample["boxes"], scale) + sample["labels"] += abs(torch.randint_like(sample["labels"], 0, 10)) + return x + + +@pytest.mark.parametrize( + "class_metric, functional_metric, reference_metric", + [ + (IntersectionOverUnion, intersection_over_union, tv_iou), + (CompleteIntersectionOverUnion, complete_intersection_over_union, tv_ciou), + (DistanceIntersectionOverUnion, distance_intersection_over_union, tv_diou), + (GeneralizedIntersectionOverUnion, generalized_intersection_over_union, tv_giou), + ], +) +@pytest.mark.skipif(not _TORCHVISION_GREATER_EQUAL_0_13, reason="test requires torchvision >= 0.13") +class TestIntersectionMetrics(MetricTester): + """Tester class for the different intersection metrics.""" + + @pytest.mark.parametrize( + ("preds", "target"), [(_preds_class, _target_class), (_add_noise(_preds_class), _add_noise(_target_class))] + ) + @pytest.mark.parametrize("respect_labels", [True, False]) + @pytest.mark.parametrize("iou_threshold", [None, 0.5, 0.7, 0.9]) + @pytest.mark.parametrize("class_metrics", [True, False]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_intersection_class( + self, + class_metric, + functional_metric, + reference_metric, + preds, + target, + respect_labels, + iou_threshold, + class_metrics, + ddp, + ): + """Test class implementation for correctness.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=class_metric, + reference_metric=partial( + _tv_wrapper_class, + base_fn=reference_metric, + respect_labels=respect_labels, + iou_threshold=iou_threshold, + class_metrics=class_metrics, + ), + metric_args={ + "respect_labels": respect_labels, + "iou_threshold": iou_threshold, + "class_metrics": class_metrics, + }, + check_batch=not class_metrics, + ) + + @pytest.mark.parametrize( + ("preds", "target"), + [ + (_preds_fn, _target_fn), + (_add_noise(_preds_fn), _add_noise(_target_fn)), + ], + ) + @pytest.mark.parametrize("aggregate", [True, False]) + @pytest.mark.parametrize("iou_threshold", [None, 0.5, 0.7, 0.9]) + def test_intersection_function( + self, class_metric, functional_metric, reference_metric, preds, target, aggregate, iou_threshold + ): + """Test functional implementation for correctness.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional_metric, + reference_metric=partial( + _tv_wrapper, base_fn=reference_metric, aggregate=aggregate, iou_threshold=iou_threshold + ), + metric_args={"aggregate": aggregate, "iou_threshold": iou_threshold}, + ) + + def test_error_on_wrong_input(self, class_metric, functional_metric, reference_metric): + """Test class input validation.""" + metric = class_metric() + + metric.update([], []) # no error + + with pytest.raises(ValueError, match="Expected argument `preds` to be of type Sequence"): + metric.update(Tensor(), []) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `target` to be of type Sequence"): + metric.update([], Tensor()) # type: ignore + + with pytest.raises(ValueError, match="Expected argument `preds` and `target` to have the same length"): + metric.update([{}], [{}, {}]) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `boxes` key"): + metric.update( + [{"scores": Tensor(), "labels": IntTensor}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `preds` to contain the `labels` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `boxes` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], + [{"labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all dicts in `target` to contain the `labels` key"): + metric.update( + [{"boxes": Tensor(), "scores": IntTensor, "labels": IntTensor}], + [{"boxes": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all boxes in `preds` to be of type Tensor"): + metric.update( + [{"boxes": [], "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all labels in `preds` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": []}], + [{"boxes": Tensor(), "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all boxes in `target` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": [], "labels": IntTensor()}], + ) + + with pytest.raises(ValueError, match="Expected all labels in `target` to be of type Tensor"): + metric.update( + [{"boxes": Tensor(), "scores": Tensor(), "labels": IntTensor()}], + [{"boxes": Tensor(), "labels": []}], + ) + + +def test_corner_case(): + """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921.""" + preds = [ + { + "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00], [298.55, 98.96, 314.97, 151.79]]), + "scores": torch.tensor([0.236, 0.56]), + "labels": torch.tensor([4, 5]), + } + ] + + target = [ + { + "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00], [298.55, 98.96, 314.97, 151.79]]), + "labels": torch.tensor([4, 5]), + } + ] + + metric = IntersectionOverUnion(class_metrics=True, iou_threshold=0.75, respect_labels=True) + iou = metric(preds, target) + for val in iou.values(): + assert val == torch.tensor(1.0) diff --git a/tests/unittests/detection/test_iou.py b/tests/unittests/detection/test_iou.py deleted file mode 100644 index 020b8c8a221..00000000000 --- a/tests/unittests/detection/test_iou.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Callable, ClassVar, Dict - -import pytest -import torch -from torch import Tensor, tensor -from torchmetrics.detection.iou import IntersectionOverUnion -from torchmetrics.functional.detection.iou import intersection_over_union -from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8 - -from unittests.detection.base_iou_test import BaseTestIntersectionOverUnion, TestCaseData, _box_inputs, _inputs -from unittests.helpers.testers import MetricTester - -iou = torch.Tensor( - [ - [0.40733114], - ] -) -iou_dontrespect = torch.Tensor( - [ - [0.6165285], - ] -) -box_iou = torch.Tensor( - [ - [0.6898, 0.0000, 0.0000], - [0.0000, 0.5086, 0.0000], - [0.0000, 0.0000, 0.5654], - ] -) - -_pytest_condition = not (_TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8) - - -@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") -class TestIntersectionOverUnion(MetricTester, BaseTestIntersectionOverUnion): - """Test the Intersection over Union metric for object detection predictions.""" - - data: ClassVar[Dict[str, TestCaseData]] = { - "iou_variant": TestCaseData(data=_inputs, result={IntersectionOverUnion._iou_type: iou}), - "iou_variant_respect": TestCaseData(data=_inputs, result={IntersectionOverUnion._iou_type: iou_dontrespect}), - "fn_iou_variant": TestCaseData(data=_box_inputs, result=box_iou), - } - metric_class: ClassVar[Metric] = IntersectionOverUnion - metric_fn: ClassVar[Callable[[Tensor, Tensor, bool, float], Tensor]] = intersection_over_union - - -def test_corner_case(): - """Test corner case where preds is empty for a given target. - - See this issue: https://github.com/Lightning-AI/torchmetrics/issues/1889 - - """ - target = [{"boxes": tensor([[238.0000, 74.0000, 343.0000, 275.0000]]), "labels": tensor([6])}] - preds = [{"boxes": tensor([[], [], [], []]).T, "labels": tensor([], dtype=torch.int64), "scores": tensor([])}] - metric = IntersectionOverUnion() - metric.update(preds, target) - result = metric.compute() - assert result["iou"] == tensor(0.0)