From 907d9f95bc26bb7957606397c72d04ac12c91948 Mon Sep 17 00:00:00 2001
From: jiyang1011 <110882834+jiyang1011@users.noreply.github.com>
Date: Fri, 8 Nov 2024 00:21:25 +0800
Subject: [PATCH] Cooperative prefetch (#151)

* hot fix: batch gemm & revert to previous config

* cooperative prefetch selector perf 310tflops 4K gemm

---------

Co-authored-by: Alejandro Acosta
---
 include/cute/atom/copy_traits_xe.hpp       | 86 ++++++++++++++++++++++
 include/cutlass/gemm/collective/xe_mma.hpp | 49 ++++++++++--
 2 files changed, 128 insertions(+), 7 deletions(-)

diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp
index d645cb91b..4d9eaf6b4 100644
--- a/include/cute/atom/copy_traits_xe.hpp
+++ b/include/cute/atom/copy_traits_xe.hpp
@@ -489,6 +489,24 @@ struct Copy_Traits
         Stride<_16,Stride< _1,_256,_512>>>;
   // Reference map from (thr,val) to bit
   using RefLayout = DstLayout;
+  template <class... ArgT>
+  Copy_Traits(ArgT... args)
+      : XE_2D_LD_Unpack(args...) {}
+};
+
+template
+struct Copy_Traits
+    : XE_2D_LD_Unpack {
+  using Shape_MN = Shape<_16, _64>;
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout>,
+                    Stride<_16,Stride< _1,_256,_512>>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout = Layout>,
+                    Stride<_16,Stride< _1,_256,_512>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
 };
 
 template
@@ -504,6 +522,24 @@ struct Copy_Traits
         Stride<_16,Stride< _1,_256,_512>>>;
   // Reference map from (thr,val) to bit
   using RefLayout = DstLayout;
+  template <class... ArgT>
+  Copy_Traits(ArgT... args)
+      : XE_2D_LD_Unpack(args...) {}
+};
+
+template
+struct Copy_Traits
+    : XE_2D_LD_Unpack {
+  using Shape_MN = Shape<_32, _64>;
+  using ThrID = Layout<_16>;
+  // Map from (src-thr,src-val) to bit
+  using SrcLayout = Layout>,
+                    Stride<_16,Stride< _1,_256,_512>>>;
+  // Map from (dst-thr,dst-val) to bit
+  using DstLayout = Layout>,
+                    Stride<_16,Stride< _1,_256,_512>>>;
+  // Reference map from (thr,val) to bit
+  using RefLayout = DstLayout;
 };
 
 template
@@ -1933,4 +1969,54 @@ struct Copy_Traits> {
   using RefLayout = SrcLayout;
 };
 
+namespace detail
+{
+  template
+  auto prefetch_selector(void* ptr = nullptr, int32_t width = 0, int32_t height = 0, int32_t pitch = 0) {
+    if constexpr (get<0>(PrefetchTileSize{}) == 1) {
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+    if constexpr (get<0>(PrefetchTileSize{}) == 2) {
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+    if constexpr (get<0>(PrefetchTileSize{}) == 4) {
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+    if constexpr (get<0>(PrefetchTileSize{}) == 8) {
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+    if constexpr (get<0>(PrefetchTileSize{}) == 16) {
+      // static_assert(false);
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+    if constexpr (get<0>(PrefetchTileSize{}) == 32) {
+      using prefetch_trait = Copy_Traits;
+      using prefetch_atom = Copy_Atom;
+      return make_tiled_copy(prefetch_atom{}.with(static_cast(ptr), width, height, pitch),
+                             Layout>{},
+                             Layout>{});
+    }
+  }
+} // end namespace detail
+
 } // end namespace cute
diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp
index 06fa815e6..b08b79c2f 100644
--- a/include/cutlass/gemm/collective/xe_mma.hpp
+++ b/include/cutlass/gemm/collective/xe_mma.hpp
@@ -114,6 +114,16 @@ struct CollectiveMma<
   static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
   static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
   using SubgroupTileShape = Shape;
+
+  static constexpr size_t cacheline_bytes = 64;
+  static constexpr auto block_size_w_a = cute::min(SG_K, cacheline_bytes / sizeof(ElementA));
+  static constexpr auto block_size_w_b = cute::min(SG_N, cacheline_bytes / sizeof(ElementB));
+  static constexpr auto nums_block_w_a = ceil_div(SG_K, block_size_w_a);
+  static constexpr auto nums_block_w_b = ceil_div(SG_N, block_size_w_b);
+  using PrefetchAThrShape = Shape, Int>;
+  using PrefetchBThrShape = Shape, Int>;
+  using PrefetchATileSize = decltype(ceil_div(Shape, Int>{},PrefetchAThrShape{}));
+  using PrefetchBTileSize = decltype(ceil_div(Shape, Int>{},PrefetchBThrShape{}));
 
   static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
   using traits_load_A = Copy_Traits;
@@ -131,6 +141,8 @@ struct CollectiveMma<
           make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}),
                       get<1>(typename traits_load_B::Shape_MN{}) / Int{}))));
 
+  using XE_Prefetch_A = decltype(cute::detail::prefetch_selector());
+  using XE_Prefetch_B = decltype(cute::detail::prefetch_selector());
   // Host side kernel arguments
   struct Arguments {
     ElementA const* ptr_A;
@@ -142,6 +154,8 @@ struct CollectiveMma<
   struct Params {
     XE_Copy_A gmem_tiled_copy_a;
     XE_Copy_B gmem_tiled_copy_b;
+    XE_Prefetch_A gmem_prefetch_a;
+    XE_Prefetch_B gmem_prefetch_b;
   };
 
   //
@@ -166,7 +180,9 @@ struct CollectiveMma<
           Layout>>{},
           make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}),
                       get<1>(typename traits_load_B::Shape_MN{}) / Int{})));
-    return Params{copyA, copyB};
+    XE_Prefetch_A prefetchA = cute::detail::prefetch_selector((void *)args.ptr_A, K, M, K);
+    XE_Prefetch_B prefetchB = cute::detail::prefetch_selector((void *)args.ptr_B, N, K, N);
+    return Params{copyA, copyB, prefetchA, prefetchB};
   }
 
   /// Perform a subgroup-scoped matrix multiply-accumulate
@@ -230,14 +246,20 @@ struct CollectiveMma<
       print("===================== Config: \n");
       print("  threads per workgroup : "); print(MaxThreadsPerBlock); print("\n");
      print("  SubgroupTileShape     : "); print(SubgroupTileShape{}); print("\n");
+
+      print("  PrefetchAThrShape     : ");print(PrefetchAThrShape{});print("\n");
+      print("  PrefetchBThrShape     : ");print(PrefetchBThrShape{});print("\n");
+      print("  PrefetchATileSize     : ");print(PrefetchATileSize{});print("\n");
+      print("  PrefetchBTileSize     : ");print(PrefetchBTileSize{});print("\n");
     }
 #endif
 
    //
    // Mainloop
    //
-    const int m_coord = BlockIdxY() * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
-    const int n_coord = BlockIdxX() * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
+    int sub_group_id = get_sub_group_id();
+    const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M;
+    const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
     const int l_coord = BlockIdxZ();
     Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(
       make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
@@ -245,13 +267,25 @@ struct CollectiveMma<
     Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor(
       make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count),
       append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{});
+
+    Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor(
+      make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
+                 (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{}) * get<1>(PrefetchATileSize{}), l_coord),
+      append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
+      append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{});
+    Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor(
+      make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) * get<0>(PrefetchBTileSize{}),
+                 n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
+      append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
+      append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});
+
 #pragma unroll
     for (int i = 0; i < DispatchPolicy::Stages; i++) {
       if constexpr(cute::detail::has_prefetch) {
-        prefetch(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,i));
+        prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,i));
       }
       if constexpr(cute::detail::has_prefetch) {
-        prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,i));
+        prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,i));
       }
     }
 #pragma unroll
@@ -259,12 +293,13 @@ struct CollectiveMma<
       // Copy gmem to rmem for the first k_tile
       copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view);
       copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view);
+
      if(k_tile + DispatchPolicy::Stages < k_tile_count) {
        if constexpr(cute::detail::has_prefetch) {
-          prefetch(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile + DispatchPolicy::Stages));
+          prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages));
        }
        if constexpr(cute::detail::has_prefetch) {
-          prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile + DispatchPolicy::Stages));
+          prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages));
        }
      }
      for (int i = 0; i < SG_K / SubgroupSize; i++) {
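
For readers following the diff, the two compile-time ideas it introduces can be shown in a standalone C++17 sketch. Everything concrete below is a hypothetical stand-in (ElementA as a 16-bit type, SG_K = 32, the TileSize tags, the height<> helper, the *Row atom structs); the real code operates on cute::Shape/Int constants and the XE 2D block-load atoms, whose exact template arguments do not survive in this patch text. Only the two patterns come from the diff: a prefetch block's width is capped at one 64-byte cache line (block_size_w_a / nums_block_w_a), and detail::prefetch_selector picks a copy atom by branching on the prefetch tile height with if constexpr.

// Standalone C++17 sketch (illustrative only). ElementA, SG_K, the TileSize
// tags and the height<> helper are hypothetical stand-ins, not names from the
// patch; the real prefetch_selector returns CuTe tiled copies, not tag structs.
#include <algorithm>
#include <cstddef>
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Mirror of the CollectiveMma additions: a prefetch block is never wider than
// one 64-byte cache line, and nums_block_w_a counts how many such blocks are
// needed to tile the subgroup's K extent.
using ElementA = unsigned short;   // stand-in for a 16-bit element (e.g. bf16)
constexpr int SG_K = 32;           // hypothetical per-subgroup K extent
constexpr std::size_t cacheline_bytes = 64;
constexpr int block_size_w_a =
    std::min<int>(SG_K, int(cacheline_bytes / sizeof(ElementA)));
constexpr int nums_block_w_a = ceil_div(SG_K, block_size_w_a);

// Toy model of detail::prefetch_selector: branch at compile time on the
// height of the prefetch tile, as the diff does with
// `if constexpr (get<0>(PrefetchTileSize{}) == ...)`.
template <int H, int W> struct TileSize {};
template <class T> struct height;
template <int H, int W> struct height<TileSize<H, W>> {
  static constexpr int value = H;
};

struct Prefetch1Row  { static constexpr const char* name = "1-row atom";  };
struct Prefetch8Row  { static constexpr const char* name = "8-row atom";  };
struct Prefetch32Row { static constexpr const char* name = "32-row atom"; };

template <class PrefetchTileSize>
auto prefetch_selector() {
  if constexpr (height<PrefetchTileSize>::value == 1)
    return Prefetch1Row{};
  else if constexpr (height<PrefetchTileSize>::value == 8)
    return Prefetch8Row{};
  else
    return Prefetch32Row{};
}

int main() {
  std::printf("block_size_w_a = %d, nums_block_w_a = %d\n",
              block_size_w_a, nums_block_w_a);
  std::printf("tile height 8 selects: %s\n",
              prefetch_selector<TileSize<8, 32>>().name);
}

With 2-byte elements and SG_K = 32 this prints block_size_w_a = 32, nums_block_w_a = 1: one block spans exactly one cache line. With 4-byte elements the cap would drop to 16 elements and two blocks would be needed; that count is what the patch feeds into PrefetchAThrShape and PrefetchATileSize to decide how the cooperating subgroups split a tile.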
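Likewise, the index arithmetic behind prefetch_iter_a in xe_mma.hpp distributes the blocks of one A slab across the subgroups that share it: the subgroup's linear id is folded into a 2D position in PrefetchAThrShape, then scaled by PrefetchATileSize into element offsets. A minimal sketch, assuming a hypothetical 2x2 subgroup grid and 16x16 prefetch tile (example values, not taken from the patch):

// Standalone sketch of the cooperative-prefetch index math used to build
// prefetch_iter_a. thr_h/thr_w and tile_h/tile_w are hypothetical example
// values; in the patch they come from PrefetchAThrShape and PrefetchATileSize.
#include <cstdio>

int main() {
  constexpr int thr_h = 2, thr_w = 2;      // subgroup grid (PrefetchAThrShape)
  constexpr int ATOM_N = thr_h * thr_w;    // subgroups sharing one A row-slab
  constexpr int tile_h = 16, tile_w = 16;  // per-subgroup prefetch tile
                                           // (PrefetchATileSize)
  for (int sub_group_id = 0; sub_group_id < ATOM_N; ++sub_group_id) {
    // Same arithmetic as the make_coord(...) call for prefetch_iter_a:
    // linear id -> 2D position in the subgroup grid -> element offsets.
    int row = (sub_group_id % ATOM_N) / thr_w * tile_h;  // added to m_coord
    int col = (sub_group_id % ATOM_N) % thr_w * tile_w;  // offset along K
    std::printf("subgroup %d prefetches a %dx%d block at (+%d, +%d)\n",
                sub_group_id, tile_h, tile_w, row, col);
  }
}

Each subgroup lands on a distinct 16x16 block (offsets (0,0), (0,16), (16,0), (16,16)), so together the four subgroups cover a 32x32 region once per stage instead of each one redundantly prefetching the whole slab; that division of labor is the cooperation the commit title refers to.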