Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Dec 20, 2024
1 parent 84f24fe commit 27ea9aa
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Sequence

import numpy as np
import pytest
import xarray as xr
Expand All @@ -8,9 +10,19 @@

@pytest.mark.parametrize(
"axes",
["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"],
[
"yx",
"xy",
"cyx",
"yxc",
"bczyx",
"xyz",
"xyzc",
"bzyxc",
("batch", "channel", "x", "y"),
],
)
def test_transpose_tensor_2d(axes: str):
def test_transpose_tensor_2d(axes: Sequence[str]):

tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None)
transposed = tensor.transpose([AxisId(a) for a in axes])
Expand All @@ -19,9 +31,18 @@ def test_transpose_tensor_2d(axes: str):

@pytest.mark.parametrize(
"axes",
["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"],
[
"zyx",
"cyzx",
"yzixc",
"bczyx",
"xyz",
"xyzc",
"bzyxtc",
("batch", "channel", "x", "y", "z"),
],
)
def test_transpose_tensor_3d(axes: str):
def test_transpose_tensor_3d(axes: Sequence[str]):
tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None)
transposed = tensor.transpose([AxisId(a) for a in axes])
assert transposed.ndim == len(axes)
Expand Down

0 comments on commit 27ea9aa

Please sign in to comment.