Skip to content

Commit

Permalink
Make atom type a make_2d_copy argument
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandschulz committed Apr 18, 2024
1 parent b4c3eb1 commit 037006a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
6 changes: 3 additions & 3 deletions examples/cute/tutorial/pvc_sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ static void go_dpas_blockread_vnni_tiled(
Tensor tBr = make_tensor<uint>(Shape<_8, Int<NN>>{});
Tensor tCr = make_tensor<float>(Shape<_8, Int<MM>, Int<NN>>{});

auto A_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(A), make_shape(M, K)));
auto B_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(B), make_shape(K, N)));
auto C_copy = make_xe_2d_copy(make_tensor(make_gmem_ptr(C), make_shape(M, N)));
auto A_copy = make_xe_2d_copy<XE_2D_LOAD>(make_tensor(make_gmem_ptr(A), make_shape(M, K)));
auto B_copy = make_xe_2d_copy<XE_2D_LOAD>(make_tensor(make_gmem_ptr(B), make_shape(K, N)));
auto C_copy = make_xe_2d_copy<XE_2D_SAVE>(make_tensor(make_gmem_ptr(C), make_shape(M, N)));
//TODO: - decide on how to deal with vector types
// - create layouts with tiling/partitioning

Expand Down
21 changes: 19 additions & 2 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ namespace cute
auto [y, x] = src.data().coord_;
XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data());
}
};


template <class GTensor>
struct Copy_Traits<XE_2D_SAVE, GTensor>
{
// using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails
using ThrID = Layout<_1>;
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>; // hacky: does vec of 8
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape<_1, NumBits>>; // TODO: is _1 correct?
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape<_1, NumBits>>;
// Reference map from (thr,val) to bit
using RefLayout = SrcLayout;

GTensor tensor;

template <class TS, class SLayout,
class TD, class DLayout>
Expand All @@ -52,11 +69,11 @@ namespace cute
}
};

template <class GEngine, class GLayout>
template <class Copy, class GEngine, class GLayout>
auto make_xe_2d_copy(Tensor<GEngine, GLayout> gtensor)
{
using GTensor = Tensor<GEngine, GLayout>;
using Traits = Copy_Traits<XE_2D_LOAD, GTensor>;
using Traits = Copy_Traits<Copy, GTensor>;
Traits traits{gtensor};
return Copy_Atom<Traits, typename GEngine::value_type>{traits};
}
Expand Down

0 comments on commit 037006a

Please sign in to comment.