From a4f129f7d7f832134fa4cc6900dfa076c3669865 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 12 Oct 2024 15:57:14 +0200 Subject: [PATCH] fix docstring + add input validation --- .../segmentation/hausdorff_distance.py | 13 +++++++++ .../segmentation/hausdorff_distance.py | 29 +++++++++++++------ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/hausdorff_distance.py b/src/torchmetrics/functional/segmentation/hausdorff_distance.py index 4655b76aeab..171bf15e248 100644 --- a/src/torchmetrics/functional/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/functional/segmentation/hausdorff_distance.py @@ -21,6 +21,19 @@ from torchmetrics.utilities.checks import _check_same_shape +def _hausdorff_distance_validate_args( + distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean", + spacing: Optional[Union[Tensor, list[float]]] = None, +) -> None: + """Validate the arguments of `hausdorff_distance` function.""" + if distance_metric not in ["euclidean", "chessboard", "taxicab"]: + raise ValueError( + f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}." + ) + if spacing is not None and not isinstance(spacing, (list, Tensor)): + raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.") + + def _hausdorff_distance_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Update and returns variables required to compute `Hausdorff Distance`_. diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index b7542845a0a..d4c01e3808a 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -14,7 +14,10 @@ from torch import Tensor -from torchmetrics.functional.segmentation import hausdorff_distance +from torchmetrics.functional.segmentation.hausdorff_distance import ( + _hausdorff_distance_validate_args, + hausdorff_distance, +) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat, dim_zero_mean from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE @@ -25,24 +28,33 @@ class HausdorffDistance(Metric): - r"""Compute the Hausdorff distance between two subsets of a metric space. + r"""Compute the `Hausdorff Distance`_ between two subsets of a metric space for semantic segmentation. .. math:: d_{\Pi}(X,Y) = \max{/sup_{x\in X} {d(x,Y)}, /sup_{y\in Y} {d(X,y)}} - where :math:`\X, \Y` are ________________, :math:`\X, \Y` ______. + where :math:`\X, \Y` are two subsets of a metric space with distance metric :math:`d`. The Hausdorff distance is + the maximum distance from a point in one set to the closest point in the other set. The Hausdorff distance is a + measure of the degree of mismatch between two sets. As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): - - ``target`` (:class:`~torch.Tensor`): + - ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. + - ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being + the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)`` + can be provided, where the integer values correspond to the class index. The input type can be controlled + with the ``input_format`` argument. As output of ``forward`` and ``compute`` the metric returns the following output: - ``hausdorff_distance`` (:class:`~torch.Tensor`): A scalar float tensor with the Hausdorff distance. Args: - p: p-norm used for distance metric + distance_metric: distance metric to calculate surface distance. Choose between "euclidean", "chessboard" or + "taxicab". kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: @@ -59,8 +71,7 @@ class HausdorffDistance(Metric): ... [1, 0, 0, 1, 0], ... [1, 1, 1, 1, 0]], dtype=torch.bool) >>> hausdorff_distance = HausdorffDistance(distance_metric="euclidean") - >>> hausdorff_distance.update(preds, target) - >>> hausdorff_distance.compute() + >>> hausdorff_distance(preds, target) tensor(1.) """ @@ -80,9 +91,9 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) + _hausdorff_distance_validate_args(distance_metric, spacing) self.distance_metric = distance_metric self.spacing = spacing - self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat")