
add Habana UTs to benchmark
jiyang1011 committed Dec 2, 2024
1 parent cbf0bfa commit e7f9673
Showing 6 changed files with 109 additions and 32 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -94,6 +94,10 @@ endif()
set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL")
set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time")

set(SWITCH_WG OFF CACHE BOOL "Enable SWITCH WG")
if(SWITCH_WG)
add_compile_definitions(SWITCH_WG)
endif()
if (CUTLASS_ENABLE_SYCL)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

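The new option defaults to OFF; the SWITCH_WG preprocessor definition can be turned on at configure time, for example with cmake -DCUTLASS_ENABLE_SYCL=ON -DSWITCH_WG=ON <source-dir> (the rest of the configure line depends on the local build setup). The macro only swaps which block index is treated as M and which as N in the kernels changed further down; it does not alter any tile shapes.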
44 changes: 40 additions & 4 deletions benchmarks/pvc/benchmarks.hpp
@@ -34,15 +34,51 @@
#include "../benchmark_runner.hpp"
#include "gemm_configuration.hpp"

using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration<
using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float>;
float,256,256,32,64,32>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float,128,512,32,64,32>;

using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float,256,128,32,32,32>;

using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float,128,256,32,32,16>;

using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float,8,128,8,32,32>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);

static void register_benchmarks() {
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
}
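The five integer arguments appended to each GemmConfiguration are, in order, the workgroup tile extents wg_m and wg_n, the subgroup tile extents sg_m and sg_n, and the K depth sg_k (see the template change in gemm_configuration.hpp below). As a sketch only, a hypothetical extra configuration with the parameters spelled out would read (the _Example alias is illustrative and not part of this commit):

using PvcGemmBF16BF16FP32_RRR_Example = cutlass::gemm::device::GemmConfiguration<
    cutlass::arch::IntelPVC,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,   // A: bf16, row-major
    cutlass::bfloat16_t, cutlass::layout::RowMajor,   // B: bf16, row-major
    float, cutlass::layout::RowMajor,                 // C: fp32, row-major
    float,                                            // accumulator
    /*wg_m=*/256, /*wg_n=*/256,                       // workgroup tile (M, N)
    /*sg_m=*/32, /*sg_n=*/64,                         // subgroup tile (M, N)
    /*sg_k=*/32>;                                     // K tile per mainloop step

Only parameter combinations with a matching copy-atom specialization (Gemm_OperandA/Gemm_OperandB below) will compile.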
40 changes: 28 additions & 12 deletions benchmarks/pvc/gemm_configuration.hpp
@@ -57,7 +57,8 @@ template<
class ElementA, class LayoutA,
class ElementB, class LayoutB,
class ElementC, class LayoutC,
class ElementAccumulator>
class ElementAccumulator,
int wg_m, int wg_n, int sg_m, int sg_n, int sg_k>
struct GemmConfiguration {
static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists.");
};
@@ -68,44 +69,59 @@ struct GemmConfiguration {

namespace detail {

template<typename Element, typename Layout>
template<typename Element, typename Layout, int M, int K>
struct Gemm_OperandA;

template<typename Element, typename Layout>
template<typename Element, typename Layout, int K, int N>
struct Gemm_OperandB;

template<>
struct Gemm_OperandA<bfloat16_t, layout::RowMajor> {
struct Gemm_OperandA<bfloat16_t, layout::RowMajor, 32, 32> {
using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
};

template<>
struct Gemm_OperandB<bfloat16_t, layout::RowMajor> {
struct Gemm_OperandA<bfloat16_t, layout::RowMajor, 32, 16> {
using GmemTiledCopy = XE_2D_U16x32x16_LD_N;
};

template<>
struct Gemm_OperandA<bfloat16_t, layout::RowMajor, 8, 32> {
using GmemTiledCopy = XE_2D_U16x8x32_LD_N;
};

template<>
struct Gemm_OperandB<bfloat16_t, layout::RowMajor, 32, 32> {
using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
};

template<>
struct Gemm_OperandB<bfloat16_t, layout::RowMajor, 16, 32> {
using GmemTiledCopy = XE_2D_U16x16x32_LD_V;
};
} // namespace detail
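Gemm_OperandA and Gemm_OperandB are declared but never defined in the general case, so a subgroup-tile shape without one of the specializations above fails to compile with an incomplete-type error. Supporting another shape means adding a specialization along these lines (a sketch only; the copy-atom name is an assumption following the existing XE_2D_U16x<rows>x<cols>_LD_N pattern and is not part of this commit):

template<>
struct Gemm_OperandA<bfloat16_t, layout::RowMajor, 16, 32> {
  using GmemTiledCopy = XE_2D_U16x16x32_LD_N;  // hypothetical atom for a 16x32 bf16 tile
};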

template<typename LayoutA, typename LayoutB, typename LayoutC>
template<typename LayoutA, typename LayoutB, typename LayoutC,
int wg_m, int wg_n, int sg_m, int sg_n, int sg_k>
struct GemmConfiguration<
arch::IntelPVC,
bfloat16_t, LayoutA,
bfloat16_t, LayoutB,
float, LayoutC,
float> {
using TileShape = Shape<_256, _256, _32>;
float, wg_m, wg_n, sg_m, sg_n, sg_k> {
using TileShape = Shape<Int<wg_m>, Int<wg_n>, Int<sg_k>>;
using DispatchPolicy = MainloopIntelPVC<3>;
using TiledMma = TiledMMA<
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_8,_4,_1>>,
Tile<_64,_64,_32>>;
Layout<Shape<Int<wg_m/sg_m>,Int<wg_n/sg_n>,_1>>,
Tile<Int<8*wg_m/sg_m>,Int<16*wg_n/sg_n>,Int<sg_k>>>;

// A
using OperandA = detail::Gemm_OperandA<bfloat16_t, LayoutA>;
using OperandA = detail::Gemm_OperandA<bfloat16_t, LayoutA, sg_m, sg_k>;
using GmemTiledCopyA = typename OperandA::GmemTiledCopy;

// B
using OperandB = detail::Gemm_OperandB<bfloat16_t, LayoutB>;
using OperandB = detail::Gemm_OperandB<bfloat16_t, LayoutB, sg_k, 32>;
using GmemTiledCopyB = typename OperandB::GmemTiledCopy;

// Mainloop
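As a quick consistency check, the parameterized TiledMma reproduces the previously hard-coded values for the first configuration (wg_m=256, wg_n=256, sg_m=32, sg_n=64, sg_k=32): the subgroup layout evaluates to Shape<256/32, 256/64, 1> = Shape<_8, _4, _1>, and the permutation tile to Tile<8*256/32, 16*256/64, 32> = Tile<_64, _64, _32>, matching the literals that were removed.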
38 changes: 22 additions & 16 deletions benchmarks/pvc/input.in
@@ -1,17 +1,23 @@
# BFloat16 benchmarks
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
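Each line names one of the registered configurations and passes the batch count (--l) and problem extents (--m, --k, --n). The pairing appears to match large shapes to the 256x256 workgroup tile (_RRR_1) and the skinny, heavily batched m=8 shapes to the smaller-tile configurations (_RRR_4 and _RRR_5), though the commit does not state the selection criterion.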
5 changes: 5 additions & 0 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -258,8 +258,13 @@ struct CollectiveMma<
// Mainloop
//
int sub_group_id = get_sub_group_id();
#ifdef SWITCH_WG
const int m_coord = BlockIdxX() * BLK_M + (sub_group_id / ATOM_N) * SG_M;
const int n_coord = BlockIdxY() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
#else
const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M;
const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N;
#endif
const int l_coord = BlockIdxZ();
Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
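With SWITCH_WG defined, the subgroup derives its M coordinate from BlockIdxX and its N coordinate from BlockIdxY, the reverse of the default mapping; this stays correct only because the launch grid computed in xe_gemm.hpp below is swapped under the same macro.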
10 changes: 10 additions & 0 deletions include/cutlass/gemm/kernel/xe_gemm.hpp
@@ -172,8 +172,13 @@ class GemmUniversal<
batch_count = cute::size<3>(params.problem_shape);
}
return dim3(
#ifdef SWITCH_WG
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
#else
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
#endif
batch_count
);
}
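As a worked example (using one of the shapes from input.in above), an m=16384, n=1024 problem with the 256x256 workgroup tile launches a dim3(4, 64, 1) grid by default, where X indexes N and Y indexes M, and a dim3(64, 4, 1) grid when SWITCH_WG is defined.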
@@ -207,8 +212,13 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto blk_shape = TileShape{};
#ifdef SWITCH_WG
auto m_coord = BlockIdxX();
auto n_coord = BlockIdxY();
#else
auto m_coord = BlockIdxY();
auto n_coord = BlockIdxX();
#endif
auto l_coord = BlockIdxZ();
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
int sub_group_id = thread_idx / SubgroupSize;
