From 2219e193601ec8e97e65eb634fea974f7843b73d Mon Sep 17 00:00:00 2001
From: Alejandro Acosta
Date: Fri, 31 May 2024 14:10:03 +0200
Subject: [PATCH] Add Ampere bfloat-float benchmark (#67)

* Add generic example runner

* Init d and ref_d with different values

* Move runner to benchmark folder

* Add generic example runner

* Add Ampere half-float example

* Update benchmarks/CMakeLists.txt

Co-authored-by: Mehdi Goli

* Add Ampere half-float example

* Add Ampere half-float example

* Add Ampere half-float example

* Add Ampere bfloat-float example

---------

Co-authored-by: Mehdi Goli
---
 benchmarks/ampere/CMakeLists.txt              |   5 +
 ...ere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp | 153 ++++++++++++++++++
 ...ere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp |   2 +-
 benchmarks/ampere/gemm_configuration.hpp      |  87 +++++++++-
 benchmarks/common/benchmark_runner.hpp        |   2 +-
 ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp |   2 +-
 6 files changed, 242 insertions(+), 9 deletions(-)
 create mode 100644 benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp

diff --git a/benchmarks/ampere/CMakeLists.txt b/benchmarks/ampere/CMakeLists.txt
index 70c2bdc990..666d9cac60 100644
--- a/benchmarks/ampere/CMakeLists.txt
+++ b/benchmarks/ampere/CMakeLists.txt
@@ -31,3 +31,8 @@ cutlass_benchmark_add_executable(
   bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32
   bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
 )
+
+cutlass_benchmark_add_executable(
+  bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32
+  bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
+)
diff --git a/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp b/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
new file mode 100644
index 0000000000..8dad127417
--- /dev/null
+++ b/benchmarks/ampere/bench_ampere_gemm_bf16_bf16_fp32_tensor_op_fp32.cpp
@@ -0,0 +1,153 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#include "../common/benchmark_runner.hpp"
+#include "gemm_configuration.hpp"
+
+int main(int argc, const char** argv)
+{
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, argv);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution." << std::endl;
+    return -1;
+  }
+
+  //
+  // Run benchmark
+  //
+
+  // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
+  // information is used by the underlying kernel.
+  cutlass::KernelHardwareInfo hw_info;
+
+  // Change device_id to another value if you are running on a machine with multiple GPUs and wish
+  // to use a GPU other than that with device ID 0.
+  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+// The code section below describes datatype for input, output matrices and computation between
+// elements in input matrices.
+  using ElementAccumulator = float;      // <- data type of accumulator
+  using ElementComputeEpilogue = float;  // <- data type of epilogue operations
+  using ElementInputA = bfloat16_t;      // <- data type of elements in input matrix A
+  using ElementInputB = bfloat16_t;      // <- data type of elements in input matrix B
+  using ElementOutput = float;           // <- data type of elements in output matrix D
+
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+  using LayoutD = cutlass::layout::ColumnMajor;
+
+  using TileShape = Shape<_128, _128, _32>;
+
+  using TiledMma = TiledMMA<
+      MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>,
+      Layout<Shape<_2,_2,_1>>,  // 2x2x1 thread group
+      Tile<_32,_32,_16>>;       // 32x32x8 MMA for LDSM, 1x2x1 value group
+
+  static constexpr int kAlignmentA = 8;
+  using DefaultOperandA = DefaultGemm_TensorOpSm80_OperandA<
+      ElementInputA, LayoutA, kAlignmentA, 32>;
+  using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
+  using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
+  using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
+
+  static constexpr int kAlignmentB = 8;
+  using DefaultOperandB = DefaultGemm_TensorOpSm80_OperandB<
+      ElementInputB, LayoutB, kAlignmentB, 32>;
+  using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
+  using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
+  using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
+
+  using Stages = Int<3>;
+
+  // This code section describes the epilogue part of the kernel
+  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,                                     // <- data type of output matrix
+      128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- the number of elements per vectorized
+                                                         //    memory access. For a byte, it's 16
+                                                         //    elements. This becomes the vector width of
+                                                         //    math instructions in the epilogue too
+      ElementAccumulator,                                // <- data type of accumulator
+      ElementComputeEpilogue>;                           // <- data type for alpha/beta in linear combination function
+
+  using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync<Stages{}>;
+
+  // Define strides (mixed)
+  using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
+  using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
+  using StrideC = cutlass::detail::TagToStrideC_t<LayoutC>;
+  using StrideD = cutlass::detail::TagToStrideC_t<LayoutD>;
+
+  using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue<
+      StrideC,
+      StrideD,
+      EpilogueOp,
+      cutlass::gemm::EpilogueDefault>;
+
+  // Mainloop
+  using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+      DispatchPolicy,
+      TileShape,
+      ElementInputA,
+      StrideA,
+      ElementInputB,
+      StrideB,
+      TiledMma,
+      GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
+      GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
+  >;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>,
+      CollectiveMainloop,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  BenchmarkRunner<Gemm> runner;
+
+  runner.run(options, hw_info);
+
+  return 0;
+}
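The new translation unit hands the assembled Gemm type to BenchmarkRunner<Gemm>::run(options, hw_info), and the runner itself is not part of this patch. The sketch below shows the host-side flow such a runner is expected to follow with a CUTLASS 3.x GemmUniversalAdapter; the helper name run_gemm_once, the fixed alpha/beta values, and the exact Arguments field order are assumptions rather than code taken from benchmark_runner.hpp.

// Sketch only: approximate host-side driver for a Gemm built from
// GemmUniversalAdapter, as in the benchmark above. Element and stride types are
// passed in as template parameters; device buffers are assumed to be allocated
// and initialized by the caller.
#include <cstdint>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/device_memory.h"

template <class Gemm, class ElementA, class ElementB, class ElementC,
          class StrideA, class StrideB, class StrideC, class StrideD>
cutlass::Status run_gemm_once(int M, int N, int K,
                              ElementA const* d_A, StrideA dA,
                              ElementB const* d_B, StrideB dB,
                              ElementC const* d_C, StrideC dC,
                              ElementC* d_D, StrideD dD,
                              cutlass::KernelHardwareInfo const& hw_info) {
  // One non-batched GEMM: the problem shape is (M, N, K, L) with L = 1.
  typename Gemm::Arguments args{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K, 1},
    {d_A, dA, d_B, dB},                // mainloop: operand pointers and strides
    {{1.0f, 0.0f}, d_C, dC, d_D, dD},  // epilogue: alpha = 1, beta = 0, then C and D
    hw_info
  };

  if (Gemm::can_implement(args) != cutlass::Status::kSuccess) {
    return cutlass::Status::kErrorNotSupported;
  }

  // Some kernels need device scratch space; query and allocate it before launch.
  auto workspace_bytes = Gemm::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

  Gemm op;
  cutlass::Status status = op.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }
  return op.run();  // launches the kernel on the default stream
}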
diff --git a/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp b/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
index 69b6159f76..69bc482f12 100644
--- a/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
+++ b/benchmarks/ampere/bench_ampere_gemm_fp16_fp16_fp32_tensor_op_fp32.cpp
@@ -53,7 +53,7 @@ int main(int argc, const char** argv)
   }
 
   //
-  // Run Benchmark 
+  // Run Benchmark
   //
 
   // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This
diff --git a/benchmarks/ampere/gemm_configuration.hpp b/benchmarks/ampere/gemm_configuration.hpp
index 8a32e77e7f..484786567f 100644
--- a/benchmarks/ampere/gemm_configuration.hpp
+++ b/benchmarks/ampere/gemm_configuration.hpp
@@ -58,14 +58,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::RowMajor, 8, 64>
   using SmemLayoutAtom = decltype(
     composition(Swizzle<3,3,3>{},
                 Layout<Shape < _8,_64>,
-                       Stride<_64, _1>>{})); 
+                       Stride<_64, _1>>{}));
   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;
 
   // Gmem
   using GmemTiledCopy = decltype(
     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, half_t>{},
                     Layout<Shape <_16, _8>,
-                           Stride< _8,_1>>{}, 
+                           Stride< _8,_1>>{},
                     Layout<Shape < _1,_8>>{}));
 };
 
@@ -77,14 +77,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, 8, SizeK>
   using SmemLayoutAtom = decltype(
     composition(Swizzle<3,3,3>{},
                 Layout<Shape <_64, _8>,
-                       Stride< _1,_64>>{})); 
+                       Stride< _1,_64>>{}));
   using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, half_t>;
 
   // Gmem
   using GmemTiledCopy = decltype(
     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, half_t>{},
                     Layout<Shape <_16, _8>,
-                           Stride< _1,_16>>{}, 
+                           Stride< _1,_16>>{},
                     Layout<Shape < _8,_1>>{}));
 };
 
@@ -96,14 +96,14 @@ struct DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::RowMajor, 8, 32>
   using SmemLayoutAtom = decltype(
     composition(Swizzle<2,3,3>{},
                 Layout<Shape < _8,_32>,
-                       Stride<_32, _1>>{})); 
+                       Stride<_32, _1>>{}));
   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;
 
   // Gmem
   using GmemTiledCopy = decltype(
     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, half_t>{},
                     Layout<Shape <_32, _4>,
-                           Stride< _4,_1>>{}, 
+                           Stride< _4,_1>>{},
                     Layout<Shape < _1,_8>>{}));
 };
 
@@ -120,3 +120,78 @@ template <int Alignment, int SizeK>
 struct DefaultGemm_TensorOpSm80_OperandB<half_t, cutlass::layout::RowMajor, Alignment, SizeK>
      : DefaultGemm_TensorOpSm80_OperandA<half_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
 {};
+
+/////////////////////////////////////////////////////////////////////////
+
+// Bfloat
+
+/// Operand A - Row-major (K-Major)
+template <>
+struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, 8, 64>
+{
+  // Smem
+  using SmemLayoutAtom = decltype(
+    composition(Swizzle<3,3,3>{},
+                Layout<Shape < _8,_64>,
+                       Stride<_64, _1>>{}));
+  using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;
+
+  // Gmem
+  using GmemTiledCopy = decltype(
+    make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bfloat16_t>{},
+                    Layout<Shape <_16, _8>,
+                           Stride< _8,_1>>{},
+                    Layout<Shape < _1,_8>>{}));
+};
+
+/// Operand A - Column-major (M-major)
+template <int SizeK>
+struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, 8, SizeK>
+{
+  // Smem
+  using SmemLayoutAtom = decltype(
+    composition(Swizzle<3,3,3>{},
+                Layout<Shape <_64, _8>,
+                       Stride< _1,_64>>{}));
+  using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, bfloat16_t>;
+
+  // Gmem
+  using GmemTiledCopy = decltype(
+    make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bfloat16_t>{},
+                    Layout<Shape <_16, _8>,
+                           Stride< _1,_16>>{},
+                    Layout<Shape < _8,_1>>{}));
+};
+
+/// Operand A - Row-major (K-Major)
+template <>
+struct DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, 8, 32>
+{
+  // Smem
+  using SmemLayoutAtom = decltype(
+    composition(Swizzle<2,3,3>{},
+                Layout<Shape < _8,_32>,
+                       Stride<_32, _1>>{}));
+  using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, bfloat16_t>;
+
+  // Gmem
+  using GmemTiledCopy = decltype(
+    make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, bfloat16_t>{},
+                    Layout<Shape <_32, _4>,
+                           Stride< _4,_1>>{},
+                    Layout<Shape < _1,_8>>{}));
+};
+
+// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
+
+// Operand B - Column-Major (K-major)
+template <int Alignment, int SizeK>
+struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
+     : DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
+{};
+
+// Operand B - Row-Major (N-major)
+template <int Alignment, int SizeK>
+struct DefaultGemm_TensorOpSm80_OperandB<bfloat16_t, cutlass::layout::RowMajor, Alignment, SizeK>
+     : DefaultGemm_TensorOpSm80_OperandA<bfloat16_t, cutlass::layout::ColumnMajor, Alignment, SizeK>
+{};
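The operand-B reuse at the end of the new section is plain inheritance, so it can be exercised with a compile-time check. The snippet below is a hypothetical sanity test, assuming the operand structs sit at namespace scope and that gemm_configuration.hpp is included the same way the benchmark .cpp includes it; it is not part of the patch.

// Hypothetical sanity check (assumption: operand structs are at namespace scope):
// a K-major (column-major) bfloat16 B operand should expose exactly the layout and
// copy atoms of the K-major A operand, since DefaultGemm_TensorOpSm80_OperandB
// only inherits from the corresponding A configuration.
#include <type_traits>

#include "gemm_configuration.hpp"  // include path assumed, as in the benchmark .cpp

using OperandA_K32 =
    DefaultGemm_TensorOpSm80_OperandA<cutlass::bfloat16_t, cutlass::layout::RowMajor, 8, 32>;
using OperandB_K32 =
    DefaultGemm_TensorOpSm80_OperandB<cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8, 32>;

static_assert(std::is_same_v<OperandA_K32::SmemLayoutAtom, OperandB_K32::SmemLayoutAtom>,
              "K-major B reuses the K-major A shared-memory layout atom");
static_assert(std::is_same_v<OperandA_K32::GmemTiledCopy, OperandB_K32::GmemTiledCopy>,
              "K-major B reuses the K-major A global-memory tiled copy");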
diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/common/benchmark_runner.hpp
index e3d5d8f3a5..5eb2ade3eb 100644
--- a/benchmarks/common/benchmark_runner.hpp
+++ b/benchmarks/common/benchmark_runner.hpp
@@ -97,7 +97,7 @@ struct Options {
   /// Prints the usage statement.
   std::ostream & print_usage(std::ostream &out) const {
 
-    out << "PVC GEMM Example\n\n"
+    out << "PVC GEMM Benchmark\n\n"
       << "Options:\n\n"
       << "  --help                      If specified, displays this usage statement\n\n"
       << "  --m=<int>                   Sets the M extent of the GEMM\n"
diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
index 2ff30d7a79..67b76929db 100644
--- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
+++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
@@ -56,7 +56,7 @@ int main(int argc, const char** argv)
   }
 
   //
-  // Run examples
+  // Run benchmark
   //
 
   // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This