Add Ampere bfloat-float example
aacostadiaz committed May 27, 2024
1 parent fa70215 commit fc7e60b
Showing 8 changed files with 90 additions and 709 deletions.
5 changes: 5 additions & 0 deletions benchmarks/ampere/CMakeLists.txt
@@ -31,3 +31,8 @@ cutlass_benchmark_add_executable(
bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32
bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
)

cutlass_benchmark_add_executable(
bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32
bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
)
@@ -29,7 +29,7 @@
*
**************************************************************************************************/

#include "../common/example_runner.hpp"
#include "../common/benchmark_runner.hpp"
#include "gemm_configuration.hpp"

int main(int argc, const char** argv)
@@ -68,8 +68,8 @@ int main(int argc, const char** argv)
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
-  using ElementInputA = half_t;      // <- data type of elements in input matrix A
-  using ElementInputB = half_t;      // <- data type of elements in input matrix B
+  using ElementInputA = bfloat16_t;  // <- data type of elements in input matrix A
+  using ElementInputB = bfloat16_t;  // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = cutlass::layout::ColumnMajor;
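
The move from half_t to bfloat16_t keeps float32's 8-bit exponent but stores only 7 mantissa bits, so inputs keep their dynamic range while losing fine precision; the accumulator stays float to absorb the rounding. A minimal host-side sketch of that trade-off, assuming only cutlass::bfloat16_t's float conversions (the values are illustrative):

  #include <cutlass/bfloat16.h>
  #include <cstdio>

  int main() {
    float x = 1.0f + 1.0f / 256.0f;  // needs 8 fractional bits: 1.00390625
    cutlass::bfloat16_t bx(x);       // bf16 keeps only 7, so the 2^-8 bit is dropped
    std::printf("%.8f -> %.8f\n", x, float(bx));  // prints 1.00390625 -> 1.00000000
    return 0;
  }
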
@@ -80,7 +80,7 @@ int main(int argc, const char** argv)
using TileShape = Shape<_128, _128, _32>;

using TiledMma = TiledMMA<
-    MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
+    MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
Layout<Shape<_2,_2,_1>>, // 2x2x1 thread group
    Tile<_32,_32,_16>>;   // 32x32x16 MMA for LDSM, 1x2x1 value group

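The kernel configuration changes only in the MMA atom: SM80_16x8x16_F32BF16BF16F32_TN maps to one warp-wide mma.sync.aligned.m16n8k16 instruction computing D = A * B + C with bf16 A (16x16) and B (16x8), and f32 C and D (16x8). A scalar reference of that contraction, as a sketch only (the function below is illustrative, not the atom's implementation):

  #include <cutlass/bfloat16.h>

  // D (16x8, f32) = A (16x16, bf16) * B (16x8, bf16) + C (16x8, f32)
  void mma_m16n8k16_reference(cutlass::bfloat16_t const A[16][16],
                              cutlass::bfloat16_t const B[16][8],
                              float const C[16][8],
                              float D[16][8]) {
    for (int m = 0; m < 16; ++m) {
      for (int n = 0; n < 8; ++n) {
        float acc = C[m][n];                       // f32 accumulation (ElementAccumulator)
        for (int k = 0; k < 16; ++k) {
          acc += float(A[m][k]) * float(B[k][n]);  // operands widened from bf16 to f32
        }
        D[m][n] = acc;
      }
    }
  }
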
87 changes: 81 additions & 6 deletions benchmarks/ampere/gemm_configuration.hpp
@@ -58,14 +58,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<cutlass::half_t, cutlass::layout::RowMajor, 8, 64>
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape < _8,_64>,
-                Stride<_64, _1>>{}));
+                Stride<_64, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_16,_8>,
-                Stride< _8,_1>>{},
+                Stride< _8,_1>>{},
Layout<Shape < _1,_8>>{}));
};

@@ -77,14 +77,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, 8, SizeK>
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape <_64, _8>,
-                Stride< _1,_64>>{}));
+                Stride< _1,_64>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_16, _8>,
-                Stride< _1,_16>>{},
+                Stride< _1,_16>>{},
Layout<Shape < _8, _1>>{}));
};

@@ -96,14 +96,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::RowMajor, 8, 32>
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,3>{},
Layout<Shape < _8,_32>,
-                Stride<_32, _1>>{}));
+                Stride<_32, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_32,_4>,
-                Stride< _4,_1>>{},
+                Stride< _4,_1>>{},
Layout<Shape < _1,_8>>{}));
};

@@ -120,3 +120,78 @@ template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

/////////////////////////////////////////////////////////////////////////

// Bfloat16

/// Operand A - Row-major (K-Major) (kBlock = 64)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape < _8,_64>,
Stride<_64, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16,_8>,
Stride< _8,_1>>{},
Layout<Shape < _1,_8>>{}));
};

/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape <_64, _8>,
Stride< _1,_64>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16, _8>,
Stride< _1,_16>>{},
Layout<Shape < _8, _1>>{}));
};

/// Operand A - Row-major (K-Major) (kBlock = 32)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,3>{},
Layout<Shape < _8,_32>,
Stride<_32, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_32,_4>,
Stride< _4,_1>>{},
Layout<Shape < _1,_8>>{}));
};

// Because the F32BF16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands

// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{};

// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};
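
A collective mainloop would consume these traits by expanding the swizzled shared-memory atom across the CTA tile. A hedged sketch of that expansion for the 128x128x32 TileShape used by the benchmark (the alias names are illustrative, and the traits are assumed to be visible in the enclosing namespace):

  #include <cute/tensor.hpp>

  // Select the bf16 row-major (K-major) operand-A traits for a 32-wide K block,
  // then tile the Swizzle<2,3,3> 8x32 atom over the CTA's 128x32 A tile.
  using OperandA    = DefaultGemm_TensorOpSm80_OperandA<
                          cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, 32>;
  using SmemLayoutA = decltype(cute::tile_to_shape(
                          typename OperandA::SmemLayoutAtom{},
                          cute::Shape<cute::_128, cute::_32>{}));
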
2 changes: 0 additions & 2 deletions examples/sycl/CMakeLists.txt
@@ -29,6 +29,4 @@

if(SYCL_INTEL_TARGET)
add_subdirectory(pvc)
-else(SYCL_NVIDIA_TARGET)
-  add_subdirectory(ampere)
endif()
33 changes: 0 additions & 33 deletions examples/sycl/ampere/CMakeLists.txt

This file was deleted.

155 changes: 0 additions & 155 deletions examples/sycl/ampere/ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cu

This file was deleted.

