From e7f96730676b28f748a6ebb90833fbc123e638d9 Mon Sep 17 00:00:00 2001 From: jiyang1011 Date: Sun, 1 Dec 2024 23:15:11 -0800 Subject: [PATCH] add Habana UTs to benchmark --- CMakeLists.txt | 4 ++ benchmarks/pvc/benchmarks.hpp | 44 ++++++++++++++++++++-- benchmarks/pvc/gemm_configuration.hpp | 40 ++++++++++++++------ benchmarks/pvc/input.in | 38 +++++++++++-------- include/cutlass/gemm/collective/xe_mma.hpp | 5 +++ include/cutlass/gemm/kernel/xe_gemm.hpp | 10 +++++ 6 files changed, 109 insertions(+), 32 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 28e2c2b4c..03a45a17f 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,10 @@ endif() set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL") set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time") +set(SWITCH_WG OFF CACHE BOOL "Enable SWITCH WG") +if(SWITCH_WG) + add_compile_definitions(SWITCH_WG) +endif() if (CUTLASS_ENABLE_SYCL) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) diff --git a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index a8ebc0b67..a224b9231 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -34,15 +34,51 @@ #include "../benchmark_runner.hpp" #include "gemm_configuration.hpp" -using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration< +using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, - float>; + float,256,256,32,64,32>; -CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR); +using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float,128,512,32,64,32>; + +using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float,256,128,32,32,32>; + +using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float,128,256,32,32,16>; + +using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float,8,128,8,32,32>; + +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); static void register_benchmarks() { - CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); } diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp index 3a07857f5..84a65d22d 100644 --- a/benchmarks/pvc/gemm_configuration.hpp +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -57,7 +57,8 @@ template< class ElementA, class LayoutA, class ElementB, class LayoutB, class ElementC, class LayoutC, - class ElementAccumulator> + class ElementAccumulator, + int wg_m, int wg_n, int sg_m, int sg_n, int sg_k> struct GemmConfiguration { static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); }; @@ -68,44 +69,59 @@ struct GemmConfiguration { namespace detail { -template +template struct Gemm_OperandA; -template +template struct Gemm_OperandB; template<> -struct Gemm_OperandA { +struct Gemm_OperandA { using GmemTiledCopy = XE_2D_U16x32x32_LD_N; }; template<> -struct Gemm_OperandB { +struct Gemm_OperandA { + using GmemTiledCopy = XE_2D_U16x32x16_LD_N; +}; + +template<> +struct Gemm_OperandA { + using GmemTiledCopy = XE_2D_U16x8x32_LD_N; +}; + +template<> +struct Gemm_OperandB { using GmemTiledCopy = XE_2D_U16x32x32_LD_V; }; +template<> +struct Gemm_OperandB { + using GmemTiledCopy = XE_2D_U16x16x32_LD_V; +}; } // namespace details -template +template struct GemmConfiguration< arch::IntelPVC, bfloat16_t, LayoutA, bfloat16_t, LayoutB, float, LayoutC, - float> { - using TileShape = Shape<_256, _256, _32>; + float, wg_m, wg_n, sg_m, sg_n, sg_k> { + using TileShape = Shape, Int, Int>; using DispatchPolicy = MainloopIntelPVC<3>;; using TiledMma = TiledMMA< MMA_Atom, - Layout>, - Tile<_64,_64,_32>>; + Layout,Int,_1>>, + Tile,Int<16*wg_n/sg_n>,Int>>; // A - using OperandA = detail::Gemm_OperandA; + using OperandA = detail::Gemm_OperandA; using GmemTiledCopyA = typename OperandA::GmemTiledCopy; // B - using OperandB = detail::Gemm_OperandB; + using OperandB = detail::Gemm_OperandB; using GmemTiledCopyB = typename OperandB::GmemTiledCopy; // Mainloop diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 8e68fcd56..4f5d47648 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -1,17 +1,23 @@ # BFloat16 benchmarks -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824 +PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 +PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 +PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index b08b79c2f..fd634ac03 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -258,8 +258,13 @@ struct CollectiveMma< // Mainloop // int sub_group_id = get_sub_group_id(); + #ifdef SWITCH_WG + const int m_coord = BlockIdxX() * BLK_M + (sub_group_id / ATOM_N) * SG_M; + const int n_coord = BlockIdxY() * BLK_N + (sub_group_id % ATOM_N) * SG_N; + #else const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M; const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N; + #endif const int l_coord = BlockIdxZ(); Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 9e4968aca..0d951a788 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -172,8 +172,13 @@ class GemmUniversal< batch_count = cute::size<3>(params.problem_shape); } return dim3( + #ifdef SWITCH_WG + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), + #else cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + #endif batch_count ); } @@ -207,8 +212,13 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); auto blk_shape = TileShape{}; + #ifdef SWITCH_WG + auto m_coord = BlockIdxX(); + auto n_coord = BlockIdxY(); + #else auto m_coord = BlockIdxY(); auto n_coord = BlockIdxX(); + #endif auto l_coord = BlockIdxZ(); auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); int sub_group_id = thread_idx / SubgroupSize;