diff --git a/docs/source/links.rst b/docs/source/links.rst
index 78a2b34d764..3bfc12dcffa 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -154,3 +154,6 @@
 .. _Normalized Mutual Information Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.normalized_mutual_info_score.html
 .. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
 .. _Rand Score: https://link.springer.com/article/10.1007/BF01908075
+.. _V-Measure Score: https://www.aclweb.org/anthology/D07-1043.pdf
+.. _Homogeneity Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.homogeneity_score.html
+.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py
index 0648322d005..e01eb132883 100644
--- a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py
+++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py
@@ -16,8 +16,8 @@
 from torch import Tensor
 
 from torchmetrics.functional.clustering.homogeneity_completeness_v_measure import (
-    homogeneity_score,
     completeness_score,
+    homogeneity_score,
     v_measure_score,
 )
 from torchmetrics.metric import Metric
@@ -26,23 +26,18 @@
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 
 if not _MATPLOTLIB_AVAILABLE:
-    __doctest_skip__ = [
-        "HomogeneityScore.plot",
-        "CompletenessScore.plot",
-        "VMeasureScore.plot"
-    ]
+    __doctest_skip__ = ["HomogeneityScore.plot", "CompletenessScore.plot", "VMeasureScore.plot"]
 
 
 class HomogeneityScore(Metric):
-    r"""Compute `Rand Score`_ (alternatively known as Rand Index).
-
-    .. math::
-        RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
+    r"""Compute `Homogeneity Score`_.
 
-    The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
-    (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
+    The homogeneity score measures how homogeneous a clustering is. A clustering result satisfies homogeneity if all
+    of its clusters contain only data points that are members of a single class. The metric is not symmetric,
+    therefore swapping ``preds`` and ``target`` yields a different score.
 
-    The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
+    This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
+    be available in practice since clustering is generally used for unsupervised learning.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -67,13 +62,13 @@ class HomogeneityScore(Metric):
 
     """
 
-    is_differentiable = True
-    higher_is_better = None
-    full_state_update: bool = True
+    is_differentiable: bool = True
+    higher_is_better: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
     preds: List[Tensor]
     target: List[Tensor]
-    contingency: Tensor
 
     def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
@@ -110,8 +105,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting a single value
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import HomogeneityScore
+            >>> metric = HomogeneityScore()
             >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
@@ -120,8 +115,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting multiple values
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import HomogeneityScore
+            >>> metric = HomogeneityScore()
             >>> for _ in range(10):
             ...     metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
@@ -131,15 +126,14 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
 
 class CompletenessScore(Metric):
-    r"""Compute `Rand Score`_ (alternatively known as Rand Index).
+    r"""Compute `Completeness Score`_.
 
-    .. math::
-        RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
+    A clustering result satisfies completeness if all the data points that are members of a given class are elements
+    of the same cluster. The metric is not symmetric, therefore swapping ``preds`` and ``target`` yields a different
+    score.
 
-    The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
-    (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
-
-    The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
+    This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
+    be available in practice since clustering is generally used for unsupervised learning.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -164,13 +157,13 @@ class CompletenessScore(Metric):
 
     """
 
-    is_differentiable = True
-    higher_is_better = None
-    full_state_update: bool = True
+    is_differentiable: bool = True
+    higher_is_better: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
     preds: List[Tensor]
     target: List[Tensor]
-    contingency: Tensor
 
     def __init__(self, **kwargs: Any) -> None:
         super().__init__(**kwargs)
@@ -207,8 +200,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting a single value
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import CompletenessScore
+            >>> metric = CompletenessScore()
             >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
@@ -217,8 +210,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting multiple values
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import CompletenessScore
+            >>> metric = CompletenessScore()
             >>> for _ in range(10):
             ...     metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
@@ -228,15 +221,19 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
 
 class VMeasureScore(Metric):
-    r"""Compute `Rand Score`_ (alternatively known as Rand Index).
+    r"""Compute `V-Measure Score`_.
+
+    The V-measure is the weighted harmonic mean of homogeneity and completeness:
 
-    .. math::
-        RS(U, V) = \text{number of agreeing pairs} / \text{number of pairs}
+    .. math::
+        v = \frac{(1 + \beta) * \text{homogeneity} * \text{completeness}}{\beta * \text{homogeneity} + \text{completeness}}
 
-    The number of agreeing pairs is every :math:`(i, j)` pair of samples where :math:`i \in U` and :math:`j \in V`
-    (the predicted and true clusterings, respectively) that are in the same cluster for both clusterings.
+    where :math:`\beta` weights the two terms: values above one favor completeness, values below one favor
+    homogeneity, and the default :math:`\beta=1` gives the plain harmonic mean. For :math:`\beta=1` the V-measure is
+    symmetric, so swapping ``preds`` and ``target`` does not change the score.
 
-    The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields the same rand score.
+    This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not
+    be available in practice since clustering is generally used for unsupervised learning.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
 
@@ -252,25 +249,28 @@ class VMeasureScore(Metric):
 
     Example:
        >>> import torch
-        >>> from torchmetrics.clustering import RandScore
+        >>> from torchmetrics.clustering import VMeasureScore
        >>> preds = torch.tensor([2, 1, 0, 1, 0])
        >>> target = torch.tensor([0, 2, 1, 1, 0])
-        >>> metric = RandScore()
+        >>> metric = VMeasureScore()
        >>> metric(preds, target)
        tensor(0.6000)
 
     """
 
-    is_differentiable = True
-    higher_is_better = None
-    full_state_update: bool = True
+    is_differentiable: bool = True
+    higher_is_better: bool = True
+    full_state_update: bool = False
     plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
     preds: List[Tensor]
     target: List[Tensor]
-    contingency: Tensor
 
-    def __init__(self, **kwargs: Any) -> None:
+    def __init__(self, beta: float = 1.0, **kwargs: Any) -> None:
         super().__init__(**kwargs)
+        if not (isinstance(beta, float) and beta > 0):
+            raise ValueError(f"Argument `beta` should be a positive float. Got {beta}.")
+        self.beta = beta
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")
 
@@ -282,7 +282,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
 
     def compute(self) -> Tensor:
-        """Compute rand score over state."""
-        return v_measure_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))
+        """Compute the V-measure score over state."""
+        return v_measure_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), beta=self.beta)
 
     def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.
@@ -304,8 +304,8 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting a single value
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import VMeasureScore
+            >>> metric = VMeasureScore()
             >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
@@ -314,12 +314,11 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
 
             >>> # Example plotting multiple values
             >>> import torch
-            >>> from torchmetrics.clustering import RandScore
-            >>> metric = RandScore()
+            >>> from torchmetrics.clustering import VMeasureScore
+            >>> metric = VMeasureScore()
             >>> for _ in range(10):
             ...     metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
             >>> fig_, ax_ = metric.plot(metric.compute())
 
         """
         return self._plot(val, ax)
-
diff --git a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py
index 3399252b9f1..d412abdf611 100644
--- a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py
+++ b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Tuple
+
 import torch
 from torch import Tensor
 
@@ -18,12 +20,13 @@
 from torchmetrics.functional.clustering.utils import calculate_entropy, check_cluster_labels
 
 
-def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tensor:
+def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Computes the homogeneity score of a clustering given the predicted and target cluster labels."""
     check_cluster_labels(preds, target)
 
-    if len(target) == 0:
-        return torch.tensor(0.0, dtype=torch.float32, device=preds.device)
+    if len(target) == 0:  # special case where no clustering is defined
+        zero = torch.tensor(0.0, dtype=torch.float32, device=preds.device)
+        return zero.clone(), zero.clone(), zero.clone(), zero.clone()
 
     entropy_target = calculate_entropy(target)
     entropy_preds = calculate_entropy(preds)
@@ -33,7 +36,7 @@
     return homogeneity, mutual_info, entropy_preds, entropy_target
 
 
-def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tensor:
+def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
     """Computes the completeness score of a clustering given the predicted and target cluster labels."""
     homogeneity, mutual_info, entropy_preds, _ = _homogeneity_score_compute(preds, target)
     completeness = mutual_info / entropy_preds if entropy_preds else torch.ones_like(entropy_preds)
@@ -41,7 +44,7 @@
 
 
 def homogeneity_score(preds: Tensor, target: Tensor) -> Tensor:
-    """Compute the Rand score between two clusterings.
+    """Compute the Homogeneity score between two clusterings.
 
     Args:
         preds: predicted cluster labels
@@ -64,7 +67,7 @@
 
 
 def completeness_score(preds: Tensor, target: Tensor) -> Tensor:
-    """Compute the Rand score between two clusterings.
+    """Compute the Completeness score between two clusterings.
 
     Args:
         preds: predicted cluster labels
@@ -87,7 +90,7 @@
 
 
 def v_measure_score(preds: Tensor, target: Tensor, beta: float = 1.0) -> Tensor:
-    """Compute the Rand score between two clusterings.
+    """Compute the V-measure score between two clusterings.
 
     Args:
         preds: predicted cluster labels
@@ -106,7 +109,7 @@
         tensor(0.8333)
 
     """
-    homogeneity, completeness = _completeness_score_compute(preds, target)
+    completeness, homogeneity = _completeness_score_compute(preds, target)
     if homogeneity + completeness == 0.0:
         v_measure = torch.ones_like(homogeneity)
     else:
diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py
index c716ea77daf..31590f956cc 100644
--- a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py
+++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
+
 import pytest
 from sklearn.metrics import completeness_score as sklearn_completeness_score
 from sklearn.metrics import homogeneity_score as sklearn_homogeneity_score
@@ -39,6 +41,11 @@
         (HomogeneityScore, homogeneity_score, sklearn_homogeneity_score),
         (CompletenessScore, completeness_score, sklearn_completeness_score),
         (VMeasureScore, v_measure_score, sklearn_v_measure_score),
+        (
+            partial(VMeasureScore, beta=2.0),
+            partial(v_measure_score, beta=2.0),
+            partial(sklearn_v_measure_score, beta=2.0),
+        ),
     ],
 )
 @pytest.mark.parametrize(
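
To exercise the new metrics end to end, a minimal sanity check along the following lines can be run locally; it is a sketch, not part of the diff. It assumes torchmetrics' usual mapping of ``preds``/``target`` onto scikit-learn's ``labels_pred``/``labels_true`` (the convention the parametrized tests above rely on) and that ``VMeasureScore`` implements the weighted harmonic mean stated in its docstring.

# Hypothetical sanity check: cross-validate the new clustering metrics against
# each other and against scikit-learn's reference implementation.
import torch
from sklearn.metrics import v_measure_score as sklearn_v_measure_score

from torchmetrics.clustering import CompletenessScore, HomogeneityScore, VMeasureScore

# Same labels as in the VMeasureScore doctest above.
preds = torch.tensor([2, 1, 0, 1, 0])
target = torch.tensor([0, 2, 1, 1, 0])

homogeneity = HomogeneityScore()(preds, target)
completeness = CompletenessScore()(preds, target)

# beta > 1 weights completeness more strongly in the harmonic mean.
beta = 2.0
v_measure = VMeasureScore(beta=beta)(preds, target)

# The result should match the weighted harmonic mean from the class docstring ...
expected = (1 + beta) * homogeneity * completeness / (beta * homogeneity + completeness)
torch.testing.assert_close(v_measure, expected)

# ... and scikit-learn's value (labels_true comes first in scikit-learn's API).
reference = sklearn_v_measure_score(target.numpy(), preds.numpy(), beta=beta)
torch.testing.assert_close(v_measure, torch.tensor(reference).to(v_measure.dtype))

With ``beta=2.0`` both asserts would presumably catch a reversed tuple unpacking in ``v_measure_score``, which is exactly what the new ``partial(..., beta=2.0)`` test row guards against; for the symmetric default ``beta=1.0`` such a swap would go unnoticed.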