From 794630cc3aee308ffbd1524ef4d56f5d54f9c2e5 Mon Sep 17 00:00:00 2001 From: Denis Prokopenko <22414094+denproc@users.noreply.github.com> Date: Sat, 1 Jul 2023 14:40:08 +0100 Subject: [PATCH] fix: tests for clipiqa --- tests/test_clip_iqa.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/tests/test_clip_iqa.py b/tests/test_clip_iqa.py index cc72324..b936b94 100644 --- a/tests/test_clip_iqa.py +++ b/tests/test_clip_iqa.py @@ -78,38 +78,32 @@ def test_clip_iqa_input_dtype_does_not_change(clipiqa: _Loss, x_rgb: torch.Tenso def test_clip_iqa_dims_work(clipiqa: _Loss, device: str) -> None: clipiqa = clipiqa.to(device) - x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))] - for x in x_3dims: - clipiqa(x.to(device)) x_4dims = [torch.rand((3, 3, 96, 96)), torch.rand((4, 3, 128, 128)), torch.rand((5, 3, 160, 160))] for x in x_4dims: clipiqa(x.to(device)) -def test_clip_iqa_results_equal_for_3_and_4_dims(clipiqa: _Loss, device: str) -> None: - clipiqa = clipiqa.to(device) - x = torch.rand((3, 128, 128)) - x_copy = x[None] - x_result = clipiqa(x.to(device)) - x_copy_result = clipiqa(x_copy.to(device)) - assert torch.isclose(x_result, x_copy_result, rtol=1e-2), \ - f'Expected values to be equal, got {x_result} and {x_copy_result}' - - def test_clip_iqa_dims_does_not_work(clipiqa: _Loss, device: str) -> None: clipiqa = clipiqa.to(device) x_2dims = [torch.rand((96, 96)), torch.rand((128, 128)), torch.rand((160, 160))] - with pytest.raises(AssertionError): - for x in x_2dims: + for x in x_2dims: + with pytest.raises(AssertionError): clipiqa(x.to(device)) x_1dims = [torch.rand((96)), torch.rand((128)), torch.rand((160))] - with pytest.raises(AssertionError): - for x in x_1dims: + + for x in x_1dims: + with pytest.raises(AssertionError): + clipiqa(x.to(device)) + + x_3dims = [torch.rand((3, 96, 96)), torch.rand((3, 128, 128)), torch.rand((3, 160, 160))] + for x in x_3dims: + with pytest.raises(AssertionError): clipiqa(x.to(device)) x_5dims = [torch.rand((1, 3, 3, 96, 96)), torch.rand((2, 4, 3, 128, 128)), torch.rand((1, 5, 3, 160, 160))] - with pytest.raises(AssertionError): - for x in x_5dims: + + for x in x_5dims: + with pytest.raises(AssertionError): clipiqa(x.to(device))