Skip to content

Commit

Permalink
Add Ampere bfloat-float example
Browse files Browse the repository at this point in the history
  • Loading branch information
aacostadiaz committed May 31, 2024
1 parent 3b8108c commit 74cbb6d
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 332 deletions.
5 changes: 0 additions & 5 deletions benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ endfunction()

if(SYCL_INTEL_TARGET)
add_subdirectory(pvc)
else(SYCL_NVIDIA_TARGET)
add_subdirectory(ampere)
endif()
if (SYCL_NVIDIA_TARGET)
add_subdirectory(ampere)
endif()
if (SYCL_NVIDIA_TARGET)
add_subdirectory(ampere)
Expand Down
5 changes: 5 additions & 0 deletions benchmarks/ampere/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ cutlass_benchmark_add_executable(
bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32
bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
)

cutlass_benchmark_add_executable(
bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32
bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*
**************************************************************************************************/

#include "../common/example_runner.hpp"
#include "../common/benchmark_runner.hpp"
#include "gemm_configuration.hpp"

int main(int argc, const char** argv)
Expand All @@ -53,7 +53,7 @@ int main(int argc, const char** argv)
}

//
// Run examples
// Run benchmark
//

// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
Expand All @@ -68,19 +68,19 @@ int main(int argc, const char** argv)
// elements in input matrices.
using ElementAccumulator = float; // <- data type of accumulator
using ElementComputeEpilogue = float; // <- data type of epilogue operations
using ElementInputA = half_t; // <- data type of elements in input matrix A
using ElementInputB = half_t; // <- data type of elements in input matrix B
using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A
using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;

using TileShape = Shape<_128, _128, _32>;

using TiledMma = TiledMMA<
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
Layout<Shape<_2,_2,_1>>, // 2x2x1 thread group
Tile<_32,_32,_16>>; // 32x32x8 MMA for LDSM, 1x2x1 value group

Expand Down Expand Up @@ -145,7 +145,7 @@ int main(int argc, const char** argv)

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

ExampleRunner<Gemm> runner;
BenchmarkRunner<Gemm> runner;

runner.run(options, hw_info);

Expand Down
87 changes: 81 additions & 6 deletions benchmarks/ampere/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<cutlass::half_t, cutlass::layout::RowMa
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape < _8,_64>,
Stride<_64, _1>>{}));
Stride<_64, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_16,_8>,
Stride< _8,_1>>{},
Stride< _8,_1>>{},
Layout<Shape < _1,_8>>{}));
};

Expand All @@ -77,14 +77,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, 8
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape <_64, _8>,
Stride< _1,_64>>{}));
Stride< _1,_64>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_16, _8>,
Stride< _1,_16>>{},
Stride< _1,_16>>{},
Layout<Shape < _8, _1>>{}));
};

Expand All @@ -96,14 +96,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::RowMajor, 8, 3
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,3>{},
Layout<Shape < _8,_32>,
Stride<_32, _1>>{}));
Stride<_32, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, half_t>{},
Layout<Shape <_32,_4>,
Stride< _4,_1>>{},
Stride< _4,_1>>{},
Layout<Shape < _1,_8>>{}));
};

Expand All @@ -120,3 +120,78 @@ template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<half_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};

/////////////////////////////////////////////////////////////////////////

// Bfloat

/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, 64>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape < _8,_64>,
Stride<_64, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16,_8>,
Stride< _8,_1>>{},
Layout<Shape < _1,_8>>{}));
};

/// Operand A - Column-major (M-major)
template <int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<3,3,3>{},
Layout<Shape <_64, _8>,
Stride< _1,_64>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_16, _8>,
Stride< _1,_16>>{},
Layout<Shape < _8, _1>>{}));
};

/// Operand A - Row-major (K-Major)
template <>
struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, 8, 32>
{
// Smem
using SmemLayoutAtom = decltype(
composition(Swizzle<2,3,3>{},
Layout<Shape < _8,_32>,
Stride<_32, _1>>{}));
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;

// Gmem
using GmemTiledCopy = decltype(
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, bfloat16_t>{},
Layout<Shape <_32,_4>,
Stride< _4,_1>>{},
Layout<Shape < _1,_8>>{}));
};

// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands

// Operand B - Column-Major (K-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
{};

// Operand B - Row-Major (N-major)
template <int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
: DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
{};
2 changes: 1 addition & 1 deletion benchmarks/common/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "PVC GEMM Example\n\n"
out << "PVC GEMM Benchmark\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --m=<int> Sets the M extent of the GEMM\n"
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ int main(int argc, const char** argv)
}

//
// Run examples
// Run benchmark
//

// The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
Expand Down
2 changes: 0 additions & 2 deletions examples/sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@

if(SYCL_INTEL_TARGET)
add_subdirectory(pvc)
else(SYCL_NVIDIA_TARGET)
add_subdirectory(ampere)
endif()
33 changes: 0 additions & 33 deletions examples/sycl/ampere/CMakeLists.txt

This file was deleted.

Loading

0 comments on commit 74cbb6d

Please sign in to comment.