diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e49f86d671..1a9d9fc6764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed flakiness in tests related to `torch.unique` with `dim=None` ([#2650](https://github.com/Lightning-AI/torchmetrics/pull/2650)) +- Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743)) + + ## [1.4.1] - 2024-08-02 ### Changed diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 544414ee4a8..45e0238dae5 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -64,12 +64,14 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: denom = cov_ypyp * cov_ytyt if denom == 0 and confmat.numel() == 4: - if tp == 0 or tn == 0: - a = tp + tn - - if fp == 0 or fn == 0: - b = fp + fn - + if fn == 0 and tn == 0: + a, b = tp, fp + elif fp == 0 and tn == 0: + a, b = tp, fn + elif tp == 0 and fn == 0: + a, b = tn, fp + elif tp == 0 and fp == 0: + a, b = tn, fn eps = torch.tensor(torch.finfo(torch.float32).eps, dtype=torch.float32, device=confmat.device) numerator = torch.sqrt(eps) * (a - b) denom = (tp + fp + eps) * (tp + fn + eps) * (tn + fp + eps) * (tn + fn + eps) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 03f649bc0ac..2f881604d09 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -331,6 +331,12 @@ def test_zero_case_in_multiclass(): torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]), 0.0, ), + ( + binary_matthews_corrcoef, + torch.tensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), + torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), + 0.0, + ), (binary_matthews_corrcoef, torch.zeros(10), torch.ones(10), -1.0), (binary_matthews_corrcoef, torch.ones(10), torch.zeros(10), -1.0), (