diff --git a/benchmarks/benchmark_cosine_similarity.py b/benchmarks/benchmark_cosine_similarity.py
index 84700b04..e5372d34 100644
--- a/benchmarks/benchmark_cosine_similarity.py
+++ b/benchmarks/benchmark_cosine_similarity.py
@@ -23,41 +23,33 @@
     "cosine similarity forward",
     ["x_size"],
     [256 * i for i in range(1, 21)],
-    {"num_batches": 16, "y_size": 16},
+    {"z_size": 16, "y_size": 16},
 )
-def bench_cosine_similarity_forward(num_batches, y_size, x_size, backend):
-    factory_kwargs = {"device": "cuda"}
-
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+def bench_cosine_similarity_forward(z_size, y_size, x_size, backend):
+    x1 = torch.randn(z_size, y_size, x_size, device="cuda")
+    x2 = torch.randn(z_size, y_size, x_size, device="cuda")

     if backend == "torch":
-        return triton.testing.do_bench_cudagraph(lambda: torch.nn.functional.cosine_similarity(input, other, 2))
+        return triton.testing.do_bench_cudagraph(lambda: torch.nn.functional.cosine_similarity(x1, x2, 2))
     else:
-        return triton.testing.do_bench_cudagraph(lambda: trident.function.cosine_similarity(input, other, 2))


 @util.report(
     "cosine similarity backward",
     ["x_size"],
     [256 * i for i in range(1, 21)],
-    {"num_batches": 16, "y_size": 16},
+    {"z_size": 16, "y_size": 16},
 )
-def bench_cosine_similarity_backward(num_batches, y_size, x_size, backend):
-    factory_kwargs = {"device": "cuda"}
-
-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-
-    input.requires_grad = other.requires_grad = True
+def bench_cosine_similarity_backward(z_size, y_size, x_size, backend):
+    x1 = torch.randn(z_size, y_size, x_size, device="cuda", requires_grad=True)
+    x2 = torch.randn(z_size, y_size, x_size, device="cuda", requires_grad=True)
+    grad_output = torch.randn(z_size, y_size, device="cuda")

     if backend == "torch":
-        operation = torch.nn.CosineSimilarity(2)
+        output = torch.nn.functional.cosine_similarity(x1, x2, 2)
     else:
-        operation = trident.CosineSimilarity(2)
-
-    output = operation.forward(input, other)
-    grad_output = torch.ones_like(output)
+        output = trident.function.cosine_similarity(x1, x2, 2)

     return triton.testing.do_bench_cudagraph(lambda: output.backward(grad_output, retain_graph=True))
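Note: the reworked backward benchmark runs the forward pass once, outside the timed lambda, and pre-builds grad_output with the reduced output shape, so do_bench_cudagraph times only backward(). A minimal sketch of that pattern, on CPU and with illustrative sizes (the benchmark itself uses device="cuda" and triton.testing.do_bench_cudagraph):

import torch

z_size, y_size, x_size = 16, 16, 256
x1 = torch.randn(z_size, y_size, x_size, requires_grad=True)
x2 = torch.randn(z_size, y_size, x_size, requires_grad=True)

# dim=2 reduces x, so the incoming gradient has shape (z_size, y_size).
output = torch.nn.functional.cosine_similarity(x1, x2, 2)
grad_output = torch.randn(z_size, y_size)

# retain_graph=True lets the same graph be replayed on every timed iteration.
output.backward(grad_output, retain_graph=True)
assert x1.grad.shape == x1.shape and x2.grad.shape == x2.shape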
diff --git a/tests/test_cosine_similarity.py b/tests/test_cosine_similarity.py
index 7f823c03..43fe2fc1 100644
--- a/tests/test_cosine_similarity.py
+++ b/tests/test_cosine_similarity.py
@@ -6,49 +6,49 @@
 @pytest.mark.parametrize(
-    "num_batches, y_size, x_size, dim",
+    "z_size, y_size, x_size, dim",
     [(1431, 500, 200, 0), (221, 1250, 200, 1), (21, 6400, 86, 2)],
 )
-def test_forward(num_batches, y_size, x_size, dim, device):
+def test_forward(z_size, y_size, x_size, dim, device):
     factory_kwargs = {"device": device}

-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)

     assert util.equal(
-        torch.nn.functional.cosine_similarity(input, other, dim=dim),
-        trident.function.cosine_similarity(input, other, dim=dim),
+        torch.nn.functional.cosine_similarity(x1, x2, dim=dim),
+        trident.function.cosine_similarity(x1, x2, dim=dim),
     )


 @pytest.mark.parametrize(
-    "num_batches, y_size, x_size, dim",
+    "z_size, y_size, x_size, dim",
     [(1280, 1000, 200, 0), (200, 1280, 200, 1), (640, 21, 86, 2)],
 )
-def test_backward(num_batches, y_size, x_size, dim, device):
+def test_backward(z_size, y_size, x_size, dim, device):
     factory_kwargs = {"device": device}

-    input = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
-    other = torch.randn(num_batches, y_size, x_size, **factory_kwargs)
+    x1 = torch.randn(z_size, y_size, x_size, **factory_kwargs)
+    x2 = torch.randn(z_size, y_size, x_size, **factory_kwargs)

     if dim == 0:
         target_dim = (y_size, x_size)
     elif dim == 1:
-        target_dim = (num_batches, x_size)
+        target_dim = (z_size, x_size)
     else:
-        target_dim = (num_batches, y_size)
+        target_dim = (z_size, y_size)

-    target = torch.randn(target_dim, **factory_kwargs)
+    grad_output = torch.randn(target_dim, **factory_kwargs)

     def train(func):
-        x1 = input.clone()
-        x2 = other.clone()
-        x1.requires_grad = x2.requires_grad = True
-        func(x1, x2).backward(target, retain_graph=True)
-        return [x1.grad, x2.grad]
+        i = x1.clone()
+        j = x2.clone()
+        i.requires_grad = j.requires_grad = True
+        func(i, j).backward(grad_output, retain_graph=True)
+        return i.grad, j.grad

-    grad_a = train(torch.nn.CosineSimilarity(dim=dim))
-    grad_b = train(trident.CosineSimilarity(dim=dim))
+    torch_grad_x1, torch_grad_x2 = train(torch.nn.CosineSimilarity(dim))
+    trident_grad_x1, trident_grad_x2 = train(trident.CosineSimilarity(dim))

-    assert util.equal(grad_a[0], grad_b[0])
-    assert util.equal(grad_a[1], grad_b[1])
+    assert util.equal(torch_grad_x1, trident_grad_x1)
+    assert util.equal(torch_grad_x2, trident_grad_x2)
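Note: the target_dim branches in test_backward encode the fact that cosine similarity drops the reduced dimension from the output shape, which is also the shape backward() expects for its incoming gradient. A quick self-contained check of that rule (shapes are illustrative):

import torch

x1 = torch.randn(4, 5, 6)
x2 = torch.randn(4, 5, 6)

for dim in range(3):
    output = torch.nn.functional.cosine_similarity(x1, x2, dim)
    # The reduced dim disappears, e.g. dim=1 maps (4, 5, 6) -> (4, 6),
    # matching target_dim = (z_size, x_size) in the test.
    expected = tuple(s for d, s in enumerate(x1.shape) if d != dim)
    assert output.shape == expected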
diff --git a/trident/kernel/cosine_similarity.py b/trident/kernel/cosine_similarity.py
index 2d5a5e9a..e68c7c55 100644
--- a/trident/kernel/cosine_similarity.py
+++ b/trident/kernel/cosine_similarity.py
@@ -32,148 +32,69 @@ class CosineSimilarity:
     @triton.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
     @triton.jit
     def forward(
-        output_ptr,
-        denominator_ptr,
-        numerator_ptr,
-        x1_ptr,
-        x2_ptr,
-        num_batches,
-        y_size,
-        x_size,
-        eps,
-        size_along_dim,
-        dim: tl.constexpr,
+        output_ptr: tl.tensor,
+        denominator_ptr: tl.tensor,
+        numerator_ptr: tl.tensor,
+        x1_ptr: tl.tensor,
+        x2_ptr: tl.tensor,
+        z_size: tl.int32,
+        y_size: tl.int32,
+        x_size: tl.int32,
+        z_stride: tl.int32,
+        y_stride: tl.int32,
+        x_stride: tl.int32,
+        eps: tl.float32,
+        size_along_dim: tl.int32,
+        output_y_size: tl.int32,
+        output_x_size: tl.int32,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
     ):
         pid = tl.program_id(0)
-        projected_x_size = y_size if dim == 2 else x_size
-        i = pid // projected_x_size
-        j = pid % projected_x_size
-        if dim == language.dim[0]:
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            output_block_ptr = tl.make_block_ptr(
-                output_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-        elif dim == language.dim[1]:
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            output_block_ptr = tl.make_block_ptr(
-                output_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-        else:
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            output_block_ptr = tl.make_block_ptr(
-                output_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
+        num_output_y = pid // output_x_size
+        num_output_x = pid % output_x_size
+
+        x1_block_ptr = tl.make_block_ptr(
+            x1_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        x2_block_ptr = tl.make_block_ptr(
+            x2_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        output_block_ptr = tl.make_block_ptr(
+            output_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )
+        denominator_block_ptr = tl.make_block_ptr(
+            denominator_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )
+        numerator_block_ptr = tl.make_block_ptr(
+            numerator_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )

         denominator_accumulation1 = tl.zeros((block_size, 1, 1), tl.float32)
         denominator_accumulation2 = tl.zeros((block_size, 1, 1), tl.float32)
@@ -205,201 +126,88 @@ def forward(
     @triton.autotune(cosine_similarity_configs(), ["y_size", "x_size"])
     @triton.jit
     def backward(
-        grad_x1_ptr,
-        grad_x2_ptr,
-        grad_output_ptr,
-        denominator_ptr,
-        numerator_ptr,
-        x1_ptr,
-        x2_ptr,
-        num_batches,
-        y_size,
-        x_size,
-        size_along_dim,
-        dim: tl.constexpr,
+        grad_x1_ptr: tl.tensor,
+        grad_x2_ptr: tl.tensor,
+        grad_output_ptr: tl.tensor,
+        denominator_ptr: tl.tensor,
+        numerator_ptr: tl.tensor,
+        x1_ptr: tl.tensor,
+        x2_ptr: tl.tensor,
+        z_size: tl.int32,
+        y_size: tl.int32,
+        x_size: tl.int32,
+        z_stride: tl.int32,
+        y_stride: tl.int32,
+        x_stride: tl.int32,
+        size_along_dim: tl.int32,
+        output_y_size: tl.int32,
+        output_x_size: tl.int32,
         dtype: tl.constexpr,
         block_size: tl.constexpr,
     ):
         pid = tl.program_id(0)
-        projected_x_size = y_size if dim == 2 else x_size
-        i = pid // projected_x_size
-        j = pid % projected_x_size
-
-        if dim == language.dim[0]:
-            grad_x1_block_ptr = tl.make_block_ptr(
-                grad_x1_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            grad_x2_block_ptr = tl.make_block_ptr(
-                grad_x2_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            grad_output_block_ptr = tl.make_block_ptr(
-                grad_output_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(num_batches, y_size, x_size),
-                strides=(y_size * x_size, x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(2, 1, 0),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(y_size, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-        elif dim == language.dim[1]:
-            grad_x1_block_ptr = tl.make_block_ptr(
-                grad_x1_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            grad_x2_block_ptr = tl.make_block_ptr(
-                grad_x2_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            grad_output_block_ptr = tl.make_block_ptr(
-                grad_output_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(y_size, num_batches, x_size),
-                strides=(x_size, y_size * x_size, 1),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(1, 2, 0),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(num_batches, x_size),
-                strides=(x_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-        else:
-            grad_x1_block_ptr = tl.make_block_ptr(
-                grad_x1_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            grad_x2_block_ptr = tl.make_block_ptr(
-                grad_x2_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            grad_output_block_ptr = tl.make_block_ptr(
-                grad_output_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            x1_block_ptr = tl.make_block_ptr(
-                x1_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            x2_block_ptr = tl.make_block_ptr(
-                x2_ptr,
-                shape=(x_size, num_batches, y_size),
-                strides=(1, y_size * x_size, x_size),
-                offsets=(0, i, j),
-                block_shape=(block_size, 1, 1),
-                order=(0, 2, 1),
-            )
-            denominator_block_ptr = tl.make_block_ptr(
-                denominator_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
-            numerator_block_ptr = tl.make_block_ptr(
-                numerator_ptr,
-                shape=(num_batches, y_size),
-                strides=(y_size, 1),
-                offsets=(i, j),
-                block_shape=(1, 1),
-                order=(1, 0),
-            )
+        num_output_y = pid // output_x_size
+        num_output_x = pid % output_x_size
+
+        grad_x1_block_ptr = tl.make_block_ptr(
+            grad_x1_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        grad_x2_block_ptr = tl.make_block_ptr(
+            grad_x2_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        grad_output_block_ptr = tl.make_block_ptr(
+            grad_output_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )
+        x1_block_ptr = tl.make_block_ptr(
+            x1_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        x2_block_ptr = tl.make_block_ptr(
+            x2_ptr,
+            shape=(z_size, y_size, x_size),
+            strides=(z_stride, y_stride, x_stride),
+            offsets=(0, num_output_y, num_output_x),
+            block_shape=(block_size, 1, 1),
+            order=(2, 1, 0),
+        )
+        denominator_block_ptr = tl.make_block_ptr(
+            denominator_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )
+        numerator_block_ptr = tl.make_block_ptr(
+            numerator_ptr,
+            shape=(output_y_size, output_x_size),
+            strides=(output_x_size, 1),
+            offsets=(num_output_y, num_output_x),
+            block_shape=(1, 1),
+            order=(1, 0),
+        )

         for _ in range(0, size_along_dim, block_size):
             x1 = tl.load(x1_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
-            x2 = tl.load(x2_block_ptr, boundary_check=(0,), padding_option="zero").to(tl.float32)
             denominator = tl.load(denominator_block_ptr)
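Note: the rewrite collapses the three dim-specific branches into a single code path. The host side permutes sizes and strides so that the reduced dimension always appears as the z axis, and one make_block_ptr layout then serves every dim. A rough sketch of that remapping trick using torch.as_strided (this illustrates the idea on the host; it is not the kernel code):

import torch

x = torch.randn(3, 4, 5)
dim = 1  # reduce along dim 1

# Put the reduced axis first (z); the surviving axes keep their order.
perm = {0: (0, 1, 2), 1: (1, 0, 2), 2: (2, 0, 1)}[dim]
sizes = tuple(x.shape[i] for i in perm)
strides = tuple(x.stride(i) for i in perm)
view = torch.as_strided(x, sizes, strides)

# Reducing the remapped z axis is the same as reducing the original dim,
# which is why the kernel can always iterate blocks along z.
assert torch.allclose(view.sum(0), x.sum(dim))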
diff --git a/trident/operation/cosine_similarity.py b/trident/operation/cosine_similarity.py
index 16931cff..f59eb0af 100644
--- a/trident/operation/cosine_similarity.py
+++ b/trident/operation/cosine_similarity.py
@@ -41,26 +41,15 @@ def __forward(x1: torch.Tensor, x2: torch.Tensor, dim: torch.int32, eps: torch.f
         assert x1.is_contiguous() and x2.is_contiguous() and x1.shape == x2.shape

         factory_kwargs = {"device": x1.device, "dtype": x1.dtype}
-        num_batches, y_size, x_size = x1.shape
-        if dim == 0:
-            grid_size = y_size * x_size
-            size_along_dim = num_batches
-            output = torch.empty(y_size, x_size, **factory_kwargs)
-        elif dim == 1:
-            grid_size = num_batches * x_size
-            size_along_dim = y_size
-            output = torch.empty(num_batches, x_size, **factory_kwargs)
-        else:
-            grid_size = num_batches * y_size
-            size_along_dim = x_size
-            output = torch.empty(num_batches, y_size, **factory_kwargs)
-
-        denominator = output.clone()
-        numerator = output.clone()
+        z_size, y_size, x_size, z_stride, y_stride, x_stride = util.size_and_stride(x1, dim)
+        output_y_size, output_x_size, size_along_dim = CosineSimilarity.__output_size_and_size_along_dim(x1, dim)
+        output = torch.empty(output_y_size, output_x_size, **factory_kwargs)
+        denominator = torch.empty_like(output)
+        numerator = torch.empty_like(output)

         def grid(meta):
-            return (grid_size,)
+            return (output_y_size * output_x_size,)

         kernel.CosineSimilarity.forward[grid](
             output,
@@ -68,12 +57,16 @@ def grid(meta):
             numerator,
             x1,
             x2,
-            num_batches,
+            z_size,
             y_size,
             x_size,
+            z_stride,
+            y_stride,
+            x_stride,
             eps,
             size_along_dim,
-            dim,
+            output_y_size,
+            output_x_size,
             util.dtype(x1.dtype),
         )

@@ -81,23 +74,14 @@ def grid(meta):
     @staticmethod
     def __backward(grad_output, x1, x2, denominator, numerator, dim):
-        num_batches, y_size, x_size = x1.shape
-
         grad_x1 = torch.empty_like(x1)
         grad_x2 = torch.empty_like(x2)

-        if dim == 0:
-            grid_size = y_size * x_size
-            size_along_dim = num_batches
-        elif dim == 1:
-            grid_size = num_batches * x_size
-            size_along_dim = y_size
-        else:
-            grid_size = num_batches * y_size
-            size_along_dim = x_size
+        z_size, y_size, x_size, z_stride, y_stride, x_stride = util.size_and_stride(x1, dim)
+        output_y_size, output_x_size, size_along_dim = CosineSimilarity.__output_size_and_size_along_dim(x1, dim)

         def grid(meta):
-            return (grid_size,)
+            return (output_y_size * output_x_size,)

         kernel.CosineSimilarity.backward[grid](
             grad_x1,
@@ -107,12 +91,32 @@ def grid(meta):
             numerator,
             x1,
             x2,
-            num_batches,
+            z_size,
             y_size,
             x_size,
+            z_stride,
+            y_stride,
+            x_stride,
             size_along_dim,
-            dim,
+            output_y_size,
+            output_x_size,
             util.dtype(x1.dtype),
         )

         return grad_x1, grad_x2, None, None
+
+    @staticmethod
+    def __output_size_and_size_along_dim(input: torch.Tensor, dim: int):
+        z_size, y_size, x_size = input.shape
+
+        if dim == 0:
+            output_y_size, output_x_size = y_size, x_size
+            size_along_dim = z_size
+        elif dim == 1:
+            output_y_size, output_x_size = z_size, x_size
+            size_along_dim = y_size
+        else:
+            output_y_size, output_x_size = z_size, y_size
+            size_along_dim = x_size
+
+        return output_y_size, output_x_size, size_along_dim
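Note: both launches now use a grid of output_y_size * output_x_size, i.e. one kernel program per output element; each program then loops over size_along_dim in block_size steps. The pid decomposition in the kernels inverts that flattening, sketched here in plain Python:

# pid enumerates the flattened 2-D output, row-major.
output_y_size, output_x_size = 4, 6
for pid in range(output_y_size * output_x_size):
    num_output_y = pid // output_x_size  # output row
    num_output_x = pid % output_x_size  # output column
    assert pid == num_output_y * output_x_size + num_output_x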
diff --git a/trident/util/util.py b/trident/util/util.py
index b5c810b1..a5d4eb97 100644
--- a/trident/util/util.py
+++ b/trident/util/util.py
@@ -37,16 +37,31 @@ def dtype(input):


 def size_and_stride(input: torch.Tensor, dim: int):
-    if dim == 0:
-        x_size, y_size = input.shape
-        y_stride = input.stride(1)
-        x_stride = input.stride(0)
+    if input.dim() == 2:
+        if dim == 0:
+            x_size, y_size = input.shape
+            y_stride = input.stride(1)
+            x_stride = input.stride(0)
+        else:
+            y_size, x_size = input.shape
+            y_stride = input.stride(0)
+            x_stride = input.stride(1)
+
+        return y_size, x_size, y_stride, x_stride
+    elif input.dim() == 3:
+        if dim == 0:
+            z_size, y_size, x_size = input.shape[0], input.shape[1], input.shape[2]
+            z_stride, y_stride, x_stride = input.stride(0), input.stride(1), input.stride(2)
+        elif dim == 1:
+            z_size, y_size, x_size = input.shape[1], input.shape[0], input.shape[2]
+            z_stride, y_stride, x_stride = input.stride(1), input.stride(0), input.stride(2)
+        else:
+            z_size, y_size, x_size = input.shape[2], input.shape[0], input.shape[1]
+            z_stride, y_stride, x_stride = input.stride(2), input.stride(0), input.stride(1)
+
+        return z_size, y_size, x_size, z_stride, y_stride, x_stride
     else:
-        y_size, x_size = input.shape
-        y_stride = input.stride(0)
-        x_stride = input.stride(1)
-
-    return y_size, x_size, y_stride, x_stride
+        raise ValueError(f"{input.dim()}-dimensional input is not supported.")


 def optimize_module(mod):
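Note: for 3-D inputs, size_and_stride presents the reduced dimension as z and the two surviving dimensions, in their original order, as (y, x), which is exactly the 2-D layout of the kernels' output. A small check of that contract against torch (this re-implements the helper locally rather than importing trident):

import torch

def size_and_stride(input: torch.Tensor, dim: int):
    # Reduced dim -> z; the surviving dims keep their relative order.
    perm = {0: (0, 1, 2), 1: (1, 0, 2), 2: (2, 0, 1)}[dim]
    sizes = tuple(input.shape[i] for i in perm)
    strides = tuple(input.stride(i) for i in perm)
    return (*sizes, *strides)

x = torch.randn(3, 4, 5)
for dim in range(3):
    z_size, y_size, x_size, *_ = size_and_stride(x, dim)
    assert z_size == x.shape[dim]  # the reduced axis
    # (y_size, x_size) matches torch's cosine_similarity output shape.
    assert (y_size, x_size) == tuple(torch.nn.functional.cosine_similarity(x, x, dim).shape)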