From 27ea9aa9aa483624ad1ebc28ec74c27e92081fa1 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Fri, 20 Dec 2024 16:34:52 +0100 Subject: [PATCH] add test cases --- tests/test_tensor.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 33163077..e00efe04 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -1,3 +1,5 @@ +from typing import Sequence + import numpy as np import pytest import xarray as xr @@ -8,9 +10,19 @@ @pytest.mark.parametrize( "axes", - ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], + [ + "yx", + "xy", + "cyx", + "yxc", + "bczyx", + "xyz", + "xyzc", + "bzyxc", + ("batch", "channel", "x", "y"), + ], ) -def test_transpose_tensor_2d(axes: str): +def test_transpose_tensor_2d(axes: Sequence[str]): tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) transposed = tensor.transpose([AxisId(a) for a in axes]) @@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str): @pytest.mark.parametrize( "axes", - ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], + [ + "zyx", + "cyzx", + "yzixc", + "bczyx", + "xyz", + "xyzc", + "bzyxtc", + ("batch", "channel", "x", "y", "z"), + ], ) -def test_transpose_tensor_3d(axes: str): +def test_transpose_tensor_3d(axes: Sequence[str]): tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None) transposed = tensor.transpose([AxisId(a) for a in axes]) assert transposed.ndim == len(axes)