diff --git a/CHANGELOG.md b/CHANGELOG.md index f8c45a53d4c..38b04aa4438 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `DunnIndex` metric to cluster package ([#2049](https://github.com/Lightning-AI/torchmetrics/pull/2049)) +- Added `FowlkesMallowsIndex` metric to cluster package ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066)) + ### Changed - diff --git a/docs/source/links.rst b/docs/source/links.rst index 2d0000115d2..0bb83459071 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -158,3 +158,4 @@ .. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi .. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index .. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index +.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index dfa89738e1f..97b34861778 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -14,6 +14,7 @@ from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.clustering.dunn_index import DunnIndex +from torchmetrics.clustering.fowlkes_mallows_index import FowlkesMallowsIndex from torchmetrics.clustering.mutual_info_score import MutualInfoScore from torchmetrics.clustering.normalized_mutual_info_score import NormalizedMutualInfoScore from torchmetrics.clustering.rand_score import RandScore @@ -22,6 +23,7 @@ "AdjustedRandScore", "CalinskiHarabaszScore", "DunnIndex", + "FowlkesMallowsIndex", "MutualInfoScore", "NormalizedMutualInfoScore", "RandScore", diff --git a/src/torchmetrics/clustering/fowlkes_mallows_index.py 
if not _MATPLOTLIB_AVAILABLE:
    # The ``plot`` examples need matplotlib; skip their doctests when it is not installed.
    __doctest_skip__ = ["FowlkesMallowsIndex.plot"]


class FowlkesMallowsIndex(Metric):
    r"""Compute `Fowlkes-Mallows Index`_.

    .. math::
        FMI(U,V) = \frac{TP}{\sqrt{(TP + FP) * (TP + FN)}}

    Where :math:`TP` is the number of true positives, :math:`FP` is the number of false positives, and :math:`FN` is
    the number of false negatives. All three are counted over *pairs of samples*: a pair is a true positive when its
    two samples share a cluster in both ``preds`` and ``target``, a false positive when they share a cluster only in
    ``preds``, and a false negative when they share a cluster only in ``target``.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
    - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``fmi`` (:class:`~torch.Tensor`): A tensor with the Fowlkes-Mallows index.

    Args:
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> import torch
        >>> from torchmetrics.clustering import FowlkesMallowsIndex
        >>> preds = torch.tensor([2, 2, 0, 1, 0])
        >>> target = torch.tensor([2, 2, 1, 1, 0])
        >>> fmi = FowlkesMallowsIndex()
        >>> fmi(preds, target)
        tensor(0.5000)

    """

    is_differentiable: bool = True
    higher_is_better: Optional[bool] = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0
    # The raw label tensors are stored per ``update`` call; the index is only computed in ``compute``.
    preds: List[Tensor]
    target: List[Tensor]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        # Both states are lists registered with the "cat" reduction.
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Args:
            preds: predicted cluster labels for the current batch
            target: ground truth cluster labels for the current batch

        """
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> Tensor:
        """Compute Fowlkes-Mallows index over state.

        All stored batches are concatenated along the first dimension and passed to the functional
        :func:`~torchmetrics.functional.clustering.fowlkes_mallows_index`.

        """
        return fowlkes_mallows_index(dim_zero_cat(self.preds), dim_zero_cat(self.target))

    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> # Example plotting a single value
            >>> import torch
            >>> from torchmetrics.clustering import FowlkesMallowsIndex
            >>> metric = FowlkesMallowsIndex()
            >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
            >>> fig_, ax_ = metric.plot(metric.compute())

        .. plot::
            :scale: 75

            >>> # Example plotting multiple values
            >>> import torch
            >>> from torchmetrics.clustering import FowlkesMallowsIndex
            >>> metric = FowlkesMallowsIndex()
            >>> values = [ ]
            >>> for _ in range(10):
            ...     values.append(metric(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))))
            >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
a/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py new file mode 100644 index 00000000000..2369e49c565 --- /dev/null +++ b/src/torchmetrics/functional/clustering/fowlkes_mallows_index.py @@ -0,0 +1,78 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels + + +def _fowlkes_mallows_index_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: + """Return contingency matrix required to compute the Fowlkes-Mallows index. + + Args: + preds: predicted class labels + target: ground truth class labels + + Returns: + contingency: contingency matrix + + """ + check_cluster_labels(preds, target) + return calculate_contingency_matrix(preds, target), preds.size(0) + + +def _fowlkes_mallows_index_compute(contingency: Tensor, n: int) -> Tensor: + """Compute the Fowlkes-Mallows index based on the contingency matrix. 
+ + Args: + contingency: contingency matrix + n: number of samples + + Returns: + fowlkes_mallows: Fowlkes-Mallows index + + """ + tk = torch.sum(contingency**2) - n + if torch.allclose(tk, tensor(0)): + return torch.tensor(0.0, device=contingency.device) + + pk = torch.sum(contingency.sum(dim=0) ** 2) - n + qk = torch.sum(contingency.sum(dim=1) ** 2) - n + + return torch.sqrt(tk / pk) * torch.sqrt(tk / qk) + + +def fowlkes_mallows_index(preds: Tensor, target: Tensor) -> Tensor: + """Compute Fowlkes-Mallows index between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + fowlkes_mallows: Fowlkes-Mallows index + + Example: + >>> import torch + >>> from torchmetrics.functional.clustering import fowlkes_mallows_index + >>> preds = torch.tensor([2, 2, 0, 1, 0]) + >>> target = torch.tensor([2, 2, 1, 1, 0]) + >>> fowlkes_mallows_index(preds, target) + tensor(0.5000) + + """ + contingency, n = _fowlkes_mallows_index_update(preds, target) + return _fowlkes_mallows_index_compute(contingency, n) diff --git a/tests/unittests/clustering/test_fowlkes_mallows_index.py b/tests/unittests/clustering/test_fowlkes_mallows_index.py new file mode 100644 index 00000000000..1516a517c73 --- /dev/null +++ b/tests/unittests/clustering/test_fowlkes_mallows_index.py @@ -0,0 +1,56 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Fix the random seed once at import time so every run of this module is reproducible.
seed_all(42)


@pytest.mark.parametrize(
    "preds, target",
    [(case.preds, case.target) for case in (_single_target_extrinsic1, _single_target_extrinsic2)],
)
class TestFowlkesMallowsIndex(MetricTester):
    """Check the `FowlkesMallowsIndex` metric against the scikit-learn reference implementation."""

    # Absolute tolerance used when comparing against the reference values.
    atol = 1e-5

    @pytest.mark.parametrize("ddp", [True, False])
    def test_fowlkes_mallows_index(self, preds, target, ddp):
        """Run the modular (class) metric, with and without DDP, against the reference."""
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            reference_metric=sklearn_fowlkes_mallows_score,
            metric_class=FowlkesMallowsIndex,
        )

    def test_fowlkes_mallows_index_functional(self, preds, target):
        """Run the functional metric against the reference."""
        self.run_functional_metric_test(
            preds=preds,
            target=target,
            reference_metric=sklearn_fowlkes_mallows_score,
            metric_functional=fowlkes_mallows_index,
        )