diff --git a/test/test_utils.py b/test/test_utils.py index e6842c1..999991b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,4 +14,5 @@ def test_pad_and_stack(tensors: list): ref = torch.concat(tensors).to(dtype=stacked_and_padded.dtype) + assert torch.sum(stacked_and_padded) == torch.sum(ref) assert torch.nanmean(stacked_and_padded) == torch.nanmean(ref)