Commit

Merge branch 'master' into feature/map_average

mergify[bot] authored Aug 28, 2023
2 parents 0595357 + 6c54478 commit 75d6146
Showing 16 changed files with 459 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008))


- Added `RandScore` metric to cluster package ([#2025](https://github.com/Lightning-AI/torchmetrics/pull/2025))


### Changed

-
21 changes: 21 additions & 0 deletions docs/source/clustering/rand_score.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Rand Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
:tags: Clustering

.. include:: ../links.rst

##########
Rand Score
##########

Module Interface
________________

.. autoclass:: torchmetrics.clustering.RandScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.rand_score
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -152,3 +152,4 @@
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
2 changes: 2 additions & 0 deletions src/torchmetrics/clustering/__init__.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.mutual_info_score import MutualInfoScore
from torchmetrics.clustering.rand_score import RandScore

__all__ = [
"MutualInfoScore",
"RandScore",
]
4 changes: 2 additions & 2 deletions src/torchmetrics/clustering/mutual_info_score.py
@@ -41,8 +41,8 @@ class MutualInfoScore(Metric):
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
- ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
As output of ``forward`` and ``compute`` the metric returns the following output:
122 changes: 122 additions & 0 deletions src/torchmetrics/clustering/rand_score.py
@@ -0,0 +1,122 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.rand_score import rand_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["RandScore.plot"]


class RandScore(Metric):
r"""Compute `Rand Score`_ (alternatively known as Rand Index).
.. math::
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels
- ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``rand_score`` (:class:`~torch.Tensor`): A tensor with the Rand Score
Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> metric = RandScore()
>>> metric(preds, target)
tensor(0.6000)
"""

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Compute rand score over state."""
return rand_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> metric = RandScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import RandScore
>>> metric = RandScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
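
A quick way to sanity-check the accumulating behaviour of the new module interface is to compare it against scikit-learn's reference implementation. This is only an illustrative sketch, not part of the diff, and it assumes scikit-learn is installed:

import torch
from sklearn.metrics import rand_score as sk_rand_score  # external reference implementation
from torchmetrics.clustering import RandScore

metric = RandScore()
all_preds, all_target = [], []
for _ in range(3):  # simulate three batches of cluster assignments
    preds = torch.randint(0, 4, (10,))
    target = torch.randint(0, 4, (10,))
    metric.update(preds, target)
    all_preds.append(preds)
    all_target.append(target)

# compute() concatenates the stored batches via dim_zero_cat before scoring,
# so it should agree with the reference value computed on the full arrays
tm_score = metric.compute()
sk_score = sk_rand_score(torch.cat(all_target).numpy(), torch.cat(all_preds).numpy())
assert torch.isclose(tm_score, torch.tensor(sk_score, dtype=tm_score.dtype))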
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/clustering/__init__.py
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.rand_score import rand_score

__all__ = ["mutual_info_score"]
__all__ = ["mutual_info_score", "rand_score"]
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/clustering/mutual_info_score.py
@@ -64,8 +64,8 @@ def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute mutual information between two clusterings.
Args:
preds: predicted classes
target: ground truth classes
preds: predicted cluster labels
target: ground truth cluster labels
Example:
>>> from torchmetrics.functional.clustering import mutual_info_score
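
The clarified argument description above can be illustrated with a short functional-API call. This is an illustrative sketch, not part of the diff, and the label values are made up:

import torch
from torchmetrics.functional.clustering import mutual_info_score

# both arguments are flat integer tensors of cluster assignments with shape (N,)
preds = torch.tensor([2, 1, 0, 1, 0])
target = torch.tensor([0, 2, 1, 1, 0])
score = mutual_info_score(preds, target)
print(score)  # scalar tensor with the mutual information between the two labelings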
82 changes: 82 additions & 0 deletions src/torchmetrics/functional/clustering/rand_score.py
@@ -0,0 +1,82 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor

from torchmetrics.functional.clustering.utils import (
calcualte_pair_cluster_confusion_matrix,
calculate_contingency_matrix,
check_cluster_labels,
)


def _rand_score_update(preds: Tensor, target: Tensor) -> Tensor:
"""Update and return variables required to compute the rand score.
Args:
preds: predicted cluster labels
target: ground truth cluster labels
Returns:
contingency: contingency matrix
"""
check_cluster_labels(preds, target)
return calculate_contingency_matrix(preds, target)


def _rand_score_compute(contingency: Tensor) -> Tensor:
"""Compute the rand score based on the contingency matrix.
Args:
contingency: contingency matrix
Returns:
rand_score: rand score
"""
pair_matrix = calcualte_pair_cluster_confusion_matrix(contingency=contingency)

numerator = pair_matrix.diagonal().sum()
denominator = pair_matrix.sum()
if numerator == denominator or denominator == 0:
# Special limit cases: no clustering since the data is not split;
# or trivial clustering where each document is assigned a unique
# cluster. These are perfect matches hence return 1.0.
return torch.ones_like(numerator, dtype=torch.float32)

return numerator / denominator


def rand_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute the Rand score between two clusterings.
Args:
preds: predicted cluster labels
target: ground truth cluster labels
Returns:
scalar tensor with the rand score
Example:
>>> from torchmetrics.functional.clustering import rand_score
>>> import torch
>>> rand_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0]))
tensor(1.)
>>> rand_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1]))
tensor(0.8333)
"""
contingency = _rand_score_update(preds, target)
return _rand_score_compute(contingency)
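
To make the numerator/denominator step in `_rand_score_compute` concrete, the 0.8333 value from the docstring example can be reproduced by counting agreeing sample pairs directly. This is an illustrative sketch, not part of the diff:

import itertools

import torch

preds = torch.tensor([0, 0, 1, 2])
target = torch.tensor([0, 0, 1, 1])

# over all unordered sample pairs, count how often preds and target agree
# on the "same cluster" / "different cluster" decision
agree, total = 0, 0
for i, j in itertools.combinations(range(len(preds)), 2):
    same_pred = bool(preds[i] == preds[j])
    same_target = bool(target[i] == target[j])
    agree += int(same_pred == same_target)
    total += 1

print(agree / total)  # 5 / 6 = 0.8333..., matching rand_score(preds, target)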
70 changes: 70 additions & 0 deletions src/torchmetrics/functional/clustering/utils.py
@@ -99,3 +99,73 @@ def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
f"Expected real, discrete values but received {preds.dtype} for"
f"predictions and {target.dtype} for target labels instead."
)


def calcualte_pair_cluster_confusion_matrix(
preds: Optional[Tensor] = None,
target: Optional[Tensor] = None,
contingency: Optional[Tensor] = None,
) -> Tensor:
"""Calculates the pair cluster confusion matrix.
Can either be calculated from predicted cluster labels and target cluster labels or from a pre-computed
contingency matrix. The pair cluster confusion matrix is a 2x2 matrix where that defines the similarity between
two clustering by considering all pairs of samples and counting pairs that are assigned into same or different
clusters in the predicted and target clusterings.
Note that the matrix is not symmetric.
Inspired by:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cluster.pair_confusion_matrix.html
Args:
preds: predicted cluster labels
target: ground truth cluster labels
contingency: contingency matrix
Returns:
A 2x2 tensor containing the pair cluster confusion matrix.
Raises:
ValueError:
If neither `preds` and `target` nor `contingency` are provided.
ValueError:
If both `preds` and `target` and `contingency` are provided.
Example:
>>> import torch
>>> from torchmetrics.functional.clustering.utils import calcualte_pair_cluster_confusion_matrix
>>> preds = torch.tensor([0, 0, 1, 1])
>>> target = torch.tensor([1, 1, 0, 0])
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
tensor([[8, 0],
[0, 4]])
>>> preds = torch.tensor([0, 0, 1, 2])
>>> target = torch.tensor([0, 0, 1, 1])
>>> calcualte_pair_cluster_confusion_matrix(preds, target)
tensor([[8, 2],
[0, 2]])
"""
if preds is None and target is None and contingency is None:
raise ValueError("Must provide either `preds` and `target` or `contingency`.")
if preds is not None and target is not None and contingency is not None:
raise ValueError("Must provide either `preds` and `target` or `contingency`, not both.")

if preds is not None and target is not None:
contingency = calculate_contingency_matrix(preds, target)

if contingency is None:
raise ValueError("Must provide `contingency` if `preds` and `target` are not provided.")

n_samples = contingency.sum()
n_c = contingency.sum(dim=1)
n_k = contingency.sum(dim=0)
sum_squared = (contingency**2).sum()

pair_matrix = torch.zeros(2, 2, dtype=contingency.dtype, device=contingency.device)
pair_matrix[1, 1] = sum_squared - n_samples
pair_matrix[1, 0] = (contingency * n_k).sum() - sum_squared
pair_matrix[0, 1] = (contingency.T * n_c).sum() - sum_squared
pair_matrix[0, 0] = n_samples**2 - pair_matrix[0, 1] - pair_matrix[1, 0] - sum_squared
return pair_matrix
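
As a sanity check of the closed-form arithmetic above, the second docstring example can be reproduced by brute force over ordered sample pairs. This is an illustrative sketch, not part of the diff; the row/column convention below (rows index "same cluster in preds", columns "same cluster in target") is inferred from the docstring outputs:

import torch

preds = torch.tensor([0, 0, 1, 2])
target = torch.tensor([0, 0, 1, 1])

# brute-force pair confusion matrix over ordered pairs (i != j), so every
# unordered pair is counted twice
pair_matrix = torch.zeros(2, 2, dtype=torch.long)
n = len(preds)
for i in range(n):
    for j in range(n):
        if i == j:
            continue
        same_pred = int(preds[i] == preds[j])
        same_target = int(target[i] == target[j])
        pair_matrix[same_pred, same_target] += 1

print(pair_matrix)  # tensor([[8, 2], [0, 2]]), matching the docstring example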
12 changes: 11 additions & 1 deletion tests/unittests/__init__.py
@@ -3,7 +3,16 @@
import numpy
import torch

from unittests.conftest import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, NUM_PROCESSES, THRESHOLD, setup_ddp
from unittests.conftest import (
BATCH_SIZE,
EXTRA_DIM,
NUM_BATCHES,
NUM_CLASSES,
NUM_PROCESSES,
THRESHOLD,
setup_ddp,
skip_on_running_out_of_memory,
)

# adding compatibility for numpy >= 1.24
for tp_name, tp_ins in [("object", object), ("bool", bool), ("int", int), ("float", float)]:
@@ -25,4 +34,5 @@
"NUM_PROCESSES",
"THRESHOLD",
"setup_ddp",
"skip_on_running_out_of_memory",
]