diff --git a/pyproject.toml b/pyproject.toml index 067d1a0007c..df8b8891776 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,7 +121,6 @@ ignore = [ "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo "B905", # todo: `zip()` without an explicit `strict=` parameter - "PYI024", # todo: Use `typing.NamedTuple` instead of `collections.namedtuple` ] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index e6fa726f2ec..1c6e1b58906 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -24,7 +24,6 @@ # License under BSD 2-clause import inspect import os -from collections import namedtuple from typing import List, NamedTuple, Optional, Tuple, Union import torch @@ -86,13 +85,21 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None def forward(self, x: Tensor) -> NamedTuple: """Process input.""" - squeeze_output = namedtuple("squeeze_output", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) + + class _SqueezeOutput(NamedTuple): + relu1: Tensor + relu2: Tensor + relu3: Tensor + relu4: Tensor + relu5: Tensor + relu6: Tensor + relu7: Tensor relus = [] for slice_ in self.slices: x = slice_(x) relus.append(x) - return squeeze_output(*relus) + return _SqueezeOutput(*relus) class Alexnet(torch.nn.Module): @@ -134,8 +141,15 @@ def forward(self, x: Tensor) -> NamedTuple: h_relu4 = h h = self.slice5(h) h_relu5 = h - alexnet_outputs = namedtuple("alexnet_outputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]) - return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + class _AlexnetOutputs(NamedTuple): + relu1: Tensor + relu2: Tensor + relu3: Tensor + relu4: Tensor + relu5: Tensor + + return _AlexnetOutputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) class Vgg16(torch.nn.Module): @@ -177,8 +191,15 @@ def forward(self, x: Tensor) -> NamedTuple: h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h - vgg_outputs = namedtuple("vgg_outputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) - return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + class _VGGOutputs(NamedTuple): + relu1_2: Tensor + relu2_2: Tensor + relu3_3: Tensor + relu4_3: Tensor + relu5_3: Tensor + + return _VGGOutputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) def _spatial_average(in_tens: Tensor, keep_dim: bool = True) -> Tensor: diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py index e77f74161bb..64a08805e34 100644 --- a/tests/unittests/__init__.py +++ b/tests/unittests/__init__.py @@ -1,7 +1,9 @@ import os.path +from typing import NamedTuple import numpy import torch +from torch import Tensor from unittests.conftest import ( BATCH_SIZE, @@ -26,9 +28,23 @@ torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + +class _Input(NamedTuple): + preds: Tensor + target: Tensor + + +class _GroupInput(NamedTuple): + preds: Tensor + target: Tensor + groups: Tensor + + __all__ = [ "BATCH_SIZE", "EXTRA_DIM", + "_Input", + "_GroupInput", "NUM_BATCHES", "NUM_CLASSES", "NUM_PROCESSES", diff --git a/tests/unittests/audio/test_c_si_snr.py b/tests/unittests/audio/test_c_si_snr.py index 9a5123b8309..aed96ea6285 100644 --- a/tests/unittests/audio/test_c_si_snr.py +++ b/tests/unittests/audio/test_c_si_snr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import pytest import torch @@ -19,16 +18,15 @@ from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -inputs = Input( +inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2), ) diff --git a/tests/unittests/audio/test_pesq.py b/tests/unittests/audio/test_pesq.py index 5696822a81c..d847fc8f834 100644 --- a/tests/unittests/audio/test_pesq.py +++ b/tests/unittests/audio/test_pesq.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -22,21 +21,21 @@ from torchmetrics.audio import PerceptualEvaluationSpeechQuality from torchmetrics.functional.audio import perceptual_evaluation_speech_quality +from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) # for 8k sample rate, need at least 8k/4=2000 samples -inputs_8k = Input( +inputs_8k = _Input( preds=torch.rand(2, 3, 2100), target=torch.rand(2, 3, 2100), ) # for 16k sample rate, need at least 16k/4=4000 samples -inputs_16k = Input( +inputs_16k = _Input( preds=torch.rand(2, 3, 4100), target=torch.rand(2, 3, 4100), ) diff --git a/tests/unittests/audio/test_pit.py b/tests/unittests/audio/test_pit.py index 8221b7f7a0d..ce79187e417 100644 --- a/tests/unittests/audio/test_pit.py +++ b/tests/unittests/audio/test_pit.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial from typing import Callable, Tuple @@ -31,7 +30,7 @@ _find_best_perm_by_linear_sum_assignment, ) -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -39,15 +38,14 @@ TIME = 10 -Input = namedtuple("Input", ["preds", "target"]) # three speaker examples to test _find_best_perm_by_linear_sum_assignment -inputs1 = Input( +inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME), ) # two speaker examples to test _find_best_perm_by_exhuastive_method -inputs2 = Input( +inputs2 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME), ) diff --git a/tests/unittests/audio/test_sa_sdr.py b/tests/unittests/audio/test_sa_sdr.py index 0206969bc68..b2d3c986213 100644 --- a/tests/unittests/audio/test_sa_sdr.py +++ b/tests/unittests/audio/test_sa_sdr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -24,7 +23,7 @@ source_aggregated_signal_distortion_ratio, ) -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -32,9 +31,8 @@ NUM_SAMPLES = 100 # the number of samples -Input = namedtuple("Input", ["preds", "target"]) -inputs = Input( +inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, NUM_SAMPLES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, NUM_SAMPLES), ) diff --git a/tests/unittests/audio/test_sdr.py b/tests/unittests/audio/test_sdr.py index 2389be0d941..6f94d3c8efb 100644 --- a/tests/unittests/audio/test_sdr.py +++ b/tests/unittests/audio/test_sdr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial from typing import Callable @@ -25,19 +24,20 @@ from torchmetrics.functional import signal_distortion_ratio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_11 +from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _SAMPLE_NUMPY_ISSUE_895 from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -inputs_1spk = Input( +inputs_1spk = _Input( preds=torch.rand(2, 1, 1, 500), target=torch.rand(2, 1, 1, 500), ) -inputs_2spk = Input( + +inputs_2spk = _Input( preds=torch.rand(2, 1, 2, 500), target=torch.rand(2, 1, 2, 500), ) diff --git a/tests/unittests/audio/test_si_sdr.py b/tests/unittests/audio/test_si_sdr.py index 32d759e33f3..5abd76c2ea4 100644 --- a/tests/unittests/audio/test_si_sdr.py +++ b/tests/unittests/audio/test_si_sdr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -21,7 +20,7 @@ from torchmetrics.audio import ScaleInvariantSignalDistortionRatio from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -29,9 +28,8 @@ NUM_SAMPLES = 100 -Input = namedtuple("Input", ["preds", "target"]) -inputs = Input( +inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES), ) diff --git a/tests/unittests/audio/test_si_snr.py b/tests/unittests/audio/test_si_snr.py index 6792ed4cbff..590458fa56c 100644 --- a/tests/unittests/audio/test_si_snr.py +++ b/tests/unittests/audio/test_si_snr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -21,7 +20,7 @@ from torchmetrics.audio import ScaleInvariantSignalNoiseRatio from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -29,9 +28,8 @@ NUM_SAMPLES = 100 -Input = namedtuple("Input", ["preds", "target"]) -inputs = Input( +inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, NUM_SAMPLES), ) diff --git a/tests/unittests/audio/test_snr.py b/tests/unittests/audio/test_snr.py index 5003eddcde8..7b6e9e2a9d5 100644 --- a/tests/unittests/audio/test_snr.py +++ b/tests/unittests/audio/test_snr.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial from typing import Callable @@ -22,14 +21,14 @@ from torchmetrics.audio import SignalNoiseRatio from torchmetrics.functional.audio import signal_noise_ratio +from unittests import _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -inputs = Input( +inputs = _Input( preds=torch.rand(2, 1, 1, 25), target=torch.rand(2, 1, 1, 25), ) diff --git a/tests/unittests/audio/test_stoi.py b/tests/unittests/audio/test_stoi.py index c1545027b14..10c8a55685b 100644 --- a/tests/unittests/audio/test_stoi.py +++ b/tests/unittests/audio/test_stoi.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -22,19 +21,19 @@ from torchmetrics.audio import ShortTimeObjectiveIntelligibility from torchmetrics.functional.audio import short_time_objective_intelligibility +from unittests import _Input from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -inputs_8k = Input( +inputs_8k = _Input( preds=torch.rand(2, 3, 8000), target=torch.rand(2, 3, 8000), ) -inputs_16k = Input( +inputs_16k = _Input( preds=torch.rand(2, 3, 16000), target=torch.rand(2, 3, 16000), ) diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index 4897ce68687..c660e625214 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Any import pytest import torch from torch import Tensor -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _GroupInput, _Input from unittests.helpers import seed_all seed_all(1) @@ -32,82 +31,79 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: return torch.nn.functional.log_softmax(x, dim) -Input = namedtuple("Input", ["preds", "target"]) -GroupInput = namedtuple("GroupInput", ["preds", "target", "groups"]) - -_input_binary_prob = Input( +_input_binary_prob = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_binary = Input( +_input_binary = _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_binary_logits = Input( +_input_binary_logits = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_multilabel_prob = Input( +_input_multilabel_prob = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) -_input_multilabel_multidim_prob = Input( +_input_multilabel_multidim_prob = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ) -_input_multilabel_logits = Input( +_input_multilabel_logits = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) -_input_multilabel = Input( +_input_multilabel = _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ) -_input_multilabel_multidim = Input( +_input_multilabel_multidim = _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ) _binary_cases = ( pytest.param( - Input( + _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), id="input[single_dim-labels]", ), pytest.param( - Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), + _Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), id="input[single_dim-probs]", ), pytest.param( - Input( + _Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ), id="input[single_dim-logits]", ), pytest.param( - Input( + _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), id="input[multi_dim-labels]", ), pytest.param( - Input( + _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), id="input[multi_dim-probs]", ), pytest.param( - Input( + _Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), @@ -134,49 +130,49 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): _multiclass_cases = ( pytest.param( - Input( + _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), id="input[single_dim-labels]", ), pytest.param( - Input( + _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), id="input[single_dim-probs]", ), pytest.param( - Input( + _Input( preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), id="input[single_dim-logits]", ), pytest.param( - Input( + _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), id="input[multi_dim-labels]", ), pytest.param( - Input( + _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), id="input[multi_dim-probs]", ), pytest.param( - Input( + _Input( preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), id="input[multi_dim-logits]", ), pytest.param( - Input( + _Input( preds=_multiclass_with_missing_class(NUM_BATCHES, BATCH_SIZE, num_classes=NUM_CLASSES), target=_multiclass_with_missing_class(NUM_BATCHES, BATCH_SIZE, num_classes=NUM_CLASSES), ), @@ -187,42 +183,42 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): _multilabel_cases = ( pytest.param( - Input( + _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), id="input[single_dim-labels]", ), pytest.param( - Input( + _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), id="input[single_dim-probs]", ), pytest.param( - Input( + _Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), ), id="input[single_dim-logits]", ), pytest.param( - Input( + _Input( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), id="input[multi_dim-labels]", ), pytest.param( - Input( + _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), id="input[multi_dim-probs]", ), pytest.param( - Input( + _Input( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ), @@ -233,7 +229,7 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): _group_cases = ( pytest.param( - GroupInput( + _GroupInput( preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), @@ -241,7 +237,7 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): id="input[single_dim-labels]", ), pytest.param( - GroupInput( + _GroupInput( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), @@ -249,7 +245,7 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): id="input[single_dim-probs]", ), pytest.param( - GroupInput( + _GroupInput( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), @@ -262,20 +258,20 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) -_input_multilabel_no_match = Input(preds=__temp_preds, target=__temp_target) +_input_multilabel_no_match = _Input(preds=__temp_preds, target=__temp_target) __mc_prob_logits = 10 * torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) __mc_prob_preds = __mc_prob_logits.abs() / __mc_prob_logits.abs().sum(dim=2, keepdim=True) -_input_multiclass_prob = Input( +_input_multiclass_prob = _Input( preds=__mc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_multiclass_logits = Input( +_input_multiclass_logits = _Input( preds=__mc_prob_logits, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) ) -_input_multiclass = Input( +_input_multiclass = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) @@ -283,11 +279,11 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): __mdmc_prob_preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM) __mdmc_prob_preds = __mdmc_prob_preds / __mdmc_prob_preds.sum(dim=2, keepdim=True) -_input_multidim_multiclass_prob = Input( +_input_multidim_multiclass_prob = _Input( preds=__mdmc_prob_preds, target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) ) -_input_multidim_multiclass = Input( +_input_multidim_multiclass = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ) @@ -305,13 +301,13 @@ def _generate_plausible_inputs_multilabel(num_classes=NUM_CLASSES, num_batches=N preds = preds / preds.sum(dim=2, keepdim=True) - return Input(preds=preds, target=targets) + return _Input(preds=preds, target=targets) def _generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_SIZE): targets = torch.randint(high=2, size=(num_batches, batch_size)) preds = torch.rand(num_batches, batch_size) + torch.rand(num_batches, batch_size) * targets / 3 - return Input(preds=preds / (preds.max() + 0.01), target=targets) + return _Input(preds=preds / (preds.max() + 0.01), target=targets) _input_multilabel_prob_plausible = _generate_plausible_inputs_multilabel() @@ -323,7 +319,7 @@ def _generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_ _class_remove, _class_replace = torch.multinomial(torch.ones(NUM_CLASSES), num_samples=2, replacement=False) _temp[_temp == _class_remove] = _class_replace -_input_multiclass_with_missing_class = Input(_temp.clone(), _temp.clone()) +_input_multiclass_with_missing_class = _Input(_temp.clone(), _temp.clone()) _negmetric_noneavg = { diff --git a/tests/unittests/clustering/inputs.py b/tests/unittests/clustering/inputs.py index 15b24298f7d..b13bc0c0947 100644 --- a/tests/unittests/clustering/inputs.py +++ b/tests/unittests/clustering/inputs.py @@ -11,22 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import NamedTuple import torch from sklearn.datasets import make_blobs +from torch import Tensor -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input from unittests.helpers import seed_all seed_all(42) -# extrinsic input for clustering metrics that requires predicted clustering labels and target clustering labels -ExtrinsicInput = namedtuple("ExtrinsicInput", ["preds", "target"]) - # intrinsic input for clustering metrics that requires only predicted clustering labels and the cluster embeddings -IntrinsicInput = namedtuple("IntrinsicInput", ["data", "labels"]) +class _IntrinsicInput(NamedTuple): + data: Tensor + labels: Tensor def _batch_blobs(num_batches, num_samples, num_features, num_classes): @@ -36,20 +36,20 @@ def _batch_blobs(num_batches, num_samples, num_features, num_classes): data.append(torch.tensor(_data)) labels.append(torch.tensor(_labels)) - return IntrinsicInput(data=torch.stack(data), labels=torch.stack(labels)) + return _IntrinsicInput(data=torch.stack(data), labels=torch.stack(labels)) -_single_target_extrinsic1 = ExtrinsicInput( +_single_target_extrinsic1 = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_single_target_extrinsic2 = ExtrinsicInput( +_single_target_extrinsic2 = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_float_inputs_extrinsic = ExtrinsicInput( +_float_inputs_extrinsic = _Input( preds=torch.rand((NUM_BATCHES, BATCH_SIZE)), target=torch.rand((NUM_BATCHES, BATCH_SIZE)) ) diff --git a/tests/unittests/clustering/test_utils.py b/tests/unittests/clustering/test_utils.py index 6fc8577f15c..e6ffe222b46 100644 --- a/tests/unittests/clustering/test_utils.py +++ b/tests/unittests/clustering/test_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import numpy as np import pytest @@ -27,25 +26,25 @@ calculate_pair_cluster_confusion_matrix, ) -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) + NUM_CLASSES = 10 -_sklearn_inputs = Input( +_sklearn_inputs = _Input( preds=torch.tensor([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]), target=torch.tensor([1, 1, 1, 1, 2, 1, 2, 2, 2, 2, 3, 1, 3, 3, 3, 2, 2]), ) -_single_dim_inputs = Input( +_single_dim_inputs = _Input( preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)), ) -_multi_dim_inputs = Input( +_multi_dim_inputs = _Input( preds=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE, 2)), ) diff --git a/tests/unittests/detection/test_map.py b/tests/unittests/detection/test_map.py index a3b2b762b45..3027035c4da 100644 --- a/tests/unittests/detection/test_map.py +++ b/tests/unittests/detection/test_map.py @@ -14,7 +14,6 @@ import contextlib import io import json -from collections import namedtuple from copy import deepcopy from functools import partial from itertools import product @@ -221,11 +220,8 @@ def test_compare_both_same_time(tmpdir, backend): assert torch.allclose(res[f"segm_{k}"], v, atol=1e-2) -Input = namedtuple("Input", ["preds", "target"]) - - -_inputs = Input( - preds=[ +_inputs = { + "preds": [ [ { "boxes": Tensor([[258.15, 41.29, 606.41, 285.07]]), @@ -273,7 +269,7 @@ def test_compare_both_same_time(tmpdir, backend): }, # coco image id 987 category_id 49 ], ], - target=[ + "target": [ [ { "boxes": Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]), @@ -323,11 +319,11 @@ def test_compare_both_same_time(tmpdir, backend): }, # coco image id 987 category_id 49 ], ], -) +} # example from this issue https://github.com/Lightning-AI/torchmetrics/issues/943 -_inputs2 = Input( - preds=[ +_inputs2 = { + "preds": [ [ { "boxes": Tensor([[258.0, 41.0, 606.0, 285.0]]), @@ -343,7 +339,7 @@ def test_compare_both_same_time(tmpdir, backend): } ], ], - target=[ + "target": [ [ { "boxes": Tensor([[214.0, 41.0, 562.0, 285.0]]), @@ -357,13 +353,13 @@ def test_compare_both_same_time(tmpdir, backend): } ], ], -) +} # Test empty preds case, to ensure bool inputs are properly casted to uint8 # From https://github.com/Lightning-AI/torchmetrics/issues/981 # and https://github.com/Lightning-AI/torchmetrics/issues/1147 -_inputs3 = Input( - preds=[ +_inputs3 = { + "preds": [ [ { "boxes": Tensor([[258.0, 41.0, 606.0, 285.0]]), @@ -375,7 +371,7 @@ def test_compare_both_same_time(tmpdir, backend): {"boxes": Tensor([]), "scores": Tensor([]), "labels": Tensor([])}, ], ], - target=[ + "target": [ [ { "boxes": Tensor([[214.0, 41.0, 562.0, 285.0]]), @@ -390,7 +386,7 @@ def test_compare_both_same_time(tmpdir, backend): }, ], ], -) +} def _generate_random_segm_input(device, batch_size=2, num_preds_size=10, num_gt_size=10, random_size=True): @@ -519,7 +515,7 @@ def test_map_gpu(self, backend, inputs): """Test predictions on single gpu.""" metric = MeanAveragePrecision(backend=backend) metric = metric.to("cuda") - for preds, targets in zip(deepcopy(inputs.preds), deepcopy(inputs.target)): + for preds, targets in zip(deepcopy(inputs["preds"]), deepcopy(inputs["target"])): metric.update( apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), @@ -531,7 +527,7 @@ def test_map_with_custom_thresholds(self, backend): """Test that map works with custom iou thresholds.""" metric = MeanAveragePrecision(iou_thresholds=[0.1, 0.2], backend=backend) metric = metric.to("cuda") - for preds, targets in zip(deepcopy(_inputs.preds), deepcopy(_inputs.target)): + for preds, targets in zip(deepcopy(_inputs["preds"]), deepcopy(_inputs["target"])): metric.update( apply_to_collection(preds, Tensor, lambda x: x.to("cuda")), apply_to_collection(targets, Tensor, lambda x: x.to("cuda")), @@ -764,8 +760,8 @@ def test_warning_on_many_detections(self, iou_type, backend): (10, 1, 4, 3), ), ( - _inputs.preds, - _inputs.target, + _inputs["preds"], + _inputs["target"], 24, # 4 images x 6 classes = 24 list(product([0, 1, 2, 3], [0, 1, 2, 3, 4, 49])), (10, 101, 6, 4, 3), @@ -806,11 +802,11 @@ def test_average_argument(self, class_metrics, backend): """ if class_metrics: - _preds = _inputs.preds - _target = _inputs.target + _preds = _inputs["preds"] + _target = _inputs["target"] else: - _preds = apply_to_collection(deepcopy(_inputs.preds), IntTensor, lambda x: torch.ones_like(x)) - _target = apply_to_collection(deepcopy(_inputs.target), IntTensor, lambda x: torch.ones_like(x)) + _preds = apply_to_collection(deepcopy(_inputs["preds"]), IntTensor, lambda x: torch.ones_like(x)) + _target = apply_to_collection(deepcopy(_inputs["target"]), IntTensor, lambda x: torch.ones_like(x)) metric_macro = MeanAveragePrecision(average="macro", class_metrics=class_metrics, backend=backend) metric_macro.update(_preds[0], _target[0]) @@ -818,8 +814,8 @@ def test_average_argument(self, class_metrics, backend): result_macro = metric_macro.compute() metric_micro = MeanAveragePrecision(average="micro", class_metrics=class_metrics, backend=backend) - metric_micro.update(_inputs.preds[0], _inputs.target[0]) - metric_micro.update(_inputs.preds[1], _inputs.target[1]) + metric_micro.update(_inputs["preds"][0], _inputs["target"][0]) + metric_micro.update(_inputs["preds"][1], _inputs["target"][1]) result_micro = metric_micro.compute() if class_metrics: diff --git a/tests/unittests/detection/test_modified_panoptic_quality.py b/tests/unittests/detection/test_modified_panoptic_quality.py index 17dd7cd8f09..e771454c6ab 100644 --- a/tests/unittests/detection/test_modified_panoptic_quality.py +++ b/tests/unittests/detection/test_modified_panoptic_quality.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Any, Dict import numpy as np @@ -20,14 +19,14 @@ from torchmetrics.detection import ModifiedPanopticQuality from torchmetrics.functional.detection import modified_panoptic_quality +from unittests import _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_INPUTS_0 = Input( +_INPUTS_0 = _Input( # Shape of input tensors is (num_batches, batch_size, height, width, 2). preds=torch.tensor( [ @@ -52,7 +51,7 @@ .reshape((1, 1, 5, 5, 2)) .repeat(2, 1, 1, 1, 1), ) -_INPUTS_1 = Input( +_INPUTS_1 = _Input( # Shape of input tensors is (num_batches, batch_size, num_points, 2). # NOTE: IoU for stuff category 6 is < 0.5, modified PQ behaves differently there. preds=torch.tensor([[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]]).reshape((1, 1, 6, 2)).repeat(2, 1, 1, 1), @@ -185,7 +184,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/detection/test_panoptic_quality.py b/tests/unittests/detection/test_panoptic_quality.py index 688af6a0c6f..2ef5f890b41 100644 --- a/tests/unittests/detection/test_panoptic_quality.py +++ b/tests/unittests/detection/test_panoptic_quality.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from typing import Any, Dict import numpy as np @@ -20,14 +19,14 @@ from torchmetrics.detection.panoptic_qualities import PanopticQuality from torchmetrics.functional.detection.panoptic_qualities import panoptic_quality +from unittests import _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_INPUTS_0 = Input( +_INPUTS_0 = _Input( # Shape of input tensors is (num_batches, batch_size, height, width, 2). preds=torch.tensor( [ @@ -52,7 +51,7 @@ .reshape((1, 1, 5, 5, 2)) .repeat(2, 1, 1, 1, 1), ) -_INPUTS_1 = Input( +_INPUTS_1 = _Input( # Shape of input tensors is (num_batches, batch_size, num_points, 2). preds=torch.tensor( [[10, 0], [10, 123], [0, 1], [10, 0], [1, 2]], @@ -192,7 +191,7 @@ def test_extreme_values(): (_INPUTS_1, _ARGS_2, 1), ], ) -def test_ignore_mask(inputs: Input, args: Dict[str, Any], cat_dim: int): +def test_ignore_mask(inputs: _Input, args: Dict[str, Any], cat_dim: int): """Test that the metric correctly ignores regions of the inputs that do not map to a know category ID.""" preds = inputs.preds[0] target = inputs.target[0] diff --git a/tests/unittests/image/test_d_lambda.py b/tests/unittests/image/test_d_lambda.py index 32a9c277b43..8a7d93c47ee 100644 --- a/tests/unittests/image/test_d_lambda.py +++ b/tests/unittests/image/test_d_lambda.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import numpy as np import pytest import torch +from torch import Tensor from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.image.d_lambda import SpectralDistortionIndex @@ -28,7 +29,11 @@ seed_all(42) -Input = namedtuple("Input", ["preds", "target", "p"]) +class _Input(NamedTuple): + preds: Tensor + target: Tensor + p: int + _inputs = [] for size, channel, p, dtype in [ @@ -40,7 +45,7 @@ preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) _inputs.append( - Input( + _Input( preds=preds, target=target, p=p, diff --git a/tests/unittests/image/test_ergas.py b/tests/unittests/image/test_ergas.py index e87b4b63a0f..4848bda4479 100644 --- a/tests/unittests/image/test_ergas.py +++ b/tests/unittests/image/test_ergas.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import torch @@ -27,7 +27,12 @@ seed_all(42) -Input = namedtuple("Input", ["preds", "target", "ratio"]) + +class _Input(NamedTuple): + preds: Tensor + target: Tensor + ratio: int + _inputs = [] for size, channel, coef, ratio, dtype in [ @@ -37,7 +42,7 @@ (15, 3, 0.5, 4, torch.float64), ]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input(preds=preds, target=preds * coef, ratio=ratio)) + _inputs.append(_Input(preds=preds, target=preds * coef, ratio=ratio)) def _baseline_ergas( diff --git a/tests/unittests/image/test_lpips.py b/tests/unittests/image/test_lpips.py index f17aa0f59fb..8c7170626f5 100644 --- a/tests/unittests/image/test_lpips.py +++ b/tests/unittests/image/test_lpips.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import torch @@ -26,9 +26,13 @@ seed_all(42) -Input = namedtuple("Input", ["img1", "img2"]) -_inputs = Input( +class _Input(NamedTuple): + img1: Tensor + img2: Tensor + + +_inputs = _Input( img1=torch.rand(4, 2, 3, 50, 50), img2=torch.rand(4, 2, 3, 50, 50), ) diff --git a/tests/unittests/image/test_ms_ssim.py b/tests/unittests/image/test_ms_ssim.py index 284a75d958e..882f304cb83 100644 --- a/tests/unittests/image/test_ms_ssim.py +++ b/tests/unittests/image/test_ms_ssim.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -20,20 +19,20 @@ from torchmetrics.functional.image.ssim import multiscale_structural_similarity_index_measure from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure -from unittests import NUM_BATCHES +from unittests import NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) + BATCH_SIZE = 1 _inputs = [] for size, coef in [(182, 0.9), (182, 0.7)]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 1, size, size) _inputs.append( - Input( + _Input( preds=preds, target=preds * coef, ) diff --git a/tests/unittests/image/test_psnr.py b/tests/unittests/image/test_psnr.py index f58a55c8225..4032e5fc775 100644 --- a/tests/unittests/image/test_psnr.py +++ b/tests/unittests/image/test_psnr.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np @@ -23,17 +22,16 @@ from torchmetrics.image import PeakSignalNoiseRatio from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) _input_size = (NUM_BATCHES, BATCH_SIZE, 32, 32) _inputs = [ - Input( + _Input( preds=torch.randint(n_cls_pred, _input_size, dtype=torch.float), target=torch.randint(n_cls_target, _input_size, dtype=torch.float), ) diff --git a/tests/unittests/image/test_rase.py b/tests/unittests/image/test_rase.py index 4a9b696e3a5..2d8aa812968 100644 --- a/tests/unittests/image/test_rase.py +++ b/tests/unittests/image/test_rase.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import sewar import torch +from torch import Tensor from torchmetrics.functional import relative_average_spectral_error from torchmetrics.functional.image.helper import _uniform_filter from torchmetrics.image import RelativeAverageSpectralError @@ -25,7 +25,12 @@ from unittests import BATCH_SIZE from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target", "window_size"]) + +class _InputWindowSized(NamedTuple): + preds: Tensor + target: Tensor + window_size: int + _inputs = [] for size, channel, window_size, dtype in [ @@ -36,7 +41,7 @@ ]: preds = torch.rand(2, BATCH_SIZE, channel, size, size, dtype=dtype) target = torch.rand(2, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input(preds=preds, target=target, window_size=window_size)) + _inputs.append(_InputWindowSized(preds=preds, target=target, window_size=window_size)) def _sewar_rase(preds, target, window_size): diff --git a/tests/unittests/image/test_rmse_sw.py b/tests/unittests/image/test_rmse_sw.py index 0f38b8c0476..4bddd97bef9 100644 --- a/tests/unittests/image/test_rmse_sw.py +++ b/tests/unittests/image/test_rmse_sw.py @@ -11,20 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import sewar import torch +from torch import Tensor from torchmetrics.functional import root_mean_squared_error_using_sliding_window from torchmetrics.image import RootMeanSquaredErrorUsingSlidingWindow from unittests import BATCH_SIZE, NUM_BATCHES from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target", "window_size"]) + +class _InputWindowSized(NamedTuple): + preds: Tensor + target: Tensor + window_size: int + _inputs = [] for size, channel, window_size, dtype in [ @@ -35,7 +40,7 @@ ]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input(preds=preds, target=target, window_size=window_size)) + _inputs.append(_InputWindowSized(preds=preds, target=target, window_size=window_size)) def _sewar_rmse_sw(preds, target, window_size): diff --git a/tests/unittests/image/test_sam.py b/tests/unittests/image/test_sam.py index 194f6f91b7a..6ab999cb02a 100644 --- a/tests/unittests/image/test_sam.py +++ b/tests/unittests/image/test_sam.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -22,13 +21,12 @@ from torchmetrics.image.sam import SpectralAngleMapper from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) _inputs = [] for size, channel, dtype in [ @@ -39,7 +37,7 @@ ]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) target = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) - _inputs.append(Input(preds=preds, target=target)) + _inputs.append(_Input(preds=preds, target=target)) def _baseline_sam( diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index e0d4eb848e6..b0af003871d 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np @@ -23,13 +22,12 @@ from torchmetrics.functional import structural_similarity_index_measure from torchmetrics.image import StructuralSimilarityIndexMeasure -from unittests import NUM_BATCHES +from unittests import NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) BATCH_SIZE = 2 # custom batch size to prevent memory issues in CI _inputs = [] @@ -41,14 +39,14 @@ ]: preds2d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) _inputs.append( - Input( + _Input( preds=preds2d, target=preds2d * coef, ) ) preds3d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, size, dtype=dtype) _inputs.append( - Input( + _Input( preds=preds3d, target=preds3d * coef, ) diff --git a/tests/unittests/image/test_tv.py b/tests/unittests/image/test_tv.py index 34e704ad9ab..4ab9d2c96d0 100644 --- a/tests/unittests/image/test_tv.py +++ b/tests/unittests/image/test_tv.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial from typing import Any @@ -21,6 +20,7 @@ from torchmetrics.functional.image.tv import total_variation from torchmetrics.image.tv import TotalVariation +from unittests import _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -50,7 +50,7 @@ def _total_variation_kornia_tester(preds, target, reduction): # define inputs -Input = namedtuple("Input", ["preds", "target"]) + _inputs = [] for size, channel, dtype in [ @@ -61,7 +61,7 @@ def _total_variation_kornia_tester(preds, target, reduction): ]: preds = torch.rand(2, 4, channel, size, size, dtype=dtype) target = torch.rand(2, 4, channel, size, size, dtype=dtype) - _inputs.append(Input(preds=preds, target=target)) + _inputs.append(_Input(preds=preds, target=target)) @pytest.mark.parametrize( diff --git a/tests/unittests/image/test_uqi.py b/tests/unittests/image/test_uqi.py index 66597e6d7a8..f572e52e57b 100644 --- a/tests/unittests/image/test_uqi.py +++ b/tests/unittests/image/test_uqi.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import torch from skimage.metrics import structural_similarity +from torch import Tensor from torchmetrics.functional.image.uqi import universal_image_quality_index from torchmetrics.image.uqi import UniversalImageQualityIndex @@ -26,10 +27,16 @@ seed_all(42) + +class _InputMultichannel(NamedTuple): + preds: Tensor + target: Tensor + multichannel: bool + + # UQI is SSIM with both constants k1 and k2 as 0 skimage_uqi = partial(structural_similarity, k1=0, k2=0) -Input = namedtuple("Input", ["preds", "target", "multichannel"]) _inputs = [] for size, channel, coef, multichannel, dtype in [ @@ -40,7 +47,7 @@ ]: preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype) _inputs.append( - Input( + _InputMultichannel( preds=preds, target=preds * coef, multichannel=multichannel, diff --git a/tests/unittests/image/test_vif.py b/tests/unittests/image/test_vif.py index 6842715842e..926fbc4fda6 100644 --- a/tests/unittests/image/test_vif.py +++ b/tests/unittests/image/test_vif.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple import numpy as np import pytest @@ -20,15 +19,15 @@ from torchmetrics.functional.image.vif import visual_information_fidelity from torchmetrics.image.vif import VisualInformationFidelity -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) + _inputs = [ - Input( + _Input( preds=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float), target=torch.randint(0, 255, size=(NUM_BATCHES, BATCH_SIZE, channels, 41, 41), dtype=torch.float), ) diff --git a/tests/unittests/multimodal/test_clip_score.py b/tests/unittests/multimodal/test_clip_score.py index 2cb65ee5a58..b350b2843d0 100644 --- a/tests/unittests/multimodal/test_clip_score.py +++ b/tests/unittests/multimodal/test_clip_score.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import List, NamedTuple import matplotlib import matplotlib.pyplot as plt import pytest import torch +from torch import Tensor from torchmetrics.functional.multimodal.clip_score import clip_score from torchmetrics.multimodal.clip_score import CLIPScore from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10 @@ -31,7 +32,9 @@ seed_all(42) -Input = namedtuple("Input", ["images", "captions"]) +class _InputImagesCaptions(NamedTuple): + images: Tensor + captions: List[List[str]] captions = [ @@ -41,7 +44,9 @@ "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto.", ] -_random_input = Input(images=torch.randint(255, (2, 2, 3, 64, 64)), captions=[captions[0:2], captions[2:]]) +_random_input = _InputImagesCaptions( + images=torch.randint(255, (2, 2, 3, 64, 64)), captions=[captions[0:2], captions[2:]] +) def _compare_fn(preds, target, model_name_or_path): diff --git a/tests/unittests/nominal/test_cramers.py b/tests/unittests/nominal/test_cramers.py index 373feecca92..1ae0c2575ca 100644 --- a/tests/unittests/nominal/test_cramers.py +++ b/tests/unittests/nominal/test_cramers.py @@ -13,7 +13,6 @@ # limitations under the License. import itertools import operator -from collections import namedtuple from functools import partial import pytest @@ -23,13 +22,12 @@ from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.nominal.cramers import CramersV -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target"]) NUM_CLASSES = 4 -_input_default = Input( +_input_default = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) @@ -41,9 +39,9 @@ _target = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.float) _target[1, 0] = float("nan") _target[-1, 0] = float("nan") -_input_with_nans = Input(preds=_preds, target=_target) +_input_with_nans = _Input(preds=_preds, target=_target) -_input_logits = Input( +_input_logits = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) ) diff --git a/tests/unittests/nominal/test_pearson.py b/tests/unittests/nominal/test_pearson.py index adb3320c573..042f467e792 100644 --- a/tests/unittests/nominal/test_pearson.py +++ b/tests/unittests/nominal/test_pearson.py @@ -13,7 +13,6 @@ # limitations under the License. import itertools import operator -from collections import namedtuple import pandas as pd import pytest @@ -26,18 +25,17 @@ ) from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target"]) NUM_CLASSES = 4 -_input_default = Input( +_input_default = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_logits = Input( +_input_logits = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) ) diff --git a/tests/unittests/nominal/test_theils_u.py b/tests/unittests/nominal/test_theils_u.py index cd6774d7917..02523144eca 100644 --- a/tests/unittests/nominal/test_theils_u.py +++ b/tests/unittests/nominal/test_theils_u.py @@ -13,7 +13,6 @@ # limitations under the License. import itertools import operator -from collections import namedtuple from functools import partial import pytest @@ -23,13 +22,12 @@ from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix from torchmetrics.nominal import TheilsU -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target"]) NUM_CLASSES = 4 -_input_default = Input( +_input_default = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) @@ -41,9 +39,9 @@ _target = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.float) _target[1, 0] = float("nan") _target[-1, 0] = float("nan") -_input_with_nans = Input(preds=_preds, target=_target) +_input_with_nans = _Input(preds=_preds, target=_target) -_input_logits = Input( +_input_logits = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) ) diff --git a/tests/unittests/nominal/test_tschuprows.py b/tests/unittests/nominal/test_tschuprows.py index 23a4669a954..0076b34239e 100644 --- a/tests/unittests/nominal/test_tschuprows.py +++ b/tests/unittests/nominal/test_tschuprows.py @@ -13,7 +13,6 @@ # limitations under the License. import itertools import operator -from collections import namedtuple import pandas as pd import pytest @@ -23,18 +22,17 @@ from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix from torchmetrics.nominal.tschuprows import TschuprowsT -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers.testers import MetricTester -Input = namedtuple("Input", ["preds", "target"]) NUM_CLASSES = 4 -_input_default = Input( +_input_default = _Input( preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_logits = Input( +_input_logits = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), target=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) ) diff --git a/tests/unittests/pairwise/test_pairwise_distance.py b/tests/unittests/pairwise/test_pairwise_distance.py index d80377df432..0f6f8fdf6ea 100644 --- a/tests/unittests/pairwise/test_pairwise_distance.py +++ b/tests/unittests/pairwise/test_pairwise_distance.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import pytest import torch @@ -23,6 +23,7 @@ manhattan_distances, pairwise_distances, ) +from torch import Tensor from torchmetrics.functional import ( pairwise_cosine_similarity, pairwise_euclidean_distance, @@ -40,16 +41,19 @@ extra_dim = 5 -Input = namedtuple("Input", ["x", "y"]) +class _Input(NamedTuple): + x: Tensor + y: Tensor -_inputs1 = Input( + +_inputs1 = _Input( x=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim), y=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim), ) -_inputs2 = Input( +_inputs2 = _Input( x=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim), y=torch.rand(NUM_BATCHES, BATCH_SIZE, extra_dim), ) diff --git a/tests/unittests/regression/test_concordance.py b/tests/unittests/regression/test_concordance.py index 135655db770..1583434fdd8 100644 --- a/tests/unittests/regression/test_concordance.py +++ b/tests/unittests/regression/test_concordance.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np @@ -22,30 +21,29 @@ from torchmetrics.regression.concordance import ConcordanceCorrCoef from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs1 = Input( +_single_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_single_target_inputs2 = Input( +_single_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs1 = Input( +_multi_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) -_multi_target_inputs2 = Input( +_multi_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) diff --git a/tests/unittests/regression/test_cosine_similarity.py b/tests/unittests/regression/test_cosine_similarity.py index b99ab83ee47..37582d98978 100644 --- a/tests/unittests/regression/test_cosine_similarity.py +++ b/tests/unittests/regression/test_cosine_similarity.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np @@ -21,7 +20,7 @@ from torchmetrics.functional.regression.cosine_similarity import cosine_similarity from torchmetrics.regression.cosine_similarity import CosineSimilarity -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -29,14 +28,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_explained_variance.py b/tests/unittests/regression/test_explained_variance.py index f601e215ced..96fd2897037 100644 --- a/tests/unittests/regression/test_explained_variance.py +++ b/tests/unittests/regression/test_explained_variance.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -20,7 +19,7 @@ from torchmetrics.functional import explained_variance from torchmetrics.regression import ExplainedVariance -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -28,14 +27,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_kendall.py b/tests/unittests/regression/test_kendall.py index 5b3a31ba482..04421cc6a36 100644 --- a/tests/unittests/regression/test_kendall.py +++ b/tests/unittests/regression/test_kendall.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import operator -from collections import namedtuple from functools import partial import pytest @@ -22,7 +21,7 @@ from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef from torchmetrics.regression.kendall import KendallRankCorrCoef -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -30,19 +29,19 @@ seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_single_inputs1 = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE)) -_single_inputs2 = Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE)) -_single_inputs3 = Input( + +_single_inputs1 = _Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE)) +_single_inputs2 = _Input(preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE)) +_single_inputs3 = _Input( preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE)), target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE)) ) -_multi_inputs1 = Input( +_multi_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM) ) -_multi_inputs2 = Input( +_multi_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM) ) -_multi_inputs3 = Input( +_multi_inputs3 = _Input( preds=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(-10, 10, (NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ) diff --git a/tests/unittests/regression/test_kl_divergence.py b/tests/unittests/regression/test_kl_divergence.py index 11794e738f6..9531ff5f337 100644 --- a/tests/unittests/regression/test_kl_divergence.py +++ b/tests/unittests/regression/test_kl_divergence.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial -from typing import Optional +from typing import NamedTuple, Optional import numpy as np import pytest @@ -30,14 +29,18 @@ seed_all(42) -Input = namedtuple("Input", ["p", "q"]) -_probs_inputs = Input( +class _Input(NamedTuple): + p: Tensor + q: Tensor + + +_probs_inputs = _Input( p=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), q=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) -_log_probs_inputs = Input( +_log_probs_inputs = _Input( p=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).softmax(dim=-1).log(), q=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM).softmax(dim=-1).log(), ) diff --git a/tests/unittests/regression/test_log_cosh_error.py b/tests/unittests/regression/test_log_cosh_error.py index d288e37f177..7da6885eae7 100644 --- a/tests/unittests/regression/test_log_cosh_error.py +++ b/tests/unittests/regression/test_log_cosh_error.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections import namedtuple from functools import partial import numpy as np @@ -21,7 +19,7 @@ from torchmetrics.functional.regression.log_cosh import log_cosh_error from torchmetrics.regression.log_cosh import LogCoshError -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -29,14 +27,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 10671df138d..a56dfc4e373 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from collections import namedtuple from functools import partial from typing import Optional @@ -42,7 +41,7 @@ ) from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -50,14 +49,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_minkowski_distance.py b/tests/unittests/regression/test_minkowski_distance.py index 3ee9f047208..2feaceea199 100644 --- a/tests/unittests/regression/test_minkowski_distance.py +++ b/tests/unittests/regression/test_minkowski_distance.py @@ -1,4 +1,3 @@ -from collections import namedtuple from functools import partial import pytest @@ -9,7 +8,7 @@ from torchmetrics.utilities.exceptions import TorchMetricsUserError from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9 -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -17,14 +16,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 90c7df76b92..b76c030380f 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -20,31 +19,30 @@ from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs1 = Input( +_single_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_single_target_inputs2 = Input( +_single_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs1 = Input( +_multi_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) -_multi_target_inputs2 = Input( +_multi_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) diff --git a/tests/unittests/regression/test_r2.py b/tests/unittests/regression/test_r2.py index 1de25d328ff..9c87c6ad9fb 100644 --- a/tests/unittests/regression/test_r2.py +++ b/tests/unittests/regression/test_r2.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -20,7 +19,7 @@ from torchmetrics.functional import r2_score from torchmetrics.regression import R2Score -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -28,14 +27,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py index 4d2e9c4efd4..45647abfcf3 100644 --- a/tests/unittests/regression/test_rse.py +++ b/tests/unittests/regression/test_rse.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import numpy as np @@ -21,7 +20,7 @@ from torchmetrics.regression import RelativeSquaredError from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -29,14 +28,13 @@ num_targets = 5 -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs = Input( +_single_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) diff --git a/tests/unittests/regression/test_spearman.py b/tests/unittests/regression/test_spearman.py index e21835bbe75..d0d04aedf5b 100644 --- a/tests/unittests/regression/test_spearman.py +++ b/tests/unittests/regression/test_spearman.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -21,35 +20,34 @@ from torchmetrics.regression.spearman import SpearmanCorrCoef from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "target"]) -_single_target_inputs1 = Input( +_single_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_single_target_inputs2 = Input( +_single_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randn(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs1 = Input( +_multi_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) -_multi_target_inputs2 = Input( +_multi_target_inputs2 = _Input( preds=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) -_specific_input = Input( +_specific_input = _Input( preds=torch.stack([torch.tensor([1.0, 0.0, 4.0, 1.0, 0.0, 3.0, 0.0]) for _ in range(NUM_BATCHES)]), target=torch.stack([torch.tensor([4.0, 0.0, 3.0, 3.0, 3.0, 1.0, 1.0]) for _ in range(NUM_BATCHES)]), ) diff --git a/tests/unittests/regression/test_tweedie_deviance.py b/tests/unittests/regression/test_tweedie_deviance.py index ea0a7c67d7c..bc64bcb09a1 100644 --- a/tests/unittests/regression/test_tweedie_deviance.py +++ b/tests/unittests/regression/test_tweedie_deviance.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial import pytest @@ -22,27 +21,26 @@ from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9 -from unittests import BATCH_SIZE, NUM_BATCHES +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester seed_all(42) -Input = namedtuple("Input", ["preds", "targets"]) -_single_target_inputs1 = Input( +_single_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - targets=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_single_target_inputs2 = Input( +_single_target_inputs2 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - targets=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), ) -_multi_target_inputs = Input( +_multi_target_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 5), - targets=torch.rand(NUM_BATCHES, BATCH_SIZE, 5), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, 5), ) @@ -54,65 +52,65 @@ def _sklearn_deviance(preds: Tensor, targets: Tensor, power: float): @pytest.mark.parametrize("power", [-0.5, 0, 1, 1.5, 2, 3]) @pytest.mark.parametrize( - "preds, targets", + "preds, target", [ - (_single_target_inputs1.preds, _single_target_inputs1.targets), - (_single_target_inputs2.preds, _single_target_inputs2.targets), - (_multi_target_inputs.preds, _multi_target_inputs.targets), + (_single_target_inputs2.preds, _single_target_inputs2.target), + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_multi_target_inputs.preds, _multi_target_inputs.target), ], ) class TestDevianceScore(MetricTester): """Test class for `TweedieDevianceScore` metric.""" @pytest.mark.parametrize("ddp", [True, False]) - def test_deviance_scores_class(self, ddp, preds, targets, power): + def test_deviance_scores_class(self, ddp, preds, target, power): """Test class implementation of metric.""" self.run_class_metric_test( ddp, preds, - targets, + target, TweedieDevianceScore, partial(_sklearn_deviance, power=power), metric_args={"power": power}, ) - def test_deviance_scores_functional(self, preds, targets, power): + def test_deviance_scores_functional(self, preds, target, power): """Test functional implementation of metric.""" self.run_functional_metric_test( preds, - targets, + target, tweedie_deviance_score, partial(_sklearn_deviance, power=power), metric_args={"power": power}, ) - def test_deviance_scores_differentiability(self, preds, targets, power): + def test_deviance_scores_differentiability(self, preds, target, power): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( - preds, targets, metric_module=TweedieDevianceScore, metric_functional=tweedie_deviance_score + preds, target, metric_module=TweedieDevianceScore, metric_functional=tweedie_deviance_score ) # Tweedie Deviance Score half + cpu does not work for power=[1,2] due to missing support in torch.log - def test_deviance_scores_half_cpu(self, preds, targets, power): + def test_deviance_scores_half_cpu(self, preds, target, power): """Test dtype support of the metric on CPU.""" if not _TORCH_GREATER_EQUAL_1_9 or power in [1, 2]: pytest.xfail(reason="TweedieDevianceScore metric does not support cpu + half precision for older Pytorch") metric_args = {"power": power} self.run_precision_test_cpu( preds, - targets, + target, metric_module=TweedieDevianceScore, metric_functional=tweedie_deviance_score, metric_args=metric_args, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_deviance_scores_half_gpu(self, preds, targets, power): + def test_deviance_scores_half_gpu(self, preds, target, power): """Test dtype support of the metric on GPU.""" metric_args = {"power": power} self.run_precision_test_gpu( preds, - targets, + target, metric_module=TweedieDevianceScore, metric_functional=tweedie_deviance_score, metric_args=metric_args, diff --git a/tests/unittests/retrieval/inputs.py b/tests/unittests/retrieval/inputs.py index 93c94feab63..cf6b9c40377 100644 --- a/tests/unittests/retrieval/inputs.py +++ b/tests/unittests/retrieval/inputs.py @@ -11,46 +11,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import NamedTuple import torch +from torch import Tensor from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES -Input = namedtuple("InputMultiple", ["indexes", "preds", "target"]) + +class _Input(NamedTuple): + indexes: Tensor + preds: Tensor + target: Tensor + # correct -_input_retrieval_scores = Input( +_input_retrieval_scores = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_for_adaptive_k = Input( +_input_retrieval_scores_for_adaptive_k = _Input( indexes=torch.randint(high=NUM_BATCHES * BATCH_SIZE // 2, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_extra = Input( +_input_retrieval_scores_extra = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ) -_input_retrieval_scores_int_target = Input( +_input_retrieval_scores_int_target = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, 2 * BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE), target=torch.randint(low=-1, high=4, size=(NUM_BATCHES, 2 * BATCH_SIZE)), ) -_input_retrieval_scores_float_target = Input( +_input_retrieval_scores_float_target = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, 2 * BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE), target=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE), ) -_input_retrieval_scores_with_ignore_index = Input( +_input_retrieval_scores_with_ignore_index = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).masked_fill( @@ -59,37 +65,37 @@ ) # with errors -_input_retrieval_scores_no_target = Input( +_input_retrieval_scores_no_target = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=1, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_all_target = Input( +_input_retrieval_scores_all_target = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(low=1, high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_empty = Input( +_input_retrieval_scores_empty = _Input( indexes=torch.randint(high=10, size=[0]), preds=torch.rand(0), target=torch.randint(high=2, size=[0]), ) -_input_retrieval_scores_mismatching_sizes = Input( +_input_retrieval_scores_mismatching_sizes = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE - 2)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_mismatching_sizes_func = Input( +_input_retrieval_scores_mismatching_sizes_func = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE - 2), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), ) -_input_retrieval_scores_wrong_targets = Input( +_input_retrieval_scores_wrong_targets = _Input( indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)), preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(low=-(2**31), high=2**31, size=(NUM_BATCHES, BATCH_SIZE)), diff --git a/tests/unittests/text/inputs.py b/tests/unittests/text/inputs.py index eeb00ce7843..9d976864623 100644 --- a/tests/unittests/text/inputs.py +++ b/tests/unittests/text/inputs.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple +from typing import NamedTuple import torch +from torch import Tensor -from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES +from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES, _Input from unittests.helpers import seed_all seed_all(1) -Input = namedtuple("Input", ["preds", "targets"]) -SquadInput = namedtuple("SquadInput", ["preds", "targets", "exact_match", "f1"]) -LogitsInput = namedtuple("LogitsInput", ["preds", "target"]) + +class _SquadInput(NamedTuple): + preds: Tensor + target: Tensor + exact_match: Tensor + f1: Tensor + # example taken from # https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu and adjusted @@ -47,15 +52,15 @@ ) TUPLE_OF_HYPOTHESES = ((HYPOTHESIS_A, HYPOTHESIS_B), (HYPOTHESIS_B, HYPOTHESIS_C)) -_inputs_single_sentence_multiple_references = Input(preds=[HYPOTHESIS_B], targets=[[REFERENCE_1B, REFERENCE_2B]]) +_inputs_single_sentence_multiple_references = _Input(preds=[HYPOTHESIS_B], target=[[REFERENCE_1B, REFERENCE_2B]]) -_inputs_multiple_references = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_REFERENCES) +_inputs_multiple_references = _Input(preds=TUPLE_OF_HYPOTHESES, target=TUPLE_OF_REFERENCES) -_inputs_single_sentence_single_reference = Input(preds=HYPOTHESIS_B, targets=REFERENCE_1B) +_inputs_single_sentence_single_reference = _Input(preds=HYPOTHESIS_B, target=REFERENCE_1B) ERROR_RATES_BATCHES_1 = { "preds": [["hello world"], ["what a day"]], - "targets": [["hello world"], ["what a wonderful day"]], + "target": [["hello world"], ["what a wonderful day"]], } ERROR_RATES_BATCHES_2 = { @@ -63,28 +68,28 @@ ["i like python", "what you mean or swallow"], ["hello duck", "i like python"], ], - "targets": [ + "target": [ ["i like monthy python", "what do you mean, african or european swallow"], ["hello world", "i like monthy python"], ], } -_inputs_error_rate_batch_size_1 = Input(**ERROR_RATES_BATCHES_1) +_inputs_error_rate_batch_size_1 = _Input(**ERROR_RATES_BATCHES_1) -_inputs_error_rate_batch_size_2 = Input(**ERROR_RATES_BATCHES_2) +_inputs_error_rate_batch_size_2 = _Input(**ERROR_RATES_BATCHES_2) SAMPLE_1 = { "exact_match": 100.0, "f1": 100.0, "preds": {"prediction_text": "1976", "id": "id1"}, - "targets": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}, + "target": {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}, } SAMPLE_2 = { "exact_match": 0.0, "f1": 0.0, "preds": {"prediction_text": "Hello", "id": "id2"}, - "targets": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"}, + "target": {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"}, } BATCH = { @@ -94,34 +99,34 @@ {"prediction_text": "1976", "id": "id1"}, {"prediction_text": "Hello", "id": "id2"}, ], - "targets": [ + "target": [ {"answers": {"answer_start": [97], "text": ["1976"]}, "id": "id1"}, {"answers": {"answer_start": [97], "text": ["World"]}, "id": "id2"}, ], } -_inputs_squad_exact_match = SquadInput( - preds=SAMPLE_1["preds"], targets=SAMPLE_1["targets"], exact_match=SAMPLE_1["exact_match"], f1=SAMPLE_1["f1"] +_inputs_squad_exact_match = _SquadInput( + preds=SAMPLE_1["preds"], target=SAMPLE_1["target"], exact_match=SAMPLE_1["exact_match"], f1=SAMPLE_1["f1"] ) -_inputs_squad_exact_mismatch = SquadInput( - preds=SAMPLE_2["preds"], targets=SAMPLE_2["targets"], exact_match=SAMPLE_2["exact_match"], f1=SAMPLE_2["f1"] +_inputs_squad_exact_mismatch = _SquadInput( + preds=SAMPLE_2["preds"], target=SAMPLE_2["target"], exact_match=SAMPLE_2["exact_match"], f1=SAMPLE_2["f1"] ) -_inputs_squad_batch_match = SquadInput( - preds=BATCH["preds"], targets=BATCH["targets"], exact_match=BATCH["exact_match"], f1=BATCH["f1"] +_inputs_squad_batch_match = _SquadInput( + preds=BATCH["preds"], target=BATCH["target"], exact_match=BATCH["exact_match"], f1=BATCH["f1"] ) # single reference TUPLE_OF_SINGLE_REFERENCES = ((REFERENCE_1A, REFERENCE_1B), (REFERENCE_1B, REFERENCE_1C)) -_inputs_single_reference = Input(preds=TUPLE_OF_HYPOTHESES, targets=TUPLE_OF_SINGLE_REFERENCES) +_inputs_single_reference = _Input(preds=TUPLE_OF_HYPOTHESES, target=TUPLE_OF_SINGLE_REFERENCES) # Logits-based inputs for perplexity metrics -_logits_inputs_fp32 = LogitsInput( +_logits_inputs_fp32 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float32), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ) -_logits_inputs_fp64 = LogitsInput( +_logits_inputs_fp64 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM, NUM_CLASSES, dtype=torch.float64), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ) @@ -130,5 +135,5 @@ _target_with_mask = _logits_inputs_fp32.target.clone() _target_with_mask[:, 0, 1:] = MASK_INDEX _target_with_mask[:, BATCH_SIZE - 1, :] = MASK_INDEX -_logits_inputs_fp32_with_mask = LogitsInput(preds=_logits_inputs_fp32.preds, target=_target_with_mask) -_logits_inputs_fp64_with_mask = LogitsInput(preds=_logits_inputs_fp64.preds, target=_target_with_mask) +_logits_inputs_fp32_with_mask = _Input(preds=_logits_inputs_fp32.preds, target=_target_with_mask) +_logits_inputs_fp64_with_mask = _Input(preds=_logits_inputs_fp64.preds, target=_target_with_mask) diff --git a/tests/unittests/text/test_bertscore.py b/tests/unittests/text/test_bertscore.py index cc873edefb8..117d909d9a3 100644 --- a/tests/unittests/text/test_bertscore.py +++ b/tests/unittests/text/test_bertscore.py @@ -84,7 +84,7 @@ def _reference_bert_score( ) @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_single_reference.preds, _inputs_single_reference.targets)], + [(_inputs_single_reference.preds, _inputs_single_reference.target)], ) @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") @pytest.mark.skipif(not _BERTSCORE_AVAILABLE, reason="test requires bert_score") diff --git a/tests/unittests/text/test_bleu.py b/tests/unittests/text/test_bleu.py index 472e499313e..978741110ac 100644 --- a/tests/unittests/text/test_bleu.py +++ b/tests/unittests/text/test_bleu.py @@ -46,7 +46,7 @@ def _compute_bleu_metric_nltk(preds, targets, weights, smoothing_function, **kwa ) @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], + [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) class TestBLEUScore(TextTester): """Test class for `BLEUScore` metric.""" diff --git a/tests/unittests/text/test_cer.py b/tests/unittests/text/test_cer.py index c193b55b90e..e8b47ebc1ef 100644 --- a/tests/unittests/text/test_cer.py +++ b/tests/unittests/text/test_cer.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, List, Union import pytest @@ -23,8 +36,8 @@ def _compare_fn(preds: Union[str, List[str]], target: Union[str, List[str]]): @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.targets), - (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.targets), + (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.target), + (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.target), ], ) class TestCharErrorRate(TextTester): diff --git a/tests/unittests/text/test_chrf.py b/tests/unittests/text/test_chrf.py index 137ae0fc6f9..35c056fa505 100644 --- a/tests/unittests/text/test_chrf.py +++ b/tests/unittests/text/test_chrf.py @@ -57,7 +57,7 @@ def _sacrebleu_chrf_fn( ) @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], + [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestCHRFScore(TextTester): @@ -141,7 +141,7 @@ def test_chrf_empty_class(): def test_chrf_return_sentence_level_score_functional(): """Test that chrf can return sentence level scores.""" preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.targets + targets = _inputs_single_sentence_multiple_references.target _, chrf_sentence_score = chrf_score(preds, targets, return_sentence_level_score=True) isinstance(chrf_sentence_score, Tensor) @@ -150,6 +150,6 @@ def test_chrf_return_sentence_level_class(): """Test that chrf can return sentence level scores.""" chrf = CHRFScore(return_sentence_level_score=True) preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.targets + targets = _inputs_single_sentence_multiple_references.target _, chrf_sentence_score = chrf(preds, targets) isinstance(chrf_sentence_score, Tensor) diff --git a/tests/unittests/text/test_edit.py b/tests/unittests/text/test_edit.py index c8ba21eb1ee..9af7b41c944 100644 --- a/tests/unittests/text/test_edit.py +++ b/tests/unittests/text/test_edit.py @@ -80,7 +80,7 @@ def _ref_implementation(preds, target, substitution_cost=1, reduction="mean"): @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_single_reference.preds, _inputs_single_reference.targets)], + [(_inputs_single_reference.preds, _inputs_single_reference.target)], ) class TestEditDistance(TextTester): """Test class for `EditDistance` metric.""" diff --git a/tests/unittests/text/test_eed.py b/tests/unittests/text/test_eed.py index 9db9a58d6f3..6aa5bbdb79d 100644 --- a/tests/unittests/text/test_eed.py +++ b/tests/unittests/text/test_eed.py @@ -46,7 +46,7 @@ def _rwth_manual_metric(preds, targets) -> Tensor: @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_single_reference.preds, _inputs_single_reference.targets)], + [(_inputs_single_reference.preds, _inputs_single_reference.target)], ) class TestExtendedEditDistance(TextTester): """Test class for `ExtendedEditDistance` metric.""" @@ -117,7 +117,7 @@ def test_eed_empty_with_non_empty_hyp_class(): def test_eed_return_sentence_level_score_functional(): """Test that eed can return sentence level scores.""" hyp = _inputs_single_sentence_multiple_references.preds - ref = _inputs_single_sentence_multiple_references.targets + ref = _inputs_single_sentence_multiple_references.target _, sentence_eed = extended_edit_distance(hyp, ref, return_sentence_level_score=True) isinstance(sentence_eed, Tensor) @@ -126,6 +126,6 @@ def test_eed_return_sentence_level_class(): """Test that eed can return sentence level scores.""" metric = ExtendedEditDistance(return_sentence_level_score=True) hyp = _inputs_single_sentence_multiple_references.preds - ref = _inputs_single_sentence_multiple_references.targets + ref = _inputs_single_sentence_multiple_references.target _, sentence_eed = metric(hyp, ref) isinstance(sentence_eed, Tensor) diff --git a/tests/unittests/text/test_infolm.py b/tests/unittests/text/test_infolm.py index 51e63116244..b557aaacd9f 100644 --- a/tests/unittests/text/test_infolm.py +++ b/tests/unittests/text/test_infolm.py @@ -97,7 +97,7 @@ def reference_infolm_score(preds, target, model_name, information_measure, idf, ) @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_single_reference.preds, _inputs_single_reference.targets)], + [(_inputs_single_reference.preds, _inputs_single_reference.target)], ) @pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>=4.4") class TestInfoLM(TextTester): diff --git a/tests/unittests/text/test_mer.py b/tests/unittests/text/test_mer.py index e6f75a50078..f0f0c49a9f7 100644 --- a/tests/unittests/text/test_mer.py +++ b/tests/unittests/text/test_mer.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, List, Union import pytest @@ -22,8 +35,8 @@ def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, L @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.targets), - (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.targets), + (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.target), + (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.target), ], ) class TestMatchErrorRate(TextTester): diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index fe1ba4cbfcd..0020a368203 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -25,7 +25,7 @@ from typing_extensions import Literal from unittests.text.helpers import TextTester, skip_on_connection_issues -from unittests.text.inputs import Input, _inputs_multiple_references, _inputs_single_sentence_single_reference +from unittests.text.inputs import _Input, _inputs_multiple_references, _inputs_single_sentence_single_reference if _ROUGE_SCORE_AVAILABLE: from rouge_score.rouge_scorer import RougeScorer @@ -39,7 +39,7 @@ # Some randomly adjusted input from CNN/DailyMail dataset which brakes the test _preds = "A lawyer says him .\nMoschetto, 54 and prosecutors say .\nAuthority abc Moschetto ." _target = "A trainer said her and Moschetto, 54s or weapons say . \nAuthorities Moschetto of ." -_inputs_summarization = Input(preds=_preds, targets=_target) +_inputs_summarization = _Input(preds=_preds, target=_target) def _compute_rouge_score( @@ -111,7 +111,7 @@ def _compute_rouge_score( @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_multiple_references.preds, _inputs_multiple_references.targets), + (_inputs_multiple_references.preds, _inputs_multiple_references.target), ], ) @pytest.mark.parametrize("accumulate", ["avg", "best"]) @@ -177,7 +177,7 @@ def test_rouge_metric_wrong_key_value_error(): with pytest.raises(ValueError, match="Got unknown rouge key rouge. Expected to be one of"): rouge_score( _inputs_single_sentence_single_reference.preds, - _inputs_single_sentence_single_reference.targets, + _inputs_single_sentence_single_reference.target, rouge_keys=key, accumulate="best", ) @@ -209,7 +209,7 @@ def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key): rouge_level, metric = pl_rouge_metric_key.split("_") original_score = _compute_rouge_score( preds=_inputs_single_sentence_single_reference.preds, - target=_inputs_single_sentence_single_reference.targets, + target=_inputs_single_sentence_single_reference.target, rouge_level=rouge_level, metric=metric, accumulate="best", @@ -221,7 +221,7 @@ def test_rouge_metric_normalizer_tokenizer(pl_rouge_metric_key): ) scorer.update( _inputs_single_sentence_single_reference.preds, - _inputs_single_sentence_single_reference.targets, + _inputs_single_sentence_single_reference.target, ) metrics_score = scorer.compute() @@ -246,7 +246,7 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer): rouge_level, metric = pl_rouge_metric_key.split("_") original_score = _compute_rouge_score( preds=_inputs_summarization.preds, - target=_inputs_summarization.targets, + target=_inputs_summarization.target, rouge_level=rouge_level, metric=metric, accumulate=None, @@ -255,7 +255,7 @@ def test_rouge_lsum_score(pl_rouge_metric_key, use_stemmer): metrics_score = rouge_score( _inputs_summarization.preds, - _inputs_summarization.targets, + _inputs_summarization.target, rouge_keys=rouge_level, use_stemmer=use_stemmer, ) diff --git a/tests/unittests/text/test_sacre_bleu.py b/tests/unittests/text/test_sacre_bleu.py index f601da7e85e..7555835161f 100644 --- a/tests/unittests/text/test_sacre_bleu.py +++ b/tests/unittests/text/test_sacre_bleu.py @@ -41,7 +41,7 @@ def _sacrebleu_fn(preds: Sequence[str], targets: Sequence[Sequence[str]], tokeni @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], + [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) @pytest.mark.parametrize(["lowercase"], [(False,), (True,)]) @pytest.mark.parametrize("tokenize", TOKENIZERS) diff --git a/tests/unittests/text/test_squad.py b/tests/unittests/text/test_squad.py index ffedbb4b5eb..51a2e24e38f 100644 --- a/tests/unittests/text/test_squad.py +++ b/tests/unittests/text/test_squad.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os import pytest @@ -16,13 +29,13 @@ [ ( _inputs_squad_exact_match.preds, - _inputs_squad_exact_match.targets, + _inputs_squad_exact_match.target, _inputs_squad_exact_match.exact_match, _inputs_squad_exact_match.f1, ), ( _inputs_squad_exact_mismatch.preds, - _inputs_squad_exact_mismatch.targets, + _inputs_squad_exact_mismatch.target, _inputs_squad_exact_mismatch.exact_match, _inputs_squad_exact_mismatch.f1, ), @@ -42,7 +55,7 @@ def test_score_fn(preds, targets, exact_match, f1): [ ( _inputs_squad_batch_match.preds, - _inputs_squad_batch_match.targets, + _inputs_squad_batch_match.target, _inputs_squad_batch_match.exact_match, _inputs_squad_batch_match.f1, ) @@ -86,7 +99,7 @@ def _test_score_ddp_fn(rank, world_size, preds, targets, exact_match, f1): [ ( _inputs_squad_batch_match.preds, - _inputs_squad_batch_match.targets, + _inputs_squad_batch_match.target, _inputs_squad_batch_match.exact_match, _inputs_squad_batch_match.f1, ) diff --git a/tests/unittests/text/test_ter.py b/tests/unittests/text/test_ter.py index 05d8915f622..51faf546aea 100644 --- a/tests/unittests/text/test_ter.py +++ b/tests/unittests/text/test_ter.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import partial from typing import Sequence @@ -44,7 +57,7 @@ def _sacrebleu_ter_fn( ) @pytest.mark.parametrize( ["preds", "targets"], - [(_inputs_multiple_references.preds, _inputs_multiple_references.targets)], + [(_inputs_multiple_references.preds, _inputs_multiple_references.target)], ) @pytest.mark.skipif(not _SACREBLEU_AVAILABLE, reason="test requires sacrebleu") class TestTER(TextTester): @@ -151,7 +164,7 @@ def test_ter_empty_with_non_empty_hyp_class(): def test_ter_return_sentence_level_score_functional(): """Test that functional metric can return sentence level scores.""" preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.targets + targets = _inputs_single_sentence_multiple_references.target _, sentence_ter = translation_edit_rate(preds, targets, return_sentence_level_score=True) isinstance(sentence_ter, Tensor) @@ -160,6 +173,6 @@ def test_ter_return_sentence_level_class(): """Test that modular metric can return sentence level scores.""" ter_metric = TranslationEditRate(return_sentence_level_score=True) preds = _inputs_single_sentence_multiple_references.preds - targets = _inputs_single_sentence_multiple_references.targets + targets = _inputs_single_sentence_multiple_references.target _, sentence_ter = ter_metric(preds, targets) isinstance(sentence_ter, Tensor) diff --git a/tests/unittests/text/test_wer.py b/tests/unittests/text/test_wer.py index ad7b590904a..8821f9101c4 100644 --- a/tests/unittests/text/test_wer.py +++ b/tests/unittests/text/test_wer.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, List, Union import pytest @@ -22,8 +35,8 @@ def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, L @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.targets), - (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.targets), + (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.target), + (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.target), ], ) class TestWER(TextTester): diff --git a/tests/unittests/text/test_wil.py b/tests/unittests/text/test_wil.py index f42fe6d2d00..2f8f063e407 100644 --- a/tests/unittests/text/test_wil.py +++ b/tests/unittests/text/test_wil.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Union import pytest @@ -18,8 +31,8 @@ def _compute_wil_metric_jiwer(preds: Union[str, List[str]], target: Union[str, L @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.targets), - (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.targets), + (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.target), + (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.target), ], ) class TestWordInfoLost(TextTester): diff --git a/tests/unittests/text/test_wip.py b/tests/unittests/text/test_wip.py index 7ce4bbe6314..4dd380a45cb 100644 --- a/tests/unittests/text/test_wip.py +++ b/tests/unittests/text/test_wip.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import List, Union import pytest @@ -18,8 +31,8 @@ def _compute_wip_metric_jiwer(preds: Union[str, List[str]], target: Union[str, L @pytest.mark.parametrize( ["preds", "targets"], [ - (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.targets), - (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.targets), + (_inputs_error_rate_batch_size_1.preds, _inputs_error_rate_batch_size_1.target), + (_inputs_error_rate_batch_size_2.preds, _inputs_error_rate_batch_size_2.target), ], ) class TestWordInfoPreserved(TextTester): diff --git a/tests/unittests/utilities/test_auc.py b/tests/unittests/utilities/test_auc.py index dd078847cc6..37d9c1105ee 100644 --- a/tests/unittests/utilities/test_auc.py +++ b/tests/unittests/utilities/test_auc.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections import namedtuple from functools import partial +from typing import NamedTuple import numpy as np import pytest from sklearn.metrics import auc as _sk_auc -from torch import tensor +from torch import Tensor, tensor from torchmetrics.utilities.compute import auc from unittests import NUM_BATCHES from unittests.helpers import seed_all @@ -26,6 +26,11 @@ seed_all(42) +class _Input(NamedTuple): + x: Tensor + y: Tensor + + def sk_auc(x, y, reorder=False): """Comparison function for correctness of auc implementation.""" x = x.flatten() @@ -37,8 +42,6 @@ def sk_auc(x, y, reorder=False): return _sk_auc(x, y) -Input = namedtuple("Input", ["x", "y"]) - _examples = [] # generate already ordered samples, sorted in both directions for batch_size in (8, 4049): @@ -50,7 +53,7 @@ def sk_auc(x, y, reorder=False): y = y[idx] if i % 2 == 0 else x[idx[::-1]] x = x.reshape(NUM_BATCHES, batch_size) y = y.reshape(NUM_BATCHES, batch_size) - _examples.append(Input(x=tensor(x), y=tensor(y))) + _examples.append(_Input(x=tensor(x), y=tensor(y))) @pytest.mark.parametrize("x, y", _examples) diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index 90df27a3992..fe406537baf 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -1,3 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from copy import deepcopy from functools import partial from typing import Any diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 443ac0a496e..b39e46cc39d 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -1,4 +1,16 @@ -from collections import namedtuple +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import partial from typing import Any @@ -12,7 +24,7 @@ from torchmetrics.regression import R2Score from torchmetrics.wrappers.multioutput import MultioutputWrapper -from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES +from unittests import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, _Input from unittests.helpers import seed_all from unittests.helpers.testers import MetricTester @@ -54,13 +66,12 @@ def reset(self) -> None: num_targets = 2 -Input = namedtuple("Input", ["preds", "target"]) -_multi_target_regression_inputs = Input( +_multi_target_regression_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) -_multi_target_classification_inputs = Input( +_multi_target_classification_inputs = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets), target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, num_targets)), )