Skip to content

Commit

Permalink
add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Dec 4, 2024
1 parent 2af93f0 commit 357cef4
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch

from chronos.utils import left_pad_and_stack_1D


@pytest.mark.parametrize("tensors", [
list(map(torch.tensor, [[1, 2, 3], [5, 6]])),
list(map(torch.tensor, [[2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]])),
])
def test_pad_and_stack(tensors: list):
stacked_and_padded = left_pad_and_stack_1D(tensors)
assert stacked_and_padded.shape == (len(tensors), max(len(t) for t in tensors))

ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype)

assert torch.nanmean(stacked_and_padded) == torch.nanmean(ref)

0 comments on commit 357cef4

Please sign in to comment.