From c139a96ece4ebf49c6fc917422407d967b45dee4 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Mon, 4 Sep 2023 15:39:58 +0900 Subject: [PATCH] Create inputs.py for clustering tests (#2045) Create inputs.py for clustering tests --- tests/unittests/clustering/inputs.py | 51 +++++++++++++++++++ .../clustering/test_mutual_info_score.py | 39 +++----------- .../test_normalized_mutual_info_score.py | 30 +++-------- tests/unittests/clustering/test_rand_score.py | 36 +++---------- 4 files changed, 72 insertions(+), 84 deletions(-) create mode 100644 tests/unittests/clustering/inputs.py diff --git a/tests/unittests/clustering/inputs.py b/tests/unittests/clustering/inputs.py new file mode 100644 index 00000000000..61fff9eed31 --- /dev/null +++ b/tests/unittests/clustering/inputs.py @@ -0,0 +1,51 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple + +import torch + +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests.helpers import seed_all + +seed_all(42) + + +Input = namedtuple("Input", ["preds", "target"]) +NUM_CLASSES = 10 + +# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels +_single_target_extrinsic1 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_extrinsic2 = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_float_inputs_extrinsic = Input( + preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), target=torch.rand((NUM_BATCHES, BATCH_SIZE)) +) + +# intrinsic input for clustering metrics that requires only predicted clustering labels and the cluster embeddings +_single_target_intrinsic1 = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) + +_single_target_intrinsic2 = Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +) diff --git a/tests/unittests/clustering/test_mutual_info_score.py b/tests/unittests/clustering/test_mutual_info_score.py index c4e0e56f38e..49522d50ce9 100644 --- a/tests/unittests/clustering/test_mutual_info_score.py +++ b/tests/unittests/clustering/test_mutual_info_score.py @@ -11,44 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple - import pytest import torch from sklearn.metrics import mutual_info_score as sklearn_mutual_info_score from torchmetrics.clustering.mutual_info_score import MutualInfoScore from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_CLASSES +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -NUM_CLASSES = 10 - -_single_target_inputs1 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_single_target_inputs2 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_float_inputs = Input( - preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), - target=torch.rand((NUM_BATCHES, BATCH_SIZE)), -) - @pytest.mark.parametrize( "preds, target", [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - (_single_target_inputs2.preds, _single_target_inputs2.target), + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), ], ) class TestMutualInfoScore(MetricTester): @@ -87,18 +68,14 @@ def test_mutual_info_score_functional_single_cluster(): def test_mutual_info_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" - preds, target = _float_inputs + preds, target = _float_inputs_extrinsic with pytest.raises(ValueError, match=r"Expected *"): mutual_info_score(preds, target) -@pytest.mark.parametrize( - ("preds", "target"), - [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - ], -) -def test_mutual_info_score_functional_is_symmetric(preds, target): +def test_mutual_info_score_functional_is_symmetric( + preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target +): """Check that the metric funtional is symmetric.""" for p, t in zip(preds, target): assert torch.allclose(mutual_info_score(p, t), mutual_info_score(t, p)) diff --git a/tests/unittests/clustering/test_normalized_mutual_info_score.py b/tests/unittests/clustering/test_normalized_mutual_info_score.py index 97d40d2e66a..095bc5963d2 100644 --- a/tests/unittests/clustering/test_normalized_mutual_info_score.py +++ b/tests/unittests/clustering/test_normalized_mutual_info_score.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -20,36 +19,19 @@ from torchmetrics.clustering import NormalizedMutualInfoScore from torchmetrics.functional.clustering import normalized_mutual_info_score -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_CLASSES +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -NUM_CLASSES = 10 - -_single_target_inputs1 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_single_target_inputs2 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_float_inputs = Input( - preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), - target=torch.rand((NUM_BATCHES, BATCH_SIZE)), -) - @pytest.mark.parametrize( "preds, target", [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - (_single_target_inputs2.preds, _single_target_inputs2.target), + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), ], ) @pytest.mark.parametrize( @@ -96,7 +78,7 @@ def test_normalized_mutual_info_score_functional_single_cluster(average_method): @pytest.mark.parametrize("average_method", ["min", "geometric", "arithmetic", "max"]) def test_normalized_mutual_info_score_functional_raises_invalid_task(average_method): """Check that metric rejects continuous-valued inputs.""" - preds, target = _float_inputs + preds, target = _float_inputs_extrinsic with pytest.raises(ValueError, match=r"Expected *"): normalized_mutual_info_score(preds, target, average_method) @@ -106,7 +88,7 @@ def test_normalized_mutual_info_score_functional_raises_invalid_task(average_met ["min", "geometric", "arithmetic", "max"], ) def test_normalized_mutual_info_score_functional_is_symmetric( - average_method, preds=_single_target_inputs1.preds, target=_single_target_inputs1.target + average_method, preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target ): """Check that the metric funtional is symmetric.""" for p, t in zip(preds, target): diff --git a/tests/unittests/clustering/test_rand_score.py b/tests/unittests/clustering/test_rand_score.py index d00fd421d34..08df4ff5e5e 100644 --- a/tests/unittests/clustering/test_rand_score.py +++ b/tests/unittests/clustering/test_rand_score.py @@ -11,44 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple - import pytest import torch from sklearn.metrics import rand_score as sklearn_rand_score from torchmetrics.clustering.rand_score import RandScore from torchmetrics.functional.clustering.rand_score import rand_score -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.clustering.inputs import _float_inputs_extrinsic, _single_target_extrinsic1, _single_target_extrinsic2 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -NUM_CLASSES = 10 - -_single_target_inputs1 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_single_target_inputs2 = Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - -_float_inputs = Input( - preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), - target=torch.rand((NUM_BATCHES, BATCH_SIZE)), -) - @pytest.mark.parametrize( "preds, target", [ - (_single_target_inputs1.preds, _single_target_inputs1.target), - (_single_target_inputs2.preds, _single_target_inputs2.target), + (_single_target_extrinsic1.preds, _single_target_extrinsic1.target), + (_single_target_extrinsic2.preds, _single_target_extrinsic2.target), ], ) class TestRandScore(MetricTester): @@ -79,16 +59,14 @@ def test_rand_score_functional(self, preds, target): def test_rand_score_functional_raises_invalid_task(): """Check that metric rejects continuous-valued inputs.""" - preds, target = _float_inputs + preds, target = _float_inputs_extrinsic with pytest.raises(ValueError, match=r"Expected *"): rand_score(preds, target) -@pytest.mark.parametrize( - ("preds", "target"), - [(_single_target_inputs1.preds, _single_target_inputs1.target)], -) -def test_rand_score_functional_is_symmetric(preds, target): +def test_rand_score_functional_is_symmetric( + preds=_single_target_extrinsic1.preds, target=_single_target_extrinsic1.target +): """Check that the metric funtional is symmetric.""" for p, t in zip(preds, target): assert torch.allclose(rand_score(p, t), rand_score(t, p))