From 430f9ebc566b863c89d22d644aa4a41dfdd1aed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 7 Nov 2024 10:41:05 +0100 Subject: [PATCH 01/14] pre-state - WIP --- examples/35_gemm_softmax/CMakeLists.txt | 8 +- .../35_gemm_softmax/gemm_online_softmax.cpp | 519 ++++++++++++++++++ examples/35_gemm_softmax/softmax_epilogue.hpp | 319 +++++++++++ examples/CMakeLists.txt | 1 + 4 files changed, 846 insertions(+), 1 deletion(-) create mode 100644 examples/35_gemm_softmax/gemm_online_softmax.cpp create mode 100644 examples/35_gemm_softmax/softmax_epilogue.hpp diff --git a/examples/35_gemm_softmax/CMakeLists.txt b/examples/35_gemm_softmax/CMakeLists.txt index b7ecd99fcc..d7f2cd574b 100644 --- a/examples/35_gemm_softmax/CMakeLists.txt +++ b/examples/35_gemm_softmax/CMakeLists.txt @@ -29,8 +29,14 @@ +if (NOT CUTLASS_ENABLE_SYCL) cutlass_example_add_executable( 35_gemm_softmax gemm_softmax.cu ) - +else() +cutlass_example_add_executable( + 35_gemm_online_softmax + gemm_online_softmax.cpp + ) +endif() \ No newline at end of file diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp new file mode 100644 index 0000000000..2a7d21fd33 --- /dev/null +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -0,0 +1,519 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple GEMM example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture + + This example demonstrate how to instantiate and run a TF32 GEMM using the Cute and + CUTLASS 3.x APIs on NVIDIA Ampere architecture. 
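    (A hedged sketch added by the editor, not part of this commit: the header text below is
     carried over from example 14, while this file itself drives a GEMM whose epilogue computes
     a row-wise softmax. The math the patch series works towards, matching the finalize TODO
     later in softmax_epilogue.hpp, is, for D = alpha * A*B + beta * C and per row i:

         per N-tile t:  m_{i,t} = max_j D_{ij},    s_{i,t} = sum_j exp(D_{ij} - m_{i,t})
         finalize:      m_i = max_t m_{i,t},       s_i = sum_t s_{i,t} * exp(m_{i,t} - m_i)
         output:        softmax(D)_{ij} = exp(D_{ij} - m_i) / s_i )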
Please check example 07 and 08 for + the basics of tensor op gemm kernels. On NVIDIA Ampere architecture, most concept + still holds. The two main differences are: + + (1) NVIDIA Ampere architecture introduces a new series of tensor core instructions + (see include/cute/arch/mma_sm80.hpp) which are more efficient on Ampere. + (2) NVIDIA Ampere architecture uses CP_ASYNC (see include/cute/arch/copy_sm80.hpp) + to build a multistage software pipeline to better hide latency (see + include/cutlass/gemm/collective/sm80_mma_multistage.hpp). + + Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) + data types in tensor cores. One big advantage is that we can load in fp32 data and convert + them implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate + traditional fp32 data by using NVIDIA Ampere architecture. + + Examples: + + $ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm_cute + +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else +#include "cutlass/util/reference/device/tensor_fill.h" +#endif +#include "helper.h" +#include "softmax_epilogue.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +/// Result structure +struct Result { + + double avg_runtime_ms; + double gflops; + bool passed; + + // + // Methods + // + + Result( + double avg_runtime_ms = 0, + double gflops = 0) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), passed(false) + {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + int m, n, k, l; + float alpha, beta; + int iterations; + + Options(): + help(false), + m(5120), n(4096), k(4096), l(1), + alpha(1), beta(0), + iterations(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + //cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. 
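  // Illustrative only (editor's sketch, not from the patch): once the SYCL target from the
  // CMake change above is built, a typical invocation and the arithmetic behind gflops()
  // would look like the following; the binary path is an assumption.
  //
  //   $ ./examples/35_gemm_softmax/35_gemm_online_softmax --m=5120 --n=4096 --k=4096 --l=1 --iterations=100
  //
  // With these defaults, flop = 2*M*N*K*L = 2*5120*4096*4096 ~= 1.72e11, so an average
  // runtime of 10 ms corresponds to roughly 1.72e11 / 0.01 / 1e9 ~= 17,000 GFLOP/s.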
+ std::ostream & print_usage(std::ostream &out) const { + + out << "14_ampere_tf32_tensorop_gemm_cute example\n\n" + << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Wrapper to run and verify a GEMM. +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementCompute(alpha), + ref_A, + 
cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementCompute(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + +#if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait_and_throw(); +#else + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } +#endif + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(problem_size, options.alpha, options.beta); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. +#if !defined(CUTLASS_ENABLE_SYCL) + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + return 0; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + return 0; + } +#endif + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. + // This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // Problem configuration + using ElementA = float; + using ElementB = float; + using ElementAcc = float; + using ElementOutput = float; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Tiling configuration selection + using TileShape = Shape<_128,_128,_32>; + + // + // Assembling the CollectiveMainloop type + // + + // Number of pipelines you want to use + constexpr int PipelineStages = 4; + + using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; + + // This code section describes the MMA op and the tile size a warp will compute + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2,_1,_1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // Define the copy layout and atom for device memory copy. + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_1,_16>>{}, + Layout>{})); + + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_8,_1>>{}, + Layout>{})); + + // Define the copy layout and atom for shared memory copy. 
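  // Rough shared-memory budget (editor's note, assuming the 128x128x32 float tile and
  // PipelineStages = 4 declared above, ignoring swizzle/padding details): each stage holds
  // a 128x32 A tile and a 128x32 B tile, i.e. 2 * 128*32*4 B = 32 KB, so the 4-stage
  // cp.async pipeline uses about 128 KB of shared memory.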
+ using SmemLayoutAtomA = decltype(composition(Swizzle<2,3,2>{}, Layout, Stride< _1,_32>>{})); + using SmemCopyAtomA = Copy_Atom, float>; + + using SmemLayoutAtomB = decltype(composition(Swizzle<3,2,3>{}, Layout, Stride<_32, _1>>{})); + using SmemCopyAtomB = Copy_Atom; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape, + ElementA, + cutlass::detail::TagToStrideA_t, + ElementB, + cutlass::detail::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // + // Assembling the Collective Epilogue Type + // + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAcc, // <- data type of accumulator + ElementOutput>; // <- data type for alpha/beta in linear combination function + + using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + EpilogueOp, + cutlass::gemm::EpilogueDefault>; + + // + // Assembling the GemmKernel + // + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + runner.run(options, hw_info); + + return 0; +} diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp new file mode 100644 index 0000000000..7f7a055f09 --- /dev/null +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class SoftmaxEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& 
problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + SoftmaxEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { + // reduce across all the N tiles in shape + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + dst(x, y) = zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + CUTLASS_PRAGMA_UNROLL + for(int z = 1; z < SizeC; z++) { + dst(x, y) = op(dst(x, y), src(x, y, z)); + } + } + } + + // reduce across the sub_group to get the final output + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + CUTLASS_PRAGMA_UNROLL + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); + } + } + } + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { + reduceSg(src, dst, op); + for(int i=ThreadIdxX() % NumThreadsPerWarp; i + CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { + reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
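      // Editor's note on the sub-group butterfly in reduceSg above: assuming the 16-wide
      // sub-group implied by the permute width, the XOR masks 8, 4, 2, 1 pair each lane i
      // with lane i^mask, so after log2(16) = 4 steps every lane holds the reduction of
      // all 16 lanes (a max here, a sum in reduce_sum below).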
x : y; }); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragSum + > + CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { + reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + //auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + //Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + //Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + //Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + //Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + //CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + // "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ + print("thr_mma: "); print(thr_mma); print("\n"); + print("tiled_mma: "); print(tiled_mma); print("\n"); + //print("tiled_mma L: "); print_latex(tiled_mma); print("\n"); + print("acc: "); print(accumulators); print("\n"); + print("tCgD: "); print(tCgD); print("\n"); + print("gD: "); print(gD); print("\n"); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + for (int j = 0; j < size<1>(accumulators); ++j) { + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), 
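      // Editor's note: this elem_less predicate masks out-of-bounds elements of partial
      // tiles at the M/N edges. For example (illustrative numbers only), with M = 5000 and
      // a 128-row tile, the last M-tile starts at row 4992, so only 8 of its 128 rows pass
      // the residue check and the remaining 120 are skipped.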
make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i,j,k) = epilogue_op(accumulators(i,j,k)); + } + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 26c69b310a..c1081d7417 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -155,6 +155,7 @@ if (NOT CUTLASS_ENABLE_SYCL) else() foreach(EXAMPLE 14_ampere_tf32_tensorop_gemm + 35_gemm_softmax cute sycl ) From f191b9a2d01fc0b22ef5d3f65cd61f7eabd0c8a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 11 Nov 2024 13:29:14 +0100 Subject: [PATCH 02/14] WIP second kernel --- .../35_gemm_softmax/gemm_online_softmax.cpp | 21 ++- examples/35_gemm_softmax/softmax_epilogue.hpp | 141 +++++++++++++----- 2 files changed, 123 insertions(+), 39 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 2a7d21fd33..6ad3c95333 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -77,6 +77,7 @@ #endif #include "helper.h" #include "softmax_epilogue.hpp" +#include "gemm_softmax_adapter.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,7 +117,7 @@ struct Options { help(false), m(5120), n(4096), k(4096), l(1), alpha(1), beta(0), - iterations(0) + iterations(100) { } // Parses the command line @@ -205,6 +206,7 @@ struct ExampleRunner { using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideTmp = typename Gemm::CollectiveEpilogue::StrideD; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; @@ -232,11 +234,14 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; + StrideTmp stride_tmp; uint64_t seed = 0; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_max; + cutlass::DeviceAllocation block_sum; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -292,16 +297,24 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; + // 1 element per warp. 
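      // Editor's note, working the sizing below through with the default shapes and the
      // 128x128 tile (and assuming NumWarpsPerWarpGroup == 4):
      //   ceil_div(5120*4096*1, 128*128) * 4 = 1280 * 4 = 5120 partial max/sum entries;
      // the third commit in this series re-derives the size as M * ceil_div(N, tile_N) * L instead.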
+ auto tmp_size = cute::ceil_div(M * K * L, cute::shape<0>(typename Gemm::TileShape{}) * cute::shape<1>(typename Gemm::TileShape{})) * NumWarpsPerWarpGroup; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(cute::ceil_div(M, cute::shape<0>(typename Gemm::TileShape{})), + cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})), + L)); block_A.reset(M * K * L); block_B.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); + block_sum.reset(tmp_size); + block_max.reset(tmp_size); initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); @@ -317,7 +330,7 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, hw_info }; @@ -431,6 +444,7 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; + using LayoutTmp = cutlass::layout::ColumnMajor; // Tiling configuration selection using TileShape = Shape<_128,_128,_32>; @@ -497,6 +511,7 @@ int main(int argc, char const **args) { using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, EpilogueOp, cutlass::gemm::EpilogueDefault>; @@ -510,7 +525,7 @@ int main(int argc, char const **args) { CollectiveEpilogue >; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = cutlass::gemm::device::GemmSoftmaxAdapter; ExampleRunner runner; runner.run(options, hw_info); diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 7f7a055f09..e3e558da8d 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -55,6 +55,7 @@ namespace collective { template < class StrideC_, class StrideD_, + class StrideTmp_, class ThreadEpilogueOp_, class EpilogueSchedule_ > @@ -76,6 +77,7 @@ class SoftmaxEpilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; + using StrideTmp = StrideTmp_; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -97,6 +99,9 @@ class SoftmaxEpilogue { StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; + ElementAccumulator* ptr_max; + ElementAccumulator* ptr_sum; + StrideTmp dTmp{}; }; // Device side epilogue params @@ -148,9 +153,6 @@ class SoftmaxEpilogue { template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragDst, class Op @@ -158,12 +160,12 @@ class SoftmaxEpilogue { CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { // reduce across all the N tiles in shape CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < SizeA; x++) { + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < SizeB; y++) { - dst(x, y) = 
zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + for(int y = 0; y < size<1>(src); y++) { + dst(0, 0) = zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < SizeC; z++) { + for(int z = 1; z < size<2>(src); z++) { dst(x, y) = op(dst(x, y), src(x, y, z)); } } @@ -172,9 +174,9 @@ class SoftmaxEpilogue { // reduce across the sub_group to get the final output auto sg = syclcompat::get_nd_item<1>().get_sub_group(); CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < SizeA; x++) { + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < SizeB; y++) { + for(int y = 0; y < size<1>(src); y++) { CUTLASS_PRAGMA_UNROLL for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); @@ -185,42 +187,45 @@ class SoftmaxEpilogue { template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragDst, class Op > CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { - reduceSg(src, dst, op); + reduceSg(src, dst, op); for(int i=ThreadIdxX() % NumThreadsPerWarp; i y ? x : y; } + };*/ + template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragMax > CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { - reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); + reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); + //reduceSg(src, max, MaxOp()); } + /*struct SumOp { + CUTLASS_DEVICE ElementAccumulator + operator()(ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; } + };*/ + template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragSum > CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { - reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + //reduceSg(src, sum, SumOp()); } template< @@ -236,7 +241,7 @@ class SoftmaxEpilogue { ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, + cute::Tensor & accumulators, TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, @@ -255,28 +260,28 @@ class SoftmaxEpilogue { auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); - //auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); // Represent the full output tensor - //Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - //Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, 
n_coord, k_coord, l_coord] = blk_coord_mnkl; - //Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - //Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) static_assert(is_static::value, "Accumulator layout must be static"); - //CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), - // "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); @@ -284,25 +289,89 @@ class SoftmaxEpilogue { auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); + //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); + //Tensor acc_max = make_tensor(size<0>(accumulators)); + Tensor acc_max = make_tensor_like(take<0,2>(accumulators)); + Tensor acc_sum = make_tensor_like(take<0,2>(accumulators)); //TODO can reuse prev? + if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ - print("thr_mma: "); print(thr_mma); print("\n"); - print("tiled_mma: "); print(tiled_mma); print("\n"); - //print("tiled_mma L: "); print_latex(tiled_mma); print("\n"); - print("acc: "); print(accumulators); print("\n"); - print("tCgD: "); print(tCgD); print("\n"); - print("gD: "); print(gD); print("\n"); + //print("thr_mma: "); print(thr_mma); print("\n"); + //print("tiled_mma: "); print(tiled_mma); print("\n"); + //print("acc: "); print(accumulators); print("\n"); + //print("tCgD: "); print(tCgD); print("\n"); + //print("acc_max: "); print(acc_max); print("\n"); + //print("take<0,2>(accumulators): "); print(take<0,2>(accumulators)); print("\n"); + //print("gD: "); print(gD); print("\n"); + } + + if(is_source_needed()){ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + } + } else{ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + } } + reduce_max(accumulators, acc_max); + //reduceSg(accumulators, acc_max, MaxOp()); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(accumulators); ++k) { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i,j,k) = 
epilogue_op(accumulators(i,j,k)); + accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(i,j)); } } } } + + reduce_sum(accumulators, acc_sum); + + //TODO write out reductions + + //second kernel: + // - finalize max reduction: mN = sum(mj) + // - finalize sum reduction: sN = sum(sj * exp(mj-mN)) + // - finalize softmax: yi = exp(xi-mN)/sN + + /*CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + }*/ + } private: From e0fe666b48faa0c21bb8943d104cd266ff0dbf3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 18 Nov 2024 14:36:46 +0100 Subject: [PATCH 03/14] working softmax --- .../35_gemm_softmax/gemm_online_softmax.cpp | 198 ++++++- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 527 ++++++++++++++++++ examples/35_gemm_softmax/softmax_epilogue.hpp | 254 +++++++-- examples/35_gemm_softmax/softmax_finalize.hpp | 305 ++++++++++ include/cutlass/gpu_generics.h | 2 +- 5 files changed, 1232 insertions(+), 54 deletions(-) create mode 100644 examples/35_gemm_softmax/gemm_softmax_adapter.hpp create mode 100644 examples/35_gemm_softmax/softmax_finalize.hpp diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 6ad3c95333..eb2533947b 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -112,12 +112,14 @@ struct Options { int m, n, k, l; float alpha, beta; int iterations; + float tolerance; Options(): help(false), m(5120), n(4096), k(4096), l(1), alpha(1), beta(0), - iterations(100) + iterations(100), + tolerance(1e-5f) { } // Parses the command line @@ -134,8 +136,9 @@ struct Options { cmd.get_cmd_line_argument("k", k, 4096); cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); - //cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tolerance", tolerance); } @@ -152,7 +155,8 @@ struct Options { << " --l= Sets the L extent (batch count) of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Number of profiling iterations to perform.\n\n"; + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --tolerance Error tolerance\n"; return out; } @@ -212,13 +216,15 @@ struct ExampleRunner { using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; + using LayoutTmp = typename Gemm::LayoutTmp; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; using ElementAcc = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; @@ -249,7 +255,7 @@ struct ExampleRunner { // Methods // - bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, 
ElementOutput beta) { + /*bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); @@ -290,6 +296,177 @@ struct ExampleRunner { bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); return passed; + }*/ + + template + bool verify_tensor(std::vector vector_Input, \ + std::vector vector_Input_Ref, const Options& options) { + + auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size()); + float abs_tol = options.tolerance; + float rel_tol = options.tolerance; + + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i)); + float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { + printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the reference matches + bool verify(const Options& options) { + using ElementSoftmax = ElementD; + + cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{options.m, options.n, options.k}; + + int64_t total_elements_A_per_batch = options.m * options.k; + int64_t total_elements_B_per_batch = options.k * options.n; + int64_t total_elements_C_per_batch = options.m * options.n; + int64_t total_elements_D_per_batch = total_elements_C_per_batch; + + int64_t lda = LayoutA::packed({options.m, options.k}).stride(0); + int64_t ldb = LayoutB::packed({options.k, options.n}).stride(0); + int64_t ldc = LayoutC::packed({options.m, options.n}).stride(0); + + int64_t ldn = options.m; + int64_t lds = ldn; + + LayoutA layout_A(lda); + LayoutB layout_B(ldb); + LayoutC layout_C(ldc); + LayoutTmp Layout_N(ldn); + LayoutTmp Layout_S(lds); + + cutlass::MatrixCoord extent_A{options.m, options.k}; + cutlass::MatrixCoord extent_B{options.k, options.n}; + cutlass::MatrixCoord extent_C{options.m, options.n}; + + cutlass::HostTensor reference_N; + reference_N.reset({options.m, 1}, false); + + for (int batch_idx = 0; batch_idx < options.l; batch_idx++) { + cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); + cutlass::TensorView view_Ref_device(block_ref_D.get(), layout_C, extent_C); + + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementCompute + >( + problem_size, + options.alpha, + view_A, + cutlass::ComplexTransform::kNone, + view_B, + cutlass::ComplexTransform::kNone, + options.beta, + view_C, + view_Ref_device, + ElementCompute(0) + ); + + // Copy reference results to host memory for verification + std::vector matrix_D_Ref(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_ref_D.get(), matrix_D_Ref.size()); + cutlass::TensorView view_D_Ref(matrix_D_Ref.data(), layout_C, extent_C); + + std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); + cutlass::TensorView 
view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); + + // Copy computed results to host memory + std::vector matrix_D(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); + + auto& matrix_Softmax = matrix_D; + //std::vector matrix_Softmax(layout_C.capacity(extent_C)); + //cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); + + // Compute the norm + for (int m = 0; m < options.m; ++m) { + reference_N.at({m, 0}) = view_D_Ref.ref().at({m, 0}); + if(batch_idx == 0 && m < 3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref tmp " << m << " " << 0 << ": " << view_D_Ref.ref().at({m, 0}) << std::endl; + } + for (int n = 1; n < options.n; ++n) { + //std::cout << "val: " << view_D_Ref.ref().at({m, n}) << std::endl; + reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); + + if(batch_idx == 0 && m < 3 && n<3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref tmp " << m << " " << n << ": " << view_D_Ref.ref().at({m, n}) << std::endl; + } + if(batch_idx == 0 && m==0 && n==127 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref max tmp " << m << " " << n << ": " << reference_N.at({m, 0}) << std::endl; + } + } + if(batch_idx == 0 && m == 0){ + std::cout << "ref max: " << reference_N.at({m, 0}) << std::endl; + } + } + + // Compute softmax + for (int m = 0; m < options.m; ++m) { + float sum = float(); + + for (int n = 0; n < options.n; ++n) { + sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); + } + if(batch_idx == 0 && m == 0){ + std::cout << "ref sum: " << sum << std::endl; + } + + float inv_sum = float(1.0f / sum); + + for (int n = 0; n < options.n; ++n) { + view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( + std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum + ); + } + } + + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_D = false; + bool verified_Softmax = false; + + // Verify softmax output + if (!verified_D) { + verified_D = verify_tensor(matrix_D, matrix_D_Ref, options); + } + + if (!verified_Softmax) { + verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref, options); + } + //TODO(Tadej): just softmax + if (!verified_D && !verified_Softmax) { + std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; + + // Summarize which checks failed + if (!verified_D) { + std::cerr << "Verification of D tensor failed\n"; + } else{ + std::cerr << "Verification of D tensor passed\n"; + } + + if (!verified_Softmax) { + std::cerr << "Verification of Softmax tensor failed\n"; + } else{ + std::cerr << "Verification of Softmax tensor passed\n"; + } + + return false; + } + } + return true; } /// Initialize operands to be used in the GEMM and reference GEMM @@ -297,16 +474,14 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - // 1 element per warp. 
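      // Editor's note, a small numeric check of the host reference softmax above: for a row
      // with values [1, 2, 3] the row max is 3, the shifted exponentials are exp(-2) ~= 0.135,
      // exp(-1) ~= 0.368, exp(0) = 1.000, their sum is ~1.503, and the softmax outputs are
      // approximately [0.090, 0.245, 0.665], which sum to 1.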
- auto tmp_size = cute::ceil_div(M * K * L, cute::shape<0>(typename Gemm::TileShape{}) * cute::shape<1>(typename Gemm::TileShape{})) * NumWarpsPerWarpGroup; + auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})); + auto tmp_size = M * partials_N * L; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(cute::ceil_div(M, cute::shape<0>(typename Gemm::TileShape{})), - cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})), - L)); + stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(M, partials_N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); @@ -348,7 +523,7 @@ struct ExampleRunner { // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; - result.passed = verify(problem_size, options.alpha, options.beta); + result.passed = verify(options); std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; @@ -512,6 +687,7 @@ int main(int argc, char const **args) { cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, + TileShape, EpilogueOp, cutlass::gemm::EpilogueDefault>; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp new file mode 100644 index 0000000000..08fa798684 --- /dev/null +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -0,0 +1,527 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +//#include "cutlass/gemm/device/gemm_universal_base.h" +//#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +//#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +//#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_event_manager.hpp" +#endif + +#include "softmax_finalize.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmSoftmaxAdapter +{ +public: + using GemmKernel = GemmKernel_; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< + ElementD, typename GemmKernel::StrideD, + ElementD, typename GemmKernel::StrideD, + ElementD, typename GemmKernel::StrideD>; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + using LayoutTmp = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? 
+ ComplexTransform::kConjugate : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: 
Kernel API + //using Params = typename GemmKernel::Params; + + struct Params{ + typename GemmKernel::Params gemm_params; + typename SoftmaxFinalizeKernel::Params softmax_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params.gemm_params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + //TODO(Tadej) move to finalize kernel class? 
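+    // Wiring for the softmax finalize kernel (grounded in the argument mapping just below):
+    // it re-reads the GEMM output D together with the per-tile partial maxima/sums written
+    // by the epilogue (ptr_max / ptr_sum with the Tmp stride), and overwrites D in place
+    // with softmax(D). IOSize mirrors the (M, N) problem extent, while partialSize keeps
+    // one column of partial results per N-tile of TileShape.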
+ auto& softmax_args = params_.softmax_params.args; + softmax_args.IOSize = {get<0>(args.problem_shape), get<1>(args.problem_shape)}; + softmax_args.partialSize = {get<0>(args.problem_shape), + cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{}))}; + softmax_args.batch_count = get<3>(args.problem_shape); + softmax_args.dInput = args.epilogue.dD; + softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dOutput = args.epilogue.dD; + softmax_args.ptr_in = args.epilogue.ptr_D; + softmax_args.ptr_partial_max = args.epilogue.ptr_max; + softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; + softmax_args.ptr_out = args.epilogue.ptr_D; + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if !defined(CUTLASS_ENABLE_SYCL) + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } +#endif + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + //TODO(Tadej) update softmax args + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if !defined(CUTLASS_ENABLE_SYCL) + constexpr bool is_static_1x1x1 = cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { + if (is_static_1x1x1 && not launch_with_pdl) { + device_kernel<<>>(params); + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + } +#endif + } + else { + launch_result = Status::kSuccess; + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms.gemm_params}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if defined(CUTLASS_ENABLE_SYCL) + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + using namespace syclcompat::experimental; +#if defined (SYCL_INTEL_TARGET) + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size} + }, params.gemm_params); +#else + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, + params.gemm_params); +#endif + //EventManager::getInstance().addEvent(event); + + const auto sycl_block2 = syclcompat::dim3(128, 1, 1); + const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), + params.softmax_params.args.batch_count, + 1); + auto event2 = launch>(launch_policy{ + sycl_grid2, sycl_block2, local_mem_size{0}}, + params.softmax_params); + EventManager::getInstance().addEvent(event2); +#else + device_kernel<<>>(params.gemm_params); +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + 
} + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index e3e558da8d..96c902e373 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -56,6 +56,7 @@ template < class StrideC_, class StrideD_, class StrideTmp_, + class BlockShapeMNK, class ThreadEpilogueOp_, class EpilogueSchedule_ > @@ -88,7 +89,10 @@ class SoftmaxEpilogue { static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - struct SharedStorage { }; + struct SharedStorage { + //cute::array_aligned(BlockShapeMNK{})>,C(BlockShapeMNK{})>>>>> smem_c; + cute::array_aligned(BlockShapeMNK{}) * get<1>(BlockShapeMNK{})> smem_c; + }; using TensorStorage = SharedStorage; @@ -158,15 +162,15 @@ class SoftmaxEpilogue { class Op > CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { - // reduce across all the N tiles in shape + // reduce across all the -N- M tiles in shape CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { + for(int z = 1; z < size<2>(src); z++) { + dst(z) = zero_init ? src(0, 0, z) : op(dst(z), src(0, 0, z)); CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - dst(0, 0) = zero_init ? 
src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - dst(x, y) = op(dst(x, y), src(x, y, z)); + for(int y = 0; y < size<1>(src); y++) { + dst(z) = op(dst(z), src(x, y, z)); } } } @@ -174,35 +178,73 @@ class SoftmaxEpilogue { // reduce across the sub_group to get the final output auto sg = syclcompat::get_nd_item<1>().get_sub_group(); CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { + for(int z = 1; z < size<2>(src); z++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); - } + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + dst(z) = op(dst(z), syclcompat::permute_sub_group_by_xor(sg, dst(z), laneMask, 16)); } } } template < - bool zero_init, class FragSrc, class FragDst, + class SharedThreadTens, + class SharedTens, + class ResidueMap, + class Residue, class Op > - CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { - reduceSg(src, dst, op); + CUTLASS_DEVICE static ElementAccumulator reduceWg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx, + ElementAccumulator init, Op op) { + //TODO(Tadej): single loop over all dims + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(src); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(src); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(src); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCsC(i,j,k) = src(i,j,k); + } else{ + tCsC(i,j,k) = init; + } + } + } + } + + syncthreads(); + + ElementAccumulator acc = sC(0, thread_idx); + for (int i = 1; i < size(src); ++i) { + acc = op(acc, sC(i, thread_idx)); + } + + syncthreads(); + + //broadcast it back to threads + //TODO(Tadej): optimize + for (int i = 0; i < size(src); ++i) { + sC(i, thread_idx) = acc; + } + + syncthreads(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < size<2>(src); k++) { + dst(k) = tCsC(0,0,k); + } + + return acc; + + /*reduceSg(src, dst, op); for(int i=ThreadIdxX() % NumThreadsPerWarp; i y ? x : y; } - };*/ - template < bool zero_init, class FragSrc, @@ -210,13 +252,24 @@ class SoftmaxEpilogue { > CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); - //reduceSg(src, max, MaxOp()); } - /*struct SumOp { - CUTLASS_DEVICE ElementAccumulator - operator()(ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; } - };*/ + template < + class FragSrc, + class FragDst, + class SharedThreadTens, + class SharedTens, + class ResidueMap, + class Residue + > + CUTLASS_DEVICE static ElementAccumulator reduce_max_wg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx) { + + return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, + std::numeric_limits::min(), + [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
x : y; }); + } template < bool zero_init, @@ -225,12 +278,27 @@ class SoftmaxEpilogue { > CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); - //reduceSg(src, sum, SumOp()); + } + + template < + class FragSrc, + class FragDst, + class SharedThreadTens, + class SharedTens, + class Residue, + class ResidueMap + > + CUTLASS_DEVICE static ElementAccumulator reduce_sum_wg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx) { + + return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, + 0, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x+y; }); } template< class ProblemShapeMNKL, - class BlockShapeMNK, + //class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, class TiledMma, @@ -255,29 +323,53 @@ class SoftmaxEpilogue { static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - // Separate out problem shape for convenience + //auto wlid = thread_idx % NumWarpsPerWarpGroup; // warp local id + //auto wid = thread_idx / NumWarpsPerWarpGroup; // warp id in tile + + // Separate out problem and tile shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); + auto M_tile = get<0>(blk_shape_MNK); + auto N_tile = get<1>(blk_shape_MNK); + auto K_tile = get<2>(blk_shape_MNK); + + auto N_tmp = cute::ceil_div(N, N_tile); + + cute::packed_tuple partial_block(M_tile, C<1>(), K_tile); + auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); - // Represent the full output tensor + // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gMax = gMax_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gSum = gSum_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - // Partition source and destination tiles to match the accumulator partitioning + //Represent the shared tensor + Tensor sC = make_tensor(make_smem_ptr(reinterpret_cast(smem_buf)), make_layout(make_shape(M_tile, N_tile))); + + // Partition the tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = 
thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + //Tensor tCgMax = thr_mma.partition_C(gMax); // (VEC,THR_M,THR_N) + //Tensor tCgSum = thr_mma.partition_C(gSum); // (VEC,THR_M,THR_N) + Tensor tCsC = thr_mma.partition_C(sC); // (VEC,THR_M,THR_N) + static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), @@ -291,17 +383,35 @@ class SoftmaxEpilogue { //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); //Tensor acc_max = make_tensor(size<0>(accumulators)); - Tensor acc_max = make_tensor_like(take<0,2>(accumulators)); - Tensor acc_sum = make_tensor_like(take<0,2>(accumulators)); //TODO can reuse prev? - - if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ + //Tensor acc_max = make_tensor_like(take<2,3>(accumulators)); + //Tensor acc_sum = make_tensor_like(acc_max); //TODO can reuse prev? + + //Tensor acc_max = make_tensor(shape<2>(accumulators), LayoutLeft{}); + //Tensor acc_sum = make_tensor(shape<2>(accumulators), LayoutLeft{}); //TODO can reuse prev? + + bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; + if(is_first){ + print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); + //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); + //print("partial_block: "); print(partial_block); print("\n"); //print("thr_mma: "); print(thr_mma); print("\n"); //print("tiled_mma: "); print(tiled_mma); print("\n"); //print("acc: "); print(accumulators); print("\n"); + //print("mD_mnl: "); print(mD_mnl); print("\n"); + print("mMax_mnl: "); print(mMax_mnl); print("\n"); + //print("gD_mnl: "); print(gD_mnl); print("\n"); + print("gMax_mnl: "); print(gMax_mnl); print("\n"); + //print("gD: "); print(gD); print("\n"); + print("gMax: "); print(gMax); print("\n"); //print("tCgD: "); print(tCgD); print("\n"); + //print("sC: "); print(sC); print("\n"); + //print("tCsC: "); print(tCsC); print("\n"); + //print("sC.data: "); print(&sC(0)); print("\n"); + //print("tCsC.data: "); print(&tCsC(0)); print("\n"); + //decltype(tCsC(0)) a = "asd"; + //print("tCgMax: "); print(tCgMax); print("\n"); //print("acc_max: "); print(acc_max); print("\n"); - //print("take<0,2>(accumulators): "); print(take<0,2>(accumulators)); print("\n"); - //print("gD: "); print(gD); print("\n"); + //print("accumulators: "); print(accumulators); print("\n"); } if(is_source_needed()){ @@ -314,6 +424,10 @@ class SoftmaxEpilogue { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); + tCsC(i,j,k) = accumulators(i,j,k); + /*if(is_first){ + print("acc1.1:"); print(tCsC(i,j,k)); print("\n"); + }*/ } } } @@ -328,14 +442,65 @@ class SoftmaxEpilogue { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); - } + tCsC(i,j,k) = accumulators(i,j,k); + /*if(is_first){ + print("acc1.2:"); print(accumulators(i,j,k)); print(".\n"); + print("idx:"); print(tCsC.layout()(i,j,k)); print(".\n"); + }*/ + } } } } } - reduce_max(accumulators, acc_max); - //reduceSg(accumulators, acc_max, MaxOp()); + syncthreads(); + + // assumption size<0>(sC) == wg size + ElementAccumulator max = std::numeric_limits::min(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(sC); ++i) { + if (elem_less(cD(thread_idx, i), 
make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i) = sC(thread_idx, i); + max = cutlass::fast_max(max, accumulators(i)); + /*if(is_first && i < 3){ + print("acc2 :"); print(accumulators(i)); print("\n"); + //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); + for (int j = 0; j < 3; ++j) { + print("shared :"); print(j); print(" "); print(i); print(": "); print(sC(j, i)); print("\n"); + } + }*/ + } + } + /*if(m_coord == 0 && n_coord == 1 && ThreadIdxX()==0){ + print("max epilogue val:"); print(max); print("\n"); + print("idx:"); print(n_coord); print("\n"); + }*/ + + gMax(thread_idx,0) = max; + + ElementAccumulator sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(sC); ++i) { + if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + sum += cutlass::fast_exp(accumulators(i) - max); + if(is_first){ + //print("acc3 :"); print(accumulators(i)); print("\n"); + //print("diff :"); print(accumulators(i) - max); print("\n"); + //print("add :"); print(cutlass::fast_exp(accumulators(i) - max)); print("\n"); + //print("sum :"); print(sum); print("\n"); + //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); + } + } + } + if(is_first){ + print("sum epilogue val:"); print(sum); print("\n"); + } + + gSum(thread_idx,0) = sum; + + + /*//reduce_max(accumulators, acc_max); + reduce_max_wg(accumulators, acc_max, tCsC, sC, tCcD, residue_mnk, thread_idx); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { @@ -344,13 +509,18 @@ class SoftmaxEpilogue { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(accumulators); ++k) { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(i,j)); + accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(k)); } } } } reduce_sum(accumulators, acc_sum); + if(wlid == 0){ + for (int k = 0; k < size<2>(accumulators); ++k) { + gSum(wid,k) = acc_sum(k); + } + }*/ //TODO write out reductions diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp new file mode 100644 index 0000000000..bc91f85dd9 --- /dev/null +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a final calculation of softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +template < + typename ElementInput_, + typename StrideInput_, + typename ElementPartial_, + typename StridePartial_, + typename ElementOutput_, + typename StrideOutput_ +> +class SoftmaxFinalize { +public: + + using ElementInput = ElementInput_; + using StrideInput = StrideInput_; + using ElementPartial = ElementPartial_; + using StridePartial = StridePartial_; + using ElementOutput = ElementOutput_; + using StrideOutput = StrideOutput_; + + // + // Arguments + // + + struct Arguments { + //TODO(Tadej): duplicated part of sizes + cutlass::MatrixCoord IOSize; ///< Extent of input and output matrices + cutlass::MatrixCoord partialSize; ///< Extent of partial max and sum matrices + int batch_count; ///< Batch count + StrideInput dInput; + StridePartial dPartial; + StrideOutput dOutput; + ElementInput* ptr_in; + ElementPartial* ptr_partial_max; + ElementPartial* ptr_partial_sum; + ElementOutput* ptr_out; + +/* + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + }*/ + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + SoftmaxFinalize() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char* shared_storage) { + + apply(params, shared_storage); + } + +private: + + template + CUTLASS_DEVICE static ElementPartial reduceSg(ElementPartial val, Op op) { + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + val = op(val, syclcompat::permute_sub_group_by_xor(sg, val, laneMask, 16)); + } + return val; + } + + CUTLASS_DEVICE static ElementPartial reduce_max(ElementPartial val) { + return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x > y ? 
x : y; }); + } + + CUTLASS_DEVICE static ElementPartial reduce_sum(ElementPartial val) { + return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x + y; }); + } + + /// Full reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, char* shared_storage) { + using ConvertInput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + /*int tid = ThreadIdxX(); + //int bid = BlockIdxX(); + int bdim = BlockDimX(); + + int warps_in_block = bdim / NumThreadsPerWarp; + + int wlid = tid % warps_in_block; // local id of thread in warp + int gid = tid;// + bid * GridDimX(); + int bsize = GridDimX(); + int m_batch_id = BlockIdxY(); + int batch_id = m_batch_id / params.args.IOSize[1]; + int m = m_batch_id % params.args.IOSize[1];*/ + + int m = ThreadIdxX() + BlockDimX() * BlockIdxX(); + int batch_id = BlockIdxY(); + + if(m>=params.args.IOSize[0]){ + return; + } + + + // Represent the full tensors + auto IOTensorShape = make_shape(params.args.IOSize[0], params.args.IOSize[1], params.args.batch_count); + auto PartialTensorShape = make_shape(params.args.partialSize[0], params.args.partialSize[1], params.args.batch_count); + Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); // (m,n,l) + Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); // (m,n,l) + Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) + Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + + if(m==0 && batch_id==0){ + print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); + } + + ElementPartial max_val = std::numeric_limits::min(); + for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + /*if(m==0 && batch_id==0){ + print("partial_max: "); print(partial_max); print("\n"); + }*/ + max_val = max_val > partial_max ? max_val : partial_max; + } + //max_val = reduce_max(max_val); + + //mOut(0,0,0) = max_val; return; + + ElementPartial sum_val = 0; + for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); + sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); + } + //sum_val = reduce_sum(sum_val); + + if(m==0 && batch_id==0){ + print("max_val: "); print(max_val); print("\n"); + print("sum_val: "); print(sum_val); print("\n"); + } + + ElementPartial norm = 1 / sum_val; + + for(int n = 0; n < params.args.IOSize[1]; n += 1){ + mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; + } + } + /* + // defining three vars for a general reduction module + cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; + int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; + int access_offset = isGroupedProblem ? 0 : bid * bdim; + + if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; + + ElementNorm *curr_ptr_Max = isGroupedProblem ? \ + params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ + params.args.block_Norm + block_batch * params.args.batch_stride_Max; + ElementSum *curr_ptr_Sum = isGroupedProblem ? 
\ + params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ + params.args.block_Sum + block_batch * params.args.batch_stride_Sum; + + int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + ConvertSumOutput convert_sum_output; + ConvertNormOutput convert_norm_output; + + uint32_t float_max_bits = 0xff7fffff; + float min_float = reinterpret_cast(float_max_bits); + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { + ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; + ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; + ElementNorm *access_n_bak = access_n; + ElementSum *access_s_bak = access_s; + ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); + ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); + ElementNorm fetch_n; + ElementSum fetch_s; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); + access_n += problem_size.m(); + } + + access_n = access_n_bak; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + cutlass::arch::global_load(fetch_s, access_s, true); + sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); + access_n += problem_size.m(); + access_s += problem_size.m(); + } + + ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; + + access_n = access_n_bak; + access_s = access_s_bak; + + access_n[0] = convert_norm_output(max_val); + access_s[0] = convert_sum_output(inv_sum); + } + + }*/ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 22d82e9d5d..7dab25b1dc 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -43,7 +43,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -static const int NumThreadsPerWarp = 32; +static constexpr int NumThreadsPerWarp = 32; static const int NumThreadsPerWarpGroup = 128; static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; From abafbabd2722f8dd1415cf3efe41c2bf67ce8d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 19 Nov 2024 13:06:27 +0100 Subject: [PATCH 04/14] partial cleanup --- .../35_gemm_softmax/gemm_online_softmax.cpp | 139 +++--------------- examples/35_gemm_softmax/gemm_softmax.cu | 10 +- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 2 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 57 +------ examples/35_gemm_softmax/softmax_finalize.hpp | 71 +-------- 5 files changed, 35 insertions(+), 244 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index eb2533947b..720de5d824 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -31,28 +31,11 @@ **************************************************************************************************/ /*! 
\file - \brief Simple GEMM example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture + \brief GEMM + Softmax example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture This example demonstrate how to instantiate and run a TF32 GEMM using the Cute and CUTLASS 3.x APIs on NVIDIA Ampere architecture. Please check example 07 and 08 for - the basics of tensor op gemm kernels. On NVIDIA Ampere architecture, most concept - still holds. The two main differences are: - - (1) NVIDIA Ampere architecture introduces a new series of tensor core instructions - (see include/cute/arch/mma_sm80.hpp) which are more efficient on Ampere. - (2) NVIDIA Ampere architecture uses CP_ASYNC (see include/cute/arch/copy_sm80.hpp) - to build a multistage software pipeline to better hide latency (see - include/cutlass/gemm/collective/sm80_mma_multistage.hpp). - - Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) - data types in tensor cores. One big advantage is that we can load in fp32 data and convert - them implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate - traditional fp32 data by using NVIDIA Ampere architecture. - - Examples: - - $ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm_cute - + the basics of tensor op gemm kernels. */ #include @@ -246,59 +229,15 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_max; - cutlass::DeviceAllocation block_sum; + cutlass::DeviceAllocation block_max; + cutlass::DeviceAllocation block_sum; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; // // Methods // - - /*bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { - auto [M, N, K, L] = problem_size; - - cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); - cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); - cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); - - cutlass::reference::device::GemmComplex( - {M, N, K}, - ElementCompute(alpha), - ref_A, - cutlass::ComplexTransform::kNone, - ref_B, - cutlass::ComplexTransform::kNone, - ElementCompute(beta), - ref_C, - ref_D, - ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); - -#if defined(CUTLASS_ENABLE_SYCL) - syclcompat::wait_and_throw(); -#else - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - std::cerr << "Reference kernel failed. 
Last CUDA error: " - << cudaGetErrorString(result) << std::endl; - return false; - } -#endif - - // Check if output from CUTLASS kernel and reference kernel are equal or not - bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); - - return passed; - }*/ - - template + template bool verify_tensor(std::vector vector_Input, \ std::vector vector_Input_Ref, const Options& options) { @@ -362,7 +301,7 @@ struct ExampleRunner { ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - ElementCompute, ElementCompute + ElementCompute, ElementCompute, ElementD >( problem_size, options.alpha, @@ -388,42 +327,20 @@ struct ExampleRunner { std::vector matrix_D(layout_C.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); - auto& matrix_Softmax = matrix_D; - //std::vector matrix_Softmax(layout_C.capacity(extent_C)); - //cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); - // Compute the norm for (int m = 0; m < options.m; ++m) { reference_N.at({m, 0}) = view_D_Ref.ref().at({m, 0}); - if(batch_idx == 0 && m < 3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref tmp " << m << " " << 0 << ": " << view_D_Ref.ref().at({m, 0}) << std::endl; - } for (int n = 1; n < options.n; ++n) { - //std::cout << "val: " << view_D_Ref.ref().at({m, n}) << std::endl; reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); - - if(batch_idx == 0 && m < 3 && n<3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref tmp " << m << " " << n << ": " << view_D_Ref.ref().at({m, n}) << std::endl; - } - if(batch_idx == 0 && m==0 && n==127 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref max tmp " << m << " " << n << ": " << reference_N.at({m, 0}) << std::endl; - } - } - if(batch_idx == 0 && m == 0){ - std::cout << "ref max: " << reference_N.at({m, 0}) << std::endl; } } // Compute softmax for (int m = 0; m < options.m; ++m) { - float sum = float(); - + float sum = 0; for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } - if(batch_idx == 0 && m == 0){ - std::cout << "ref sum: " << sum << std::endl; - } float inv_sum = float(1.0f / sum); @@ -434,35 +351,9 @@ struct ExampleRunner { } } - // Verification checks - set any of these to 'true' to override the verification checks. 
- bool verified_D = false; - bool verified_Softmax = false; - - // Verify softmax output - if (!verified_D) { - verified_D = verify_tensor(matrix_D, matrix_D_Ref, options); - } - + bool verified_Softmax = verify_tensor(matrix_D, matrix_Softmax_Ref, options); if (!verified_Softmax) { - verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref, options); - } - //TODO(Tadej): just softmax - if (!verified_D && !verified_Softmax) { - std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; - - // Summarize which checks failed - if (!verified_D) { - std::cerr << "Verification of D tensor failed\n"; - } else{ - std::cerr << "Verification of D tensor passed\n"; - } - - if (!verified_Softmax) { - std::cerr << "Verification of Softmax tensor failed\n"; - } else{ - std::cerr << "Verification of Softmax tensor passed\n"; - } - + std::cerr << "Verification of Softmax tensor failed\n"; return false; } } @@ -505,7 +396,11 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, + {{options.alpha,//static_cast(options.alpha), + options.beta},//static_cast(options.beta)}, + block_C.get(), stride_C, + block_D.get(), stride_D, + block_max.get(), block_sum.get(), stride_tmp}, hw_info }; @@ -610,6 +505,10 @@ int main(int argc, char const **args) { hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); // Problem configuration + /*using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementAcc = float; + using ElementOutput = cutlass::half_t;*/ using ElementA = float; using ElementB = float; using ElementAcc = float; @@ -681,7 +580,7 @@ int main(int argc, char const **args) { // elements. 
This becomes the vector width of // math instructions in the epilogue too ElementAcc, // <- data type of accumulator - ElementOutput>; // <- data type for alpha/beta in linear combination function + ElementAcc>; // <- data type for alpha/beta in linear combination function using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 27156ea02d..7e663679b6 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -201,19 +201,23 @@ struct Testbed { // - using ElementA = cutlass::half_t; + /*using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; + using ElementCompute = float;*/ + using ElementA = float; + using ElementB = float; + using ElementC = float; using ElementCompute = float; using ElementD = ElementC; using ElementSoftmax = ElementC; - using LayoutA = cutlass::layout::RowMajor; + using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor; using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 08fa798684..bdc4f6052a 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -91,7 +91,7 @@ class GemmSoftmaxAdapter using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< ElementD, typename GemmKernel::StrideD, - ElementD, typename GemmKernel::StrideD, + ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StrideTmp, ElementD, typename GemmKernel::StrideD>; // Map back to 2.x type as best as possible diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 96c902e373..f5f807a84d 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -298,7 +298,6 @@ class SoftmaxEpilogue { template< class ProblemShapeMNKL, - //class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, class TiledMma, @@ -391,18 +390,18 @@ class SoftmaxEpilogue { bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; if(is_first){ - print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); + //print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); //print("partial_block: "); print(partial_block); print("\n"); //print("thr_mma: "); print(thr_mma); print("\n"); //print("tiled_mma: "); print(tiled_mma); print("\n"); //print("acc: "); print(accumulators); print("\n"); //print("mD_mnl: "); print(mD_mnl); print("\n"); - print("mMax_mnl: "); print(mMax_mnl); print("\n"); + //print("mMax_mnl: "); print(mMax_mnl); print("\n"); //print("gD_mnl: "); print(gD_mnl); print("\n"); - print("gMax_mnl: "); print(gMax_mnl); print("\n"); + //print("gMax_mnl: "); print(gMax_mnl); print("\n"); //print("gD: "); print(gD); print("\n"); - print("gMax: "); print(gMax); print("\n"); + //print("gMax: "); print(gMax); print("\n"); //print("tCgD: "); print(tCgD); 
print("\n"); //print("sC: "); print(sC); print("\n"); //print("tCsC: "); print(tCsC); print("\n"); @@ -493,55 +492,9 @@ class SoftmaxEpilogue { } } if(is_first){ - print("sum epilogue val:"); print(sum); print("\n"); + //print("sum epilogue val:"); print(sum); print("\n"); } - gSum(thread_idx,0) = sum; - - - /*//reduce_max(accumulators, acc_max); - reduce_max_wg(accumulators, acc_max, tCsC, sC, tCcD, residue_mnk, thread_idx); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(accumulators); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(accumulators); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(accumulators); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(k)); - } - } - } - } - - reduce_sum(accumulators, acc_sum); - if(wlid == 0){ - for (int k = 0; k < size<2>(accumulators); ++k) { - gSum(wid,k) = acc_sum(k); - } - }*/ - - //TODO write out reductions - - //second kernel: - // - finalize max reduction: mN = sum(mj) - // - finalize sum reduction: sN = sum(sj * exp(mj-mN)) - // - finalize softmax: yi = exp(xi-mN)/sN - - /*CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(accumulators); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(accumulators); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(accumulators); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i,j,k) = accumulators(i,j,k); - } - } - } - }*/ - } private: diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index bc91f85dd9..72a7c93f8c 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -197,7 +197,7 @@ class SoftmaxFinalize { Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) if(m==0 && batch_id==0){ - print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); + //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); } ElementPartial max_val = std::numeric_limits::min(); @@ -221,8 +221,8 @@ class SoftmaxFinalize { //sum_val = reduce_sum(sum_val); if(m==0 && batch_id==0){ - print("max_val: "); print(max_val); print("\n"); - print("sum_val: "); print(sum_val); print("\n"); + //print("max_val: "); print(max_val); print("\n"); + //print("sum_val: "); print(sum_val); print("\n"); } ElementPartial norm = 1 / sum_val; @@ -231,71 +231,6 @@ class SoftmaxFinalize { mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; } } - /* - // defining three vars for a general reduction module - cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; - int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; - int access_offset = isGroupedProblem ? 0 : bid * bdim; - - if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; - - ElementNorm *curr_ptr_Max = isGroupedProblem ? \ - params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ - params.args.block_Norm + block_batch * params.args.batch_stride_Max; - ElementSum *curr_ptr_Sum = isGroupedProblem ? 
\ - params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ - params.args.block_Sum + block_batch * params.args.batch_stride_Sum; - - int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - ConvertSum convert_sum; - ConvertNorm convert_norm; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - uint32_t float_max_bits = 0xff7fffff; - float min_float = reinterpret_cast(float_max_bits); - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { - ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; - ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; - ElementNorm *access_n_bak = access_n; - ElementSum *access_s_bak = access_s; - ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); - ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); - ElementNorm fetch_n; - ElementSum fetch_s; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); - access_n += problem_size.m(); - } - - access_n = access_n_bak; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - cutlass::arch::global_load(fetch_s, access_s, true); - sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); - access_n += problem_size.m(); - access_s += problem_size.m(); - } - - ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; - - access_n = access_n_bak; - access_s = access_s_bak; - - access_n[0] = convert_norm_output(max_val); - access_s[0] = convert_sum_output(inv_sum); - } - - }*/ }; ///////////////////////////////////////////////////////////////////////////////////////////////// From 8d8c1bcbd7626e0e6521d634c1912e927486bcf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 25 Nov 2024 14:20:32 +0100 Subject: [PATCH 05/14] 2nd kernel work distribution fix --- .../35_gemm_softmax/gemm_online_softmax.cpp | 11 ++++- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 4 +- examples/35_gemm_softmax/softmax_finalize.hpp | 42 ++++++++++++++++--- 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 720de5d824..f8ebdc3221 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -250,8 +250,8 @@ struct ExampleRunner { float abs_diff = fabs(diff); float abs_ref = fabs((float)vector_Input_Ref.at(i)); float relative_diff = abs_ref > abs_tol ? 
abs_diff / abs_ref : 0; - if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { - printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("i = %d diff = %f, {%f, %f}.\n", i, abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } @@ -333,6 +333,9 @@ struct ExampleRunner { for (int n = 1; n < options.n; ++n) { reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); } + /*if(m == 3516 && batch_idx == 0){ + std:: cout << "max0: " << reference_N.at({m, 0}) << std::endl; + }*/ } // Compute softmax @@ -341,6 +344,10 @@ struct ExampleRunner { for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } + + /*if(m == 3516 && batch_idx == 0){ + std:: cout << "sum0: " << sum << std::endl; + }*/ float inv_sum = float(1.0f / sum); diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index bdc4f6052a..18c5b4ec15 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -448,12 +448,12 @@ class GemmSoftmaxAdapter #endif //EventManager::getInstance().addEvent(event); - const auto sycl_block2 = syclcompat::dim3(128, 1, 1); + const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), params.softmax_params.args.batch_count, 1); auto event2 = launch>(launch_policy{ - sycl_grid2, sycl_block2, local_mem_size{0}}, + sycl_grid2, sycl_block2, local_mem_size{SoftmaxFinalizeKernel::SharedStorageSize}}, params.softmax_params); EventManager::getInstance().addEvent(event2); #else diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 72a7c93f8c..ca6c3c458c 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -109,10 +109,11 @@ class SoftmaxFinalize { }; struct SharedStorage { - - + cute::array_aligned s_mem; }; + static constexpr int SharedStorageSize = sizeof(SharedStorage); + // // Params struct // @@ -180,7 +181,10 @@ class SoftmaxFinalize { int batch_id = m_batch_id / params.args.IOSize[1]; int m = m_batch_id % params.args.IOSize[1];*/ - int m = ThreadIdxX() + BlockDimX() * BlockIdxX(); + int x = ThreadIdxX(); + int m = x + BlockDimX() * BlockIdxX(); + int y = ThreadIdxY(); + int y_size = BlockDimY(); int batch_id = BlockIdxY(); if(m>=params.args.IOSize[0]){ @@ -196,28 +200,54 @@ class SoftmaxFinalize { Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + //Represent the shared tensor + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); + if(m==0 && batch_id==0){ //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); } ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ 
ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); /*if(m==0 && batch_id==0){ print("partial_max: "); print(partial_max); print("\n"); }*/ max_val = max_val > partial_max ? max_val : partial_max; } + sPartial(x,y,0) = max_val; + syncthreads(); + //TODO(Tadej): improve reduction + for(int y2 = 0; y2 < y_size; y2++){ + ElementPartial partial_max = sPartial(x,y2,0); + max_val = max_val > partial_max ? max_val : partial_max; + } + /*if(m == 3516 && y == 0){ + print("kernel max"); print(max_val); print("\n"); + }*/ //max_val = reduce_max(max_val); //mOut(0,0,0) = max_val; return; ElementPartial sum_val = 0; - for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); } + syncthreads(); + sPartial(x,y,1) = sum_val; + syncthreads(); + sum_val = 0; + //TODO(Tadej): improve reduction + for(int y2 = 0; y2 < y_size; y2++){ + ElementPartial partial_max = sPartial(x,y2,0); + ElementPartial partial_sum = sPartial(x,y2,1); + sum_val = sum_val + partial_sum /** cutlass::fast_exp(partial_max - max_val)*/; + } + /*if(m == 3516 && y == 0){ + print("kernel sum"); print(sum_val); print("\n"); + }*/ //sum_val = reduce_sum(sum_val); if(m==0 && batch_id==0){ @@ -227,7 +257,7 @@ class SoftmaxFinalize { ElementPartial norm = 1 / sum_val; - for(int n = 0; n < params.args.IOSize[1]; n += 1){ + for(int n = y; n < params.args.IOSize[1]; n += y_size){ mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; } } From 1ad0f8ec9f0c5f8e16ba7d0dfc1348953e7d0816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 2 Dec 2024 10:29:14 +0100 Subject: [PATCH 06/14] native exp + optimizations --- .../35_gemm_softmax/gemm_online_softmax.cpp | 2 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 19 +------------------ examples/35_gemm_softmax/softmax_finalize.hpp | 19 +++++++++++++++++-- include/cutlass/fast_math.h | 2 ++ 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index f8ebdc3221..cdafc65a50 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -430,7 +430,7 @@ struct ExampleRunner { std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; if (!result.passed) { - exit(-1); + //exit(-1); } // Run profiling loop diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index f5f807a84d..0e0505bf67 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -424,9 +424,6 @@ class SoftmaxEpilogue { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); tCsC(i,j,k) = accumulators(i,j,k); - /*if(is_first){ - print("acc1.1:"); print(tCsC(i,j,k)); print("\n"); - }*/ } } } @@ -442,10 +439,6 @@ class SoftmaxEpilogue { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); tCsC(i,j,k) = accumulators(i,j,k); - /*if(is_first){ - print("acc1.2:"); print(accumulators(i,j,k)); print(".\n"); - print("idx:"); print(tCsC.layout()(i,j,k)); print(".\n"); - }*/ } } } @@ -461,19 +454,8 @@ class SoftmaxEpilogue { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i) = sC(thread_idx, i); max = cutlass::fast_max(max, accumulators(i)); - /*if(is_first && i < 3){ - print("acc2 :"); print(accumulators(i)); print("\n"); - //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); - for (int j = 0; j < 3; ++j) { - print("shared :"); print(j); print(" "); print(i); print(": "); print(sC(j, i)); print("\n"); - } - }*/ } } - /*if(m_coord == 0 && n_coord == 1 && ThreadIdxX()==0){ - print("max epilogue val:"); print(max); print("\n"); - print("idx:"); print(n_coord); print("\n"); - }*/ gMax(thread_idx,0) = max; @@ -482,6 +464,7 @@ class SoftmaxEpilogue { for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { sum += cutlass::fast_exp(accumulators(i) - max); + //sum += sycl::native::exp(accumulators(i) - max); if(is_first){ //print("acc3 :"); print(accumulators(i)); print("\n"); //print("diff :"); print(accumulators(i) - max); print("\n"); diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index ca6c3c458c..56a88e8acd 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -257,8 +257,23 @@ class SoftmaxFinalize { ElementPartial norm = 1 / sum_val; - for(int n = y; n < params.args.IOSize[1]; n += y_size){ - mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; + int unroll = 2; + //_Pragma("unroll 2") + //for(int n = y; n < params.args.IOSize[1]; n += y_size){ + for(int n = y * unroll; n < params.args.IOSize[1]; n += y_size * unroll){ + auto inVal = mIn(m, n, batch_id); + auto inVal2 = mIn(m, n+1, batch_id); + //auto inVal3 = mIn(m, n+2, batch_id); + //auto inVal4 = mIn(m, n+3, batch_id); + mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; + mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; + //mOut(m, n+2, batch_id) = cutlass::fast_exp(inVal3 - max_val) * norm; + //mOut(m, n+3, batch_id) = cutlass::fast_exp(inVal4 - max_val) * norm; + } + if(params.args.IOSize[1]%2==1){ + int n = params.args.IOSize[1] - 1; + auto inVal = mIn(m, n, batch_id); + mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; } } }; diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index fa3873c5e7..c856afa76e 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -859,6 +859,8 @@ CUTLASS_HOST_DEVICE float 
fast_exp(float x) { #if defined(__CUDA_ARCH__) return ::expf(x); + #elif defined(__SYCL_CUDA_ARCH__) + return ::sycl::native::exp(x); #else return std::exp(x); #endif From cf34fdad589ec19f670709704f8a80ead58e5a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 2 Dec 2024 12:51:58 +0100 Subject: [PATCH 07/14] basic cleanup --- .../35_gemm_softmax/gemm_online_softmax.cpp | 18 +---- examples/35_gemm_softmax/gemm_softmax.cu | 5 -- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 2 - examples/35_gemm_softmax/softmax_epilogue.hpp | 53 +------------- examples/35_gemm_softmax/softmax_finalize.hpp | 69 +------------------ 5 files changed, 6 insertions(+), 141 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index cdafc65a50..67dff1460a 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -333,9 +333,6 @@ struct ExampleRunner { for (int n = 1; n < options.n; ++n) { reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); } - /*if(m == 3516 && batch_idx == 0){ - std:: cout << "max0: " << reference_N.at({m, 0}) << std::endl; - }*/ } // Compute softmax @@ -344,11 +341,6 @@ struct ExampleRunner { for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } - - /*if(m == 3516 && batch_idx == 0){ - std:: cout << "sum0: " << sum << std::endl; - }*/ - float inv_sum = float(1.0f / sum); for (int n = 0; n < options.n; ++n) { @@ -403,8 +395,8 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha,//static_cast(options.alpha), - options.beta},//static_cast(options.beta)}, + {{options.alpha, + options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, @@ -430,7 +422,7 @@ struct ExampleRunner { std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; if (!result.passed) { - //exit(-1); + exit(-1); } // Run profiling loop @@ -512,10 +504,6 @@ int main(int argc, char const **args) { hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); // Problem configuration - /*using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementAcc = float; - using ElementOutput = cutlass::half_t;*/ using ElementA = float; using ElementB = float; using ElementAcc = float; diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 7e663679b6..47673501bb 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -200,11 +200,6 @@ struct Testbed { // Type definitions // - - /*using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = cutlass::half_t; - using ElementCompute = float;*/ using ElementA = float; using ElementB = float; using ElementC = float; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 18c5b4ec15..c56e509af0 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -446,8 +446,6 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - //EventManager::getInstance().addEvent(event); - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), params.softmax_params.args.batch_count, diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 0e0505bf67..2582784d29 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -322,9 +322,6 @@ class SoftmaxEpilogue { static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - //auto wlid = thread_idx % NumWarpsPerWarpGroup; // warp local id - //auto wid = thread_idx / NumWarpsPerWarpGroup; // warp id in tile - // Separate out problem and tile shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); @@ -365,11 +362,8 @@ class SoftmaxEpilogue { auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) - //Tensor tCgMax = thr_mma.partition_C(gMax); // (VEC,THR_M,THR_N) - //Tensor tCgSum = thr_mma.partition_C(gSum); // (VEC,THR_M,THR_N) Tensor tCsC = thr_mma.partition_C(sC); // (VEC,THR_M,THR_N) - static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), "Source and destination must have the same number of elements."); @@ -380,39 +374,6 @@ class SoftmaxEpilogue { auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); - //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); - //Tensor acc_max = make_tensor(size<0>(accumulators)); - //Tensor acc_max = make_tensor_like(take<2,3>(accumulators)); - //Tensor acc_sum = make_tensor_like(acc_max); //TODO can reuse prev? 
- - //Tensor acc_max = make_tensor(shape<2>(accumulators), LayoutLeft{}); - //Tensor acc_sum = make_tensor(shape<2>(accumulators), LayoutLeft{}); //TODO can reuse prev? - - bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; - if(is_first){ - //print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); - //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); - //print("partial_block: "); print(partial_block); print("\n"); - //print("thr_mma: "); print(thr_mma); print("\n"); - //print("tiled_mma: "); print(tiled_mma); print("\n"); - //print("acc: "); print(accumulators); print("\n"); - //print("mD_mnl: "); print(mD_mnl); print("\n"); - //print("mMax_mnl: "); print(mMax_mnl); print("\n"); - //print("gD_mnl: "); print(gD_mnl); print("\n"); - //print("gMax_mnl: "); print(gMax_mnl); print("\n"); - //print("gD: "); print(gD); print("\n"); - //print("gMax: "); print(gMax); print("\n"); - //print("tCgD: "); print(tCgD); print("\n"); - //print("sC: "); print(sC); print("\n"); - //print("tCsC: "); print(tCsC); print("\n"); - //print("sC.data: "); print(&sC(0)); print("\n"); - //print("tCsC.data: "); print(&tCsC(0)); print("\n"); - //decltype(tCsC(0)) a = "asd"; - //print("tCgMax: "); print(tCgMax); print("\n"); - //print("acc_max: "); print(acc_max); print("\n"); - //print("accumulators: "); print(accumulators); print("\n"); - } - if(is_source_needed()){ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { @@ -447,7 +408,7 @@ class SoftmaxEpilogue { syncthreads(); - // assumption size<0>(sC) == wg size + // assumption: size<0>(sC) == wg size ElementAccumulator max = std::numeric_limits::min(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(sC); ++i) { @@ -456,7 +417,6 @@ class SoftmaxEpilogue { max = cutlass::fast_max(max, accumulators(i)); } } - gMax(thread_idx,0) = max; ElementAccumulator sum = 0; @@ -464,19 +424,8 @@ class SoftmaxEpilogue { for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { sum += cutlass::fast_exp(accumulators(i) - max); - //sum += sycl::native::exp(accumulators(i) - max); - if(is_first){ - //print("acc3 :"); print(accumulators(i)); print("\n"); - //print("diff :"); print(accumulators(i) - max); print("\n"); - //print("add :"); print(cutlass::fast_exp(accumulators(i) - max)); print("\n"); - //print("sum :"); print(sum); print("\n"); - //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); - } } } - if(is_first){ - //print("sum epilogue val:"); print(sum); print("\n"); - } gSum(thread_idx,0) = sum; } diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 56a88e8acd..ceb4de1813 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -83,29 +83,6 @@ class SoftmaxFinalize { ElementPartial* ptr_partial_max; ElementPartial* ptr_partial_sum; ElementOutput* ptr_out; - -/* - // - // Methods - // - Arguments() { } - - Arguments( - cutlass::gemm::GemmCoord problem_size, - ElementNorm* block_Norm, - ElementSum* block_Sum - ): - problem_size(problem_size), - block_Norm(block_Norm), - block_Sum(block_Sum), - problem_sizes(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr), - batch_stride_Max(0), - batch_stride_Sum(0) - { - - }*/ }; struct SharedStorage { @@ -168,19 +145,6 @@ class SoftmaxFinalize { using ConvertInput = cutlass::NumericConverter; using ConvertNormOutput = cutlass::NumericConverter; - 
/*int tid = ThreadIdxX(); - //int bid = BlockIdxX(); - int bdim = BlockDimX(); - - int warps_in_block = bdim / NumThreadsPerWarp; - - int wlid = tid % warps_in_block; // local id of thread in warp - int gid = tid;// + bid * GridDimX(); - int bsize = GridDimX(); - int m_batch_id = BlockIdxY(); - int batch_id = m_batch_id / params.args.IOSize[1]; - int m = m_batch_id % params.args.IOSize[1];*/ - int x = ThreadIdxX(); int m = x + BlockDimX() * BlockIdxX(); int y = ThreadIdxY(); @@ -203,16 +167,9 @@ class SoftmaxFinalize { //Represent the shared tensor Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); - if(m==0 && batch_id==0){ - //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); - } - ElementPartial max_val = std::numeric_limits::min(); for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); - /*if(m==0 && batch_id==0){ - print("partial_max: "); print(partial_max); print("\n"); - }*/ max_val = max_val > partial_max ? max_val : partial_max; } sPartial(x,y,0) = max_val; @@ -222,12 +179,6 @@ class SoftmaxFinalize { ElementPartial partial_max = sPartial(x,y2,0); max_val = max_val > partial_max ? max_val : partial_max; } - /*if(m == 3516 && y == 0){ - print("kernel max"); print(max_val); print("\n"); - }*/ - //max_val = reduce_max(max_val); - - //mOut(0,0,0) = max_val; return; ElementPartial sum_val = 0; for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ @@ -243,32 +194,16 @@ class SoftmaxFinalize { for(int y2 = 0; y2 < y_size; y2++){ ElementPartial partial_max = sPartial(x,y2,0); ElementPartial partial_sum = sPartial(x,y2,1); - sum_val = sum_val + partial_sum /** cutlass::fast_exp(partial_max - max_val)*/; - } - /*if(m == 3516 && y == 0){ - print("kernel sum"); print(sum_val); print("\n"); - }*/ - //sum_val = reduce_sum(sum_val); - - if(m==0 && batch_id==0){ - //print("max_val: "); print(max_val); print("\n"); - //print("sum_val: "); print(sum_val); print("\n"); + sum_val = sum_val + partial_sum; } ElementPartial norm = 1 / sum_val; - int unroll = 2; - //_Pragma("unroll 2") - //for(int n = y; n < params.args.IOSize[1]; n += y_size){ - for(int n = y * unroll; n < params.args.IOSize[1]; n += y_size * unroll){ + for(int n = y * 2; n < params.args.IOSize[1]; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); - //auto inVal3 = mIn(m, n+2, batch_id); - //auto inVal4 = mIn(m, n+3, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; - //mOut(m, n+2, batch_id) = cutlass::fast_exp(inVal3 - max_val) * norm; - //mOut(m, n+3, batch_id) = cutlass::fast_exp(inVal4 - max_val) * norm; } if(params.args.IOSize[1]%2==1){ int n = params.args.IOSize[1] - 1; From 84ce5c11d2e61c30b66b8ddf870e7ded038a71e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 3 Dec 2024 11:52:04 +0100 Subject: [PATCH 08/14] final cleanup --- examples/35_gemm_softmax/CMakeLists.txt | 2 +- examples/35_gemm_softmax/gemm_softmax.cu | 11 +- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 43 ++--- examples/35_gemm_softmax/softmax_epilogue.hpp | 164 ++---------------- examples/35_gemm_softmax/softmax_finalize.hpp | 82 ++++----- include/cutlass/fast_math.h | 1 + 6 files changed, 66 insertions(+), 237 deletions(-) diff --git a/examples/35_gemm_softmax/CMakeLists.txt 
b/examples/35_gemm_softmax/CMakeLists.txt index d7f2cd574b..824453a656 100644 --- a/examples/35_gemm_softmax/CMakeLists.txt +++ b/examples/35_gemm_softmax/CMakeLists.txt @@ -39,4 +39,4 @@ cutlass_example_add_executable( 35_gemm_online_softmax gemm_online_softmax.cpp ) -endif() \ No newline at end of file +endif() diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 47673501bb..27156ea02d 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -200,19 +200,20 @@ struct Testbed { // Type definitions // - using ElementA = float; - using ElementB = float; - using ElementC = float; + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; using ElementCompute = float; using ElementD = ElementC; using ElementSoftmax = ElementC; - using LayoutA = cutlass::layout::ColumnMajor; + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index c56e509af0..0d31e19457 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -49,12 +50,6 @@ #include "cutlass/trace.h" #endif // !defined(__CUDACC_RTC__) -// 2.x -//#include "cutlass/gemm/device/gemm_universal_base.h" -//#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -//#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -//#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" - // 3.x #include "cutlass/gemm/kernel/gemm_universal.hpp" @@ -171,8 +166,6 @@ class GemmSoftmaxAdapter /// Argument structure: User API using Arguments = typename GemmKernel::Arguments; - /// Argument structure: Kernel API - //using Params = typename GemmKernel::Params; struct Params{ typename GemmKernel::Params gemm_params; @@ -271,6 +264,20 @@ class GemmSoftmaxAdapter return max_active_blocks; } + void initialize_softmax_params(Arguments const& args, typename SoftmaxFinalizeKernel::Arguments& softmax_args){ + softmax_args.M = get<0>(args.problem_shape); + softmax_args.dataN = get<1>(args.problem_shape); + softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{})); + softmax_args.batch_count = get<3>(args.problem_shape); + softmax_args.dInput = args.epilogue.dD; + softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dOutput = args.epilogue.dD; + softmax_args.ptr_in = args.epilogue.ptr_D; + softmax_args.ptr_partial_max = args.epilogue.ptr_max; + softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; + softmax_args.ptr_out = args.epilogue.ptr_D; + } + /// Initializes GEMM state from arguments. 
Status initialize( @@ -289,19 +296,7 @@ class GemmSoftmaxAdapter } // Initialize the Params structure params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); - //TODO(Tadej) move to finalize kernel class? - auto& softmax_args = params_.softmax_params.args; - softmax_args.IOSize = {get<0>(args.problem_shape), get<1>(args.problem_shape)}; - softmax_args.partialSize = {get<0>(args.problem_shape), - cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{}))}; - softmax_args.batch_count = get<3>(args.problem_shape); - softmax_args.dInput = args.epilogue.dD; - softmax_args.dPartial = args.epilogue.dTmp; - softmax_args.dOutput = args.epilogue.dD; - softmax_args.ptr_in = args.epilogue.ptr_D; - softmax_args.ptr_partial_max = args.epilogue.ptr_max; - softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; - softmax_args.ptr_out = args.epilogue.ptr_D; + initialize_softmax_params(args, params_.softmax_params.args); // Don't set the function attributes - require the CudaHostAdapter to set it. if constexpr (kEnableCudaHostAdapter) { @@ -345,7 +340,7 @@ class GemmSoftmaxAdapter } params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); - //TODO(Tadej) update softmax args + initialize_softmax_params(args, params_.softmax_params.args); return Status::kSuccess; } @@ -446,8 +441,8 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); - const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), + const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.M), 1); + const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), params.softmax_params.args.batch_count, 1); auto event2 = launch>(launch_policy{ diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 2582784d29..093a6d70d1 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,9 +29,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ #pragma once @@ -90,7 +88,6 @@ class SoftmaxEpilogue { static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { - //cute::array_aligned(BlockShapeMNK{})>,C(BlockShapeMNK{})>>>>> smem_c; cute::array_aligned(BlockShapeMNK{}) * get<1>(BlockShapeMNK{})> smem_c; }; @@ -155,147 +152,6 @@ class SoftmaxEpilogue { return epilogue_op.is_source_needed(); } - template < - bool zero_init, - class FragSrc, - class FragDst, - class Op - > - CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { - // reduce across all the -N- M tiles in shape - CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - dst(z) = zero_init ? 
src(0, 0, z) : op(dst(z), src(0, 0, z)); - CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { - CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - dst(z) = op(dst(z), src(x, y, z)); - } - } - } - - // reduce across the sub_group to get the final output - auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - dst(z) = op(dst(z), syclcompat::permute_sub_group_by_xor(sg, dst(z), laneMask, 16)); - } - } - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class ResidueMap, - class Residue, - class Op - > - CUTLASS_DEVICE static ElementAccumulator reduceWg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx, - ElementAccumulator init, Op op) { - //TODO(Tadej): single loop over all dims - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(src); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(src); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(src); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCsC(i,j,k) = src(i,j,k); - } else{ - tCsC(i,j,k) = init; - } - } - } - } - - syncthreads(); - - ElementAccumulator acc = sC(0, thread_idx); - for (int i = 1; i < size(src); ++i) { - acc = op(acc, sC(i, thread_idx)); - } - - syncthreads(); - - //broadcast it back to threads - //TODO(Tadej): optimize - for (int i = 0; i < size(src); ++i) { - sC(i, thread_idx) = acc; - } - - syncthreads(); - - CUTLASS_PRAGMA_UNROLL - for(int k = 1; k < size<2>(src); k++) { - dst(k) = tCsC(0,0,k); - } - - return acc; - - /*reduceSg(src, dst, op); - for(int i=ThreadIdxX() % NumThreadsPerWarp; i - CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { - reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class ResidueMap, - class Residue - > - CUTLASS_DEVICE static ElementAccumulator reduce_max_wg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx) { - - return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, - std::numeric_limits::min(), - [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
x : y; }); - } - - template < - bool zero_init, - class FragSrc, - class FragSum - > - CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { - reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class Residue, - class ResidueMap - > - CUTLASS_DEVICE static ElementAccumulator reduce_sum_wg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx) { - - return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, - 0, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x+y; }); - } - template< class ProblemShapeMNKL, class BlockCoordMNKL, @@ -333,7 +189,7 @@ class SoftmaxEpilogue { auto N_tmp = cute::ceil_div(N, N_tile); - cute::packed_tuple partial_block(M_tile, C<1>(), K_tile); + cute::packed_tuple partial_block(M_tile, K_tile); auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); @@ -341,19 +197,19 @@ class SoftmaxEpilogue { // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) - Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); + Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gMax = gMax_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gSum = gSum_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gMax = gMax_mnl(_,m_coord,n_coord,l_coord); + Tensor gSum = gSum_mnl(_,m_coord,n_coord,l_coord); //Represent the shared tensor Tensor sC = make_tensor(make_smem_ptr(reinterpret_cast(smem_buf)), make_layout(make_shape(M_tile, N_tile))); @@ -417,7 +273,7 @@ class SoftmaxEpilogue { max = cutlass::fast_max(max, accumulators(i)); } } - gMax(thread_idx,0) = max; + gMax(thread_idx) = max; ElementAccumulator sum = 0; CUTLASS_PRAGMA_UNROLL @@ -426,7 +282,7 @@ class SoftmaxEpilogue { sum += cutlass::fast_exp(accumulators(i) - max); } } - gSum(thread_idx,0) = sum; + gSum(thread_idx) = sum; } private: diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp 
b/examples/35_gemm_softmax/softmax_finalize.hpp index ceb4de1813..a9e4b41dae 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -72,9 +72,9 @@ class SoftmaxFinalize { // struct Arguments { - //TODO(Tadej): duplicated part of sizes - cutlass::MatrixCoord IOSize; ///< Extent of input and output matrices - cutlass::MatrixCoord partialSize; ///< Extent of partial max and sum matrices + int M; + int dataN; + int partialN; int batch_count; ///< Batch count StrideInput dInput; StridePartial dPartial; @@ -106,8 +106,6 @@ class SoftmaxFinalize { Params(Arguments const &args_): args(args_) { } }; -private: - public: CUTLASS_DEVICE @@ -115,31 +113,11 @@ class SoftmaxFinalize { CUTLASS_DEVICE void operator()(Params const ¶ms, char* shared_storage) { - apply(params, shared_storage); } private: - template - CUTLASS_DEVICE static ElementPartial reduceSg(ElementPartial val, Op op) { - auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - val = op(val, syclcompat::permute_sub_group_by_xor(sg, val, laneMask, 16)); - } - return val; - } - - CUTLASS_DEVICE static ElementPartial reduce_max(ElementPartial val) { - return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x > y ? x : y; }); - } - - CUTLASS_DEVICE static ElementPartial reduce_sum(ElementPartial val) { - return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x + y; }); - } - - /// Full reduction CUTLASS_DEVICE void apply(Params const ¶ms, char* shared_storage) { using ConvertInput = cutlass::NumericConverter; @@ -151,62 +129,60 @@ class SoftmaxFinalize { int y_size = BlockDimY(); int batch_id = BlockIdxY(); - if(m>=params.args.IOSize[0]){ + if(m>=params.args.M){ return; } - // Represent the full tensors - auto IOTensorShape = make_shape(params.args.IOSize[0], params.args.IOSize[1], params.args.batch_count); - auto PartialTensorShape = make_shape(params.args.partialSize[0], params.args.partialSize[1], params.args.batch_count); - Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); // (m,n,l) - Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); // (m,n,l) - Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) - Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + auto IOTensorShape = make_shape(params.args.M, params.args.dataN, params.args.batch_count); + auto PartialTensorShape = make_shape(params.args.M, params.args.partialN, params.args.batch_count); + Tensor mMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); + Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); + Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); + Tensor mIn = 
make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); //Represent the shared tensor - Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32))); ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ - ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); - max_val = max_val > partial_max ? max_val : partial_max; + for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mMax(m, partial_n, batch_id); + max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(x,y,0) = max_val; + sPartial(x,y) = max_val; syncthreads(); - //TODO(Tadej): improve reduction + //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2,0); - max_val = max_val > partial_max ? max_val : partial_max; + ElementPartial partial_max = sPartial(x,y2); + max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ - ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); - sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); + sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(x,y,1) = sum_val; + sPartial(x,y) = sum_val; syncthreads(); sum_val = 0; - //TODO(Tadej): improve reduction + //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2,0); - ElementPartial partial_sum = sPartial(x,y2,1); - sum_val = sum_val + partial_sum; + ElementPartial partial_sum = sPartial(x,y2); + sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = y * 2; n < params.args.IOSize[1]; n += y_size * 2){ + for(int n = y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; } - if(params.args.IOSize[1]%2==1){ - int n = params.args.IOSize[1] - 1; + if(params.args.dataN % 2 == 1){ + int n = params.args.dataN - 1; auto inVal = mIn(m, n, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; } diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index c856afa76e..59a6649d76 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without From fcc7d2f78a78922b6cb43ce7d52c7a65a40a43fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 4 Dec 2024 13:28:42 +0000 Subject: [PATCH 09/14] addressed most review comments --- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 5 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 4 +- examples/35_gemm_softmax/softmax_finalize.hpp | 67 ++++++++++--------- include/cutlass/gpu_generics.h | 11 +-- 4 files changed, 46 insertions(+), 41 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 0d31e19457..3b3e97fead 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -441,7 +441,10 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.M), 1); + const auto sycl_block2 = syclcompat::dim3(NumThreadsPerWarp, + std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, + params.softmax_params.args.M), + 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), params.softmax_params.args.batch_count, 1); diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 093a6d70d1..0c35f50458 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -168,7 +168,7 @@ class SoftmaxEpilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - [[maybe_unused]] char* smem_buf) + char* smem_buf) { using namespace cute; using X = Underscore; @@ -265,7 +265,7 @@ class SoftmaxEpilogue { syncthreads(); // assumption: size<0>(sC) == wg size - ElementAccumulator max = std::numeric_limits::min(); + ElementAccumulator max = std::numeric_limits::lowest(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index a9e4b41dae..30ecd90fef 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -72,21 +72,21 @@ class SoftmaxFinalize { // struct Arguments { - int M; - int dataN; - int partialN; - int batch_count; ///< Batch count - StrideInput dInput; - StridePartial dPartial; - StrideOutput dOutput; - ElementInput* ptr_in; - ElementPartial* ptr_partial_max; - ElementPartial* ptr_partial_sum; - ElementOutput* ptr_out; + int M; // dimension M of input, output and partially reduced tensors + int dataN; // dimension N of the input and output + int partialN; // dimension N of the partially reduced tensors + int batch_count; // batch count + StrideInput dInput; // stride of the input + StridePartial dPartial; // stride of the partially reduced tensors + StrideOutput dOutput; // stride of the output + ElementInput* ptr_in; // pointer to start of input data + ElementPartial* ptr_partial_max; // pointer to start of partially reduced max data + ElementPartial* ptr_partial_sum; // pointer to start of partially reduced sum data + ElementOutput* ptr_out; // pointer to start of output data }; struct SharedStorage { - cute::array_aligned s_mem; + cute::array_aligned s_mem; }; static constexpr int 
SharedStorageSize = sizeof(SharedStorage); @@ -123,11 +123,11 @@ class SoftmaxFinalize { using ConvertInput = cutlass::NumericConverter; using ConvertNormOutput = cutlass::NumericConverter; - int x = ThreadIdxX(); - int m = x + BlockDimX() * BlockIdxX(); - int y = ThreadIdxY(); - int y_size = BlockDimY(); - int batch_id = BlockIdxY(); + const int idx_x = ThreadIdxX(); + const int m = idx_x + BlockDimX() * BlockIdxX(); + const int idx_y = ThreadIdxY(); + const int y_size = BlockDimY(); + const int batch_id = BlockIdxY(); if(m>=params.args.M){ return; @@ -136,46 +136,47 @@ class SoftmaxFinalize { // Represent the full tensors auto IOTensorShape = make_shape(params.args.M, params.args.dataN, params.args.batch_count); auto PartialTensorShape = make_shape(params.args.M, params.args.partialN, params.args.batch_count); - Tensor mMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); + Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); //Represent the shared tensor - Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32))); + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), + make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); - ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ - ElementPartial partial_max = mMax(m, partial_n, batch_id); + ElementPartial max_val = std::numeric_limits::lowest(); + for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(x,y) = max_val; + sPartial(idx_x,idx_y) = max_val; syncthreads(); - //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck - for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2); + // tree-reduction could be better, although it does not seem to be a bottleneck + for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_max = sPartial(idx_x,idx_y2); max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ - ElementPartial partial_max = mMax(m, partial_n, batch_id); + for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(x,y) = sum_val; + sPartial(idx_x,idx_y) = sum_val; syncthreads(); sum_val = 0; - //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck - for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_sum = sPartial(x,y2); + // tree-reduction could be better, although it does not seem to be a bottleneck + for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_sum = 
sPartial(idx_x,idx_y2); sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = y * 2; n < params.args.dataN; n += y_size * 2){ + for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 5d46e4c370..037087e197 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -44,11 +44,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// static constexpr int NumThreadsPerWarp = 32; -static const int NumThreadsPerWarpGroup = 128; -static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; -static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -static const int NumThreadsPerQuad = 4; -static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static constexpr int NumThreadsPerWarpGroup = 128; +static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static constexpr int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static constexpr int NumThreadsPerQuad = 4; +static constexpr int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static constexpr int MaxNumThreadsPerBlock = 1024; //////////////////////////////////////////////////////////////////////////////////////////////////// From d5a77e302b84e8fefb130e50a1fe46be0f82ccbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 4 Dec 2024 13:58:18 +0000 Subject: [PATCH 10/14] addressed remaining review comments --- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 3b3e97fead..c367d8387f 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -354,9 +354,17 @@ class GemmSoftmaxAdapter CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); + dim3 const block_finalize = syclcompat::dim3(NumThreadsPerWarp, + std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, + params.softmax_params.args.M), + 1); + dim3 const grid_finalize = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, block_finalize.x), + params.softmax_params.args.batch_count, + 1); // configure smem size and carveout int smem_size = GemmKernel::SharedStorageSize; + int smem_size_finalize = SoftmaxFinalizeKernel::SharedStorageSize; Status launch_result{ Status::kSuccess }; // Use extended launch API only for mainloops that use it @@ -367,7 +375,9 @@ class GemmSoftmaxAdapter dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - void* kernel_params[] = {¶ms}; + dim3 cluster_finalize(1,1,1); + void* kernel_params[] = {¶ms.gemm_params}; + void* kernel_params_finalize[] = {¶ms.softmax_params}; if constexpr (kEnableCudaHostAdapter) { // @@ -388,6 +398,13 @@ class GemmSoftmaxAdapter stream, kernel_params, 0); + launch_result = cuda_adapter->launch(grid_finalize, + cluster_finalize, + block_finalize, + smem_size_finalize, + stream, + kernel_params_finalize, + 0); } else { return 
Status::kErrorInternal; @@ -396,13 +413,17 @@ class GemmSoftmaxAdapter else { CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; + void const* kernel_finalize = (void const*) device_kernel; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { if (is_static_1x1x1 && not launch_with_pdl) { - device_kernel<<>>(params); + device_kernel<<>>(params.gemm_params); + device_kernel<<>>(params.softmax_params); } else { launch_result = ClusterLauncher::launch( grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + launch_result = ClusterLauncher::launch( + grid_finalize, cluster_finalize, block_finalize, smem_size_finalize, stream, kernel_finalize, kernel_params, launch_with_pdl); } } } @@ -414,10 +435,14 @@ class GemmSoftmaxAdapter CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { void* kernel_params[] = {¶ms.gemm_params}; + void* kernel_params_finalize[] = {¶ms.softmax_params}; launch_result = cuda_adapter->launch( grid, block, smem_size, stream, kernel_params, 0 ); + launch_result = cuda_adapter->launch( + grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 0 + ); } else { @@ -441,19 +466,15 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(NumThreadsPerWarp, - std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, - params.softmax_params.args.M), - 1); - const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), - params.softmax_params.args.batch_count, - 1); + const auto sycl_block_finalize = syclcompat::dim3(block_finalize.x, block_finalize.y, block_finalize.z); + const auto sycl_grid_finalize = syclcompat::dim3(grid_finalize.x, grid_finalize.y, grid_finalize.z); auto event2 = launch>(launch_policy{ - sycl_grid2, sycl_block2, local_mem_size{SoftmaxFinalizeKernel::SharedStorageSize}}, + sycl_grid_finalize, sycl_block_finalize, local_mem_size{static_cast(smem_size_finalize)}}, params.softmax_params); EventManager::getInstance().addEvent(event2); #else device_kernel<<>>(params.gemm_params); + device_kernel<<>>(params.softmax_params); #endif } } From 1ea4f40966750eac1f14a628abfb1498010e02e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 5 Dec 2024 09:40:23 +0000 Subject: [PATCH 11/14] addressed second round of review comments. 
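The loop bounds in the epilogue reductions below iterate over the columns of the shared tile sC, whose first mode is assumed to match the block size (one row per thread). A minimal host-side sketch of the intended per-row reduction, using plain std::vector in place of the CuTe tensors (the names sC, t, row_max and row_sum are illustrative only, not taken from the patch):

    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    // Illustrative stand-in for the per-thread epilogue reduction: thread `t`
    // owns row `t` of the shared tile sC (shape: block_size x n_tile) and
    // reduces across its n_tile columns, i.e. the size<1>(sC) extent.
    void row_partials(const std::vector<std::vector<float>>& sC, int t,
                      float& row_max, float& row_sum) {
      row_max = std::numeric_limits<float>::lowest();
      for (std::size_t j = 0; j < sC[t].size(); ++j)   // size<1>(sC) in the patch
        row_max = std::max(row_max, sC[t][j]);
      row_sum = 0.f;
      for (std::size_t j = 0; j < sC[t].size(); ++j)   // sum of exp(x - max); cutlass::fast_exp in the patch
        row_sum += std::exp(sC[t][j] - row_max);
    }

The residue predication (elem_less against residue_mnk) is omitted from this sketch.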
--- examples/35_gemm_softmax/gemm_softmax_adapter.hpp |  2 +-
 examples/35_gemm_softmax/softmax_epilogue.hpp |  8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp
index c367d8387f..3eb7e9d537 100644
--- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp
+++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp
@@ -404,7 +404,7 @@ class GemmSoftmaxAdapter
                                              smem_size_finalize,
                                              stream,
                                              kernel_params_finalize,
-                                             0);
+                                             1);
       }
       else {
         return Status::kErrorInternal;
diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp
index 0c35f50458..755f33d5f8 100644
--- a/examples/35_gemm_softmax/softmax_epilogue.hpp
+++ b/examples/35_gemm_softmax/softmax_epilogue.hpp
@@ -264,10 +264,12 @@ class SoftmaxEpilogue {
 
       syncthreads();
 
-      // assumption: size<0>(sC) == wg size
+      // assumption for reductions: size<0>(sC) == block size
+      assert(size<0>(sC) == BlockDimX() * BlockDimY() * BlockDimZ());
+
       ElementAccumulator max = std::numeric_limits<ElementAccumulator>::lowest();
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < size<0>(sC); ++i) {
+      for (int i = 0; i < size<1>(sC); ++i) {
         if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
           accumulators(i) = sC(thread_idx, i);
           max = cutlass::fast_max(max, accumulators(i));
@@ -277,7 +279,7 @@ class SoftmaxEpilogue {
 
       ElementAccumulator sum = 0;
       CUTLASS_PRAGMA_UNROLL
-      for (int i = 0; i < size<0>(sC); ++i) {
+      for (int i = 0; i < size<1>(sC); ++i) {
         if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) {
           sum += cutlass::fast_exp(accumulators(i) - max);
         }

From dc123265709763230ac5159c4fd1c3c426d8a807 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?=
Date: Thu, 5 Dec 2024 11:52:29 +0100
Subject: [PATCH 12/14] Apply formatting suggestions from code review

Co-authored-by: Finlay
---
 .../35_gemm_softmax/gemm_online_softmax.cpp   |  4 ++--
 examples/35_gemm_softmax/softmax_finalize.hpp | 20 +++++++++----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp
index 67dff1460a..2a6e6f1d52 100644
--- a/examples/35_gemm_softmax/gemm_online_softmax.cpp
+++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp
@@ -128,7 +128,7 @@ struct Options {
   /// Prints the usage statement.
   std::ostream & print_usage(std::ostream &out) const {
 
-    out << "14_ampere_tf32_tensorop_gemm_cute example\n\n"
+    out << "35_gemm_softmax example\n\n"
       << "  This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
       << "Options:\n\n"
       << "  --help                      If specified, displays this usage statement.\n\n"
@@ -238,7 +238,7 @@ struct ExampleRunner {
   // Methods
   //
   template <typename Element>
-  bool verify_tensor(std::vector<Element> vector_Input, \
+  bool verify_tensor(std::vector<Element> vector_Input,
                      std::vector<Element> vector_Input_Ref, const Options& options) {
 
     auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ?
vector_Input.size() : vector_Input_Ref.size()); diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 30ecd90fef..ca6e6ac93a 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -129,7 +129,7 @@ class SoftmaxFinalize { const int y_size = BlockDimY(); const int batch_id = BlockIdxY(); - if(m>=params.args.M){ + if (m >= params.args.M) { return; } @@ -146,43 +146,43 @@ class SoftmaxFinalize { make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); ElementPartial max_val = std::numeric_limits::lowest(); - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(idx_x,idx_y) = max_val; + sPartial(idx_x, idx_y) = max_val; syncthreads(); // tree-reduction could be better, although it does not seem to be a bottleneck - for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ - ElementPartial partial_max = sPartial(idx_x,idx_y2); + for (int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_max = sPartial(idx_x, idx_y2); max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(idx_x,idx_y) = sum_val; + sPartial(idx_x, idx_y) = sum_val; syncthreads(); sum_val = 0; // tree-reduction could be better, although it does not seem to be a bottleneck for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ - ElementPartial partial_sum = sPartial(idx_x,idx_y2); + ElementPartial partial_sum = sPartial(idx_x, idx_y2); sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ + for (int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; } - if(params.args.dataN % 2 == 1){ + if (params.args.dataN % 2 == 1){ int n = params.args.dataN - 1; auto inVal = mIn(m, n, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; From 51cb0b5cb874c03224ee3bc45ff194ccb8aa3337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Fri, 6 Dec 2024 10:26:50 +0000 Subject: [PATCH 13/14] addressed third round of comments --- .../35_gemm_softmax/gemm_online_softmax.cpp | 26 +++++++++---------- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 7 +++-- examples/35_gemm_softmax/softmax_epilogue.hpp | 13 +++++----- include/cutlass/fast_math.h | 1 - include/cutlass/gpu_generics.h | 12 ++++----- 5 files changed, 28 insertions(+), 31 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 2a6e6f1d52..b8844286fb 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -128,7 +128,7 @@ struct Options { 
/// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "35_gemm_softmax example\n\n" + out << "35_gemm_online_softmax example\n\n" << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" @@ -193,13 +193,13 @@ struct ExampleRunner { using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; - using StrideTmp = typename Gemm::CollectiveEpilogue::StrideD; + using StridePartials = typename Gemm::CollectiveEpilogue::StrideD; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; - using LayoutTmp = typename Gemm::LayoutTmp; + using LayoutPartials = typename Gemm::LayoutPartials; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; @@ -223,7 +223,7 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; - StrideTmp stride_tmp; + StridePartials stride_partials; uint64_t seed = 0; cutlass::DeviceAllocation block_A; @@ -281,8 +281,8 @@ struct ExampleRunner { LayoutA layout_A(lda); LayoutB layout_B(ldb); LayoutC layout_C(ldc); - LayoutTmp Layout_N(ldn); - LayoutTmp Layout_S(lds); + LayoutPartials Layout_N(ldn); + LayoutPartials Layout_S(lds); cutlass::MatrixCoord extent_A{options.m, options.k}; cutlass::MatrixCoord extent_B{options.k, options.n}; @@ -365,21 +365,21 @@ struct ExampleRunner { auto [M, N, K, L] = problem_shape_MNKL; auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})); - auto tmp_size = M * partials_N * L; + auto partials_size = M * partials_N * L; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(M, partials_N, L)); + stride_partials = cutlass::make_cute_packed_stride(StridePartials{}, cute::make_shape(M, partials_N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); - block_sum.reset(tmp_size); - block_max.reset(tmp_size); + block_sum.reset(partials_size); + block_max.reset(partials_size); initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); @@ -399,7 +399,7 @@ struct ExampleRunner { options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, - block_max.get(), block_sum.get(), stride_tmp}, + block_max.get(), block_sum.get(), stride_partials}, hw_info }; @@ -513,7 +513,7 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; - using LayoutTmp = cutlass::layout::ColumnMajor; + using LayoutPartials = cutlass::layout::ColumnMajor; // Tiling configuration selection using TileShape = Shape<_128,_128,_32>; @@ -580,7 +580,7 @@ int main(int argc, char const **args) { using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, - cutlass::detail::TagToStrideC_t, + 
cutlass::detail::TagToStrideC_t, TileShape, EpilogueOp, cutlass::gemm::EpilogueDefault>; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 3eb7e9d537..75a3feeeb0 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -1,5 +1,4 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * @@ -86,7 +85,7 @@ class GemmSoftmaxAdapter using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< ElementD, typename GemmKernel::StrideD, - ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StrideTmp, + ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StridePartials, ElementD, typename GemmKernel::StrideD>; // Map back to 2.x type as best as possible @@ -94,7 +93,7 @@ class GemmSoftmaxAdapter using LayoutB = gemm::detail::StrideToLayoutTagB_t; using LayoutC = gemm::detail::StrideToLayoutTagC_t; using LayoutD = gemm::detail::StrideToLayoutTagC_t; - using LayoutTmp = gemm::detail::StrideToLayoutTagC_t; + using LayoutPartials = gemm::detail::StrideToLayoutTagC_t; static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; @@ -270,7 +269,7 @@ class GemmSoftmaxAdapter softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{})); softmax_args.batch_count = get<3>(args.problem_shape); softmax_args.dInput = args.epilogue.dD; - softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dPartial = args.epilogue.dPartials; softmax_args.dOutput = args.epilogue.dD; softmax_args.ptr_in = args.epilogue.ptr_D; softmax_args.ptr_partial_max = args.epilogue.ptr_max; diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 755f33d5f8..438552a7f2 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -1,5 +1,4 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * @@ -53,7 +52,7 @@ namespace collective { template < class StrideC_, class StrideD_, - class StrideTmp_, + class StridePartials_, class BlockShapeMNK, class ThreadEpilogueOp_, class EpilogueSchedule_ @@ -76,7 +75,7 @@ class SoftmaxEpilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - using StrideTmp = StrideTmp_; + using StridePartials = StridePartials_; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -102,7 +101,7 @@ class SoftmaxEpilogue { StrideD dD{}; ElementAccumulator* ptr_max; ElementAccumulator* ptr_sum; - StrideTmp dTmp{}; + StridePartials dPartials{}; }; // Device side epilogue params @@ -187,7 +186,7 @@ class SoftmaxEpilogue { auto N_tile = get<1>(blk_shape_MNK); auto K_tile = get<2>(blk_shape_MNK); - auto N_tmp = cute::ceil_div(N, N_tile); + auto N_partials = cute::ceil_div(N, N_tile); cute::packed_tuple partial_block(M_tile, K_tile); @@ -197,8 +196,8 @@ class SoftmaxEpilogue { // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); - Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_partials,L), params.dPartials); + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_partials,L), params.dPartials); Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 59a6649d76..c856afa76e 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -1,6 +1,5 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 037087e197..f6af850df7 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -43,12 +43,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumThreadsPerWarp = 32; -static constexpr int NumThreadsPerWarpGroup = 128; -static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; -static constexpr int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -static constexpr int NumThreadsPerQuad = 4; -static constexpr int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; static constexpr int MaxNumThreadsPerBlock = 1024; //////////////////////////////////////////////////////////////////////////////////////////////////// From bb925baa553dc2a0c0dbf24f283b56b883c29f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 10 Dec 2024 12:41:11 +0000 Subject: [PATCH 14/14] fix from review --- examples/35_gemm_softmax/gemm_softmax_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 75a3feeeb0..13e0369e4e 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -440,7 +440,7 @@ class GemmSoftmaxAdapter grid, block, smem_size, stream, kernel_params, 0 ); launch_result = cuda_adapter->launch( - grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 0 + grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 1 ); }
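For reference, the finalize kernel combines the per-tile partials produced by the epilogue into the final softmax as: row_max = max_j partial_max[j]; row_sum = sum_j partial_sum[j] * exp(partial_max[j] - row_max); out[n] = exp(in[n] - row_max) / row_sum. A self-contained host sketch of that combine step (the std::vector interface and the name finalize_row are illustrative assumptions of this sketch; the device kernel distributes the same work across a thread block and stages the cross-thread reduction through shared memory):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // partial_max[j] / partial_sum[j]: per-tile max and sum-of-exponentials for one
    // row (one entry per N-tile of the GEMM). in[n]: the GEMM output row written by
    // the epilogue; out[n]: the softmax of that row (the kernel overwrites D in place).
    void finalize_row(const std::vector<float>& partial_max,
                      const std::vector<float>& partial_sum,
                      const std::vector<float>& in,
                      std::vector<float>& out) {
      float row_max = -std::numeric_limits<float>::infinity();
      for (float m : partial_max) row_max = std::max(row_max, m);
      // Each tile's sum was taken against its own local max, so rescale by
      // exp(local_max - row_max) to make the partial sums comparable.
      float row_sum = 0.f;
      for (std::size_t j = 0; j < partial_sum.size(); ++j) {
        row_sum += partial_sum[j] * std::exp(partial_max[j] - row_max);
      }
      const float norm = 1.f / row_sum;
      out.resize(in.size());
      for (std::size_t n = 0; n < in.size(); ++n) {
        out[n] = std::exp(in[n] - row_max) * norm;
      }
    }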