diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
index 542bafb346..731edfa15f 100644
--- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
+++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
@@ -353,11 +353,12 @@ int main(int argc, const char** argv)
   using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
   using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
 
-  using TileShape = Shape<_1, _1, _1>;
+  // Workgroup-level tile
+  using TileShape = Shape<_32, _256, _32>;
 
   using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_BF16BF16F32F32_NN>,
           Layout<Shape<_1,_1,_1>>,
-          Tile<_32,_64,_32>>;
+          Tile<_32,_64,_32>>; // Subgroup-level tile
 
   using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
 
diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
index d587fbcd9d..8f2e4cb34e 100644
--- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp
+++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -103,7 +103,7 @@ struct CollectiveMma<
   using DpasShape = typename TiledMma::Shape_MNK;
   using TileDpasShape = decltype(tile_shape(TiledMma()));
 
-  static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
+  static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize;
 
   static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
   static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
index 1a91854374..db9aa03514 100644
--- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
+++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
@@ -101,14 +101,10 @@ class GemmUniversal<
 
   static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
 
   static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
-  static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;
-
-  static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group
 
   using DpasShape = typename CollectiveMainloop::DpasShape;
   using TileDpasShape = typename CollectiveMainloop::TileDpasShape;
-
   static constexpr int FragsM = CollectiveMainloop::FragsM;
   static constexpr int FragsN = CollectiveMainloop::FragsN;
 
@@ -182,9 +178,9 @@ class GemmUniversal<
     const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments
 
     return dim3(
-      sg_m,
-      cute::ceil_div(sg_n, num_sg),
-      batch_count
+      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
+      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
+      batch_count
     );
   }
 
@@ -218,9 +214,10 @@ class GemmUniversal<
 
     // Get the appropriate blocks for this sub_group -- potential for sub_group locality
    int thread_idx = int(ThreadIdxX());
-    auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
+    constexpr auto workgroup_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
+    constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
     const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
-    const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
+    const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
     const int l_coord = BlockIdxZ();
 
     Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),