diff --git a/CMakeLists.txt b/CMakeLists.txt index 9187927b1..a5e0faa67 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,10 @@ endif() set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL") set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time") +set(CUTLASS_SYCL_SWITCH_WG OFF CACHE BOOL "Enable SWITCH WG") +if(CUTLASS_SYCL_SWITCH_WG) + add_compile_definitions(CUTLASS_SYCL_SWITCH_WG) +endif() if (CUTLASS_ENABLE_SYCL) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index a8ebc0b67..7682d010e 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -34,15 +34,57 @@ #include "../benchmark_runner.hpp" #include "gemm_configuration.hpp" -using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration< +using MMAAtom = MMA_Atom; +using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, - float>; + float, Shape<_256, _256, _32>, + TiledMMA>>>; -CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR); +using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_128, _512, _32>, + TiledMMA>>>; + +using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_256, _128, _32>, + TiledMMA>>>; + +using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_128, _256, _16>, + TiledMMA>>>; + +using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_8, _128, _32>, + TiledMMA>>>; + +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); static void register_benchmarks() { - CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); } diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp index 3a07857f5..0fc0422e9 100644 --- a/benchmarks/pvc/gemm_configuration.hpp +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -57,7 +57,8 @@ template< class ElementA, class LayoutA, class ElementB, class LayoutB, class ElementC, class LayoutC, - class ElementAccumulator> + class ElementAccumulator, + class TileShape, class TiledMma> struct GemmConfiguration { static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); }; @@ -68,44 +69,60 @@ struct GemmConfiguration { namespace detail { -template +template struct Gemm_OperandA; -template +template struct Gemm_OperandB; template<> -struct Gemm_OperandA { +struct Gemm_OperandA { using GmemTiledCopy = XE_2D_U16x32x32_LD_N; }; template<> -struct Gemm_OperandB { +struct Gemm_OperandA { + using GmemTiledCopy = XE_2D_U16x32x16_LD_N; +}; + +template<> +struct Gemm_OperandA { + using GmemTiledCopy = XE_2D_U16x8x32_LD_N; +}; + +template<> +struct Gemm_OperandB { using GmemTiledCopy = XE_2D_U16x32x32_LD_V; }; +template<> +struct Gemm_OperandB { + using GmemTiledCopy = XE_2D_U16x16x32_LD_V; +}; } // namespace details -template +template struct GemmConfiguration< arch::IntelPVC, bfloat16_t, LayoutA, bfloat16_t, LayoutB, float, LayoutC, - float> { - using TileShape = Shape<_256, _256, _32>; - using DispatchPolicy = MainloopIntelPVC<3>;; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, - Tile<_64,_64,_32>>; - + float, TileShape, TiledMma> { + using DispatchPolicy = MainloopIntelPVC<3>; + static constexpr auto BLK_M = get<0>(TileShape{}); + static constexpr auto BLK_N = get<1>(TileShape{}); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto sg_m = ceil_div(BLK_M, ATOM_M); + static constexpr auto sg_n = ceil_div(BLK_N, ATOM_N); + static constexpr auto sg_k = get<2>(TileShape{}); // A - using OperandA = detail::Gemm_OperandA; + using OperandA = detail::Gemm_OperandA; using GmemTiledCopyA = typename OperandA::GmemTiledCopy; // B - using OperandB = detail::Gemm_OperandB; + using OperandB = detail::Gemm_OperandB; using GmemTiledCopyB = typename OperandB::GmemTiledCopy; // Mainloop diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 8e68fcd56..4f5d47648 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -1,17 +1,23 @@ # BFloat16 benchmarks -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824 +PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 +PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 +PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index a34083398..77cddf7df 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -263,11 +263,14 @@ struct CollectiveMma< // Mainloop // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; + #ifdef CUTLASS_SYCL_SWITCH_WG + const int m_coord = n_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = m_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #else const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #endif const int l_coord = l_idx; - - int sub_group_id = get_sub_group_id(); Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); @@ -279,13 +282,13 @@ struct CollectiveMma< int prefetch_k = 0; Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( - make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (k_start_idx + (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), + make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), + (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord(((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, - n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), + make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, + n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 2b41f3d31..d7b49666f 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -186,8 +186,13 @@ class GemmUniversal< batch_count = cute::size<3>(params.problem_shape); } return dim3( + #ifdef CUTLASS_SYCL_SWITCH_WG + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), + #else cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + #endif batch_count ); } @@ -221,8 +226,13 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); auto blk_shape = TileShape{}; + #ifdef CUTLASS_SYCL_SWITCH_WG + auto m_coord = BlockIdxX(); + auto n_coord = BlockIdxY(); + #else auto m_coord = BlockIdxY(); auto n_coord = BlockIdxX(); + #endif auto l_coord = BlockIdxZ(); auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); int sub_group_id = thread_idx / SubgroupSize;