From 8c72fd5da61baa19a88d4dd948256fdb03fcb886 Mon Sep 17 00:00:00 2001
From: Alejandro Acosta
Date: Thu, 13 Jun 2024 13:45:55 +0100
Subject: [PATCH] Rename tile shapes

---
 .../cutlass/gemm/collective/intel_pvc_mma.hpp | 18 ++++++++++--------
 .../cutlass/gemm/kernel/intel_pvc_gemm.hpp    | 21 +++++++++++----------
 2 files changed, 21 insertions(+), 18 deletions(-)

diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
index 8f2e4cb34e..0ee08a0abc 100644
--- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp
+++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -81,7 +81,7 @@ struct CollectiveMma<
   // Type Aliases
   //
   using DispatchPolicy = MainloopIntelPVCUnpredicated;
-  using TileShape = TileShape_;
+  using WorkgroupTileShape = TileShape_;
   using ElementA = ElementA_;
   using StrideA = StrideA_;
   using ElementB = ElementB_;
@@ -101,13 +101,14 @@ struct CollectiveMma<
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

   using DpasShape = typename TiledMma::Shape_MNK;
-  using TileDpasShape = decltype(tile_shape(TiledMma()));
+  using SubgroupTileShape = decltype(tile_shape(TiledMma()));

-  static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize;
+  static constexpr uint32_t MaxThreadsPerBlock =
+      cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize;

-  static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
-  static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
-  static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape());
+  static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(DpasShape()); // A frags per sub_group
+  static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(DpasShape()); // B frags per sub_group
+  static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(DpasShape());

   // Calculate the vector width based on the amount of registers
   // required per work item by dividing the total fragment size by
@@ -186,8 +187,9 @@
     static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

     // Tensor to hold input data
-    Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileDpasShape{}) * FragsK>, Int<1>>{});
-    Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<get<1>(TileDpasShape{}) / FragsN>, Int<FragsN>>{});
+    Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(SubgroupTileShape{}) * FragsK>, Int<1>>{});
+    Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(
+        Shape<Int<get<1>(SubgroupTileShape{}) / FragsN>, Int<FragsN>>{});

     Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
         Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
index db9aa03514..1f823297de 100644
--- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
+++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -65,7 +65,8 @@ class GemmUniversal<

   // Mainloop derived types
   using CollectiveMainloop = CollectiveMainloop_;
-  using TileShape = typename CollectiveMainloop::TileShape;
+  using TileShape = typename CollectiveMainloop::WorkgroupTileShape;
+  using WorkgroupTileShape = TileShape;
   using TiledMma = typename CollectiveMainloop::TiledMma;
   using ArchTag = typename CollectiveMainloop::ArchTag;
   using ElementA = typename CollectiveMainloop::ElementA;
@@ -81,7 +82,7 @@ class GemmUniversal<
     "Intel PVC does not support specializing the tile scheduler.");
   using TileSchedulerTag = TileScheduler_;
   using TileScheduler = typename detail::TileSchedulerSelector<
-    TileScheduler_, ArchTag, TileShape,
+    TileScheduler_, ArchTag, WorkgroupTileShape,
     cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>>::Scheduler;
   using TileSchedulerArguments = typename TileScheduler::Arguments;

@@ -103,7 +104,7 @@ class GemmUniversal<
   static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;

   using DpasShape = typename CollectiveMainloop::DpasShape;
-  using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
+  using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape;

   static constexpr int FragsM = CollectiveMainloop::FragsM;
   static constexpr int FragsN = CollectiveMainloop::FragsN;
@@ -174,12 +175,12 @@ class GemmUniversal<
     auto M = get<0>(params.problem_shape);
     auto N = get<1>(params.problem_shape);

-    const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments
-    const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
+    const int sg_m = (M - 1) / get<0>(SubgroupTileShape{}) + 1; // sub_groups required to process A fragments
+    const int sg_n = (N - 1) / get<1>(SubgroupTileShape{}) + 1; // sub_groups required to process B fragments

     return dim3(
-      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
-      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
+      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
       batch_count
     );
   }
@@ -196,7 +197,7 @@ class GemmUniversal<
     (void)smem_buf;

     // Preconditions
-    CUTE_STATIC_ASSERT(is_static<TileShape>::value);
+    CUTE_STATIC_ASSERT(is_static<WorkgroupTileShape>::value);

     // Separate out problem shape for convenience
     // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK)
@@ -214,8 +215,8 @@ class GemmUniversal<

     // Get the appropriate blocks for this sub_group -- potential for sub_group locality
     int thread_idx = int(ThreadIdxX());
-    constexpr auto workgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K)
-    constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
+    constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
+    constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K)
     const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
     const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
     const int l_coord = BlockIdxZ();
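
Note: after this rename the three shapes form an explicit hierarchy.
WorkgroupTileShape is the (M,N,K) tile computed by one workgroup,
SubgroupTileShape (previously TileDpasShape) is the (M,N,K) tile computed by
one subgroup, and DpasShape is the shape of a single DPAS instruction. The
sketch below replays the FragsM/FragsN/FragsK and MaxThreadsPerBlock
arithmetic from intel_pvc_mma.hpp as a standalone program; the tile sizes and
the variable names (wg, sg, dpas, subgroup_size) are illustrative assumptions,
not values taken from this patch.

#include <cstdio>

int main() {
  constexpr int wg[3]   = {256, 256, 32}; // assumed workgroup tile (M, N, K)
  constexpr int sg[3]   = {32, 64, 32};   // assumed subgroup tile (M, N, K)
  constexpr int dpas[3] = {8, 16, 16};    // one DPAS instruction (M, N, K)
  constexpr int subgroup_size = 16;       // assumed PVC subgroup width

  // FragsM/FragsN/FragsK: DPAS fragments each subgroup computes per axis.
  constexpr int frags_m = sg[0] / dpas[0]; // A frags per sub_group
  constexpr int frags_n = sg[1] / dpas[1]; // B frags per sub_group
  constexpr int frags_k = sg[2] / dpas[2];

  // MaxThreadsPerBlock = size(workgroup tile) / size(subgroup tile) * SubgroupSize
  constexpr int subgroups   = (wg[0] * wg[1] * wg[2]) / (sg[0] * sg[1] * sg[2]);
  constexpr int max_threads = subgroups * subgroup_size;

  std::printf("FragsM=%d FragsN=%d FragsK=%d MaxThreadsPerBlock=%d\n",
              frags_m, frags_n, frags_k, max_threads); // 4 4 2 512
}

With these assumed shapes a workgroup carries 32 subgroups of 16 work-items
(512 threads); the grid-shape helper in the diff divides the problem M and N
by WorkgroupTileShape, while each subgroup's coordinates step by
SubgroupTileShape, matching the renamed uses above.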