diff --git a/CHANGELOG.md b/CHANGELOG.md index 72332b4ce3d..2317ac43604 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DunnIndex` ([#2049](https://github.com/Lightning-AI/torchmetrics/pull/2049)) + - `HomogeneityScore` ([#2053](https://github.com/Lightning-AI/torchmetrics/pull/2053)) + + - `CompletenessScore` ([#2053](https://github.com/Lightning-AI/torchmetrics/pull/2053)) + + - `VMeasureScore` ([#2053](https://github.com/Lightning-AI/torchmetrics/pull/2053)) + - `FowlkesMallowsIndex` ([#2066](https://github.com/Lightning-AI/torchmetrics/pull/2066)) - `DaviesBouldinScore` ([#2071](https://github.com/Lightning-AI/torchmetrics/pull/2071)) diff --git a/docs/source/clustering/completeness_score.rst b/docs/source/clustering/completeness_score.rst new file mode 100644 index 00000000000..83f572f5489 --- /dev/null +++ b/docs/source/clustering/completeness_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Completeness Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: Clustering + +.. include:: ../links.rst + +################## +Completeness Score +################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.CompletenessScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.completeness_score diff --git a/docs/source/clustering/homogeneity_score.rst b/docs/source/clustering/homogeneity_score.rst new file mode 100644 index 00000000000..c12511aea64 --- /dev/null +++ b/docs/source/clustering/homogeneity_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Homogeneity Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: Clustering + +.. 
include:: ../links.rst + +################# +Homogeneity Score +################# + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.HomogeneityScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.homogeneity_score diff --git a/docs/source/clustering/v_measure_score.rst b/docs/source/clustering/v_measure_score.rst new file mode 100644 index 00000000000..a60a51dfa91 --- /dev/null +++ b/docs/source/clustering/v_measure_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: V-Measure Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/default.svg + :tags: Clustering + +.. include:: ../links.rst + +############### +V-Measure Score +############### + +Module Interface +________________ + +.. autoclass:: torchmetrics.clustering.VMeasureScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.clustering.v_measure_score diff --git a/docs/source/links.rst b/docs/source/links.rst index aeb2c5698fe..9c63d351b43 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -158,5 +158,8 @@ .. _fork of pycocotools: https://github.com/ppwwyyxx/cocoapi .. _Adjusted Rand Score: https://en.wikipedia.org/wiki/Rand_index#Adjusted_Rand_index .. _Dunn Index: https://en.wikipedia.org/wiki/Dunn_index +.. _V-Measure Score: https://www.aclweb.org/anthology/D07-1043.pdf +.. _Homogeneity Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.homogeneity_score.html +.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html .. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index .. 
_Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score diff --git a/src/torchmetrics/clustering/__init__.py b/src/torchmetrics/clustering/__init__.py index f2a123f9f0a..ca7f8ed29af 100644 --- a/src/torchmetrics/clustering/__init__.py +++ b/src/torchmetrics/clustering/__init__.py @@ -16,6 +16,11 @@ from torchmetrics.clustering.davies_bouldin_score import DaviesBouldinScore from torchmetrics.clustering.dunn_index import DunnIndex from torchmetrics.clustering.fowlkes_mallows_index import FowlkesMallowsIndex +from torchmetrics.clustering.homogeneity_completeness_v_measure import ( + CompletenessScore, + HomogeneityScore, + VMeasureScore, +) from torchmetrics.clustering.mutual_info_score import MutualInfoScore from torchmetrics.clustering.normalized_mutual_info_score import NormalizedMutualInfoScore from torchmetrics.clustering.rand_score import RandScore @@ -23,10 +28,13 @@ __all__ = [ "AdjustedRandScore", "CalinskiHarabaszScore", + "CompletenessScore", "DaviesBouldinScore", "DunnIndex", "FowlkesMallowsIndex", + "HomogeneityScore", "MutualInfoScore", "NormalizedMutualInfoScore", "RandScore", + "VMeasureScore", ] diff --git a/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py new file mode 100644 index 00000000000..ccef518e4a3 --- /dev/null +++ b/src/torchmetrics/clustering/homogeneity_completeness_v_measure.py @@ -0,0 +1,325 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Union + +from torch import Tensor + +from torchmetrics.functional.clustering.homogeneity_completeness_v_measure import ( + completeness_score, + homogeneity_score, + v_measure_score, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["HomogeneityScore.plot", "CompletenessScore.plot", "VMeasureScore.plot"] + + +class HomogeneityScore(Metric): + r"""Compute `Homogeneity Score`_. + + The homogeneity score is a metric to measure the homogeneity of a clustering. A clustering result satisfies + homogeneity if all of its clusters contain only data points which are members of a single class. The metric is not + symmetric, therefore swapping ``preds`` and ``target`` yields a different score. + + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in general is used for unsupervised learning.
+ + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``homogeneity_score`` (:class:`~torch.Tensor`): A tensor with the Homogeneity Score + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> import torch + >>> from torchmetrics.clustering import HomogeneityScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> metric = HomogeneityScore() + >>> metric(preds, target) + tensor(0.4744) + + """ + + is_differentiable: bool = True + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + preds: List[Tensor] + target: List[Tensor] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute homogeneity score over state.""" + return homogeneity_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object.
If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import HomogeneityScore + >>> metric = HomogeneityScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import HomogeneityScore + >>> metric = HomogeneityScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) + + +class CompletenessScore(Metric): + r"""Compute `Completeness Score`_. + + A clustering result satisfies completeness if all the data points that are members of a given class are elements of + the same cluster. The metric is not symmetric, therefore swapping ``preds`` and ``target`` yields a different score. + + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in general is used for unsupervised learning. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``completeness_score`` (:class:`~torch.Tensor`): A tensor with the Completeness Score + + Args: + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+ + Example: + >>> import torch + >>> from torchmetrics.clustering import CompletenessScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> metric = CompletenessScore() + >>> metric(preds, target) + tensor(0.4744) + + """ + + is_differentiable: bool = True + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + preds: List[Tensor] + target: List[Tensor] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute completeness score over state.""" + return completeness_score(dim_zero_cat(self.preds), dim_zero_cat(self.target)) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import CompletenessScore + >>> metric = CompletenessScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + ..
plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import CompletenessScore + >>> metric = CompletenessScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) + + +class VMeasureScore(Metric): + r"""Compute `V-Measure Score`_. + + The V-measure is the harmonic mean between homogeneity and completeness: + + .. math:: + v = \frac{(1 + \beta) * homogeneity * completeness}{\beta * homogeneity + completeness} + + where :math:`\beta` is a weight parameter that defines the weight of homogeneity in the harmonic mean, with the + default value :math:`\beta=1`. The V-measure is symmetric, which means that swapping ``preds`` and ``target`` does + not change the score. + + This clustering metric is an extrinsic measure, because it requires ground truth clustering labels, which may not + be available in practice since clustering in general is used for unsupervised learning. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with predicted cluster labels + - ``target`` (:class:`~torch.Tensor`): single integer tensor with shape ``(N,)`` with ground truth cluster labels + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``v_measure_score`` (:class:`~torch.Tensor`): A tensor with the V-Measure Score + + Args: + beta: Weight parameter that defines the weight of homogeneity in the harmonic mean + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+ + Example: + >>> import torch + >>> from torchmetrics.clustering import VMeasureScore + >>> preds = torch.tensor([2, 1, 0, 1, 0]) + >>> target = torch.tensor([0, 2, 1, 1, 0]) + >>> metric = VMeasureScore(beta=2.0) + >>> metric(preds, target) + tensor(0.4744) + + """ + + is_differentiable: bool = True + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + preds: List[Tensor] + target: List[Tensor] + + def __init__(self, beta: float = 1.0, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not (isinstance(beta, float) and beta > 0): + raise ValueError(f"Argument `beta` should be a positive float. Got {beta}.") + self.beta = beta + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Tensor: + """Compute V-measure score over state.""" + return v_measure_score(dim_zero_cat(self.preds), dim_zero_cat(self.target), beta=self.beta) + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + ..
plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.clustering import VMeasureScore + >>> metric = VMeasureScore() + >>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.clustering import VMeasureScore + >>> metric = VMeasureScore() + >>> for _ in range(10): + ... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,))) + >>> fig_, ax_ = metric.plot(metric.compute()) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/clustering/__init__.py b/src/torchmetrics/functional/clustering/__init__.py index 8942789a87c..aa83d386f28 100644 --- a/src/torchmetrics/functional/clustering/__init__.py +++ b/src/torchmetrics/functional/clustering/__init__.py @@ -16,6 +16,11 @@ from torchmetrics.functional.clustering.davies_bouldin_score import davies_bouldin_score from torchmetrics.functional.clustering.dunn_index import dunn_index from torchmetrics.functional.clustering.fowlkes_mallows_index import fowlkes_mallows_index +from torchmetrics.functional.clustering.homogeneity_completeness_v_measure import ( + completeness_score, + homogeneity_score, + v_measure_score, +) from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score from torchmetrics.functional.clustering.normalized_mutual_info_score import normalized_mutual_info_score from torchmetrics.functional.clustering.rand_score import rand_score @@ -23,10 +28,13 @@ __all__ = [ "adjusted_rand_score", "calinski_harabasz_score", + "completeness_score", "davies_bouldin_score", "dunn_index", "fowlkes_mallows_index", + "homogeneity_score", "mutual_info_score", "normalized_mutual_info_score", "rand_score", + "v_measure_score", ] diff --git a/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py 
b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py new file mode 100644 index 00000000000..e98f1e26b5b --- /dev/null +++ b/src/torchmetrics/functional/clustering/homogeneity_completeness_v_measure.py @@ -0,0 +1,115 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch +from torch import Tensor + +from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score +from torchmetrics.functional.clustering.utils import calculate_entropy, check_cluster_labels + + +def _homogeneity_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the homogeneity score of a clustering given the predicted and target cluster labels.""" + check_cluster_labels(preds, target) + + if len(target) == 0: # special case where no clustering is defined + zero = torch.tensor(0.0, dtype=torch.float32, device=preds.device) + return zero.clone(), zero.clone(), zero.clone(), zero.clone() + + entropy_target = calculate_entropy(target) + entropy_preds = calculate_entropy(preds) + mutual_info = mutual_info_score(preds, target) + + homogeneity = mutual_info / entropy_target if entropy_target else torch.ones_like(entropy_target) + return homogeneity, mutual_info, entropy_preds, entropy_target + + +def _completeness_score_compute(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """Computes the completeness score of a clustering given the 
predicted and target cluster labels.""" + homogeneity, mutual_info, entropy_preds, _ = _homogeneity_score_compute(preds, target) + completeness = mutual_info / entropy_preds if entropy_preds else torch.ones_like(entropy_preds) + return completeness, homogeneity + + +def homogeneity_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute the Homogeneity score between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + scalar tensor with the homogeneity score + + Example: + >>> from torchmetrics.functional.clustering import homogeneity_score + >>> import torch + >>> homogeneity_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0])) + tensor(1.) + >>> homogeneity_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(1.) + + """ + homogeneity, _, _, _ = _homogeneity_score_compute(preds, target) + return homogeneity + + +def completeness_score(preds: Tensor, target: Tensor) -> Tensor: + """Compute the Completeness score between two clusterings. + + Args: + preds: predicted cluster labels + target: ground truth cluster labels + + Returns: + scalar tensor with the completeness score + + Example: + >>> from torchmetrics.functional.clustering import completeness_score + >>> import torch + >>> completeness_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0])) + tensor(1.) + >>> completeness_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(0.6667) + + """ + completeness, _ = _completeness_score_compute(preds, target) + return completeness + + +def v_measure_score(preds: Tensor, target: Tensor, beta: float = 1.0) -> Tensor: + """Compute the V-measure score between two clusterings.
+ + Args: + preds: predicted cluster labels + target: ground truth cluster labels + beta: weight of the harmonic mean between homogeneity and completeness + + Returns: + scalar tensor with the V-measure score + + Example: + >>> from torchmetrics.functional.clustering import v_measure_score + >>> import torch + >>> v_measure_score(torch.tensor([0, 0, 1, 1]), torch.tensor([1, 1, 0, 0])) + tensor(1.) + >>> v_measure_score(torch.tensor([0, 0, 1, 2]), torch.tensor([0, 0, 1, 1])) + tensor(0.8000) + + """ + completeness, homogeneity = _completeness_score_compute(preds, target) + if homogeneity + completeness == 0.0: + return torch.zeros_like(homogeneity)  # both scores are zero: the harmonic mean is 0, and this avoids 0/0 + return (1 + beta) * homogeneity * completeness / (beta * homogeneity + completeness) diff --git a/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py new file mode 100644 index 00000000000..673271c26fc --- /dev/null +++ b/tests/unittests/clustering/test_homogeneity_completeness_v_measure.py @@ -0,0 +1,98 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+from functools import partial + +import pytest +from sklearn.metrics import completeness_score as sklearn_completeness_score +from sklearn.metrics import homogeneity_score as sklearn_homogeneity_score +from sklearn.metrics import v_measure_score as sklearn_v_measure_score +from torchmetrics.clustering.homogeneity_completeness_v_measure import ( + CompletenessScore, + HomogeneityScore, + VMeasureScore, +) +from torchmetrics.functional.clustering.homogeneity_completeness_v_measure import ( + completeness_score, + homogeneity_score, + v_measure_score, +) + +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + + +def _sk_reference(preds, target, fn): + """Compute reference values using sklearn.""" + return fn(target, preds) + + +@pytest.mark.parametrize( + "modular_metric, functional_metric, reference_metric", + [ + (HomogeneityScore, homogeneity_score, sklearn_homogeneity_score), + (CompletenessScore, completeness_score, sklearn_completeness_score), + (VMeasureScore, v_measure_score, sklearn_v_measure_score), + ( + partial(VMeasureScore, beta=2.0), + partial(v_measure_score, beta=2.0), + partial(sklearn_v_measure_score, beta=2.0), + ), + ], +) +@pytest.mark.parametrize( + "preds, target", + [ + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), + ], +) +class TestHomogeneityCompletenessVmeasur(MetricTester): + """Test class for testing homogeneity, completeness and v-measure metrics.""" + + atol = 1e-5 + + @pytest.mark.parametrize("ddp", [True, False]) + def test_homogeneity_completeness_vmeasure( + self, modular_metric, functional_metric, reference_metric, preds, target, ddp + ): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + 
metric_class=modular_metric, + reference_metric=partial(_sk_reference, fn=reference_metric), + ) + + def test_homogeneity_completeness_vmeasure_functional( + self, modular_metric, functional_metric, reference_metric, preds, target + ): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=functional_metric, + reference_metric=partial(_sk_reference, fn=reference_metric), + ) + + +@pytest.mark.parametrize("functional_metric", [homogeneity_score, completeness_score, v_measure_score]) +def test_homogeneity_completeness_vmeasure_functional_raises_invalid_task(functional_metric): + """Check that metric rejects continuous-valued inputs.""" + preds, target = _float_inputs_extrinsic + with pytest.raises(ValueError, match=r"Expected *"): + functional_metric(preds, target)