Skip to content

Commit

Permalink
symmetric test
Browse files Browse the repository at this point in the history
  • Loading branch information
matsumotosan committed Oct 14, 2023
1 parent 936141f commit 3fc0d37
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/torchmetrics/segmentation/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from torchmetrics.functional.segmentation import hausdorff_distance
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

Expand Down Expand Up @@ -89,7 +90,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
"""Compute final Hausdorff distance over states."""
return hausdorff_distance(self.preds, self.target, self.distance_metric, self.spacing)
return hausdorff_distance(
dim_zero_cat(self.preds), dim_zero_cat(self.target), self.distance_metric, self.spacing
)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/segmentation/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels
Input = namedtuple("Input", ["preds", "target"])


preds = torch.tensor(
[
[[1, 1, 1, 1, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 0, 0, 0, 1], [1, 1, 1, 1, 1]],
Expand Down
23 changes: 17 additions & 6 deletions tests/unittests/segmentation/test_hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import pytest
import torch
from skimage.metrics import hausdorff_distance as skimage_hausdorff_distance
from torchmetrics.functional.segmentation.hausdorff_distance import hausdorff_distance
from torchmetrics.segmentation.hausdorff_distance import HausdorffDistance
Expand Down Expand Up @@ -48,7 +47,7 @@ def test_hausdorff_distance(self, preds, target, distance_metric, ddp):
preds=preds,
target=target,
metric_class=HausdorffDistance,
reference_metric=partial(skimage_hausdorff_distance, method="standard"),
reference_metric=skimage_hausdorff_distance,
metric_args={"distance_metric": distance_metric, "spacing": None},
)

Expand All @@ -58,9 +57,8 @@ def test_hausdorff_distance_functional(self, preds, target, distance_metric):
preds=preds,
target=target,
metric_functional=hausdorff_distance,
reference_metric=partial(skimage_hausdorff_distance, method="standard"),
distance_metric=distance_metric,
spacing=None,
reference_metric=skimage_hausdorff_distance,
metric_args={"distance_metric": distance_metric, "spacing": None},
)


Expand All @@ -69,3 +67,16 @@ def test_hausdorff_distance_functional_raises_invalid_task():
preds, target = _inputs
with pytest.raises(ValueError, match=r"Expected *"):
hausdorff_distance(preds, target)


@pytest.mark.parametrize(
"distance_metric",
["euclidean", "chessboard", "taxicab"],
)
def test_hausdorff_distance_is_symmetric(distance_metric):
"""Check that the metric functional is symmetric."""
for p, t in zip(_inputs.preds, _inputs.target):
assert torch.allclose(
hausdorff_distance(p, t, distance_metric),
hausdorff_distance(t, p, distance_metric),
)

0 comments on commit 3fc0d37

Please sign in to comment.