Skip to content

Commit

Permalink
fix docstring + add input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Oct 12, 2024
1 parent 3345393 commit a4f129f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
13 changes: 13 additions & 0 deletions src/torchmetrics/functional/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@
from torchmetrics.utilities.checks import _check_same_shape


def _hausdorff_distance_validate_args(
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, list[float]]] = None,
) -> None:
"""Validate the arguments of `hausdorff_distance` function."""
if distance_metric not in ["euclidean", "chessboard", "taxicab"]:
raise ValueError(
f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}."
)
if spacing is not None and not isinstance(spacing, (list, Tensor)):
raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.")


def _hausdorff_distance_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Update and returns variables required to compute `Hausdorff Distance`_.
Expand Down
29 changes: 20 additions & 9 deletions src/torchmetrics/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

from torch import Tensor

from torchmetrics.functional.segmentation import hausdorff_distance
from torchmetrics.functional.segmentation.hausdorff_distance import (
_hausdorff_distance_validate_args,
hausdorff_distance,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat, dim_zero_mean
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
Expand All @@ -25,24 +28,33 @@


class HausdorffDistance(Metric):
r"""Compute the Hausdorff distance between two subsets of a metric space.
r"""Compute the `Hausdorff Distance`_ between two subsets of a metric space for semantic segmentation.
.. math::
d_{\Pi}(X,Y) = \max{/sup_{x\in X} {d(x,Y)}, /sup_{y\in Y} {d(X,y)}}
where :math:`\X, \Y` are ________________, :math:`\X, \Y` ______.
where :math:`\X, \Y` are two subsets of a metric space with distance metric :math:`d`. The Hausdorff distance is
the maximum distance from a point in one set to the closest point in the other set. The Hausdorff distance is a
measure of the degree of mismatch between two sets.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`):
- ``target`` (:class:`~torch.Tensor`):
- ``preds`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
- ``target`` (:class:`~torch.Tensor`): An one-hot boolean tensor of shape ``(N, C, ...)`` with ``N`` being
the number of samples and ``C`` the number of classes. Alternatively, an integer tensor of shape ``(N, ...)``
can be provided, where the integer values correspond to the class index. The input type can be controlled
with the ``input_format`` argument.
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``hausdorff_distance`` (:class:`~torch.Tensor`): A scalar float tensor with the Hausdorff distance.
Args:
p: p-norm used for distance metric
distance_metric: distance metric to calculate surface distance. Choose between "euclidean", "chessboard" or
"taxicab".
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
Expand All @@ -59,8 +71,7 @@ class HausdorffDistance(Metric):
... [1, 0, 0, 1, 0],
... [1, 1, 1, 1, 0]], dtype=torch.bool)
>>> hausdorff_distance = HausdorffDistance(distance_metric="euclidean")
>>> hausdorff_distance.update(preds, target)
>>> hausdorff_distance.compute()
>>> hausdorff_distance(preds, target)
tensor(1.)
"""
Expand All @@ -80,9 +91,9 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
_hausdorff_distance_validate_args(distance_metric, spacing)
self.distance_metric = distance_metric
self.spacing = spacing

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

Expand Down

0 comments on commit a4f129f

Please sign in to comment.