Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Dec 4, 2024
1 parent cc82c83 commit 15600f1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
11 changes: 7 additions & 4 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from chronos.utils import left_pad_and_stack_1D


@pytest.mark.parametrize("tensors", [
list(map(torch.tensor, [[1, 2, 3], [5, 6]])),
list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])),
])
@pytest.mark.parametrize(
"tensors",
[
list(map(torch.tensor, [[1, 2, 3], [5, 6]])),
list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])),
],
)
def test_pad_and_stack(tensors: list):
stacked_and_padded = left_pad_and_stack_1D(tensors)
assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors))
Expand Down
2 changes: 1 addition & 1 deletion test/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ def validate_tensor(
assert a.shape == shape

if dtype is not None:
assert a.dtype == dtype
assert a.dtype == dtype

0 comments on commit 15600f1

Please sign in to comment.