From b083ad0afeb51c82a2900f7982c3c768a624bd76 Mon Sep 17 00:00:00 2001 From: rolandschulz Date: Wed, 17 Apr 2024 18:58:59 -0700 Subject: [PATCH] Make atom type a make_2d_copy argument --- examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp | 9 +++----- include/cute/atom/copy_traits_xe.hpp | 23 +++++++++++++++++--- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp index fd8943477d..56f60b8cd4 100644 --- a/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp +++ b/examples/cute/tutorial/pvc_sycl/pvc_sycl.cpp @@ -186,12 +186,9 @@ go_dpas_blockread_vnni_tiled(sycl::queue queue, std::vector &c_vec, Tensor tCr = make_tensor(Shape<_8, Int, Int>{}); - auto A_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(A), make_shape(M, K))); - auto B_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(B), make_shape(K, N))); - auto C_copy = make_xe_2d_copy( - make_tensor(make_gmem_ptr(C), make_shape(M, N))); + auto A_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(A), make_shape(M, K))); + auto B_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(B), make_shape(K, N))); + auto C_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(C), make_shape(M, N))); // TODO: - decide on how to deal with vector types // - create layouts with tiling/partitioning diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index b4023c0b40..25f797306d 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -36,6 +36,23 @@ namespace cute auto [y, x] = src.data().coord_; XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data()); } + }; + + + template + struct Copy_Traits + { + // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails + using ThrID = Layout<_1>; + using NumBits = Int; // hacky: does vec of 8 + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; // TODO: is _1 correct? + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + GTensor tensor; template @@ -52,12 +69,12 @@ namespace cute } }; - template + template auto make_xe_2d_copy(Tensor gtensor) { using GTensor = Tensor; - using Traits = Copy_Traits; + using Traits = Copy_Traits; Traits traits{gtensor}; return Copy_Atom{traits}; } -} +} \ No newline at end of file