From 8273148758bf7b42b9c64b5d179bb4af3806e6c0 Mon Sep 17 00:00:00 2001 From: taozha2 Date: Mon, 16 Dec 2024 09:32:14 +0800 Subject: [PATCH] rebase --- examples/sycl/pvc/pvc_gemm.cpp | 34 +++++------ include/cutlass/gemm/collective/xe_mma.hpp | 70 +++++++++++++--------- 2 files changed, 56 insertions(+), 48 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 2a27529d9..5ecfea4a3 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -45,7 +45,7 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -#include "common.h" +#include "common.hpp" using namespace cute; @@ -227,7 +227,10 @@ struct ExampleRunner { size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op.can_implement(arguments); + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } gemm_op.initialize(arguments, workspace.get()); @@ -286,16 +289,13 @@ static constexpr auto gemm_run(Options const& options) { using ElementInputB = b_type; // <- data type of elements in input matrix B using ElementOutput = c_type; // <- data type of elements in output matrix D - // using LayoutA = cutlass::layout::ColumnMajor; - // using LayoutB = cutlass::layout::RowMajor; + using LayoutA = std::conditional_t; + using LayoutB = std::conditional_t; using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - // using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; // row major load - // using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; // column major load - - // using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // row major load - // using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; // column major load + using copy_traits_a = std::conditional_t; + using copy_traits_b = std::conditional_t; // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; @@ -326,23 +326,17 @@ static constexpr auto gemm_run(Options const& options) { XE_2D_U32x8x16_ST_N, void, void>; -// Mainloop + // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< GEMMDispatchPolicy, TileShape, ElementInputA, - cutlass::gemm::TagToStrideA_t>, + cutlass::gemm::TagToStrideA_t, ElementInputB, - cutlass::gemm::TagToStrideB_t>, + cutlass::gemm::TagToStrideB_t, TiledMma, - std::conditional_t, // A - void, void, cute::identity, // A - std::conditional_t, // B - void, void, cute::identity // B + copy_traits_a, void, void, cute::identity, // A + copy_traits_b, void, void, cute::identity // B >; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 4d9862f3e..adba580df 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -214,12 +214,15 @@ struct CollectiveMma< /// Perform a subgroup-scoped matrix multiply-accumulate template < + int PrefetchStrideA, + int PrefetchStrideB, class FrgTensorD, class TensorA, class TensorB, class FrgTensorC, class KTileIterator, - class ResidueMNK + class ResidueMNK, + class BlkCoord > CUTLASS_DEVICE void operator() ( @@ -229,6 +232,8 @@ struct CollectiveMma< FrgTensorC const &src_accum, KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, + BlkCoord const &blk_coord, + int const &K, int thread_idx, char *smem_buf, Params const& mainloop) @@ -284,11 +289,15 @@ struct CollectiveMma< // // Mainloop // - int sub_group_id = get_sub_group_id(); - const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M; - const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N; - const int l_coord = BlockIdxZ(); - + auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; + #ifdef CUTLASS_SYCL_SWITCH_WG + const int m_coord = n_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = m_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #else + const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #endif + const int l_coord = l_idx; Tensor iter_a = get_pvc_tensor_a(mainloop.gmem_tiled_copy_a, m_coord, 0, l_coord, append<4>(tCrA_copy_view.shape(), k_tile_count), @@ -298,44 +307,49 @@ struct CollectiveMma< append<4>(tCrB_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K)); + const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K)); + int prefetch_k = 0; + Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( - make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{}) * get<1>(PrefetchATileSize{}), l_coord), + make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), + (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) * get<0>(PrefetchBTileSize{}), - n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), + make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, + n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); -#pragma unroll - for (int i = 0; i < DispatchPolicy::Stages; i++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,i)); + prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); } if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,i)); + prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); } } -#pragma unroll - for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { - // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view); - copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view); - if(k_tile + DispatchPolicy::Stages < k_tile_count) { - if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages)); + CUTLASS_PRAGMA_UNROLL + for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) { + // Copy gmem to rmem for the first k_tile + copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k), tCrA_copy_view); + copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k), tCrB_copy_view); + + if(prefetch_k < k_tile_count) { + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); + } + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); + } } - if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages)); + + for (int i = 0; i < SG_K / SubgroupSize; i++) { + cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); } } - for (int i = 0; i < SG_K / SubgroupSize; i++) { - cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); - } - } } };