Skip to content

Commit

Permalink
Merge branch 'sycl-develop' into shuffle_64b
Browse files Browse the repository at this point in the history
  • Loading branch information
t4c1 authored Dec 12, 2024
2 parents 637e8b8 + 8398ba8 commit 451d107
Show file tree
Hide file tree
Showing 36 changed files with 3,548 additions and 82 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/cuda_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ on:

permissions: {}

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
run-tests:
name: Run cuda tests
Expand Down
15 changes: 5 additions & 10 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ on:

permissions: {}

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
run-tests:
name: Run tests
Expand Down Expand Up @@ -71,13 +75,4 @@ jobs:
shell: bash
run: |
export LD_LIBRARY_PATH=~/dpcpp/lib/:$LD_LIBRARY_PATH
echo Run sgemm_1
./examples/cute/tutorial/sgemm_1
echo Run sgemm_2
./examples/cute/tutorial/sgemm_2
echo Run sgemm_sm70
./examples/cute/tutorial/sgemm_sm70
echo Run sgemm_sm80
./examples/cute/tutorial/sgemm_sm80
echo Run tiled_copy
./examples/cute/tutorial/tiled_copy
cmake --build . --target test_examples -j 24
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ endif()
set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL")
set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time")

set(CUTLASS_SYCL_SWITCH_WG OFF CACHE BOOL "Enable SWITCH WG and for GEMM on Intel PVC during benchmarking")
if(CUTLASS_SYCL_SWITCH_WG)
add_compile_definitions(CUTLASS_SYCL_SWITCH_WG)
endif()
if (CUTLASS_ENABLE_SYCL)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

Expand Down
55 changes: 51 additions & 4 deletions benchmarks/pvc/benchmarks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,62 @@
#include "../benchmark_runner.hpp"
#include "gemm_configuration.hpp"

using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration<
using MMAAtom = MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>;
using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float>;
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_128, _512, _32>,
TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_128, _256, _16>,
TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
XE_2D_U16x32x16_LD_N, XE_2D_U16x16x32_LD_V>;

using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_8, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);

static void register_benchmarks() {
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
}
45 changes: 8 additions & 37 deletions benchmarks/pvc/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ template<
class ElementA, class LayoutA,
class ElementB, class LayoutB,
class ElementC, class LayoutC,
class ElementAccumulator>
class ElementAccumulator,
class TileShape, class TiledMma,
class GmemTiledCopyA, class GmemTiledCopyB>
struct GemmConfiguration {
static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists.");
};
Expand All @@ -66,47 +68,16 @@ struct GemmConfiguration {

// bfloat16

namespace detail {

template<typename Element, typename Layout>
struct Gemm_OperandA;

template<typename Element, typename Layout>
struct Gemm_OperandB;

template<>
struct Gemm_OperandA<bfloat16_t, layout::RowMajor> {
using GmemTiledCopy = XE_2D_U16x32x32_LD_N;
};

template<>
struct Gemm_OperandB<bfloat16_t, layout::RowMajor> {
using GmemTiledCopy = XE_2D_U16x32x32_LD_V;
};

} // namespace details

template<typename LayoutA, typename LayoutB, typename LayoutC>
template<typename LayoutA, typename LayoutB, typename LayoutC,
class TileShape, class TiledMma, class GmemTiledCopyA, class GmemTiledCopyB>
struct GemmConfiguration<
arch::IntelPVC,
bfloat16_t, LayoutA,
bfloat16_t, LayoutB,
float, LayoutC,
float> {
using TileShape = Shape<_256, _256, _32>;
using DispatchPolicy = MainloopIntelPVC<3>;;
using TiledMma = TiledMMA<
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_8,_4,_1>>,
Tile<_64,_64,_32>>;

// A
using OperandA = detail::Gemm_OperandA<bfloat16_t, LayoutA>;
using GmemTiledCopyA = typename OperandA::GmemTiledCopy;

// B
using OperandB = detail::Gemm_OperandB<bfloat16_t, LayoutB>;
using GmemTiledCopyB = typename OperandB::GmemTiledCopy;
float, TileShape, TiledMma,
GmemTiledCopyA, GmemTiledCopyB> {
using DispatchPolicy = MainloopIntelPVC<3>;

// Mainloop
using CollectiveMainloop = collective::CollectiveMma<
Expand Down
38 changes: 22 additions & 16 deletions benchmarks/pvc/input.in
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
# BFloat16 benchmarks
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824
PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128
8 changes: 7 additions & 1 deletion examples/35_gemm_softmax/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@



if (NOT CUTLASS_ENABLE_SYCL)
cutlass_example_add_executable(
35_gemm_softmax
gemm_softmax.cu
)

else()
cutlass_example_add_executable(
35_gemm_online_softmax
gemm_online_softmax.cpp
)
endif()
Loading

0 comments on commit 451d107

Please sign in to comment.