Cooperative prefetch (#151)
* hotfix: batch GEMM & revert to previous config

* cooperative prefetch selector
  perf: 310 TFLOP/s on 4K GEMM

---------

Co-authored-by: Alejandro Acosta <[email protected]>
jiyang1011 and aacostadiaz authored Nov 7, 2024
1 parent 68e1449 commit 907d9f9
Showing 2 changed files with 128 additions and 7 deletions.
86 changes: 86 additions & 0 deletions include/cute/atom/copy_traits_xe.hpp
@@ -489,6 +489,24 @@ struct Copy_Traits<XE_2D_U8x16x64_LD_N, args_t...>
Stride<_16,Stride< _1,_256,_512>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
template <class... ArgT>
Copy_Traits(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x16x64_LD_N, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits<XE_2D_U8x16x64_LD_N::PREFETCH, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x16x64_LD_N::PREFETCH, args_t...> {
using Shape_MN = Shape<_16, _64>;
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_16, _2, _16>>,
Stride<_16,Stride< _1,_256,_512>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_16, _2, _16>>,
Stride<_16,Stride< _1,_256,_512>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};

template <class... args_t>
@@ -504,6 +522,24 @@ struct Copy_Traits<XE_2D_U8x32x64_LD_N, args_t...>
Stride<_16,Stride< _1,_256,_512>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
template <class... ArgT>
Copy_Traits(ArgT... args)
: XE_2D_LD_Unpack<XE_2D_U8x32x64_LD_N, args_t...>(args...) {}
};

template <class... args_t>
struct Copy_Traits<XE_2D_U8x32x64_LD_N::PREFETCH, args_t...>
: XE_2D_LD_Unpack<XE_2D_U8x32x64_LD_N::PREFETCH, args_t...> {
using Shape_MN = Shape<_32, _64>;
using ThrID = Layout<_16>;
// Map from (src-thr,src-val) to bit
using SrcLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
Stride<_16,Stride< _1,_256,_512>>>;
// Map from (dst-thr,dst-val) to bit
using DstLayout = Layout<Shape <_16,Shape <_16, _2, _32>>,
Stride<_16,Stride< _1,_256,_512>>>;
// Reference map from (thr,val) to bit
using RefLayout = DstLayout;
};

template <class... args_t>
@@ -1933,4 +1969,54 @@ struct Copy_Traits<XE_1D_STORE_GLOBAL<S, D>> {
using RefLayout = SrcLayout;
};

namespace detail
{
template<class PrefetchTileSize, class dtype>
auto prefetch_selector(void* ptr = nullptr, int32_t width = 0, int32_t height = 0, int32_t pitch = 0) {
  if constexpr (get<0>(PrefetchTileSize{}) == 1) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x1x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2>>{});
  } else if constexpr (get<0>(PrefetchTileSize{}) == 2) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x2x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2, _2>>{});
  } else if constexpr (get<0>(PrefetchTileSize{}) == 4) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x4x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2, _4>>{});
  } else if constexpr (get<0>(PrefetchTileSize{}) == 8) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x8x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2, _8>>{});
  } else if constexpr (get<0>(PrefetchTileSize{}) == 16) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x16x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2, _16>>{});
  } else if constexpr (get<0>(PrefetchTileSize{}) == 32) {
    using prefetch_trait = Copy_Traits<XE_2D_U8x32x64_LD_N>;
    using prefetch_atom = Copy_Atom<prefetch_trait, dtype>;
    return make_tiled_copy(prefetch_atom{}.with(static_cast<dtype const*>(ptr), width, height, pitch),
                           Layout<Shape<_1, _16>>{},
                           Layout<Shape<_2, _2, _32>>{});
  } else {
    // Fail loudly on an unsupported tile height instead of silently returning void.
    static_assert(get<0>(PrefetchTileSize{}) == 1, "prefetch_selector: unsupported prefetch tile height");
  }
}
} // end namespace detail

} // end namespace cute
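
For orientation, a minimal host-side usage sketch of prefetch_selector (not from this commit: ptr_A and the 4096 extents are illustrative placeholders, and bfloat16 is an assumed element type):

// Sketch only: Shape<_8, _64> has get<0>() == 8, so this selects the
// XE_2D_U8x8x64_LD_N branch above. Assumes the CUTLASS SYCL/Xe headers are
// included and ptr_A points at a device allocation.
using PrefetchTileSize = cute::Shape<cute::_8, cute::_64>;
auto tiled_prefetch = cute::detail::prefetch_selector<PrefetchTileSize, cutlass::bfloat16_t>(
    (void *)ptr_A, /*width*/ 4096, /*height*/ 4096, /*pitch*/ 4096);
// The result is a TiledCopy whose atom carries the base pointer and 2D extents;
// xe_mma.hpp (below) feeds tensors built from it to prefetch() in the mainloop.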
49 changes: 42 additions & 7 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -114,6 +114,16 @@ struct CollectiveMma<
static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;

static constexpr size_t cacheline_bytes = 64;
static constexpr auto block_size_w_a = cute::min(SG_K, cacheline_bytes / sizeof(ElementA));
static constexpr auto block_size_w_b = cute::min(SG_N, cacheline_bytes / sizeof(ElementB));
static constexpr auto nums_block_w_a = ceil_div(SG_K, block_size_w_a);
static constexpr auto nums_block_w_b = ceil_div(SG_N, block_size_w_b);
using PrefetchAThrShape = Shape<Int<ATOM_N / cute::gcd(ATOM_N, nums_block_w_a)>, Int<cute::gcd(ATOM_N, nums_block_w_a)>>;
using PrefetchBThrShape = Shape<Int<ATOM_M / cute::gcd(ATOM_M, nums_block_w_b)>, Int<cute::gcd(ATOM_M, nums_block_w_b)>>;
using PrefetchATileSize = decltype(ceil_div(Shape<Int<SG_M>, Int<SG_K>>{}, PrefetchAThrShape{}));
using PrefetchBTileSize = decltype(ceil_div(Shape<Int<SG_K>, Int<SG_N>>{}, PrefetchBThrShape{}));
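// Illustrative numbers (assuming 2-byte elements such as bf16): block_size_w_a
// = min(SG_K, 64 / 2) = min(SG_K, 32), so one prefetch row spans a cache line;
// with SG_K = 32, nums_block_w_a = 1 and PrefetchAThrShape degenerates to
// Shape<Int<ATOM_N>, _1>, i.e. all cooperating subgroups stack along M.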

static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
using traits_load_A = Copy_Traits<GmemTiledCopyA>;
@@ -131,6 +141,8 @@
make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}),
get<1>(typename traits_load_B::Shape_MN{}) / Int<SubgroupSize>{}))));

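// Cooperative-prefetch TiledCopy types. They are default-constructed here (null
// pointer, zero extents) just to name the types; to_underlying_arguments()
// below rebuilds them with the real base pointers and matrix extents.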
using XE_Prefetch_A = decltype(cute::detail::prefetch_selector<PrefetchATileSize, ElementA>());
using XE_Prefetch_B = decltype(cute::detail::prefetch_selector<PrefetchBTileSize, ElementB>());
// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
Expand All @@ -142,6 +154,8 @@ struct CollectiveMma<
struct Params {
XE_Copy_A gmem_tiled_copy_a;
XE_Copy_B gmem_tiled_copy_b;
XE_Prefetch_A gmem_prefetch_a;
XE_Prefetch_B gmem_prefetch_b;
};

//
@@ -166,7 +180,9 @@
Layout<Shape<_1, Int<SubgroupSize>>>{},
make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}),
get<1>(typename traits_load_B::Shape_MN{}) / Int<SubgroupSize>{})));
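// The (width, height, pitch) arguments mirror the copy construction above:
// A is M x K (width K, pitch K), B is K x N (width N, pitch N).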
XE_Prefetch_A prefetchA = cute::detail::prefetch_selector<PrefetchATileSize, ElementA>((void*)args.ptr_A, K, M, K);
XE_Prefetch_B prefetchB = cute::detail::prefetch_selector<PrefetchBTileSize, ElementB>((void*)args.ptr_B, N, K, N);
return Params{copyA, copyB, prefetchA, prefetchB};
}

/// Perform a subgroup-scoped matrix multiply-accumulate
@@ -230,41 +246,60 @@
print("===================== Config: \n");
print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n");

print(" PrefetchAThrShape : ");print(PrefetchAThrShape{});print("\n");
print(" PrefetchBThrShape : ");print(PrefetchBThrShape{});print("\n");
print(" PrefetchATileSize : ");print(PrefetchATileSize{});print("\n");
print(" PrefetchBTileSize : ");print(PrefetchBTileSize{});print("\n");
}
#endif

//
// Mainloop
//
int sub_group_id = get_sub_group_id();
const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M;
const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
const int l_coord = BlockIdxZ();
Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{});
Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor(
make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count),
append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{});

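// Cooperative prefetch: the subgroup id is decomposed over PrefetchAThrShape /
// PrefetchBThrShape so that the ATOM_N subgroups sharing an M-row split that
// row's SG_M x SG_K tile of A, and the ATOM_M subgroups sharing an N-column
// split the SG_K x SG_N tile of B, instead of each subgroup re-prefetching
// its own load footprint.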
Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor(
make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
(sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{}) * get<1>(PrefetchATileSize{}), l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{});
Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor(
make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) * get<0>(PrefetchBTileSize{}),
n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});

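// Warm-up: issue cooperative prefetches for the first `Stages` k-tiles so the
// first copies in the mainloop hit cache.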
#pragma unroll
for (int i = 0; i < DispatchPolicy::Stages; i++) {
if constexpr(cute::detail::has_prefetch<GmemTiledCopyA>) {
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,i));
}
if constexpr(cute::detail::has_prefetch<GmemTiledCopyB>) {
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,i));
}
}
#pragma unroll
for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
// Copy gmem to rmem for the current k_tile
copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view);
copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view);

if(k_tile + DispatchPolicy::Stages < k_tile_count) {
if constexpr(cute::detail::has_prefetch<GmemTiledCopyA>) {
prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages));
}
if constexpr(cute::detail::has_prefetch<GmemTiledCopyB>) {
prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages));
}
}
for (int i = 0; i < SG_K / SubgroupSize; i++) {
