Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New segmentation metric: Hausdorff Distance #2122

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5ab879d
add docs
matsumotosan Sep 30, 2023
b6519c3
initial commit
matsumotosan Oct 1, 2023
4f5d606
fix hausdorff metric args
matsumotosan Oct 4, 2023
80cbb1a
ci: switch to custom docker images (#2123)
matsumotosan Oct 14, 2023
05c154a
Add `average` to curve metrics (#2084)
matsumotosan Oct 14, 2023
ea58776
symmetric test
matsumotosan Oct 14, 2023
efc972f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2024
3147712
Merge branch 'master' into 1990-hausdorff-distance
matsumotosan May 27, 2024
0c9b1a8
fix merge error
matsumotosan May 27, 2024
a3dcc86
fix imports
matsumotosan May 27, 2024
bfe0a3b
tests running
matsumotosan May 27, 2024
62b7a4c
Merge branch 'master' into 1990-hausdorff-distance
Borda Jul 22, 2024
abe3069
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 6, 2024
5e0b253
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 9, 2024
550770b
Add Torch-to-numpy wrapper for skimage metric
baskrahmer Aug 10, 2024
78da660
Return average over states
baskrahmer Aug 10, 2024
011722d
Fix docs for doctests
baskrahmer Aug 16, 2024
dd837e5
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 16, 2024
c0091e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
1a53a22
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Aug 29, 2024
e2e86d9
Merge branch 'master' into 1990-hausdorff-distance
Borda Sep 2, 2024
873a8ca
Merge branch 'master' into 1990-hausdorff-distance
Borda Sep 16, 2024
223404a
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Sep 24, 2024
80829e8
Add pytest param for ddp
baskrahmer Sep 24, 2024
0e96276
Fix type hints
baskrahmer Sep 24, 2024
7b84f03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2024
cc0d239
Refactor lambda to function definition
baskrahmer Sep 24, 2024
a07e021
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2024
c7bdce9
Fix docs
baskrahmer Sep 24, 2024
8541d97
Output a tensor for reference metric
baskrahmer Sep 24, 2024
90ad414
Set dtype to float32 in reference metric
baskrahmer Sep 24, 2024
f61e7ac
Add links back
baskrahmer Sep 24, 2024
aaee609
Merge branch 'master' into 1990-hausdorff-distance
baskrahmer Oct 7, 2024
708ce27
Merge branch 'master' into 1990-hausdorff-distance
SkafteNicki Oct 12, 2024
3345393
changelog
SkafteNicki Oct 12, 2024
a4f129f
fix docstring + add input validation
SkafteNicki Oct 12, 2024
3f4b68e
add edge_surface_distance utility
SkafteNicki Oct 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `truncation` argument to `BERTScore` ([#2776](https://github.com/Lightning-AI/torchmetrics/pull/2776))


- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
2 changes: 2 additions & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,6 @@
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _SCC: https://www.ingentaconnect.com/content/tandf/tres/1998/00000019/00000004/art00013
.. _Generalized Dice Score: https://arxiv.org/abs/1707.03237
.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
21 changes: 21 additions & 0 deletions docs/source/segmentation/hausdorff_distance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Hausdorff Distance
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
:tags: segmentation

.. include:: ../links.rst

##################
Hausdorff Distance
##################

Module Interface
________________

.. autoclass:: torchmetrics.segmentation.HausdorffDistance
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.segmentation.hausdorff_distance
Empty file added requirements/integrate.txt
Empty file.
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score
from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance
from torchmetrics.functional.segmentation.mean_iou import mean_iou

__all__ = ["generalized_dice_score", "mean_iou"]
__all__ = ["generalized_dice_score", "mean_iou", "hausdorff_distance"]
128 changes: 128 additions & 0 deletions src/torchmetrics/functional/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, Optional, Tuple, Union

import torch
from torch import Tensor

from torchmetrics.functional.segmentation.utils import check_if_binarized, surface_distance
from torchmetrics.utilities.checks import _check_same_shape


def _hausdorff_distance_validate_args(
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, list[float]]] = None,
) -> None:
"""Validate the arguments of `hausdorff_distance` function."""
if distance_metric not in ["euclidean", "chessboard", "taxicab"]:
raise ValueError(
f"Arg `distance_metric` must be one of 'euclidean', 'chessboard', 'taxicab', but got {distance_metric}."
)
if spacing is not None and not isinstance(spacing, (list, Tensor)):
raise ValueError(f"Arg `spacing` must be a list or tensor, but got {type(spacing)}.")


def _hausdorff_distance_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Update and returns variables required to compute `Hausdorff Distance`_.

Args:
preds: predicted binarized segmentation map
target: target binarized segmentation map

Returns:
preds: predicted binarized segmentation map
target: target binarized segmentation map

"""
check_if_binarized(preds)
check_if_binarized(target)
_check_same_shape(preds, target)
return preds, target


def _hausdorff_distance_compute(
preds: Tensor,
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, list[float]]] = None,
) -> Tensor:
"""Compute `Hausdorff Distance`_.

Args:
preds: predicted binarized segmentation map
target: target binarized segmentation map
distance_metric: distance metric to calculate surface distance. One of `["euclidean", "chessboard", "taxicab"]`.
spacing: spacing between pixels along each spatial dimension

Returns:
Hausdorff distance

Example:
>>> preds = torch.tensor([[1, 1, 1, 1, 1],
... [1, 0, 0, 0, 1],
... [1, 0, 0, 0, 1],
... [1, 0, 0, 0, 1],
... [1, 1, 1, 1, 1]], dtype=torch.bool)
>>> target = torch.tensor([[1, 1, 1, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 1, 1, 1, 0]], dtype=torch.bool)
>>> hausdorff_distance(preds, target, distance_metric="euclidean")
tensor(1.)

"""
fwd = surface_distance(preds, target, distance_metric=distance_metric, spacing=spacing)
bwd = surface_distance(target, preds, distance_metric=distance_metric, spacing=spacing)
return torch.max(torch.tensor([fwd.max(), bwd.max()]))


def hausdorff_distance(
preds: Tensor,
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, list[float]]] = None,
) -> Tensor:
"""Calculate `Hausdorff Distance`_.

Args:
preds: predicted binarized segmentation map
target: target binarized segmentation map
distance_metric: distance metric to calculate surface distance. One of `["euclidean", "chessboard", "taxicab"]`.
spacing: spacing between pixels along each spatial dimension

Returns:
Hausdorff Distance

Example:
>>> import torch
>>> from torchmetrics.functional.segmentation import hausdorff_distance
>>> preds = torch.tensor([[1, 1, 1, 1, 1],
... [1, 0, 0, 0, 1],
... [1, 0, 0, 0, 1],
... [1, 0, 0, 0, 1],
... [1, 1, 1, 1, 1]], dtype=torch.bool)
>>> target = torch.tensor([[1, 1, 1, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 0, 0, 1, 0],
... [1, 1, 1, 1, 0]], dtype=torch.bool)
>>> hausdorff_distance(preds, target, distance_metric="euclidean")
tensor(1.)

"""
_hausdorff_distance_validate_args(distance_metric, spacing)
preds, target = _hausdorff_distance_update(preds, target)
return _hausdorff_distance_compute(preds, target, distance_metric=distance_metric, spacing=spacing)
54 changes: 43 additions & 11 deletions src/torchmetrics/functional/segmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _ignore_background(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:


def check_if_binarized(x: Tensor) -> None:
"""Check if the input is binarized.
"""Check if tensor is binarized.

Example:
>>> from torchmetrics.functional.segmentation.utils import check_if_binarized
Expand Down Expand Up @@ -249,25 +249,25 @@ def distance_transform(
raise ValueError(f"Expected argument `sampling` to have length 2 but got length `{len(sampling)}`.")

if engine == "pytorch":
x = x.float()
# calculate distance from every foreground pixel to every background pixel
i0, j0 = torch.where(x == 0)
i1, j1 = torch.where(x == 1)
dis_row = (i1.unsqueeze(1) - i0.unsqueeze(0)).abs_().mul_(sampling[0])
dis_col = (j1.unsqueeze(1) - j0.unsqueeze(0)).abs_().mul_(sampling[1])
dis_row = (i1.view(-1, 1) - i0.view(1, -1)).abs()
dis_col = (j1.view(-1, 1) - j0.view(1, -1)).abs()

# # calculate distance
h, _ = x.shape
if metric == "euclidean":
dis_row = dis_row.float()
dis_row.pow_(2).add_(dis_col.pow_(2)).sqrt_()
dis = ((sampling[0] * dis_row) ** 2 + (sampling[1] * dis_col) ** 2).sqrt()
if metric == "chessboard":
dis_row = dis_row.max(dis_col)
dis = torch.max(sampling[0] * dis_row, sampling[1] * dis_col).float()
if metric == "taxicab":
dis_row.add_(dis_col)
dis = (sampling[0] * dis_row + sampling[1] * dis_col).float()

# select only the closest distance
mindis, _ = torch.min(dis_row, dim=1)
z = torch.zeros_like(x, dtype=mindis.dtype).view(-1)
mindis, _ = torch.min(dis, dim=1)
z = torch.zeros_like(x).view(-1)
z[i1 * h + j1] = mindis
return z.view(x.shape)

Expand All @@ -279,7 +279,7 @@ def distance_transform(

if metric == "euclidean":
return ndimage.distance_transform_edt(x.cpu().numpy(), sampling)
return ndimage.distance_transform_cdt(x.cpu().numpy(), metric=metric)
return ndimage.distance_transform_cdt(x.cpu().numpy(), sampling, metric=metric)


def mask_edges(
Expand Down Expand Up @@ -345,7 +345,7 @@ def surface_distance(
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
) -> Tensor:
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Calculate the surface distance between two binary edge masks.

May return infinity if the predicted mask is empty and the target mask is not, or vice versa.
Expand Down Expand Up @@ -390,6 +390,38 @@ def surface_distance(
return dis[preds]


def edge_surface_distance(
preds: Tensor,
target: Tensor,
distance_metric: Literal["euclidean", "chessboard", "taxicab"] = "euclidean",
spacing: Optional[Union[Tensor, List[float]]] = None,
symmetric: bool = False,
) -> Tensor:
"""Extracts the edges from the input masks and calculates the surface distance between them.

Args:
preds: The predicted binary edge mask.
target: The target binary edge mask.
distance_metric: The distance metric to use. One of `["euclidean", "chessboard", "taxicab"]`.
spacing: The spacing between pixels along each spatial dimension.
symmetric: Whether to calculate the symmetric distance between the edges.

Returns:
A tensor with length equal to the number of edges in predictions e.g. `preds.sum()`. Each element is the
distance from the corresponding edge in `preds` to the closest edge in `target`. If `symmetric` is `True`, the
function returns a tuple containing the distances from the predicted edges to the target edges and vice versa.

"""
output = mask_edges(preds, target)
edges_preds, edges_target = output[0].bool(), output[1].bool()
if symmetric:
return (
surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing),
surface_distance(edges_target, edges_preds, distance_metric=distance_metric, spacing=spacing),
)
return surface_distance(edges_preds, edges_target, distance_metric=distance_metric, spacing=spacing)


@functools.lru_cache
def get_neighbour_tables(
spacing: Union[Tuple[int, int], Tuple[int, int, int]], device: Optional[torch.device] = None
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/segmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore
from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance
from torchmetrics.segmentation.mean_iou import MeanIoU

__all__ = ["GeneralizedDiceScore", "MeanIoU"]
__all__ = ["GeneralizedDiceScore", "MeanIoU", "HausdorffDistance"]
Loading
Loading