diff --git a/src/torchmetrics/segmentation/hausdorff_distance.py b/src/torchmetrics/segmentation/hausdorff_distance.py index 8b52d4bcb8c..d38b5c3b35f 100644 --- a/src/torchmetrics/segmentation/hausdorff_distance.py +++ b/src/torchmetrics/segmentation/hausdorff_distance.py @@ -16,6 +16,7 @@ from torchmetrics.functional.segmentation import hausdorff_distance from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -89,7 +90,9 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute final Hausdorff distance over states.""" - return hausdorff_distance(self.preds, self.target, self.distance_metric, self.spacing) + return hausdorff_distance( + dim_zero_cat(self.preds), dim_zero_cat(self.target), self.distance_metric, self.spacing + ) def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/tests/unittests/segmentation/inputs.py b/tests/unittests/segmentation/inputs.py index a276876a06d..80d0a240d29 100644 --- a/tests/unittests/segmentation/inputs.py +++ b/tests/unittests/segmentation/inputs.py @@ -23,7 +23,6 @@ # extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels Input = namedtuple("Input", ["preds", "target"]) - preds = torch.tensor( [ [[1, 1, 1, 1, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]], diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py index 03eb6b20517..ee390a95c57 100644 --- a/tests/unittests/segmentation/test_hausdorff_distance.py +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial - import pytest +import torch from skimage.metrics import hausdorff_distance as skimage_hausdorff_distance from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance @@ -48,7 +47,7 @@ def test_hausdorff_distance(self, preds, target, distance_metric, ddp): preds=preds, target=target, metric_class=HausdorffDistance, - reference_metric=partial(skimage_hausdorff_distance, method="standard"), + reference_metric=skimage_hausdorff_distance, metric_args={"distance_metric": distance_metric, "spacing": None}, ) @@ -58,9 +57,8 @@ def test_hausdorff_distance_functional(self, preds, target, distance_metric): preds=preds, target=target, metric_functional=hausdorff_distance, - reference_metric=partial(skimage_hausdorff_distance, method="standard"), - distance_metric=distance_metric, - spacing=None, + reference_metric=skimage_hausdorff_distance, + metric_args={"distance_metric": distance_metric, "spacing": None}, ) @@ -69,3 +67,16 @@ def test_hausdorff_distance_functional_raises_invalid_task(): preds, target = _inputs with pytest.raises(ValueError, match=r"Expected *"): hausdorff_distance(preds, target) + + +@pytest.mark.parametrize( + "distance_metric", + ["euclidean", "chessboard", "taxicab"], +) +def test_hausdorff_distance_is_symmetric(distance_metric): + """Check that the metric functional is symmetric.""" + for p, t in zip(_inputs.preds, _inputs.target): + assert torch.allclose( + hausdorff_distance(p, t, distance_metric), + hausdorff_distance(t, p, distance_metric), + )