diff --git a/testing/python/language/test_tilelang_language_vectorize.py b/testing/python/language/test_tilelang_language_vectorize.py
new file mode 100644
index 000000000..cee8b5a63
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_vectorize.py
@@ -0,0 +1,63 @@
+import torch
+import tilelang.testing
+import tilelang.language as T
+
+
+@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
+def vectorize_test(N, M, stride_A, stride_B):
+    assert N % 128 == 0 and M % 128 == 0
+
+    @T.prim_func
+    def main(
+            A: T.StridedTensor[(N, M), (1, stride_A), "float32"],  # noqa: F821
+            B: T.StridedTensor[(N, M), (1, stride_B), "float32"],  # noqa: F821
+    ):
+        with T.Kernel(M // 128, threads=128) as (bx):
+            tx = T.get_thread_binding(0)
+            col = bx * 128 + tx
+
+            for row in T.vectorized(N):
+                B[row, col] = A[row, col]
+
+    return main
+
+
+def run_vectorize(N, M, stride_A, stride_B):
+    assert stride_A >= N and stride_B >= N
+
+    jit_kernel = vectorize_test(N, M, stride_A, stride_B)
+
+    base_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32)
+    base_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32)
+    a = torch.as_strided(base_a, size=(N, M), stride=(1, stride_A))
+    b = torch.as_strided(base_b, size=(N, M), stride=(1, stride_B))
+
+    jit_kernel(a, b)
+
+    torch.testing.assert_close(a, b, atol=1e-8, rtol=1e-8)
+
+    code = jit_kernel.get_kernel_source()
+
+    vectorize_size = 1
+    while vectorize_size <= 2 and \
+            stride_A % (vectorize_size * 2) == 0 and \
+            stride_B % (vectorize_size * 2) == 0:
+        vectorize_size *= 2
+
+    if vectorize_size == 4:
+        assert "float4" in code
+    elif vectorize_size == 2:
+        assert "float2" in code
+
+
+def test_vectorize():
+    N, M = 512, 256
+
+    run_vectorize(N, M, N, N)
+    run_vectorize(N, M, N + 2, N + 4)
+    run_vectorize(N, M, N + 4, N + 8)
+    run_vectorize(N, M, N + 8, N + 16)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py
index 4f854ba27..83513f7a1 100644
--- a/tilelang/language/proxy.py
+++ b/tilelang/language/proxy.py
@@ -178,9 +178,6 @@ def __call__(self, scope=None) -> tir.Buffer:
 
         if len(shape) != len(strides):
             raise ValueError("Invalid shape/strides' dimensions")
-        if not bool(strides[-1] == 1):
-            # TODO(chenggang): shall we support non-contiguous even for the last dimension?
-            raise ValueError("The stride of the last dimension must be 1 (contiguous)")
 
         return super().__call__(shape, dtype=dtype, strides=strides, scope=scope)
 