diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index b4023c0b40..8e92e80b84 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -36,6 +36,23 @@ namespace cute auto [y, x] = src.data().coord_; XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data()); } + }; + + + template + struct Copy_Traits + { + // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails + using ThrID = Layout<_1>; + using NumBits = Int; // hacky: does vec of 8 + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; // TODO: is _1 correct? + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + GTensor tensor; template @@ -52,11 +69,11 @@ namespace cute } }; - template + template auto make_xe_2d_copy(Tensor gtensor) { using GTensor = Tensor; - using Traits = Copy_Traits; + using Traits = Copy_Traits; Traits traits{gtensor}; return Copy_Atom{traits}; }