From b42305f42a0e12901ea2512b59101e683160e82f Mon Sep 17 00:00:00 2001
From: Alejandro Acosta
Date: Tue, 18 Jun 2024 15:27:53 +0100
Subject: [PATCH] Add workgroup level TileShape (#84)

* Add workgroup-level tile

* Rename tile shapes

* Rename mma shape

* Remove unused code

* Update benchmark
---
 ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp | 20 ++++----
 .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp    |  5 +-
 .../cutlass/gemm/collective/intel_pvc_mma.hpp | 28 ++++++-----
 .../cutlass/gemm/kernel/intel_pvc_gemm.hpp    | 48 ++++++++-----------
 4 files changed, 50 insertions(+), 51 deletions(-)

diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
index 67b76929d..6d36bb4d4 100644
--- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
+++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
@@ -67,10 +67,8 @@ int main(int argc, const char** argv)
   // to use a GPU other than that with device ID 0.
   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
 
-  bool passed;
-
-  // The code section below describes datatype for input, output matrices and computation between
-  // elements in input matrices.
+  // The code section below describes the data types of the input and output matrices and of the
+  // computation between elements of the input matrices.
   using ElementAccumulator = float;      // <- data type of accumulator
   using ElementComputeEpilogue = float;  // <- data type of epilogue operations
   using ElementInputA = bfloat16_t;      // <- data type of elements in input matrix A
@@ -82,16 +80,20 @@ int main(int argc, const char** argv)
   using LayoutC = cutlass::layout::RowMajor;
   using LayoutD = cutlass::layout::RowMajor;
 
-  using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
-  using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
+  // Workgroup-level tile
+  using TileShape = Shape<_32, _256, _32>;
 
-  using TileShape = Shape<_32, _64, _32>;
+  using TiledMma = TiledMMA<
+      MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
+      Layout<Shape<_1,_1,_1>>,
+      Tile<_32,_64,_32>>; // Subgroup-level tile
 
-  using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-                            Layout<Shape<_1,_1,_1>>>;
+  using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
+  using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
 
   using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
 
+  // This code section describes the epilogue part of the kernel
   using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
       ElementOutput,                                     // <- data type of output matrix
       128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
index 542bafb34..731edfa15 100644
--- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
+++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
@@ -353,11 +353,12 @@ int main(int argc, const char** argv)
   using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
   using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
 
-  using TileShape = Shape<_1, _1, _1>;
+  // Workgroup-level tile
+  using TileShape = Shape<_32, _256, _32>;
 
   using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
                             Layout<Shape<_1,_1,_1>>,
-                            Tile<_32,_64,_32>>;
+                            Tile<_32,_64,_32>>; // Subgroup-level tile
 
   using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
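For reference, the workgroup/subgroup decomposition configured above works out as follows. This is a standalone C++ sketch using the tile sizes from the diff (plain arithmetic, not the CUTLASS API); the 16-wide sub_group is the PVC SubgroupSize assumed throughout the patch.

// Standalone sketch: how the <32, 256, 32> workgroup tile above splits into
// <32, 64, 32> subgroup tiles (values taken from the diff above).
#include <cstdio>

int main() {
  constexpr int wg_m = 32, wg_n = 256, wg_k = 32; // workgroup-level TileShape
  constexpr int sg_m = 32, sg_n = 64,  sg_k = 32; // subgroup-level Tile
  constexpr int subgroup_size = 16;               // PVC sub_group width

  static_assert(wg_m % sg_m == 0 && wg_n % sg_n == 0 && wg_k % sg_k == 0,
                "subgroup tile must evenly divide the workgroup tile");

  constexpr int sgs_per_wg = (wg_m / sg_m) * (wg_n / sg_n); // 1 * 4 = 4
  constexpr int wi_per_wg  = sgs_per_wg * subgroup_size;    // 64 work-items

  std::printf("%d sub_groups -> %d work-items per work-group\n",
              sgs_per_wg, wi_per_wg);
  return 0;
}

The same quotient is what the collective below computes as MaxThreadsPerBlock from size(WorkgroupTileShape) / size(SubgroupTileShape).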
diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
index d587fbcd9..c552ee861 100644
--- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp
+++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -81,7 +81,7 @@ struct CollectiveMma<
   // Type Aliases
   //
   using DispatchPolicy = MainloopIntelPVCUnpredicated;
-  using TileShape = TileShape_;
+  using WorkgroupTileShape = TileShape_;
   using ElementA = ElementA_;
   using StrideA = StrideA_;
   using ElementB = ElementB_;
@@ -100,21 +100,22 @@ struct CollectiveMma<
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
 
-  using DpasShape = typename TiledMma::Shape_MNK;
-  using TileDpasShape = decltype(tile_shape(TiledMma()));
+  using MmaAtomShape = typename TiledMma::AtomShape_MNK;
+  using SubgroupTileShape = decltype(tile_shape(TiledMma()));
 
-  static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
+  static constexpr uint32_t MaxThreadsPerBlock =
+      cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{}) * SubgroupSize;
 
-  static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
-  static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
-  static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());
+  static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group
+  static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group
+  static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape());
 
   // Calculate the vector width based on the amount of registers
   // required per work item by dividing the total fragment size by
   // the sub_group size.
-  static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize;
-  static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
-  static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize;
+  static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize;
+  static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;
+  static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize;
 
   // Host side kernel arguments
   struct Arguments {
@@ -186,8 +187,9 @@ struct CollectiveMma<
     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
 
     // Tensor to hold input data
-    Tensor tAr = make_tensor<ElementA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
-    Tensor tBr = make_tensor<ElementB>(Shape<Int<FragsK * get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});
+    Tensor tAr = make_tensor<ElementA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
+    Tensor tBr = make_tensor<ElementB>(
+        Shape<Int<FragsK * get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});
 
     Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
         Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
@@ -200,7 +202,7 @@ struct CollectiveMma<
     //
     // Mainloop
     //
-    for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK)
+    for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK)
     {
       // Copy gmem to rmem for the first k_tile
       copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
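Plugging the shapes from the examples into the renamed constants: assuming the 8x16x16 DPAS atom (MmaAtomShape) and the 32x64x32 SubgroupTileShape, the fragment counts and per-work-item vector widths come out as below. This is a standalone sketch of the same arithmetic, not the collective itself.

#include <cstdio>

int main() {
  constexpr int atom_m = 8, atom_n = 16, atom_k = 16; // MmaAtomShape (DPAS)
  constexpr int sg_m = 32, sg_n = 64, sg_k = 32;      // SubgroupTileShape
  constexpr int subgroup_size = 16;

  // Fragment counts per sub_group (FragsM/FragsN/FragsK above)
  constexpr int frags_m = sg_m / atom_m; // 4
  constexpr int frags_n = sg_n / atom_n; // 4
  constexpr int frags_k = sg_k / atom_k; // 2

  // Per-work-item vector widths (VecA/VecB/VecC above): one atom-sized
  // fragment spread evenly across the 16 work-items of a sub_group
  constexpr int vec_a = (atom_m * atom_k) / subgroup_size; // 8
  constexpr int vec_b = (atom_n * atom_k) / subgroup_size; // 16
  constexpr int vec_c = (atom_n * atom_m) / subgroup_size; // 8

  std::printf("Frags M/N/K = %d/%d/%d, Vec A/B/C = %d/%d/%d\n",
              frags_m, frags_n, frags_k, vec_a, vec_b, vec_c);
  return 0;
}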
diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
index 1a9185437..5c9b6d019 100644
--- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
+++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -65,7 +65,8 @@ class GemmUniversal<
 
   // Mainloop derived types
   using CollectiveMainloop = CollectiveMainloop_;
-  using TileShape = typename CollectiveMainloop::TileShape;
+  using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
+  using WorkgroupTileShape = TileShape;
   using TiledMma = typename CollectiveMainloop::TiledMma;
   using ArchTag = typename CollectiveMainloop::ArchTag;
   using ElementA = typename CollectiveMainloop::ElementA;
@@ -81,7 +82,7 @@ class GemmUniversal<
          "Intel PVC does not support specializing the tile scheduler.");
   using TileSchedulerTag = TileScheduler_;
   using TileScheduler = typename detail::TileSchedulerSelector<
-      TileScheduler_, ArchTag, TileShape,
+      TileScheduler_, ArchTag, WorkgroupTileShape,
       cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
   using TileSchedulerArguments = typename TileScheduler::Arguments;
@@ -101,13 +102,9 @@ class GemmUniversal<
   static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
 
   static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
-  static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
-
-  static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
-
-  using DpasShape = typename CollectiveMainloop::DpasShape;
-  using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
+  using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape;
+  using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;
 
   static constexpr int FragsM = CollectiveMainloop::FragsM;
   static constexpr int FragsN = CollectiveMainloop::FragsN;
@@ -175,16 +172,10 @@ class GemmUniversal<
       batch_count = cute::size<3>(params.problem_shape);
     }
 
-    auto M = get<0>(params.problem_shape);
-    auto N = get<1>(params.problem_shape);
-
-    const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
-    const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
-
     return dim3(
-        sg_m,
-        cute::ceil_div(sg_n, num_sg),
-        batch_count
+        cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+        cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
+        batch_count
     );
   }
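With get_grid_shape() now dividing the problem extents by the workgroup tile directly, each work-group owns one 32x256 slice of the output. For a hypothetical 4096x4096 problem (illustrative numbers only) that yields a (128, 16, batch) grid; the ceil_div arithmetic in isolation:

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  constexpr int M = 4096, N = 4096, batch = 1; // hypothetical problem size
  constexpr int wg_m = 32, wg_n = 256;         // WorkgroupTileShape M, N
  std::printf("grid = (%d, %d, %d)\n",
              ceil_div(M, wg_m), ceil_div(N, wg_n), batch); // (128, 16, 1)
  return 0;
}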
@@ -200,7 +191,7 @@ class GemmUniversal<
     (void)smem_buf;
 
     // Preconditions
-    CUTE_STATIC_ASSERT(is_static<TileDpasShape>::value);
+    CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);
 
     // Separate out problem shape for convenience
     // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK)
@@ -218,18 +209,21 @@ class GemmUniversal<
     // Get the appropriate blocks for this sub_group -- potential for sub_group locality
     int thread_idx = int(ThreadIdxX());
-    auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
+    constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K)
+    constexpr auto subgroup_shape = SubgroupTileShape{};   // (SUB_M,SUB_N,SUB_K)
     const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
-    const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
+    const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
     const int l_coord = BlockIdxZ();
 
-    Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
-                                                                  make_shape(_1{}, K, L),
-                                                                  make_stride(Int<FragsM>{} * get<0>(DpasShape()), _1{}));
+    Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(
+        make_coord(m_coord, 0, 0),
+        make_shape(_1{}, K, L),
+        make_stride(Int<FragsM>{} * get<0>(MmaAtomShape()), _1{}));
 
-    Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0),
-                                                                  make_shape(K, Int<FragsN>{}, L),
-                                                                  make_stride(_1{}, get<1>(DpasShape())));
+    Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(
+        make_coord(0, n_coord, 0),
+        make_shape(K, Int<FragsN>{}, L),
+        make_stride(_1{}, get<1>(MmaAtomShape())));
 
     // Compute tile residues for predication
     auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord
@@ -263,7 +257,7 @@ class GemmUniversal<
 
     Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0),
                                                   make_shape(Int<FragsM>{}, Int<FragsN>{}, L),
-                                                  make_stride(get<0>(DpasShape()), get<1>(DpasShape())));
+                                                  make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape())));
 
     copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord));
   }
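The reworked n_coord can be checked in isolation: every sub_group in a work-group starts from the same workgroup-tile origin and then offsets by its own 64-wide subgroup tile, so the four sub_groups cover the full 256-wide tile without the old num_sg multiplier. A standalone sketch with hypothetical block indices (not kernel code):

#include <cstdio>

int main() {
  constexpr int sg_m = 32, sg_n = 64; // SubgroupTileShape M, N
  constexpr int wg_n = 256;           // WorkgroupTileShape N
  constexpr int subgroup_size = 16;
  const int block_idx_x = 3, block_idx_y = 2; // hypothetical work-group indices

  // 64 work-items per work-group = 4 sub_groups of 16
  for (int thread_idx = 0; thread_idx < 64; thread_idx += subgroup_size) {
    const int m_coord = block_idx_x * sg_m;                                     // 96
    const int n_coord = block_idx_y * wg_n + thread_idx / subgroup_size * sg_n; // 512, 576, 640, 704
    std::printf("sub_group %d -> C tile origin (m, n) = (%d, %d)\n",
                thread_idx / subgroup_size, m_coord, n_coord);
  }
  return 0;
}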