diff --git a/examples/35_gemm_softmax/CMakeLists.txt b/examples/35_gemm_softmax/CMakeLists.txt
index b7ecd99fc..824453a65 100644
--- a/examples/35_gemm_softmax/CMakeLists.txt
+++ b/examples/35_gemm_softmax/CMakeLists.txt
@@ -29,8 +29,14 @@
+if (NOT CUTLASS_ENABLE_SYCL)
 cutlass_example_add_executable(
   35_gemm_softmax
   gemm_softmax.cu
   )
-
+else()
+cutlass_example_add_executable(
+  35_gemm_online_softmax
+  gemm_online_softmax.cpp
+  )
+endif()
diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp
new file mode 100644
index 000000000..b8844286f
--- /dev/null
+++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp
@@ -0,0 +1,604 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief GEMM + Softmax example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture
+
+    This example demonstrates how to instantiate and run a TF32 GEMM using the Cute and
+    CUTLASS 3.x APIs on NVIDIA Ampere architecture. Please check examples 07 and 08 for
+    the basics of tensor op gemm kernels.
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else +#include "cutlass/util/reference/device/tensor_fill.h" +#endif +#include "helper.h" +#include "softmax_epilogue.hpp" +#include "gemm_softmax_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +/// Result structure +struct Result { + + double avg_runtime_ms; + double gflops; + bool passed; + + // + // Methods + // + + Result( + double avg_runtime_ms = 0, + double gflops = 0) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), passed(false) + {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + int m, n, k, l; + float alpha, beta; + int iterations; + float tolerance; + + Options(): + help(false), + m(5120), n(4096), k(4096), l(1), + alpha(1), beta(0), + iterations(100), + tolerance(1e-5f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tolerance", tolerance); + + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "35_gemm_online_softmax example\n\n" + << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --tolerance Error tolerance\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Wrapper to run and verify a GEMM. +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StridePartials = typename Gemm::CollectiveEpilogue::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + using LayoutPartials = typename Gemm::LayoutPartials; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + StridePartials stride_partials; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_max; + cutlass::DeviceAllocation block_sum; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + template + bool verify_tensor(std::vector vector_Input, + std::vector vector_Input_Ref, const Options& options) { + + auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? 
vector_Input.size() : vector_Input_Ref.size()); + float abs_tol = options.tolerance; + float rel_tol = options.tolerance; + + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i)); + float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("i = %d diff = %f, {%f, %f}.\n", i, abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the reference matches + bool verify(const Options& options) { + using ElementSoftmax = ElementD; + + cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{options.m, options.n, options.k}; + + int64_t total_elements_A_per_batch = options.m * options.k; + int64_t total_elements_B_per_batch = options.k * options.n; + int64_t total_elements_C_per_batch = options.m * options.n; + int64_t total_elements_D_per_batch = total_elements_C_per_batch; + + int64_t lda = LayoutA::packed({options.m, options.k}).stride(0); + int64_t ldb = LayoutB::packed({options.k, options.n}).stride(0); + int64_t ldc = LayoutC::packed({options.m, options.n}).stride(0); + + int64_t ldn = options.m; + int64_t lds = ldn; + + LayoutA layout_A(lda); + LayoutB layout_B(ldb); + LayoutC layout_C(ldc); + LayoutPartials Layout_N(ldn); + LayoutPartials Layout_S(lds); + + cutlass::MatrixCoord extent_A{options.m, options.k}; + cutlass::MatrixCoord extent_B{options.k, options.n}; + cutlass::MatrixCoord extent_C{options.m, options.n}; + + cutlass::HostTensor reference_N; + reference_N.reset({options.m, 1}, false); + + for (int batch_idx = 0; batch_idx < options.l; batch_idx++) { + cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); + cutlass::TensorView view_Ref_device(block_ref_D.get(), layout_C, extent_C); + + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementCompute, ElementD + >( + problem_size, + options.alpha, + view_A, + cutlass::ComplexTransform::kNone, + view_B, + cutlass::ComplexTransform::kNone, + options.beta, + view_C, + view_Ref_device, + ElementCompute(0) + ); + + // Copy reference results to host memory for verification + std::vector matrix_D_Ref(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_ref_D.get(), matrix_D_Ref.size()); + cutlass::TensorView view_D_Ref(matrix_D_Ref.data(), layout_C, extent_C); + + std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); + cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); + + // Copy computed results to host memory + std::vector matrix_D(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); + + // Compute the norm + for (int m = 0; m < options.m; ++m) { + reference_N.at({m, 0}) = view_D_Ref.ref().at({m, 0}); + for (int n = 1; n < options.n; ++n) { + reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); + } + } + + // Compute softmax + for 
(int m = 0; m < options.m; ++m) { + float sum = 0; + for (int n = 0; n < options.n; ++n) { + sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); + } + float inv_sum = float(1.0f / sum); + + for (int n = 0; n < options.n; ++n) { + view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( + std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum + ); + } + } + + bool verified_Softmax = verify_tensor(matrix_D, matrix_Softmax_Ref, options); + if (!verified_Softmax) { + std::cerr << "Verification of Softmax tensor failed\n"; + return false; + } + } + return true; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})); + auto partials_size = M * partials_N * L; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_partials = cutlass::make_cute_packed_stride(StridePartials{}, cute::make_shape(M, partials_N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + block_sum.reset(partials_size); + block_max.reset(partials_size); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, + options.beta}, + block_C.get(), stride_C, + block_D.get(), stride_D, + block_max.get(), block_sum.get(), stride_partials}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. +#if !defined(CUTLASS_ENABLE_SYCL) + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + return 0; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + return 0; + } +#endif + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. + // This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // Problem configuration + using ElementA = float; + using ElementB = float; + using ElementAcc = float; + using ElementOutput = float; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + using LayoutPartials = cutlass::layout::ColumnMajor; + + // Tiling configuration selection + using TileShape = Shape<_128,_128,_32>; + + // + // Assembling the CollectiveMainloop type + // + + // Number of pipelines you want to use + constexpr int PipelineStages = 4; + + using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; + + // This code section describes the MMA op and the tile size a warp will compute + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2,_1,_1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // Define the copy layout and atom for device memory copy. + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_1,_16>>{}, + Layout>{})); + + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_8,_1>>{}, + Layout>{})); + + // Define the copy layout and atom for shared memory copy. 
+ using SmemLayoutAtomA = decltype(composition(Swizzle<2,3,2>{}, Layout, Stride< _1,_32>>{})); + using SmemCopyAtomA = Copy_Atom, float>; + + using SmemLayoutAtomB = decltype(composition(Swizzle<3,2,3>{}, Layout, Stride<_32, _1>>{})); + using SmemCopyAtomB = Copy_Atom; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape, + ElementA, + cutlass::detail::TagToStrideA_t, + ElementB, + cutlass::detail::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // + // Assembling the Collective Epilogue Type + // + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAcc, // <- data type of accumulator + ElementAcc>; // <- data type for alpha/beta in linear combination function + + using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + TileShape, + EpilogueOp, + cutlass::gemm::EpilogueDefault>; + + // + // Assembling the GemmKernel + // + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmSoftmaxAdapter; + + ExampleRunner runner; + runner.run(options, hw_info); + + return 0; +} diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp new file mode 100644 index 000000000..13e0369e4 --- /dev/null +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -0,0 +1,543 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_event_manager.hpp" +#endif + +#include "softmax_finalize.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmSoftmaxAdapter +{ +public: + using GemmKernel = GemmKernel_; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< + ElementD, typename GemmKernel::StrideD, + ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StridePartials, + ElementD, typename GemmKernel::StrideD>; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + using LayoutPartials = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? 
+ ComplexTransform::kConjugate : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + + struct Params{ + typename 
GemmKernel::Params gemm_params; + typename SoftmaxFinalizeKernel::Params softmax_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params.gemm_params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + void initialize_softmax_params(Arguments const& args, typename SoftmaxFinalizeKernel::Arguments& softmax_args){ + softmax_args.M = get<0>(args.problem_shape); + softmax_args.dataN = get<1>(args.problem_shape); + softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{})); + softmax_args.batch_count = get<3>(args.problem_shape); + softmax_args.dInput = args.epilogue.dD; + softmax_args.dPartial = args.epilogue.dPartials; + softmax_args.dOutput = args.epilogue.dD; + softmax_args.ptr_in = args.epilogue.ptr_D; + softmax_args.ptr_partial_max = args.epilogue.ptr_max; + softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; + softmax_args.ptr_out = args.epilogue.ptr_D; + } + + /// Initializes GEMM state from arguments. 
+ Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + initialize_softmax_params(args, params_.softmax_params.args); + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if !defined(CUTLASS_ENABLE_SYCL) + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } +#endif + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + initialize_softmax_params(args, params_.softmax_params.args); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+  /// Supplied params struct must be constructed by calling GemmKernel::to_underlying_arguments()
+  static Status
+  run(Params& params,
+      cudaStream_t stream = nullptr,
+      CudaHostAdapter *cuda_adapter = nullptr,
+      bool launch_with_pdl = false) {
+    CUTLASS_TRACE_HOST("GemmUniversal::run()");
+    dim3 const block = GemmKernel::get_block_shape();
+    dim3 const grid = get_grid_shape(params);
+    dim3 const block_finalize = syclcompat::dim3(NumThreadsPerWarp,
+                                                 std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp,
+                                                          params.softmax_params.args.M),
+                                                 1);
+    dim3 const grid_finalize = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, block_finalize.x),
+                                                params.softmax_params.args.batch_count,
+                                                1);
+
+    // configure smem size and carveout
+    int smem_size = GemmKernel::SharedStorageSize;
+    int smem_size_finalize = SoftmaxFinalizeKernel::SharedStorageSize;
+
+    Status launch_result{ Status::kSuccess };
+    // Use extended launch API only for mainloops that use it
+    if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
+#if !defined(CUTLASS_ENABLE_SYCL)
+      constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
+                                       cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
+      dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
+                   cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
+                   cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{}));
+      dim3 cluster_finalize(1,1,1);
+      void* kernel_params[] = {&params.gemm_params};
+      void* kernel_params_finalize[] = {&params.softmax_params};
+
+      if constexpr (kEnableCudaHostAdapter) {
+        //
+        // Use the cuda host adapter
+        //
+        CUTLASS_ASSERT(cuda_adapter);
+        if (cuda_adapter) {
+
+          if (launch_with_pdl) {
+            CUTLASS_TRACE_HOST(
+              "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter.");
+            return Status::kErrorInternal;
+          }
+          launch_result = cuda_adapter->launch(grid,
+                                               cluster,
+                                               block,
+                                               smem_size,
+                                               stream,
+                                               kernel_params,
+                                               0);
+          launch_result = cuda_adapter->launch(grid_finalize,
+                                               cluster_finalize,
+                                               block_finalize,
+                                               smem_size_finalize,
+                                               stream,
+                                               kernel_params_finalize,
+                                               1);
+        }
+        else {
+          return Status::kErrorInternal;
+        }
+      }
+      else {
+        CUTLASS_ASSERT(cuda_adapter == nullptr);
+        void const* kernel = (void const*) device_kernel<GemmKernel>;
+        void const* kernel_finalize = (void const*) device_kernel<SoftmaxFinalizeKernel>;
+        if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) {
+          if (is_static_1x1x1 && not launch_with_pdl) {
+            device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params.gemm_params);
+            device_kernel<SoftmaxFinalizeKernel><<<grid_finalize, block_finalize, smem_size_finalize, stream>>>(params.softmax_params);
+          }
+          else {
+            launch_result = ClusterLauncher::launch(
+              grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl);
+            launch_result = ClusterLauncher::launch(
+              grid_finalize, cluster_finalize, block_finalize, smem_size_finalize, stream, kernel_finalize, kernel_params_finalize, launch_with_pdl);
+          }
+        }
+      }
+#endif
+    }
+    else {
+      launch_result = Status::kSuccess;
+      if constexpr (kEnableCudaHostAdapter) {
+        CUTLASS_ASSERT(cuda_adapter);
+        if (cuda_adapter) {
+          void* kernel_params[] = {&params.gemm_params};
+          void* kernel_params_finalize[] = {&params.softmax_params};
+
+          launch_result = cuda_adapter->launch(
+            grid, block, smem_size, stream, kernel_params, 0
+          );
+          launch_result = cuda_adapter->launch(
+            grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 1
+          );
+
+        }
+        else {
+          return Status::kErrorInternal;
+        }
+      }
+      else {
+        CUTLASS_ASSERT(cuda_adapter == nullptr);
+#if defined(CUTLASS_ENABLE_SYCL)
+        const auto sycl_block =
syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + using namespace syclcompat::experimental; +#if defined (SYCL_INTEL_TARGET) + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size} + }, params.gemm_params); +#else + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, + params.gemm_params); +#endif + const auto sycl_block_finalize = syclcompat::dim3(block_finalize.x, block_finalize.y, block_finalize.z); + const auto sycl_grid_finalize = syclcompat::dim3(grid_finalize.x, grid_finalize.y, grid_finalize.z); + auto event2 = launch>(launch_policy{ + sycl_grid_finalize, sycl_block_finalize, local_mem_size{static_cast(smem_size_finalize)}}, + params.softmax_params); + EventManager::getInstance().addEvent(event2); +#else + device_kernel<<>>(params.gemm_params); + device_kernel<<>>(params.softmax_params); +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp new file mode 100644 index 000000000..438552a7f --- /dev/null +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. 
+template < + class StrideC_, + class StrideD_, + class StridePartials_, + class BlockShapeMNK, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class SoftmaxEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using StridePartials = StridePartials_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { + cute::array_aligned(BlockShapeMNK{}) * get<1>(BlockShapeMNK{})> smem_c; + }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementAccumulator* ptr_max; + ElementAccumulator* ptr_sum; + StridePartials dPartials{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + SoftmaxEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor & accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // 
Separate out problem and tile shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto M_tile = get<0>(blk_shape_MNK); + auto N_tile = get<1>(blk_shape_MNK); + auto K_tile = get<2>(blk_shape_MNK); + + auto N_partials = cute::ceil_div(N, N_tile); + + cute::packed_tuple partial_block(M_tile, K_tile); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensors + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_partials,L), params.dPartials); + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_partials,L), params.dPartials); + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); + Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gMax = gMax_mnl(_,m_coord,n_coord,l_coord); + Tensor gSum = gSum_mnl(_,m_coord,n_coord,l_coord); + + //Represent the shared tensor + Tensor sC = make_tensor(make_smem_ptr(reinterpret_cast(smem_buf)), make_layout(make_shape(M_tile, N_tile))); + + // Partition the tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + Tensor tCsC = thr_mma.partition_C(sC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + if(is_source_needed()){ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); + tCgD(i,j,k) = accumulators(i,j,k); + tCsC(i,j,k) = accumulators(i,j,k); + } + } + } + } + } else{ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + 
accumulators(i,j,k) = epilogue_op(accumulators(i,j,k));
+            tCgD(i,j,k) = accumulators(i,j,k);
+            tCsC(i,j,k) = accumulators(i,j,k);
+          }
+        }
+      }
+    }
+    }
+
+    syncthreads();
+
+    // assumption for reductions: size<0>(sC) == block size
+    assert(size<0>(sC) == BlockDimX() * BlockDimY() * BlockDimZ());
+
+    ElementAccumulator max = std::numeric_limits<ElementAccumulator>::lowest();
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < size<1>(sC); ++i) {
+      if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
+        accumulators(i) = sC(thread_idx, i);
+        max = cutlass::fast_max(max, accumulators(i));
+      }
+    }
+    gMax(thread_idx) = max;
+
+    ElementAccumulator sum = 0;
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < size<1>(sC); ++i) {
+      if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
+        sum += cutlass::fast_exp(accumulators(i) - max);
+      }
+    }
+    gSum(thread_idx) = sum;
+  }
+
+private:
+  Params params;
+  ThreadEpilogueOp epilogue_op;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace collective
+} // namespace epilogue
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp
new file mode 100644
index 000000000..ca6e6ac93
--- /dev/null
+++ b/examples/35_gemm_softmax/softmax_finalize.hpp
@@ -0,0 +1,197 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Kernel performing a final calculation of softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +template < + typename ElementInput_, + typename StrideInput_, + typename ElementPartial_, + typename StridePartial_, + typename ElementOutput_, + typename StrideOutput_ +> +class SoftmaxFinalize { +public: + + using ElementInput = ElementInput_; + using StrideInput = StrideInput_; + using ElementPartial = ElementPartial_; + using StridePartial = StridePartial_; + using ElementOutput = ElementOutput_; + using StrideOutput = StrideOutput_; + + // + // Arguments + // + + struct Arguments { + int M; // dimension M of input, output and partially reduced tensors + int dataN; // dimension N of the input and output + int partialN; // dimension N of the partially reduced tensors + int batch_count; // batch count + StrideInput dInput; // stride of the input + StridePartial dPartial; // stride of the partially reduced tensors + StrideOutput dOutput; // stride of the output + ElementInput* ptr_in; // pointer to start of input data + ElementPartial* ptr_partial_max; // pointer to start of partially reduced max data + ElementPartial* ptr_partial_sum; // pointer to start of partially reduced sum data + ElementOutput* ptr_out; // pointer to start of output data + }; + + struct SharedStorage { + cute::array_aligned s_mem; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +public: + + CUTLASS_DEVICE + SoftmaxFinalize() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char* shared_storage) { + apply(params, shared_storage); + } + +private: + + CUTLASS_DEVICE + void apply(Params const ¶ms, char* shared_storage) { + using ConvertInput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + const int idx_x = ThreadIdxX(); + const int m = idx_x + BlockDimX() * BlockIdxX(); + const int idx_y = ThreadIdxY(); + const int y_size = BlockDimY(); + const int batch_id = BlockIdxY(); + + if (m >= params.args.M) { + return; + } + + // Represent the full tensors + auto IOTensorShape = make_shape(params.args.M, params.args.dataN, params.args.batch_count); + auto PartialTensorShape = make_shape(params.args.M, params.args.partialN, params.args.batch_count); + Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); + Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); + Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); + Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); + + //Represent the shared tensor + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), + make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); + + ElementPartial max_val = std::numeric_limits::lowest(); + for (int partial_n = 
idx_y; partial_n < params.args.partialN; partial_n += y_size){
+      ElementPartial partial_max = mPartialMax(m, partial_n, batch_id);
+      max_val = cutlass::fast_max(max_val, partial_max);
+    }
+    sPartial(idx_x, idx_y) = max_val;
+    syncthreads();
+    // tree-reduction could be better, although it does not seem to be a bottleneck
+    for (int idx_y2 = 0; idx_y2 < y_size; idx_y2++){
+      ElementPartial partial_max = sPartial(idx_x, idx_y2);
+      max_val = cutlass::fast_max(max_val, partial_max);
+    }
+
+    ElementPartial sum_val = 0;
+    for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){
+      ElementPartial partial_max = mPartialMax(m, partial_n, batch_id);
+      ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id);
+      sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val);
+    }
+    syncthreads();
+    sPartial(idx_x, idx_y) = sum_val;
+    syncthreads();
+    sum_val = 0;
+    // tree-reduction could be better, although it does not seem to be a bottleneck
+    for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){
+      ElementPartial partial_sum = sPartial(idx_x, idx_y2);
+      sum_val += partial_sum;
+    }
+
+    ElementPartial norm = 1 / sum_val;
+
+    // Process two elements per iteration; stop one short of dataN so that the
+    // n+1 access stays in bounds, and handle a trailing odd element below.
+    for (int n = idx_y * 2; n < params.args.dataN - 1; n += y_size * 2){
+      auto inVal = mIn(m, n, batch_id);
+      auto inVal2 = mIn(m, n+1, batch_id);
+      mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm;
+      mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm;
+    }
+    if (params.args.dataN % 2 == 1){
+      int n = params.args.dataN - 1;
+      auto inVal = mIn(m, n, batch_id);
+      mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm;
+    }
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace kernel
+} // namespace reduction
+} // namespace cutlass
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index b40a2ff59..515a5617c 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -161,6 +161,7 @@ if (NOT CUTLASS_ENABLE_SYCL)
 else()
   foreach(EXAMPLE
     14_ampere_tf32_tensorop_gemm
+    35_gemm_softmax
     cute
     sycl
   )
diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h
index fa3873c5e..c856afa76 100644
--- a/include/cutlass/fast_math.h
+++ b/include/cutlass/fast_math.h
@@ -859,6 +859,8 @@ CUTLASS_HOST_DEVICE
 float fast_exp(float x) {
   #if defined(__CUDA_ARCH__)
   return ::expf(x);
+  #elif defined(__SYCL_CUDA_ARCH__)
+  return ::sycl::native::exp(x);
   #else
   return std::exp(x);
   #endif
diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h
index c476a2fc6..7e3857e57 100644
--- a/include/cutlass/gpu_generics.h
+++ b/include/cutlass/gpu_generics.h
@@ -49,6 +49,7 @@ static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWa
 static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2;
 static const int NumThreadsPerQuad = 4;
 static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2;
+static constexpr int MaxNumThreadsPerBlock = 1024;
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
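
Editorial note (not part of the patch): the two kernels above form a two-pass "online" softmax. SoftmaxEpilogue stores one partial row maximum and one partial sum of exponentials per CTA tile (the block_max / block_sum allocations), and SoftmaxFinalize merges those partials per row and normalizes the GEMM output in place. The sketch below is a minimal host-side illustration of that merge rule for a single row; the function name softmax_from_partials and the tile_n parameter are illustrative assumptions and do not appear in the patch.

    // Host-side reference sketch of the partial-softmax merge, assuming one row
    // split into tiles of width tile_n. Each tile t contributes
    // (part_max[t], part_sum[t] = sum_j exp(x_j - part_max[t])); the global
    // softmax is recovered by rescaling each partial sum by
    // exp(part_max[t] - global_max) before normalizing.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> softmax_from_partials(const std::vector<float>& row, int tile_n) {
      int data_n = static_cast<int>(row.size());
      int num_tiles = (data_n + tile_n - 1) / tile_n;
      std::vector<float> part_max(num_tiles, -INFINITY);
      std::vector<float> part_sum(num_tiles, 0.f);

      // Pass 1: per-tile partials (what the epilogue writes to gMax / gSum).
      for (int t = 0; t < num_tiles; ++t) {
        int begin = t * tile_n;
        int end = std::min(begin + tile_n, data_n);
        for (int j = begin; j < end; ++j) part_max[t] = std::max(part_max[t], row[j]);
        for (int j = begin; j < end; ++j) part_sum[t] += std::exp(row[j] - part_max[t]);
      }

      // Pass 2: merge partials per row (what SoftmaxFinalize computes).
      float global_max = *std::max_element(part_max.begin(), part_max.end());
      float global_sum = 0.f;
      for (int t = 0; t < num_tiles; ++t) {
        global_sum += part_sum[t] * std::exp(part_max[t] - global_max);
      }

      // Normalize the row using the merged max and sum.
      std::vector<float> out(row.size());
      for (int j = 0; j < data_n; ++j) {
        out[j] = std::exp(row[j] - global_max) / global_sum;
      }
      return out;
    }

With CUTLASS_ENABLE_SYCL set, the patch builds this example as the 35_gemm_online_softmax target; problem size and verification are controlled through the --m, --n, --k, --l, --iterations and --tolerance flags listed in print_usage.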