diff --git a/test/__init__.py b/test/__init__.py index 03f633a..04f8b7b 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1,2 +1,2 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 \ No newline at end of file +# SPDX-License-Identifier: Apache-2.0 diff --git a/test/test_utils.py b/test/test_utils.py index f36186d..c02cfb6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,10 +4,13 @@ from chronos.utils import left_pad_and_stack_1D -@pytest.mark.parametrize("tensors", [ - list(map(torch.tensor, [[1, 2, 3], [5, 6]])), - list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])), -]) +@pytest.mark.parametrize( + "tensors", + [ + list(map(torch.tensor, [[1, 2, 3], [5, 6]])), + list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])), + ], +) def test_pad_and_stack(tensors: list): stacked_and_padded = left_pad_and_stack_1D(tensors) assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors)) diff --git a/test/util.py b/test/util.py index 37a2c3b..78c2e93 100644 --- a/test/util.py +++ b/test/util.py @@ -10,4 +10,4 @@ def validate_tensor( assert a.shape == shape if dtype is not None: - assert a.dtype == dtype \ No newline at end of file + assert a.dtype == dtype