add Habana UTs to benchmark (#163)
Co-authored-by: Alejandro Acosta <[email protected]>
jiyang1011 and aacostadiaz authored Dec 11, 2024
1 parent f14d683 commit 8398ba8
Showing 6 changed files with 104 additions and 63 deletions.
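
This commit introduces a CUTLASS_SYCL_SWITCH_WG build option that swaps the work-group mapping of the M and N GEMM dimensions on Intel PVC, generalizes GemmConfiguration to take the tile shape, tiled MMA, and global-memory copy atoms as template parameters, and replaces the single PVC BF16 benchmark configuration with five variants plus an updated set of benchmark input shapes.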
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -94,6 +94,10 @@ endif()
set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL")
set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time")

+set(CUTLASS_SYCL_SWITCH_WG OFF CACHE BOOL "Enable switching the work-group mapping for GEMM on Intel PVC during benchmarking")
+if(CUTLASS_SYCL_SWITCH_WG)
+add_compile_definitions(CUTLASS_SYCL_SWITCH_WG)
+endif()
if (CUTLASS_ENABLE_SYCL)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

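The option defaults to OFF. A typical configure-time invocation to enable it (illustrative; compiler and generator flags omitted, and an existing SYCL build setup is assumed) might look like:

cmake .. -DCUTLASS_ENABLE_SYCL=ON -DCUTLASS_SYCL_SWITCH_WG=ON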
55 changes: 51 additions & 4 deletions benchmarks/pvc/benchmarks.hpp
@@ -34,15 +34,62 @@
#include "../benchmark_runner.hpp"
#include "gemm_configuration.hpp"

-using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration<
+using MMAAtom = MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>;
+using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
-float>;
+float, Shape<_256, _256, _32>,
+TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
+XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;

-CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
+using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration<
+cutlass::arch::IntelPVC,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+float, cutlass::layout::RowMajor,
+float, Shape<_128, _512, _32>,
+TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
+XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;
+
+using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration<
+cutlass::arch::IntelPVC,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+float, cutlass::layout::RowMajor,
+float, Shape<_256, _128, _32>,
+TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
+XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;
+
+using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration<
+cutlass::arch::IntelPVC,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+float, cutlass::layout::RowMajor,
+float, Shape<_128, _256, _16>,
+TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
+XE_2D_U16x32x16_LD_N, XE_2D_U16x16x32_LD_V>;
+
+using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
+cutlass::arch::IntelPVC,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+cutlass::bfloat16_t, cutlass::layout::RowMajor,
+float, cutlass::layout::RowMajor,
+float, Shape<_8, _128, _32>,
+TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
+XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>;
+
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
+CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);

static void register_benchmarks() {
-CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
+CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
+CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
+CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
+CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
+CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
}
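
Each variant pairs a work-group tile Shape<M, N, K> with a subgroup layout whose first two extents partition the tile: PvcGemmBF16BF16FP32_RRR_1 splits its 256x256 tile across an 8x4 grid of subgroups, giving each a 32x64 sub-tile, while RRR_5 splits 8x128 across 1x4 for 8x32 per subgroup, with the 2D block-load atoms sized to match. A hypothetical sixth variant following the same pattern would look like this (the tile and layout below are illustrative, not part of this commit):

using PvcGemmBF16BF16FP32_RRR_6 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _512, _32>,               // 8x8 subgroups -> 32x64 sub-tile each
TiledMMA<MMAAtom, Layout<Shape<_8,_8,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6);

It would also need a matching CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_6); entry in register_benchmarks() and a line in input.in to be exercised.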
45 changes: 8 additions & 37 deletions benchmarks/pvc/gemm_configuration.hpp
@@ -57,7 +57,9 @@ template<
class ElementA, class LayoutA,
class ElementB, class LayoutB,
class ElementC, class LayoutC,
-class ElementAccumulator>
+class ElementAccumulator,
+class TileShape, class TiledMma,
+class GmemTiledCopyA, class GmemTiledCopyB>
struct GemmConfiguration {
static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists.");
};
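
The always-false static_assert(sizeof(ElementA) == 0, ...) in the primary template means any element/layout/tile combination without a matching partial specialization fails at compile time with a readable message instead of an opaque substitution error.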
@@ -66,47 +68,16 @@ struct GemmConfiguration {

// bfloat16

-namespace detail {
-
-template<typename Element, typename Layout>
-struct Gemm_OperandA;
-
-template<typename Element, typename Layout>
-struct Gemm_OperandB;
-
-template<>
-struct Gemm_OperandA<bfloat16_t, layout::RowMajor> {
-using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
-};
-
-template<>
-struct Gemm_OperandB<bfloat16_t, layout::RowMajor> {
-using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
-};
-
-} // namespace details
-template<typename LayoutA, typename LayoutB, typename LayoutC>
+template<typename LayoutA, typename LayoutB, typename LayoutC,
+class TileShape, class TiledMma, class GmemTiledCopyA, class GmemTiledCopyB>
struct GemmConfiguration<
arch::IntelPVC,
bfloat16_t, LayoutA,
bfloat16_t, LayoutB,
float, LayoutC,
-float> {
-using TileShape = Shape<_256, _256, _32>;
-using DispatchPolicy = MainloopIntelPVC<3>;;
-using TiledMma = TiledMMA<
-MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
-Layout<Shape<_8,_4,_1>>,
-Tile<_64,_64,_32>>;
-
-// A
-using OperandA = detail::Gemm_OperandA<bfloat16_t, LayoutA>;
-using GmemTiledCopyA = typename OperandA::GmemTiledCopy;
-
-// B
-using OperandB = detail::Gemm_OperandB<bfloat16_t, LayoutB>;
-using GmemTiledCopyB = typename OperandB::GmemTiledCopy;
+float, TileShape, TiledMma,
+GmemTiledCopyA, GmemTiledCopyB> {
+using DispatchPolicy = MainloopIntelPVC<3>;

// Mainloop
using CollectiveMainloop = collective::CollectiveMma<
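Moving the tile shape, tiled MMA, and copy atoms out of the hard-coded defaults and detail::Gemm_OperandA/B trait specializations and into template parameters is what lets the five benchmark variants above coexist: each supplies its own tile and load atoms at the point of use instead of requiring a new trait specialization per combination.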
38 changes: 22 additions & 16 deletions benchmarks/pvc/input.in
@@ -1,17 +1,23 @@
# BFloat16 benchmarks
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
-PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
+PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
+PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
+PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
+PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
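
Each input line names a registered benchmark followed by its problem size, where --l is the batch count (cute::size<3> of the problem shape in xe_gemm.hpp below) and --m/--n/--k are the GEMM dimensions; the heavily batched small-m shapes are routed to the small-tile variants RRR_4 and RRR_5, and the large shapes to RRR_1 through RRR_3. A line for the hypothetical RRR_6 variant sketched earlier would follow the same pattern:

PvcGemmBF16BF16FP32_RRR_6 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=8192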
15 changes: 9 additions & 6 deletions include/cutlass/gemm/collective/xe_mma.hpp
@@ -263,11 +263,14 @@ struct CollectiveMma<
// Mainloop
//
auto [m_idx, n_idx, k_idx, l_idx] = blk_coord;
+#ifdef CUTLASS_SYCL_SWITCH_WG
+const int m_coord = n_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
+const int n_coord = m_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
+#else
const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
+#endif
const int l_coord = l_idx;

-int sub_group_id = get_sub_group_id();
Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor(
make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count),
append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{});
@@ -279,13 +282,13 @@
int prefetch_k = 0;

Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor(
-make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
-(k_start_idx + (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord),
+make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}),
+(k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{});
Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor(
-make_coord(((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
-n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
+make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB,
+n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord),
append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count),
append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{});

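The switch simply exchanges which block index supplies the M and N tile origins. A minimal standalone sketch of both mappings, with tile constants assumed to match PvcGemmBF16BF16FP32_RRR_1 (BLK_M = BLK_N = 256, SG_M = 32, SG_N = 64, ATOM_N = 4):

#include <cstdio>

int main() {
  const int BLK_M = 256, BLK_N = 256;  // work-group tile (assumed)
  const int SG_M = 32, SG_N = 64;      // per-subgroup tile (assumed)
  const int ATOM_N = 4;                // subgroups along N (assumed)
  const int m_idx = 1, n_idx = 2;      // example block coordinates
  const int sub_group_id = 5;          // example subgroup index

  // Default mapping: block (m_idx, n_idx) owns the C tile at (m_idx*BLK_M, n_idx*BLK_N).
  int m_def = m_idx * BLK_M + (sub_group_id / ATOM_N) * SG_M;
  int n_def = n_idx * BLK_N + (sub_group_id % ATOM_N) * SG_N;

  // CUTLASS_SYCL_SWITCH_WG mapping: the block indices trade roles, matching
  // the swapped grid shape computed in xe_gemm.hpp below.
  int m_sw = n_idx * BLK_M + (sub_group_id / ATOM_N) * SG_M;
  int n_sw = m_idx * BLK_N + (sub_group_id % ATOM_N) * SG_N;

  std::printf("default:  (%d, %d)\nswitched: (%d, %d)\n", m_def, n_def, m_sw, n_sw);
  return 0;
}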
10 changes: 10 additions & 0 deletions include/cutlass/gemm/kernel/xe_gemm.hpp
@@ -186,8 +186,13 @@ class GemmUniversal<
batch_count = cute::size<3>(params.problem_shape);
}
return dim3(
+#ifdef CUTLASS_SYCL_SWITCH_WG
+cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
+#else
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))),
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))),
+#endif
batch_count
);
}
@@ -221,8 +226,13 @@ class GemmUniversal<
// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto blk_shape = TileShape{};
+#ifdef CUTLASS_SYCL_SWITCH_WG
+auto m_coord = BlockIdxX();
+auto n_coord = BlockIdxY();
+#else
auto m_coord = BlockIdxY();
auto n_coord = BlockIdxX();
+#endif
auto l_coord = BlockIdxZ();
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord);
int sub_group_id = thread_idx / SubgroupSize;
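The effect on the launch configuration is easy to check with a little arithmetic. A sketch using one of the benchmark shapes above (m=512, n=32768) and an assumed 256x256 work-group tile:

#include <cstdio>

int main() {
  const int m = 512, n = 32768;        // example problem size from input.in
  const int blk_m = 256, blk_n = 256;  // assumed work-group tile
  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };

  // Default: N-blocks map to grid x, M-blocks to grid y.
  std::printf("default  grid: x=%d y=%d\n", ceil_div(n, blk_n), ceil_div(m, blk_m));
  // CUTLASS_SYCL_SWITCH_WG: the two dimensions are exchanged; the total
  // work-group count (128 * 2 = 256 here) is unchanged, only the mapping
  // of work-groups to output tiles differs.
  std::printf("switched grid: x=%d y=%d\n", ceil_div(m, blk_m), ceil_div(n, blk_n));
  return 0;
}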
