New metric: Adjusted Rand Score #2032

Merged
merged 27 commits into master from newmetric/adjusted_rand_score on Sep 7, 2023
Changes from all commits · 27 commits
0994dfc
initial implementation
SkafteNicki Aug 25, 2023
00ec141
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Aug 29, 2023
682f1b1
add init files
SkafteNicki Aug 29, 2023
87f3814
add tests
SkafteNicki Aug 29, 2023
404ed03
docs
SkafteNicki Aug 29, 2023
78153b6
fix doc tests
SkafteNicki Aug 29, 2023
52f2c52
changelog
SkafteNicki Aug 29, 2023
6045755
Merge branch 'master' into newmetric/adjusted_rand_score
Borda Aug 29, 2023
8ac011f
fix
SkafteNicki Aug 29, 2023
1753225
Merge branch 'master' into newmetric/adjusted_rand_score
Borda Aug 29, 2023
1051ac0
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 1, 2023
21e6692
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 1, 2023
4b6a43e
change image
SkafteNicki Sep 1, 2023
18fcdaf
fix
SkafteNicki Sep 4, 2023
49ffcaa
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 4, 2023
abfd6ad
use new inputs
SkafteNicki Sep 4, 2023
07c4e33
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 4, 2023
b4f02f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2023
b0ea203
Update src/torchmetrics/clustering/adjusted_rand_score.py
SkafteNicki Sep 4, 2023
5af8243
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 5, 2023
449d565
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 6, 2023
ef374c6
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 6, 2023
fc135cc
Merge branch 'master' into newmetric/adjusted_rand_score
mergify[bot] Sep 6, 2023
9fb69c8
Merge branch 'master' into newmetric/adjusted_rand_score
mergify[bot] Sep 6, 2023
922053f
Update src/torchmetrics/clustering/adjusted_rand_score.py
SkafteNicki Sep 7, 2023
52dcee3
Update src/torchmetrics/clustering/adjusted_rand_score.py
SkafteNicki Sep 7, 2023
0975c26
Merge branch 'master' into newmetric/adjusted_rand_score
SkafteNicki Sep 7, 2023
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -12,11 +12,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008))


- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025))


- Added `NormalizedMutualInfoScore` metric to cluster package ([#2029](https://github.com/Lightning-AI/torchmetrics/pull/2029))


- Added `AdjustedRandScore` metric to cluster package ([#2032](https://github.com/Lightning-AI/torchmetrics/pull/2032))


- Added `CalinskiHarabaszScore` metric to cluster package ([#2036](https://github.com/Lightning-AI/torchmetrics/pull/2036))


- Added `DunnIndex` metric to cluster package ([#2049](https://github.com/Lightning-AI/torchmetrics/pull/2049))


### Changed

-
21 changes: 21 additions & 0 deletions docs/source/clustering/adjusted_rand_score.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Adjusted Rand Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg
:tags: Clustering

.. include:: ../links.rst

###################
Adjusted Rand Score
###################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.AdjustedRandScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.adjusted_rand_score
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -154,4 +154,5 @@
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
.. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index
.. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from torchmetrics.clustering.dunn_index import DunnIndex
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.clustering.normalized_mutual_info_score import NormalizedMutualInfoScore
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"AdjustedRandScore",
"CalinskiHarabaszScore",
"DunnIndex",
"MutualInfoScore",
125 changes: 125 additions & 0 deletions src/torchmetrics/clustering/adjusted_rand_score.py
@@ -0,0 +1,125 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["AdjustedRandScore.plot"]


class AdjustedRandScore(Metric):
r"""Compute `Adjusted Rand Score`_ (also known as Adjusted Rand Index).

.. math::
ARS(U, V) = (\text{RS} - \text{Expected RS}) / (\text{Max RS} - \text{Expected RS})

The adjusted rand score :math:`\text{ARS}` is in essence the :math:`\text{RS}` (rand score) adjusted for chance.
The score ensures that completely random cluster labels have a score close to zero and only a perfect match will
have a score of 1 (up to a permutation of the labels). The adjusted rand score is symmetric, therefore swapping
:math:`U` and :math:`V` yields the same adjusted rand score.

This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering is generally used for unsupervised learning.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``adj_rand_score`` (:class:`~torch.Tensor`): Scalar tensor with the adjusted rand score

Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.clustering import AdjustedRandScore
>>> metric = AdjustedRandScore()
>>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1]))
tensor(1.)
>>> metric(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 0, 1]))
tensor(-0.5000)

"""

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
plot_lower_bound: float = -0.5
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Compute mutual information over state."""
return adjusted_rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add the plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import AdjustedRandScore
>>> metric = AdjustedRandScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import AdjustedRandScore
>>> metric = AdjustedRandScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())

"""
return self._plot(val, ax)
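As a quick illustration of the chance adjustment described in the docstring above, the following is a minimal sketch (not part of this diff) comparing the plain RandScore with AdjustedRandScore on purely random labels; the exact printed values will vary with the seed, and the symmetry check mirrors the property stated in the docstring.

import torch

from torchmetrics.clustering import AdjustedRandScore, RandScore

torch.manual_seed(42)
# Random label assignments: any pairwise agreement is purely by chance
preds = torch.randint(0, 4, (1000,))
target = torch.randint(0, 4, (1000,))

rand, adj_rand = RandScore(), AdjustedRandScore()
print(rand(preds, target))      # well above zero (roughly 0.6 for 4 balanced random clusterings)
print(adj_rand(preds, target))  # close to zero, since the agreement is only by chance

# Symmetry: swapping preds and target yields the same adjusted rand score
assert torch.allclose(adj_rand(preds, target), adj_rand(target, preds))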
9 changes: 5 additions & 4 deletions src/torchmetrics/clustering/mutual_info_score.py
@@ -33,11 +33,12 @@ class MutualInfoScore(Metric):
\log\frac{N|U_i\cap V_j|}{|U_i||V_j|}

Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions,
:math:`|U_i|` is the number of samples in cluster :math:`U_i`, and
:math:`|V_i|` is the number of samples in cluster :math:`V_i`.
:math:`|U_i|` is the number of samples in cluster :math:`U_i`, and :math:`|V_i|` is the number of samples in
cluster :math:`V_i`. The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same mutual
information score.

The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
the same mutual information score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering is generally used for unsupervised learning.

As input to ``forward`` and ``update`` the metric accepts the following input:

6 changes: 4 additions & 2 deletions src/torchmetrics/clustering/rand_score.py
@@ -32,9 +32,11 @@ class RandScore(Metric):
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}

The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings. The metric is
symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.

The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering is generally used for unsupervised learning.

As input to ``forward`` and ``update`` the metric accepts the following input:

2 changes: 2 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score
from torchmetrics.functional.clustering.calinski_harabasz_score import calinski_harabasz_score
from torchmetrics.functional.clustering.dunn_index import dunn_index
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.normalized_mutual_info_score import normalized_mutual_info_score
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = [
"adjusted_rand_score",
"calinski_harabasz_score",
"dunn_index",
"mutual_info_score",
75 changes: 75 additions & 0 deletions src/torchmetrics/functional/clustering/adjusted_rand_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.clustering.utils import (
calcualte_pair_cluster_confusion_matrix,
calculate_contingency_matrix,
check_cluster_labels,
)


def _adjusted_rand_score_update(preds: Tensor, target: Tensor) -> Tensor:
"""Update and return variables required to compute the rand score.

Args:
preds: predicted cluster labels
target: ground truth cluster labels

Returns:
contingency: contingency matrix

"""
check_cluster_labels(preds, target)
return calculate_contingency_matrix(preds, target)


def _adjusted_rand_score_compute(contingency: Tensor) -> Tensor:
"""Compute the rand score based on the contingency matrix.

Args:
contingency: contingency matrix

Returns:
adjusted_rand_score: adjusted rand score

"""
(tn, fp), (fn, tp) = calcualte_pair_cluster_confusion_matrix(contingency=contingency)
if fn == 0 and fp == 0:
return torch.ones_like(tn, dtype=torch.float32)
return 2.0 * (tp * tn - fn * fp) / ((tp + fn) * (fn + tn) + (tp + fp) * (fp + tn))


def adjusted_rand_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute the Adjusted Rand score between two clusterings.

Args:
preds: predicted cluster labels
target: ground truth cluster labels

Returns:
Scalar tensor with adjusted rand score

Example:
>>> from torchmetrics.functional.clustering import adjusted_rand_score
>>> import torch
>>> adjusted_rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([0, 0, 1, 1]))
tensor(1.)
>>> adjusted_rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1]))
tensor(0.5714)

"""
contingency = _adjusted_rand_score_update(preds, target)
return _adjusted_rand_score_compute(contingency)
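To make the pair-confusion formula above concrete, here is a small cross-check sketch (not part of the diff) against scikit-learn's reference implementation, which is also what the unit tests below compare against; the specific label tensors are arbitrary.

import torch
from sklearn.metrics import adjusted_rand_score as sklearn_adjusted_rand_score

from torchmetrics.functional.clustering import adjusted_rand_score

preds = torch.tensor([0, 0, 1, 2, 2, 2])
target = torch.tensor([0, 0, 1, 1, 2, 2])

tm_score = adjusted_rand_score(preds, target)
sk_score = sklearn_adjusted_rand_score(target.numpy(), preds.numpy())

# The two implementations should agree up to floating-point precision
assert torch.allclose(tm_score, torch.tensor(sk_score, dtype=tm_score.dtype), atol=1e-5)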
69 changes: 69 additions & 0 deletions tests/unittests/clustering/test_adjusted_rand_score.py
@@ -0,0 +1,69 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from sklearn.metrics import adjusted_rand_score as sklearn_adjusted_rand_score
from torchmetrics.clustering.adjusted_rand_score import AdjustedRandScore
from torchmetrics.functional.clustering.adjusted_rand_score import adjusted_rand_score

from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2
from unittests.helpers.testers import MetricTester


@pytest.mark.parametrize(
"preds, target",
[
(_single_target_extrinsic1.preds, _single_target_extrinsic1.target),
(_single_target_extrinsic2.preds, _single_target_extrinsic2.target),
],
)
class TestAdjustedRandScore(MetricTester):
"""Test class for `AdjustedRandScore` metric."""

atol = 1e-5

@pytest.mark.parametrize("ddp", [True, False])
def test_adjusted_rand_score(self, preds, target, ddp):
"""Test class implementation of metric."""
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=AdjustedRandScore,
reference_metric=sklearn_adjusted_rand_score,
)

def test_adjusted_rand_score_functional(self, preds, target):
"""Test functional implementation of metric."""
self.run_functional_metric_test(
preds=preds,
target=target,
metric_functional=adjusted_rand_score,
reference_metric=sklearn_adjusted_rand_score,
)


def test_adjusted_rand_score_functional_raises_invalid_input():
"""Check that metric rejects continuous-valued inputs."""
preds, target = _float_inputs_extrinsic
with pytest.raises(ValueError, match=r"Expected *"):
adjusted_rand_score(preds, target)


def test_adjusted_rand_score_functional_is_symmetric(
preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target
):
"""Check that the metric funtional is symmetric."""
for p, t in zip(preds, target):
assert torch.allclose(adjusted_rand_score(p, t), adjusted_rand_score(t, p))
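For context, the invalid-input test above relies on the label validation done by check_cluster_labels (imported in the functional file earlier in this diff). A rough sketch of what it exercises is shown below; the assumption here, based on the `match=r"Expected *"` pattern, is that continuous-valued inputs raise a ValueError whose message starts with "Expected".

import torch

from torchmetrics.functional.clustering import adjusted_rand_score

try:
    # Continuous-valued inputs should be rejected; cluster labels are expected to be integer tensors
    adjusted_rand_score(torch.rand(10), torch.rand(10))
except ValueError as err:
    print(f"rejected as expected: {err}")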