diff --git a/CHANGELOG.md b/CHANGELOG.md index 2805105ee6c..3b28cad77ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,11 +12,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)) + + - Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025)) + + - Added `NormalizedMutualInfoScore` metric to cluster package ([#2029](https://github.com/Lightning-AI/torchmetrics/pull/2029)) + + +- Added `AdjustedRandScore` metric to cluster package ([#2032](https://github.com/Lightning-AI/torchmetrics/pull/2032)) + + - Added `CalinskiHarabaszScore` metric to cluster package ([#2036](https://github.com/Lightning-AI/torchmetrics/pull/2036)) + + - Added `DunnIndex` metric to cluster package ([#2049](https://github.com/Lightning-AI/torchmetrics/pull/2049)) + ### Changed - diff --git a/docs/source/clustering/adjusted_rand_score.rst b/docs/source/clustering/adjusted_rand_score.rst new file mode 100644 index 00000000000..b2e0bbdab10 --- /dev/null +++ b/docs/source/clustering/adjusted_rand_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Adjusted Rand Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: Clustering + +.. include:: ../links.rst + +################### +Adjusted Rand Score +################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.AdjustedRandScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.adjusted_rand_score diff --git a/docs/source/links.rst b/docs/source/links.rst index e6c85b2994a..88009ba5ff1 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -154,4 +154,5 @@ .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075 +.. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index .. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index 6f4e67e1197..dfa89738e1f 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.clustering.mutual_info_score import MutualInfoScore @@ -18,6 +19,7 @@ from torchmetrics.clustering.rand_score import RandScore __all__ = [ + "AdjustedRandScore", "CalinskiHarabaszScore", "DunnIndex", "MutualInfoScore", diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py new file mode 100644 index 00000000000..ece3d3aeef7 --- /dev/null +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -0,0 +1,125 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["AdjustedRandScore.plot"] + + +class AdjustedRandScore(Metric): + r"""Compute `Adjusted Rand Score`_ (also known as Adjusted Rand Index). + + .. math:: + ARS(U, V) = (\text{RS} - \text{Expected RS}) / (\text{Max RS} - \text{Expected RS}) + + The adjusted rand score :math:`\text{ARS}` is in essence the :math:`\text{RS}` (rand score) adjusted for chance. + The score ensures that completly randomly cluster labels have a score close to zero and only a perfect match will + have a score of 1 (up to a permutation of the labels). The adjusted rand score is symmetric, therefore swapping + :math:`U` and :math:`V` yields the same adjusted rand score. + + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering is generally used for unsupervised learning. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``adj_rand_score`` (:class:`~torch.Tensor`): Scalar tensor with the adjusted rand score + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.clustering import AdjustedRandScore + >>> metric = AdjustedRandScore() + >>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1])) + tensor(1.) + >>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 0, 1])) + tensor(-0.5000) + + """ + + is_differentiable = True + higher_is_better = None + full_state_update: bool = True + plot_lower_bound: float = -0.5 + plot_upper_bound: float = 1.0 + preds: List[Tensor] + target: List[Tensor] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute mutual information over state.""" + return adjusted_rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import AdjustedRandScore + >>> metric = AdjustedRandScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import AdjustedRandScore + >>> metric = AdjustedRandScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/clustering/mutual_info_score.py b/src/torchmetrics/clustering/mutual_info_score.py index e92c7c5baa6..6918393af46 100644 --- a/src/torchmetrics/clustering/mutual_info_score.py +++ b/src/torchmetrics/clustering/mutual_info_score.py @@ -33,11 +33,12 @@ class MutualInfoScore(Metric): \log\frac{N|U_i\cap V_j|}{|U_i||V_j|} Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions, - :math:`|U_i|` is the number of samples in cluster :math:`U_i`, and - :math:`|V_i|` is the number of samples in cluster :math:`V_i`. + :math:`|U_i|` is the number of samples in cluster :math:`U_i`, and :math:`|V_i|` is the number of samples in + cluster :math:`V_i`. The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same mutual + information score. - The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields - the same mutual information score. + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in generally is used for unsupervised learning. As input to ``forward`` and ``update`` the metric accepts the following input: diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index a7fa5bb83f8..29cc5351c70 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -32,9 +32,11 @@ class RandScore(Metric): RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs} The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V` - (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings. + (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings. The metric is + symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score. - The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score. + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in generally is used for unsupervised learning. As input to ``forward`` and ``update`` the metric accepts the following input: diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 08656e9e5e4..bde34f546d6 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score from torchmetrics.functional.clustering.dunn_index import dunn_index from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score @@ -18,6 +19,7 @@ from torchmetrics.functional.clustering.rand_score import rand_score __all__ = [ + "adjusted_rand_score", "calinski_harabasz_score", "dunn_index", "mutual_info_score", diff --git a/src/torchmetrics/functional/clustering/adjusted_rand_score.py b/src/torchmetrics/functional/clustering/adjusted_rand_score.py new file mode 100644 index 00000000000..59d742862f8 --- /dev/null +++ b/src/torchmetrics/functional/clustering/adjusted_rand_score.py @@ -0,0 +1,75 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch import Tensor + +from torchmetrics.functional.clustering.utils import ( + calcualte_pair_cluster_confusion_matrix, + calculate_contingency_matrix, + check_cluster_labels, +) + + +def _adjusted_rand_score_update(preds: Tensor, target: Tensor) -> Tensor: + """Update and return variables required to compute the rand score. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + contingency: contingency matrix + + """ + check_cluster_labels(preds, target) + return calculate_contingency_matrix(preds, target) + + +def _adjusted_rand_score_compute(contingency: Tensor) -> Tensor: + """Compute the rand score based on the contingency matrix. + + Args: + contingency: contingency matrix + + Returns: + rand_score: rand score + + """ + (tn, fp), (fn, tp) = calcualte_pair_cluster_confusion_matrix(contingency=contingency) + if fn == 0 and fp == 0: + return torch.ones_like(tn, dtype=torch.float32) + return 2.0 * (tp * tn - fn * fp) / ((tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)) + + +def adjusted_rand_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute the Adjusted Rand score between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + Scalar tensor with adjusted rand score + + Example: + >>> from torchmetrics.functional.clustering import adjusted_rand_score + >>> import torch + >>> adjusted_rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1])) + tensor(1.) + >>> adjusted_rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(0.5714) + + """ + contingency = _adjusted_rand_score_update(preds, target) + return _adjusted_rand_score_compute(contingency) diff --git a/tests/unittests/clustering/test_adjusted_rand_score.py b/tests/unittests/clustering/test_adjusted_rand_score.py new file mode 100644 index 00000000000..fcd2939bbe3 --- /dev/null +++ b/tests/unittests/clustering/test_adjusted_rand_score.py @@ -0,0 +1,69 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from sklearn.metrics import adjusted_rand_score as sklearn_adjusted_rand_score +from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore +from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score + +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.helpers.testers import MetricTester + + +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), + ], +) +class TestAdjustedRandScore(MetricTester): + """Test class for `AdjustedRandScore` metric.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_adjusted_rand_score(self, preds, target, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=AdjustedRandScore, + reference_metric=sklearn_adjusted_rand_score, + ) + + def test_rand_score_functional(self, preds, target): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=adjusted_rand_score, + reference_metric=sklearn_adjusted_rand_score, + ) + + +def test_rand_score_functional_raises_invalid_task(): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs_extrinsic + with pytest.raises(ValueError, match=r"Expected *"): + adjusted_rand_score(preds, target) + + +def test_rand_score_functional_is_symmetric( + preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target +): + """Check that the metric funtional is symmetric.""" + for p, t in zip(preds, target): + assert torch.allclose(adjusted_rand_score(p, t), adjusted_rand_score(t, p)) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index ff80729cc28..b900b8d292d 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -92,6 +92,7 @@ MultilabelSpecificity, ) from torchmetrics.clustering import ( + AdjustedRandScore, CalinskiHarabaszScore, DunnIndex, MutualInfoScore, @@ -624,6 +625,7 @@ pytest.param(TranslationEditRate, _text_input_3, _text_input_4, id="translation edit rate"), pytest.param(MutualInfoScore, _nominal_input, _nominal_input, id="mutual info score"), pytest.param(RandScore, _nominal_input, _nominal_input, id="rand score"), + pytest.param(AdjustedRandScore, _nominal_input, _nominal_input, id="adjusted rand score"), pytest.param(CalinskiHarabaszScore, lambda: torch.randn(100, 3), _nominal_input, id="calinski harabasz score"), pytest.param(NormalizedMutualInfoScore, _nominal_input, _nominal_input, id="normalized mutual info score"), pytest.param(DunnIndex, lambda: torch.randn(100, 3), _nominal_input, id="dunn index"),