Commit 71c106f: doc fixes
SkafteNicki committed Sep 4, 2023
1 parent 0781213 commit 71c106f
Showing 4 changed files with 75 additions and 63 deletions.
3 changes: 3 additions & 0 deletions docs/source/links.rst
@@ -154,3 +154,6 @@
.. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
.. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
+ .. _V-Measure Score: https://www.aclweb.org/anthology/D07-1043.pdf
+ .. _Homogeneity Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.homogeneity_score.html
+ .. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
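(Context, not part of the diff: Sphinx resolves references such as `Homogeneity Score`_ in the class docstrings below against these ``.. _Name: URL`` targets in ``links.rst``, which is why the three new metrics each need an entry here.)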
109 changes: 54 additions & 55 deletions src/torchmetrics/clustering/homogeneity_completeness_v_measure.py
@@ -16,8 +16,8 @@
from torch import Tensor

from torchmetrics.functional.clustering.homogeneity_completeness_v_measure import (
- homogeneity_score,
completeness_score,
+ homogeneity_score,
v_measure_score,
)
from torchmetrics.metric import Metric
@@ -26,23 +26,18 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
- __doctest_skip__ = [
-     "HomogeneityScore.plot",
-     "CompletenessScore.plot",
-     "VMeasureScore.plot"
- ]
+ __doctest_skip__ = ["HomogeneityScore.plot", "CompletenessScore.plot", "VMeasureScore.plot"]


class HomogeneityScore(Metric):
r"""Compute `Rand Score`_ (alternatively known as Rand Index).
.. math::
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
r"""Compute `Homogeneity Score`_.
The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
The homogeneity score is a metric to measure the homogeneity of a clustering. A clustering result satisfies
homogeneity if all of its clusters contain only data points which are members of a single class. The metric is not
symmetric, therefore swapping ``preds`` and ``target`` yields a different score.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -67,13 +62,13 @@ class HomogeneityScore(Metric):
"""

- is_differentiable = True
- higher_is_better = None
- full_state_update: bool = True
+ is_differentiable: bool = True
+ higher_is_better: bool = True
+ full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -110,8 +105,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting a single value
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import HomogeneityScore
+ >>> metric = HomogeneityScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
@@ -120,8 +115,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting multiple values
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import HomogeneityScore
+ >>> metric = HomogeneityScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
@@ -131,15 +126,13 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:


class CompletenessScore(Metric):
r"""Compute `Rand Score`_ (alternatively known as Rand Index).
r"""Compute `Completeness Score`_.
.. math::
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
A clustering result satisfies completeness if all the data points that are members of a given class are elements of
the same cluster. The metric is not symmetric, therefore swapping ``preds`` and ``target`` yields a different
The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -164,13 +157,13 @@ class CompletenessScore(Metric):
"""

- is_differentiable = True
- higher_is_better = None
- full_state_update: bool = True
+ is_differentiable: bool = True
+ higher_is_better: bool = True
+ full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -207,8 +200,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting a single value
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import CompletenessScore
+ >>> metric = CompletenessScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
@@ -217,8 +210,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting multiple values
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import CompletenessScore
+ >>> metric = CompletenessScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
@@ -228,15 +221,19 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:


class VMeasureScore(Metric):
r"""Compute `Rand Score`_ (alternatively known as Rand Index).
r"""Compute `V-Measure Score`_.
The V-measure is the harmonic mean between homogeneity and completeness:
.. math::
RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
..math::
v = \frac{(1 + \beta) * homogeneity * completeness}{\beta * homogeneity + completeness}
The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
(the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
where :math:`\beta` is a weight parameter that defines the weight of homogeneity in the harmonic mean, with the
default value :math:`\beta=1`. The V-measure is symmetric, which means that swapping ``preds`` and ``target`` does
not change the score.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
be available in practice since clustering in generally is used for unsupervised learning.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -252,25 +249,28 @@ class VMeasureScore(Metric):
Example:
>>> import torch
- >>> from torchmetrics.clustering import RandScore
+ >>> from torchmetrics.clustering import VMeasureScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
- >>> metric = RandScore()
+ >>> metric = VMeasureScore()
>>> metric(preds, target)
tensor(0.4744)
"""

- is_differentiable = True
- higher_is_better = None
- full_state_update: bool = True
+ is_differentiable: bool = True
+ higher_is_better: bool = True
+ full_state_update: bool = False
plot_lower_bound: float = 0.0
plot_upper_bound: float = 1.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

- def __init__(self, **kwargs: Any) -> None:
+ def __init__(self, beta: float = 1.0, **kwargs: Any) -> None:
super().__init__(**kwargs)
+ if not (isinstance(beta, float) and beta > 0):
+     raise ValueError(f"Argument `beta` should be a positive float. Got {beta}.")
+ self.beta = beta

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
@@ -282,7 +282,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:

def compute(self) -> Tensor:
    """Compute v-measure score over state."""
- return v_measure_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))
+ return v_measure_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), beta=self.beta)

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
@@ -304,8 +304,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting a single value
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import VMeasureScore
+ >>> metric = VMeasureScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
@@ -314,12 +314,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
>>> # Example plotting multiple values
>>> import torch
- >>> from torchmetrics.clustering import RandScore
- >>> metric = RandScore()
+ >>> from torchmetrics.clustering import VMeasureScore
+ >>> metric = VMeasureScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
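A minimal usage sketch of the three classes after this change (illustration, not part of the commit; it assumes only what the doctests above already show — the classes are exported from ``torchmetrics.clustering`` and ``VMeasureScore`` now accepts ``beta``):

    import torch
    from torchmetrics.clustering import CompletenessScore, HomogeneityScore, VMeasureScore

    preds = torch.randint(0, 4, (10,))   # predicted cluster labels
    target = torch.randint(0, 4, (10,))  # ground-truth class labels

    homogeneity = HomogeneityScore()     # 1.0 when every cluster holds a single class
    completeness = CompletenessScore()   # 1.0 when every class stays in a single cluster
    v_measure = VMeasureScore(beta=2.0)  # beta weights homogeneity in the harmonic mean

    print(homogeneity(preds, target), completeness(preds, target), v_measure(preds, target))

Swapping ``preds`` and ``target`` turns one score into the other (the homogeneity of the swapped inputs is the completeness of the originals), which is why the docstrings call each score asymmetric but the V-measure symmetric.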

src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py
@@ -11,19 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.functional.clustering.utils import calculate_entropy, check_cluster_labels


- def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tensor:
+ def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Computes the homogeneity score of a clustering given the predicted and target cluster labels."""
check_cluster_labels(preds, target)

- if len(target) == 0:
-     return torch.tensor(0.0, dtype=torch.float32, device=preds.device)
+ if len(target) == 0:  # special case where no clustering is defined
+     zero = torch.tensor(0.0, dtype=torch.float32, device=preds.device)
+     return zero.clone(), zero.clone(), zero.clone(), zero.clone()

entropy_target = calculate_entropy(target)
entropy_preds = calculate_entropy(preds)
@@ -33,15 +36,15 @@ def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tensor:
return homogeneity, mutual_info, entropy_preds, entropy_target


- def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tensor:
+ def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""Computes the completeness score of a clustering given the predicted and target cluster labels."""
homogeneity, mutual_info, entropy_preds, _ = _homogeneity_score_compute(preds, target)
completeness = mutual_info / entropy_preds if entropy_preds else torch.ones_like(entropy_preds)
return completeness, homogeneity


def homogeneity_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute the Rand score between two clusterings.
"""Compute the Homogeneity score between two clusterings.
Args:
preds: predicted cluster labels
Expand All @@ -64,7 +67,7 @@ def homogeneity_score(preds: Tensor, target: Tensor) -> Tensor:


def completeness_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute the Rand score between two clusterings.
"""Compute the Completeness score between two clusterings.
Args:
preds: predicted cluster labels
Expand All @@ -87,7 +90,7 @@ def completeness_score(preds: Tensor, target: Tensor) -> Tensor:


def v_measure_score(preds: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
"""Compute the Rand score between two clusterings.
"""Compute the V-measure score between two clusterings.
Args:
preds: predicted cluster labels
Expand All @@ -106,7 +109,7 @@ def v_measure_score(preds: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
tensor(0.8333)
"""
homogeneity, completeness = _completeness_score_compute(preds, target)
completeness, homogeneity = _completeness_score_compute(preds, target)
if homogeneity + completeness == 0.0:
v_measure = torch.ones_like(homogeneity)
else:
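To make the relationships in this file easier to follow, here is a short self-contained sketch of the quantities involved (illustration only: ``entropy`` and ``mutual_information`` below are hypothetical stand-ins for the library's ``calculate_entropy`` and ``mutual_info_score``, and the degenerate-case conventions mirror the code shown above):

    import torch
    from torch import Tensor

    def entropy(labels: Tensor) -> Tensor:
        # Shannon entropy of a labelling (natural log); the base cancels in the ratios below.
        _, counts = torch.unique(labels, return_counts=True)
        p = counts.float() / labels.numel()
        return -(p * p.log()).sum()

    def mutual_information(preds: Tensor, target: Tensor) -> Tensor:
        # Mutual information accumulated cell by cell over the contingency table.
        mi = torch.tensor(0.0)
        for u in torch.unique(preds):
            for v in torch.unique(target):
                p_uv = ((preds == u) & (target == v)).float().mean()
                if p_uv > 0:
                    p_u = (preds == u).float().mean()
                    p_v = (target == v).float().mean()
                    mi = mi + p_uv * torch.log(p_uv / (p_u * p_v))
        return mi

    def v_measure(preds: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
        # homogeneity h = MI / H(target), completeness c = MI / H(preds),
        # v = (1 + beta) * h * c / (beta * h + c); degenerate inputs score 1.0, as above.
        mi = mutual_information(preds, target)
        h_target, h_preds = entropy(target), entropy(preds)
        homogeneity = mi / h_target if h_target > 0 else torch.ones_like(mi)
        completeness = mi / h_preds if h_preds > 0 else torch.ones_like(mi)
        if homogeneity + completeness == 0.0:
            return torch.ones_like(mi)
        return (1 + beta) * homogeneity * completeness / (beta * homogeneity + completeness)

For the doctest labels used earlier (``preds=[2, 1, 0, 1, 0]``, ``target=[0, 2, 1, 1, 0]``) this sketch gives approximately 0.4744 for any ``beta``, since homogeneity and completeness coincide there.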
(unit tests for HomogeneityScore, CompletenessScore and VMeasureScore)
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ from functools import partial

import pytest
from sklearn.metrics import completeness_score as sklearn_completeness_score
from sklearn.metrics import homogeneity_score as sklearn_homogeneity_score
@@ -39,6 +41,11 @@
(HomogeneityScore, homogeneity_score, sklearn_homogeneity_score),
(CompletenessScore, completeness_score, sklearn_completeness_score),
(VMeasureScore, v_measure_score, sklearn_v_measure_score),
+ (
+     partial(VMeasureScore, beta=2.0),
+     partial(v_measure_score, beta=2.0),
+     partial(sklearn_v_measure_score, beta=2.0),
+ ),
],
)
@pytest.mark.parametrize(
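A quick illustration (not from the commit) of what the ``partial`` rows above bind: each row supplies callables with ``beta`` already fixed, so a shared test body can treat all parametrizations uniformly. The functional import path below is an assumption based on the names used in the test:

    import torch
    from functools import partial
    from torchmetrics.clustering import VMeasureScore
    from torchmetrics.functional.clustering import v_measure_score  # assumed export path

    preds = torch.randint(0, 4, (10,))
    target = torch.randint(0, 4, (10,))

    metric_cls = partial(VMeasureScore, beta=2.0)  # behaves like the class with beta pre-set
    fn = partial(v_measure_score, beta=2.0)        # functional form with beta pre-set
    assert torch.allclose(metric_cls()(preds, target), fn(preds, target))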
