diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp index cde3c70d9..3a07857f5 100644 --- a/benchmarks/pvc/gemm_configuration.hpp +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -76,12 +76,12 @@ struct Gemm_OperandB; template<> struct Gemm_OperandA { - using GmemTiledCopy = XE_2D_U16x8x16_LD_N; + using GmemTiledCopy = XE_2D_U16x32x32_LD_N; }; template<> struct Gemm_OperandB { - using GmemTiledCopy = XE_2D_U16x16x16_LD_V; + using GmemTiledCopy = XE_2D_U16x32x32_LD_V; }; } // namespace details @@ -93,12 +93,12 @@ struct GemmConfiguration< bfloat16_t, LayoutB, float, LayoutC, float> { - using TileShape = Shape<_256, _256, _16>; + using TileShape = Shape<_256, _256, _32>; using DispatchPolicy = MainloopIntelPVC<3>;; using TiledMma = TiledMMA< MMA_Atom, - Layout>, - Tile<_64,_128,_16>>; + Layout>, + Tile<_64,_64,_32>>; // A using OperandA = detail::Gemm_OperandA; diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 0c8ffd916..8c64fdb04 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -306,15 +306,15 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // Workgroup-level tile - using TileShape = Shape<_256, _128, _16>; + using TileShape = Shape<_256, _256, _32>; using TiledMma = TiledMMA, - Layout>, - Tile<_64,_32,_16>>; // Subgroup level-tile + Layout>, + Tile<_64,_64,_32>>; // Subgroup level-tile constexpr int PipelineStages = 3; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index b8fe97406..a5d00a0c0 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -296,7 +296,7 @@ class CollectiveEpilogue< Tensor trC = make_tensor(Shape>{}); Tensor trD = make_tensor(Shape>{}); Tensor tOuti = params.xe_store_d.get_pvc_tensor( - make_coord(m_offset, n_offset, l_offset), + make_coord(m_offset, n_offset, 0), make_shape(_, Int{}, Int{}, L), make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 61fc7867f..06fa815e6 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -200,8 +200,14 @@ struct CollectiveMma< // Instantiate the MMA object TiledMma tiled_mma; auto thread_mma = tiled_mma.get_slice(thread_idx); - Tensor tCrA = thread_mma.partition_fragment_A(gA(_, _, 0)); - Tensor tCrB = thread_mma.partition_fragment_B(gB(_, _, 0)); + Tensor tCrA_partition = thread_mma.partition_fragment_A(gA(_, _, 0)); + Tensor tCrA = make_tensor(static_cast(tCrA_partition).data(), + tCrA_partition.shape()); + Tensor tCrB_partition = thread_mma.partition_fragment_B(gB(_, _, 0)); + Tensor tCrB = make_tensor(static_cast(tCrB_partition).data(), + make_shape(size<0>(tCrB_partition.shape()), + size<2>(tCrB_partition.shape()), + size<1>(tCrB_partition.shape()))); // Partition the copying of A and B tiles across the threads auto gmem_thr_copy_A = mainloop.gmem_tiled_copy_a.get_slice(thread_idx); auto gmem_thr_copy_B = mainloop.gmem_tiled_copy_b.get_slice(thread_idx); @@ -234,10 +240,10 @@ struct CollectiveMma< const int n_coord = BlockIdxX() * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; const int l_coord = BlockIdxZ(); Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), make_shape(_, size<1>(tCrA_copy_view.shape()), size<2>(tCrA_copy_view.shape()), k_tile_count), + make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, l_coord), make_shape(_, size<2>(tCrB_copy_view.shape()), size<1>(tCrB_copy_view.shape()), k_tile_count), + make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{}); #pragma unroll for (int i = 0; i < DispatchPolicy::Stages; i++) { @@ -261,7 +267,9 @@ struct CollectiveMma< prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile + DispatchPolicy::Stages)); } } - cute::gemm(tiled_mma, accum, tCrA, tCrB, src_accum); + for (int i = 0; i < SG_K / SubgroupSize; i++) { + cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); + } } } };