From 430f9ebc566b863c89d22d644aa4a41dfdd1aed1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 7 Nov 2024 10:41:05 +0100 Subject: [PATCH 01/19] pre-state - WIP --- examples/35_gemm_softmax/CMakeLists.txt | 8 +- .../35_gemm_softmax/gemm_online_softmax.cpp | 519 ++++++++++++++++++ examples/35_gemm_softmax/softmax_epilogue.hpp | 319 +++++++++++ examples/CMakeLists.txt | 1 + 4 files changed, 846 insertions(+), 1 deletion(-) create mode 100644 examples/35_gemm_softmax/gemm_online_softmax.cpp create mode 100644 examples/35_gemm_softmax/softmax_epilogue.hpp diff --git a/examples/35_gemm_softmax/CMakeLists.txt b/examples/35_gemm_softmax/CMakeLists.txt index b7ecd99fcc..d7f2cd574b 100644 --- a/examples/35_gemm_softmax/CMakeLists.txt +++ b/examples/35_gemm_softmax/CMakeLists.txt @@ -29,8 +29,14 @@ +if (NOT CUTLASS_ENABLE_SYCL) cutlass_example_add_executable( 35_gemm_softmax gemm_softmax.cu ) - +else() +cutlass_example_add_executable( + 35_gemm_online_softmax + gemm_online_softmax.cpp + ) +endif() \ No newline at end of file diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp new file mode 100644 index 0000000000..2a7d21fd33 --- /dev/null +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -0,0 +1,519 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple GEMM example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture + + This example demonstrate how to instantiate and run a TF32 GEMM using the Cute and + CUTLASS 3.x APIs on NVIDIA Ampere architecture. 
Please check example 07 and 08 for + the basics of tensor op gemm kernels. On NVIDIA Ampere architecture, most concept + still holds. The two main differences are: + + (1) NVIDIA Ampere architecture introduces a new series of tensor core instructions + (see include/cute/arch/mma_sm80.hpp) which are more efficient on Ampere. + (2) NVIDIA Ampere architecture uses CP_ASYNC (see include/cute/arch/copy_sm80.hpp) + to build a multistage software pipeline to better hide latency (see + include/cutlass/gemm/collective/sm80_mma_multistage.hpp). + + Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) + data types in tensor cores. One big advantage is that we can load in fp32 data and convert + them implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate + traditional fp32 data by using NVIDIA Ampere architecture. + + Examples: + + $ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm_cute + +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else +#include "cutlass/util/reference/device/tensor_fill.h" +#endif +#include "helper.h" +#include "softmax_epilogue.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +/// Result structure +struct Result { + + double avg_runtime_ms; + double gflops; + bool passed; + + // + // Methods + // + + Result( + double avg_runtime_ms = 0, + double gflops = 0) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), passed(false) + {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + int m, n, k, l; + float alpha, beta; + int iterations; + + Options(): + help(false), + m(5120), n(4096), k(4096), l(1), + alpha(1), beta(0), + iterations(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + //cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "14_ampere_tf32_tensorop_gemm_cute example\n\n" + << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Wrapper to run and verify a GEMM. +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementCompute(alpha), + ref_A, + 
cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementCompute(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + +#if defined(CUTLASS_ENABLE_SYCL) + syclcompat::wait_and_throw(); +#else + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } +#endif + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(problem_size, options.alpha, options.beta); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm_op.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
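+      // The averages below follow Options::gflops(): avg_runtime_ms = elapsed_ms / iterations,
+      // and GFLOP/s = (2 * M * N * K * L) / (avg_runtime_s * 1.0e9), i.e. two flops per
+      // multiply-add of the GEMM.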
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' + << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. +#if !defined(CUTLASS_ENABLE_SYCL) + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + return 0; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + return 0; + } +#endif + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. + // This information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // Problem configuration + using ElementA = float; + using ElementB = float; + using ElementAcc = float; + using ElementOutput = float; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Tiling configuration selection + using TileShape = Shape<_128,_128,_32>; + + // + // Assembling the CollectiveMainloop type + // + + // Number of pipelines you want to use + constexpr int PipelineStages = 4; + + using DispatchPolicy = cutlass::gemm::MainloopSm80CpAsync; + + // This code section describes the MMA op and the tile size a warp will compute + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2,_1,_1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // Define the copy layout and atom for device memory copy. + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_1,_16>>{}, + Layout>{})); + + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, float>{}, + Layout, Stride<_8,_1>>{}, + Layout>{})); + + // Define the copy layout and atom for shared memory copy. 
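+  // The Swizzle<B,M,S> functor composed over each layout atom applies an XOR-based permutation
+  // to shared-memory offsets so that consecutive accesses map to different banks; the
+  // composition keeps the atom's logical shape and only changes the index mapping.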
+ using SmemLayoutAtomA = decltype(composition(Swizzle<2,3,2>{}, Layout, Stride< _1,_32>>{})); + using SmemCopyAtomA = Copy_Atom, float>; + + using SmemLayoutAtomB = decltype(composition(Swizzle<3,2,3>{}, Layout, Stride<_32, _1>>{})); + using SmemCopyAtomB = Copy_Atom; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape, + ElementA, + cutlass::detail::TagToStrideA_t, + ElementB, + cutlass::detail::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // + // Assembling the Collective Epilogue Type + // + + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAcc, // <- data type of accumulator + ElementOutput>; // <- data type for alpha/beta in linear combination function + + using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + EpilogueOp, + cutlass::gemm::EpilogueDefault>; + + // + // Assembling the GemmKernel + // + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + runner.run(options, hw_info); + + return 0; +} diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp new file mode 100644 index 0000000000..7f7a055f09 --- /dev/null +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class SoftmaxEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& 
problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + SoftmaxEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { + // reduce across all the N tiles in shape + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + dst(x, y) = zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + CUTLASS_PRAGMA_UNROLL + for(int z = 1; z < SizeC; z++) { + dst(x, y) = op(dst(x, y), src(x, y, z)); + } + } + } + + // reduce across the sub_group to get the final output + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for(int x = 0; x < SizeA; x++) { + CUTLASS_PRAGMA_UNROLL + for(int y = 0; y < SizeB; y++) { + CUTLASS_PRAGMA_UNROLL + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); + } + } + } + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragDst, + class Op + > + CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { + reduceSg(src, dst, op); + for(int i=ThreadIdxX() % NumThreadsPerWarp; i + CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { + reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
x : y; }); + } + + template < + bool zero_init, + int SizeA, + int SizeB, + int SizeC, + class FragSrc, + class FragSum + > + CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { + reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + //auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + //Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + //Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + //Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + //Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + //CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + // "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ + print("thr_mma: "); print(thr_mma); print("\n"); + print("tiled_mma: "); print(tiled_mma); print("\n"); + //print("tiled_mma L: "); print_latex(tiled_mma); print("\n"); + print("acc: "); print(accumulators); print("\n"); + print("tCgD: "); print(tCgD); print("\n"); + print("gD: "); print(gD); print("\n"); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + for (int j = 0; j < size<1>(accumulators); ++j) { + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), 
make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i,j,k) = epilogue_op(accumulators(i,j,k)); + } + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 26c69b310a..c1081d7417 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -155,6 +155,7 @@ if (NOT CUTLASS_ENABLE_SYCL) else() foreach(EXAMPLE 14_ampere_tf32_tensorop_gemm + 35_gemm_softmax cute sycl ) From f191b9a2d01fc0b22ef5d3f65cd61f7eabd0c8a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 11 Nov 2024 13:29:14 +0100 Subject: [PATCH 02/19] WIP second kernel --- .../35_gemm_softmax/gemm_online_softmax.cpp | 21 ++- examples/35_gemm_softmax/softmax_epilogue.hpp | 141 +++++++++++++----- 2 files changed, 123 insertions(+), 39 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 2a7d21fd33..6ad3c95333 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -77,6 +77,7 @@ #endif #include "helper.h" #include "softmax_epilogue.hpp" +#include "gemm_softmax_adapter.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -116,7 +117,7 @@ struct Options { help(false), m(5120), n(4096), k(4096), l(1), alpha(1), beta(0), - iterations(0) + iterations(100) { } // Parses the command line @@ -205,6 +206,7 @@ struct ExampleRunner { using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideTmp = typename Gemm::CollectiveEpilogue::StrideD; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; @@ -232,11 +234,14 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; + StrideTmp stride_tmp; uint64_t seed = 0; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_max; + cutlass::DeviceAllocation block_sum; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -292,16 +297,24 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; + // 1 element per warp. 
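+    // Scratch for the partial softmax reductions: block_max/block_sum hold one value per warp
+    // for each CTA tile, to be combined later by a second (finalize) pass.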
+ auto tmp_size = cute::ceil_div(M * K * L, cute::shape<0>(typename Gemm::TileShape{}) * cute::shape<1>(typename Gemm::TileShape{})) * NumWarpsPerWarpGroup; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(cute::ceil_div(M, cute::shape<0>(typename Gemm::TileShape{})), + cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})), + L)); block_A.reset(M * K * L); block_B.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); + block_sum.reset(tmp_size); + block_max.reset(tmp_size); initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); @@ -317,7 +330,7 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, hw_info }; @@ -431,6 +444,7 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; + using LayoutTmp = cutlass::layout::ColumnMajor; // Tiling configuration selection using TileShape = Shape<_128,_128,_32>; @@ -497,6 +511,7 @@ int main(int argc, char const **args) { using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, EpilogueOp, cutlass::gemm::EpilogueDefault>; @@ -510,7 +525,7 @@ int main(int argc, char const **args) { CollectiveEpilogue >; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = cutlass::gemm::device::GemmSoftmaxAdapter; ExampleRunner runner; runner.run(options, hw_info); diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 7f7a055f09..e3e558da8d 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -55,6 +55,7 @@ namespace collective { template < class StrideC_, class StrideD_, + class StrideTmp_, class ThreadEpilogueOp_, class EpilogueSchedule_ > @@ -76,6 +77,7 @@ class SoftmaxEpilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; + using StrideTmp = StrideTmp_; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -97,6 +99,9 @@ class SoftmaxEpilogue { StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; + ElementAccumulator* ptr_max; + ElementAccumulator* ptr_sum; + StrideTmp dTmp{}; }; // Device side epilogue params @@ -148,9 +153,6 @@ class SoftmaxEpilogue { template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragDst, class Op @@ -158,12 +160,12 @@ class SoftmaxEpilogue { CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { // reduce across all the N tiles in shape CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < SizeA; x++) { + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < SizeB; y++) { - dst(x, y) = 
zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + for(int y = 0; y < size<1>(src); y++) { + dst(0, 0) = zero_init ? src(x, y, 0) : op(dst(x, y), src(x, y, 0)); CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < SizeC; z++) { + for(int z = 1; z < size<2>(src); z++) { dst(x, y) = op(dst(x, y), src(x, y, z)); } } @@ -172,9 +174,9 @@ class SoftmaxEpilogue { // reduce across the sub_group to get the final output auto sg = syclcompat::get_nd_item<1>().get_sub_group(); CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < SizeA; x++) { + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < SizeB; y++) { + for(int y = 0; y < size<1>(src); y++) { CUTLASS_PRAGMA_UNROLL for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); @@ -185,42 +187,45 @@ class SoftmaxEpilogue { template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragDst, class Op > CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { - reduceSg(src, dst, op); + reduceSg(src, dst, op); for(int i=ThreadIdxX() % NumThreadsPerWarp; i y ? x : y; } + };*/ + template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragMax > CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { - reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); + reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); + //reduceSg(src, max, MaxOp()); } + /*struct SumOp { + CUTLASS_DEVICE ElementAccumulator + operator()(ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; } + };*/ + template < bool zero_init, - int SizeA, - int SizeB, - int SizeC, class FragSrc, class FragSum > CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { - reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); + //reduceSg(src, sum, SumOp()); } template< @@ -236,7 +241,7 @@ class SoftmaxEpilogue { ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, BlockCoordMNKL blk_coord_mnkl, - cute::Tensor const& accumulators, + cute::Tensor & accumulators, TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, @@ -255,28 +260,28 @@ class SoftmaxEpilogue { auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); - //auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); // Represent the full output tensor - //Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - //Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, 
n_coord, k_coord, l_coord] = blk_coord_mnkl; - //Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) - //Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) static_assert(is_static::value, "Accumulator layout must be static"); - //CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), - // "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); @@ -284,25 +289,89 @@ class SoftmaxEpilogue { auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); + //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); + //Tensor acc_max = make_tensor(size<0>(accumulators)); + Tensor acc_max = make_tensor_like(take<0,2>(accumulators)); + Tensor acc_sum = make_tensor_like(take<0,2>(accumulators)); //TODO can reuse prev? + if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ - print("thr_mma: "); print(thr_mma); print("\n"); - print("tiled_mma: "); print(tiled_mma); print("\n"); - //print("tiled_mma L: "); print_latex(tiled_mma); print("\n"); - print("acc: "); print(accumulators); print("\n"); - print("tCgD: "); print(tCgD); print("\n"); - print("gD: "); print(gD); print("\n"); + //print("thr_mma: "); print(thr_mma); print("\n"); + //print("tiled_mma: "); print(tiled_mma); print("\n"); + //print("acc: "); print(accumulators); print("\n"); + //print("tCgD: "); print(tCgD); print("\n"); + //print("acc_max: "); print(acc_max); print("\n"); + //print("take<0,2>(accumulators): "); print(take<0,2>(accumulators)); print("\n"); + //print("gD: "); print(gD); print("\n"); + } + + if(is_source_needed()){ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + } + } else{ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + } } + reduce_max(accumulators, acc_max); + //reduceSg(accumulators, acc_max, MaxOp()); + CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(accumulators); ++k) { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i,j,k) = 
epilogue_op(accumulators(i,j,k)); + accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(i,j)); } } } } + + reduce_sum(accumulators, acc_sum); + + //TODO write out reductions + + //second kernel: + // - finalize max reduction: mN = sum(mj) + // - finalize sum reduction: sN = sum(sj * exp(mj-mN)) + // - finalize softmax: yi = exp(xi-mN)/sN + + /*CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(accumulators); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(accumulators); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(accumulators); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i,j,k) = accumulators(i,j,k); + } + } + } + }*/ + } private: From e0fe666b48faa0c21bb8943d104cd266ff0dbf3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 18 Nov 2024 14:36:46 +0100 Subject: [PATCH 03/19] working softmax --- .../35_gemm_softmax/gemm_online_softmax.cpp | 198 ++++++- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 527 ++++++++++++++++++ examples/35_gemm_softmax/softmax_epilogue.hpp | 254 +++++++-- examples/35_gemm_softmax/softmax_finalize.hpp | 305 ++++++++++ include/cutlass/gpu_generics.h | 2 +- 5 files changed, 1232 insertions(+), 54 deletions(-) create mode 100644 examples/35_gemm_softmax/gemm_softmax_adapter.hpp create mode 100644 examples/35_gemm_softmax/softmax_finalize.hpp diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 6ad3c95333..eb2533947b 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -112,12 +112,14 @@ struct Options { int m, n, k, l; float alpha, beta; int iterations; + float tolerance; Options(): help(false), m(5120), n(4096), k(4096), l(1), alpha(1), beta(0), - iterations(100) + iterations(100), + tolerance(1e-5f) { } // Parses the command line @@ -134,8 +136,9 @@ struct Options { cmd.get_cmd_line_argument("k", k, 4096); cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); - //cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tolerance", tolerance); } @@ -152,7 +155,8 @@ struct Options { << " --l= Sets the L extent (batch count) of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Number of profiling iterations to perform.\n\n"; + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --tolerance Error tolerance\n"; return out; } @@ -212,13 +216,15 @@ struct ExampleRunner { using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; + using LayoutTmp = typename Gemm::LayoutTmp; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; using ElementAcc = typename Gemm::ElementAccumulator; using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; @@ -249,7 +255,7 @@ struct ExampleRunner { // Methods // - bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, 
ElementOutput beta) { + /*bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); @@ -290,6 +296,177 @@ struct ExampleRunner { bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); return passed; + }*/ + + template + bool verify_tensor(std::vector vector_Input, \ + std::vector vector_Input_Ref, const Options& options) { + + auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size()); + float abs_tol = options.tolerance; + float rel_tol = options.tolerance; + + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i)); + float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { + printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the reference matches + bool verify(const Options& options) { + using ElementSoftmax = ElementD; + + cutlass::gemm::GemmCoord problem_size = cutlass::gemm::GemmCoord{options.m, options.n, options.k}; + + int64_t total_elements_A_per_batch = options.m * options.k; + int64_t total_elements_B_per_batch = options.k * options.n; + int64_t total_elements_C_per_batch = options.m * options.n; + int64_t total_elements_D_per_batch = total_elements_C_per_batch; + + int64_t lda = LayoutA::packed({options.m, options.k}).stride(0); + int64_t ldb = LayoutB::packed({options.k, options.n}).stride(0); + int64_t ldc = LayoutC::packed({options.m, options.n}).stride(0); + + int64_t ldn = options.m; + int64_t lds = ldn; + + LayoutA layout_A(lda); + LayoutB layout_B(ldb); + LayoutC layout_C(ldc); + LayoutTmp Layout_N(ldn); + LayoutTmp Layout_S(lds); + + cutlass::MatrixCoord extent_A{options.m, options.k}; + cutlass::MatrixCoord extent_B{options.k, options.n}; + cutlass::MatrixCoord extent_C{options.m, options.n}; + + cutlass::HostTensor reference_N; + reference_N.reset({options.m, 1}, false); + + for (int batch_idx = 0; batch_idx < options.l; batch_idx++) { + cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); + cutlass::TensorView view_Ref_device(block_ref_D.get(), layout_C, extent_C); + + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementCompute + >( + problem_size, + options.alpha, + view_A, + cutlass::ComplexTransform::kNone, + view_B, + cutlass::ComplexTransform::kNone, + options.beta, + view_C, + view_Ref_device, + ElementCompute(0) + ); + + // Copy reference results to host memory for verification + std::vector matrix_D_Ref(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_ref_D.get(), matrix_D_Ref.size()); + cutlass::TensorView view_D_Ref(matrix_D_Ref.data(), layout_C, extent_C); + + std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); + cutlass::TensorView 
view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); + + // Copy computed results to host memory + std::vector matrix_D(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); + + auto& matrix_Softmax = matrix_D; + //std::vector matrix_Softmax(layout_C.capacity(extent_C)); + //cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); + + // Compute the norm + for (int m = 0; m < options.m; ++m) { + reference_N.at({m, 0}) = view_D_Ref.ref().at({m, 0}); + if(batch_idx == 0 && m < 3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref tmp " << m << " " << 0 << ": " << view_D_Ref.ref().at({m, 0}) << std::endl; + } + for (int n = 1; n < options.n; ++n) { + //std::cout << "val: " << view_D_Ref.ref().at({m, n}) << std::endl; + reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); + + if(batch_idx == 0 && m < 3 && n<3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref tmp " << m << " " << n << ": " << view_D_Ref.ref().at({m, n}) << std::endl; + } + if(batch_idx == 0 && m==0 && n==127 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ + std::cout << "ref max tmp " << m << " " << n << ": " << reference_N.at({m, 0}) << std::endl; + } + } + if(batch_idx == 0 && m == 0){ + std::cout << "ref max: " << reference_N.at({m, 0}) << std::endl; + } + } + + // Compute softmax + for (int m = 0; m < options.m; ++m) { + float sum = float(); + + for (int n = 0; n < options.n; ++n) { + sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); + } + if(batch_idx == 0 && m == 0){ + std::cout << "ref sum: " << sum << std::endl; + } + + float inv_sum = float(1.0f / sum); + + for (int n = 0; n < options.n; ++n) { + view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( + std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum + ); + } + } + + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_D = false; + bool verified_Softmax = false; + + // Verify softmax output + if (!verified_D) { + verified_D = verify_tensor(matrix_D, matrix_D_Ref, options); + } + + if (!verified_Softmax) { + verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref, options); + } + //TODO(Tadej): just softmax + if (!verified_D && !verified_Softmax) { + std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; + + // Summarize which checks failed + if (!verified_D) { + std::cerr << "Verification of D tensor failed\n"; + } else{ + std::cerr << "Verification of D tensor passed\n"; + } + + if (!verified_Softmax) { + std::cerr << "Verification of Softmax tensor failed\n"; + } else{ + std::cerr << "Verification of Softmax tensor passed\n"; + } + + return false; + } + } + return true; } /// Initialize operands to be used in the GEMM and reference GEMM @@ -297,16 +474,14 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - // 1 element per warp. 
- auto tmp_size = cute::ceil_div(M * K * L, cute::shape<0>(typename Gemm::TileShape{}) * cute::shape<1>(typename Gemm::TileShape{})) * NumWarpsPerWarpGroup; + auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})); + auto tmp_size = M * partials_N * L; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(cute::ceil_div(M, cute::shape<0>(typename Gemm::TileShape{})), - cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})), - L)); + stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(M, partials_N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); @@ -348,7 +523,7 @@ struct ExampleRunner { // Check if output from CUTLASS kernel and reference kernel are equal or not Result result; - result.passed = verify(problem_size, options.alpha, options.beta); + result.passed = verify(options); std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; @@ -512,6 +687,7 @@ int main(int argc, char const **args) { cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, + TileShape, EpilogueOp, cutlass::gemm::EpilogueDefault>; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp new file mode 100644 index 0000000000..08fa798684 --- /dev/null +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -0,0 +1,527 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +//#include "cutlass/gemm/device/gemm_universal_base.h" +//#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +//#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +//#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/sycl_event_manager.hpp" +#endif + +#include "softmax_finalize.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmSoftmaxAdapter +{ +public: + using GemmKernel = GemmKernel_; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< + ElementD, typename GemmKernel::StrideD, + ElementD, typename GemmKernel::StrideD, + ElementD, typename GemmKernel::StrideD>; + + // Map back to 2.x type as best as possible + using LayoutA = gemm::detail::StrideToLayoutTagA_t; + using LayoutB = gemm::detail::StrideToLayoutTagB_t; + using LayoutC = gemm::detail::StrideToLayoutTagC_t; + using LayoutD = gemm::detail::StrideToLayoutTagC_t; + using LayoutTmp = gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = cute::is_same_v ? + ComplexTransform::kConjugate : ComplexTransform::kNone; + static ComplexTransform const kTransformB = cute::is_same_v ? 
+ ComplexTransform::kConjugate : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: 
Kernel API + //using Params = typename GemmKernel::Params; + + struct Params{ + typename GemmKernel::Params gemm_params; + typename SoftmaxFinalizeKernel::Params softmax_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * size_t(cute::size<1>(TileShape{})); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params.gemm_params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + GemmKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + //TODO(Tadej) move to finalize kernel class? 
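
The assignments that follow wire the GEMM epilogue's outputs into the finalize kernel: D holds the un-normalized GEMM+epilogue output (read and overwritten in place, since ptr_in and ptr_out both point at ptr_D), while ptr_max and ptr_sum hold one partial row maximum and one partial exp-sum per N-tile. As a rough editorial sketch of that per-row contract — a hypothetical host-side finalize_row helper, not part of this patch and not the device implementation:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    // Illustration only: merge one row's per-tile partials (m_j, s_j) into a
    // global max m and exp-sum s, then normalize the row in place, mirroring
    // what the SoftmaxFinalize kernel in this patch computes on device.
    inline void finalize_row(std::vector<float>& row,
                             const std::vector<float>& partial_max,
                             const std::vector<float>& partial_sum) {
      float m = -std::numeric_limits<float>::infinity();
      for (float pm : partial_max) m = std::max(m, pm);
      float s = 0.0f;
      for (std::size_t j = 0; j < partial_sum.size(); ++j) {
        // each s_j was accumulated against its local max m_j, so rescale by exp(m_j - m)
        s += partial_sum[j] * std::exp(partial_max[j] - m);
      }
      for (float& x : row) {
        x = std::exp(x - m) / s;  // softmax(x_n) = exp(x_n - m) / s
      }
    }
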
+ auto& softmax_args = params_.softmax_params.args; + softmax_args.IOSize = {get<0>(args.problem_shape), get<1>(args.problem_shape)}; + softmax_args.partialSize = {get<0>(args.problem_shape), + cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{}))}; + softmax_args.batch_count = get<3>(args.problem_shape); + softmax_args.dInput = args.epilogue.dD; + softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dOutput = args.epilogue.dD; + softmax_args.ptr_in = args.epilogue.ptr_D; + softmax_args.ptr_partial_max = args.epilogue.ptr_max; + softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; + softmax_args.ptr_out = args.epilogue.ptr_D; + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if !defined(CUTLASS_ENABLE_SYCL) + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } +#endif + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); + //TODO(Tadej) update softmax args + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling GemmKernel::to_underling_arguments() + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if !defined(CUTLASS_ENABLE_SYCL) + constexpr bool is_static_1x1x1 = cute::is_static_v and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { + if (is_static_1x1x1 && not launch_with_pdl) { + device_kernel<<>>(params); + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + } +#endif + } + else { + launch_result = Status::kSuccess; + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms.gemm_params}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if defined(CUTLASS_ENABLE_SYCL) + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + using namespace syclcompat::experimental; +#if defined (SYCL_INTEL_TARGET) + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size} + }, params.gemm_params); +#else + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, + params.gemm_params); +#endif + //EventManager::getInstance().addEvent(event); + + const auto sycl_block2 = syclcompat::dim3(128, 1, 1); + const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), + params.softmax_params.args.batch_count, + 1); + auto event2 = launch>(launch_policy{ + sycl_grid2, sycl_block2, local_mem_size{0}}, + params.softmax_params); + EventManager::getInstance().addEvent(event2); +#else + device_kernel<<>>(params.gemm_params); +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + 
} + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index e3e558da8d..96c902e373 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -56,6 +56,7 @@ template < class StrideC_, class StrideD_, class StrideTmp_, + class BlockShapeMNK, class ThreadEpilogueOp_, class EpilogueSchedule_ > @@ -88,7 +89,10 @@ class SoftmaxEpilogue { static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - struct SharedStorage { }; + struct SharedStorage { + //cute::array_aligned(BlockShapeMNK{})>,C(BlockShapeMNK{})>>>>> smem_c; + cute::array_aligned(BlockShapeMNK{}) * get<1>(BlockShapeMNK{})> smem_c; + }; using TensorStorage = SharedStorage; @@ -158,15 +162,15 @@ class SoftmaxEpilogue { class Op > CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { - // reduce across all the N tiles in shape + // reduce across all the -N- M tiles in shape CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { + for(int z = 1; z < size<2>(src); z++) { + dst(z) = zero_init ? src(0, 0, z) : op(dst(z), src(0, 0, z)); CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - dst(0, 0) = zero_init ? 
src(x, y, 0) : op(dst(x, y), src(x, y, 0)); + for(int x = 0; x < size<0>(src); x++) { CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - dst(x, y) = op(dst(x, y), src(x, y, z)); + for(int y = 0; y < size<1>(src); y++) { + dst(z) = op(dst(z), src(x, y, z)); } } } @@ -174,35 +178,73 @@ class SoftmaxEpilogue { // reduce across the sub_group to get the final output auto sg = syclcompat::get_nd_item<1>().get_sub_group(); CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { + for(int z = 1; z < size<2>(src); z++) { CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - dst(x,y) = op(dst(x, y), syclcompat::permute_sub_group_by_xor(sg, dst(x, y), laneMask, 16)); - } + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + dst(z) = op(dst(z), syclcompat::permute_sub_group_by_xor(sg, dst(z), laneMask, 16)); } } } template < - bool zero_init, class FragSrc, class FragDst, + class SharedThreadTens, + class SharedTens, + class ResidueMap, + class Residue, class Op > - CUTLASS_DEVICE static void reduceWg(FragSrc const &src, FragDst &dst, char* smem_buf, Op op, SharedStorage const& shared_storage) { - reduceSg(src, dst, op); + CUTLASS_DEVICE static ElementAccumulator reduceWg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx, + ElementAccumulator init, Op op) { + //TODO(Tadej): single loop over all dims + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(src); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(src); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(src); ++k) { + if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCsC(i,j,k) = src(i,j,k); + } else{ + tCsC(i,j,k) = init; + } + } + } + } + + syncthreads(); + + ElementAccumulator acc = sC(0, thread_idx); + for (int i = 1; i < size(src); ++i) { + acc = op(acc, sC(i, thread_idx)); + } + + syncthreads(); + + //broadcast it back to threads + //TODO(Tadej): optimize + for (int i = 0; i < size(src); ++i) { + sC(i, thread_idx) = acc; + } + + syncthreads(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < size<2>(src); k++) { + dst(k) = tCsC(0,0,k); + } + + return acc; + + /*reduceSg(src, dst, op); for(int i=ThreadIdxX() % NumThreadsPerWarp; i y ? x : y; } - };*/ - template < bool zero_init, class FragSrc, @@ -210,13 +252,24 @@ class SoftmaxEpilogue { > CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); - //reduceSg(src, max, MaxOp()); } - /*struct SumOp { - CUTLASS_DEVICE ElementAccumulator - operator()(ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; } - };*/ + template < + class FragSrc, + class FragDst, + class SharedThreadTens, + class SharedTens, + class ResidueMap, + class Residue + > + CUTLASS_DEVICE static ElementAccumulator reduce_max_wg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx) { + + return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, + std::numeric_limits::min(), + [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
x : y; }); + } template < bool zero_init, @@ -225,12 +278,27 @@ class SoftmaxEpilogue { > CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); - //reduceSg(src, sum, SumOp()); + } + + template < + class FragSrc, + class FragDst, + class SharedThreadTens, + class SharedTens, + class Residue, + class ResidueMap + > + CUTLASS_DEVICE static ElementAccumulator reduce_sum_wg(FragSrc const &src, FragDst &dst, + SharedThreadTens& tCsC, SharedTens& sC, + ResidueMap tCcD, Residue residue_mnk, int thread_idx) { + + return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, + 0, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x+y; }); } template< class ProblemShapeMNKL, - class BlockShapeMNK, + //class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, class TiledMma, @@ -255,29 +323,53 @@ class SoftmaxEpilogue { static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - // Separate out problem shape for convenience + //auto wlid = thread_idx % NumWarpsPerWarpGroup; // warp local id + //auto wid = thread_idx / NumWarpsPerWarpGroup; // warp id in tile + + // Separate out problem and tile shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); + auto M_tile = get<0>(blk_shape_MNK); + auto N_tile = get<1>(blk_shape_MNK); + auto K_tile = get<2>(blk_shape_MNK); + + auto N_tmp = cute::ceil_div(N, N_tile); + + cute::packed_tuple partial_block(M_tile, C<1>(), K_tile); + auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); - // Represent the full output tensor + // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gMax = gMax_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gSum = gSum_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - // Partition source and destination tiles to match the accumulator partitioning + //Represent the shared tensor + Tensor sC = make_tensor(make_smem_ptr(reinterpret_cast(smem_buf)), make_layout(make_shape(M_tile, N_tile))); + + // Partition the tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = 
thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + //Tensor tCgMax = thr_mma.partition_C(gMax); // (VEC,THR_M,THR_N) + //Tensor tCgSum = thr_mma.partition_C(gSum); // (VEC,THR_M,THR_N) + Tensor tCsC = thr_mma.partition_C(sC); // (VEC,THR_M,THR_N) + static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), @@ -291,17 +383,35 @@ class SoftmaxEpilogue { //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); //Tensor acc_max = make_tensor(size<0>(accumulators)); - Tensor acc_max = make_tensor_like(take<0,2>(accumulators)); - Tensor acc_sum = make_tensor_like(take<0,2>(accumulators)); //TODO can reuse prev? - - if(ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0){ + //Tensor acc_max = make_tensor_like(take<2,3>(accumulators)); + //Tensor acc_sum = make_tensor_like(acc_max); //TODO can reuse prev? + + //Tensor acc_max = make_tensor(shape<2>(accumulators), LayoutLeft{}); + //Tensor acc_sum = make_tensor(shape<2>(accumulators), LayoutLeft{}); //TODO can reuse prev? + + bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; + if(is_first){ + print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); + //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); + //print("partial_block: "); print(partial_block); print("\n"); //print("thr_mma: "); print(thr_mma); print("\n"); //print("tiled_mma: "); print(tiled_mma); print("\n"); //print("acc: "); print(accumulators); print("\n"); + //print("mD_mnl: "); print(mD_mnl); print("\n"); + print("mMax_mnl: "); print(mMax_mnl); print("\n"); + //print("gD_mnl: "); print(gD_mnl); print("\n"); + print("gMax_mnl: "); print(gMax_mnl); print("\n"); + //print("gD: "); print(gD); print("\n"); + print("gMax: "); print(gMax); print("\n"); //print("tCgD: "); print(tCgD); print("\n"); + //print("sC: "); print(sC); print("\n"); + //print("tCsC: "); print(tCsC); print("\n"); + //print("sC.data: "); print(&sC(0)); print("\n"); + //print("tCsC.data: "); print(&tCsC(0)); print("\n"); + //decltype(tCsC(0)) a = "asd"; + //print("tCgMax: "); print(tCgMax); print("\n"); //print("acc_max: "); print(acc_max); print("\n"); - //print("take<0,2>(accumulators): "); print(take<0,2>(accumulators)); print("\n"); - //print("gD: "); print(gD); print("\n"); + //print("accumulators: "); print(accumulators); print("\n"); } if(is_source_needed()){ @@ -314,6 +424,10 @@ class SoftmaxEpilogue { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); + tCsC(i,j,k) = accumulators(i,j,k); + /*if(is_first){ + print("acc1.1:"); print(tCsC(i,j,k)); print("\n"); + }*/ } } } @@ -328,14 +442,65 @@ class SoftmaxEpilogue { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); - } + tCsC(i,j,k) = accumulators(i,j,k); + /*if(is_first){ + print("acc1.2:"); print(accumulators(i,j,k)); print(".\n"); + print("idx:"); print(tCsC.layout()(i,j,k)); print(".\n"); + }*/ + } } } } } - reduce_max(accumulators, acc_max); - //reduceSg(accumulators, acc_max, MaxOp()); + syncthreads(); + + // assumption size<0>(sC) == wg size + ElementAccumulator max = std::numeric_limits::min(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(sC); ++i) { + if (elem_less(cD(thread_idx, i), 
make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + accumulators(i) = sC(thread_idx, i); + max = cutlass::fast_max(max, accumulators(i)); + /*if(is_first && i < 3){ + print("acc2 :"); print(accumulators(i)); print("\n"); + //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); + for (int j = 0; j < 3; ++j) { + print("shared :"); print(j); print(" "); print(i); print(": "); print(sC(j, i)); print("\n"); + } + }*/ + } + } + /*if(m_coord == 0 && n_coord == 1 && ThreadIdxX()==0){ + print("max epilogue val:"); print(max); print("\n"); + print("idx:"); print(n_coord); print("\n"); + }*/ + + gMax(thread_idx,0) = max; + + ElementAccumulator sum = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(sC); ++i) { + if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + sum += cutlass::fast_exp(accumulators(i) - max); + if(is_first){ + //print("acc3 :"); print(accumulators(i)); print("\n"); + //print("diff :"); print(accumulators(i) - max); print("\n"); + //print("add :"); print(cutlass::fast_exp(accumulators(i) - max)); print("\n"); + //print("sum :"); print(sum); print("\n"); + //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); + } + } + } + if(is_first){ + print("sum epilogue val:"); print(sum); print("\n"); + } + + gSum(thread_idx,0) = sum; + + + /*//reduce_max(accumulators, acc_max); + reduce_max_wg(accumulators, acc_max, tCsC, sC, tCcD, residue_mnk, thread_idx); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { @@ -344,13 +509,18 @@ class SoftmaxEpilogue { CUTLASS_PRAGMA_UNROLL for (int k = 0; k < size<2>(accumulators); ++k) { if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(i,j)); + accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(k)); } } } } reduce_sum(accumulators, acc_sum); + if(wlid == 0){ + for (int k = 0; k < size<2>(accumulators); ++k) { + gSum(wid,k) = acc_sum(k); + } + }*/ //TODO write out reductions diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp new file mode 100644 index 0000000000..bc91f85dd9 --- /dev/null +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a final calculation of softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +template < + typename ElementInput_, + typename StrideInput_, + typename ElementPartial_, + typename StridePartial_, + typename ElementOutput_, + typename StrideOutput_ +> +class SoftmaxFinalize { +public: + + using ElementInput = ElementInput_; + using StrideInput = StrideInput_; + using ElementPartial = ElementPartial_; + using StridePartial = StridePartial_; + using ElementOutput = ElementOutput_; + using StrideOutput = StrideOutput_; + + // + // Arguments + // + + struct Arguments { + //TODO(Tadej): duplicated part of sizes + cutlass::MatrixCoord IOSize; ///< Extent of input and output matrices + cutlass::MatrixCoord partialSize; ///< Extent of partial max and sum matrices + int batch_count; ///< Batch count + StrideInput dInput; + StridePartial dPartial; + StrideOutput dOutput; + ElementInput* ptr_in; + ElementPartial* ptr_partial_max; + ElementPartial* ptr_partial_sum; + ElementOutput* ptr_out; + +/* + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + }*/ + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + SoftmaxFinalize() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char* shared_storage) { + + apply(params, shared_storage); + } + +private: + + template + CUTLASS_DEVICE static ElementPartial reduceSg(ElementPartial val, Op op) { + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { + val = op(val, syclcompat::permute_sub_group_by_xor(sg, val, laneMask, 16)); + } + return val; + } + + CUTLASS_DEVICE static ElementPartial reduce_max(ElementPartial val) { + return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x > y ? 
x : y; }); + } + + CUTLASS_DEVICE static ElementPartial reduce_sum(ElementPartial val) { + return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x + y; }); + } + + /// Full reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, char* shared_storage) { + using ConvertInput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + /*int tid = ThreadIdxX(); + //int bid = BlockIdxX(); + int bdim = BlockDimX(); + + int warps_in_block = bdim / NumThreadsPerWarp; + + int wlid = tid % warps_in_block; // local id of thread in warp + int gid = tid;// + bid * GridDimX(); + int bsize = GridDimX(); + int m_batch_id = BlockIdxY(); + int batch_id = m_batch_id / params.args.IOSize[1]; + int m = m_batch_id % params.args.IOSize[1];*/ + + int m = ThreadIdxX() + BlockDimX() * BlockIdxX(); + int batch_id = BlockIdxY(); + + if(m>=params.args.IOSize[0]){ + return; + } + + + // Represent the full tensors + auto IOTensorShape = make_shape(params.args.IOSize[0], params.args.IOSize[1], params.args.batch_count); + auto PartialTensorShape = make_shape(params.args.partialSize[0], params.args.partialSize[1], params.args.batch_count); + Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); // (m,n,l) + Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); // (m,n,l) + Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) + Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + + if(m==0 && batch_id==0){ + print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); + } + + ElementPartial max_val = std::numeric_limits::min(); + for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + /*if(m==0 && batch_id==0){ + print("partial_max: "); print(partial_max); print("\n"); + }*/ + max_val = max_val > partial_max ? max_val : partial_max; + } + //max_val = reduce_max(max_val); + + //mOut(0,0,0) = max_val; return; + + ElementPartial sum_val = 0; + for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); + sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); + } + //sum_val = reduce_sum(sum_val); + + if(m==0 && batch_id==0){ + print("max_val: "); print(max_val); print("\n"); + print("sum_val: "); print(sum_val); print("\n"); + } + + ElementPartial norm = 1 / sum_val; + + for(int n = 0; n < params.args.IOSize[1]; n += 1){ + mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; + } + } + /* + // defining three vars for a general reduction module + cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; + int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; + int access_offset = isGroupedProblem ? 0 : bid * bdim; + + if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; + + ElementNorm *curr_ptr_Max = isGroupedProblem ? \ + params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ + params.args.block_Norm + block_batch * params.args.batch_stride_Max; + ElementSum *curr_ptr_Sum = isGroupedProblem ? 
\ + params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ + params.args.block_Sum + block_batch * params.args.batch_stride_Sum; + + int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + ConvertSumOutput convert_sum_output; + ConvertNormOutput convert_norm_output; + + uint32_t float_max_bits = 0xff7fffff; + float min_float = reinterpret_cast(float_max_bits); + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { + ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; + ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; + ElementNorm *access_n_bak = access_n; + ElementSum *access_s_bak = access_s; + ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); + ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); + ElementNorm fetch_n; + ElementSum fetch_s; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); + access_n += problem_size.m(); + } + + access_n = access_n_bak; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + cutlass::arch::global_load(fetch_s, access_s, true); + sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); + access_n += problem_size.m(); + access_s += problem_size.m(); + } + + ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; + + access_n = access_n_bak; + access_s = access_s_bak; + + access_n[0] = convert_norm_output(max_val); + access_s[0] = convert_sum_output(inv_sum); + } + + }*/ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 22d82e9d5d..7dab25b1dc 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -43,7 +43,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -static const int NumThreadsPerWarp = 32; +static constexpr int NumThreadsPerWarp = 32; static const int NumThreadsPerWarpGroup = 128; static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; From abafbabd2722f8dd1415cf3efe41c2bf67ce8d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 19 Nov 2024 13:06:27 +0100 Subject: [PATCH 04/19] partial cleanup --- .../35_gemm_softmax/gemm_online_softmax.cpp | 139 +++--------------- examples/35_gemm_softmax/gemm_softmax.cu | 10 +- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 2 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 57 +------ examples/35_gemm_softmax/softmax_finalize.hpp | 71 +-------- 5 files changed, 35 insertions(+), 244 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index eb2533947b..720de5d824 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -31,28 +31,11 @@ **************************************************************************************************/ /*! 
\file - \brief Simple GEMM example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture + \brief GEMM + Softmax example using Cute and CUTLASS 3.x APIs for NVIDIA Ampere architecture This example demonstrate how to instantiate and run a TF32 GEMM using the Cute and CUTLASS 3.x APIs on NVIDIA Ampere architecture. Please check example 07 and 08 for - the basics of tensor op gemm kernels. On NVIDIA Ampere architecture, most concept - still holds. The two main differences are: - - (1) NVIDIA Ampere architecture introduces a new series of tensor core instructions - (see include/cute/arch/mma_sm80.hpp) which are more efficient on Ampere. - (2) NVIDIA Ampere architecture uses CP_ASYNC (see include/cute/arch/copy_sm80.hpp) - to build a multistage software pipeline to better hide latency (see - include/cutlass/gemm/collective/sm80_mma_multistage.hpp). - - Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) - data types in tensor cores. One big advantage is that we can load in fp32 data and convert - them implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate - traditional fp32 data by using NVIDIA Ampere architecture. - - Examples: - - $ ./examples/14_ampere_tf32_tensorop_gemm/14_ampere_tf32_tensorop_gemm_cute - + the basics of tensor op gemm kernels. */ #include @@ -246,59 +229,15 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_max; - cutlass::DeviceAllocation block_sum; + cutlass::DeviceAllocation block_max; + cutlass::DeviceAllocation block_sum; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; // // Methods // - - /*bool verify(const ProblemShapeType& problem_size, ElementOutput alpha, ElementOutput beta) { - auto [M, N, K, L] = problem_size; - - cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); - cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); - cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); - - cutlass::reference::device::GemmComplex( - {M, N, K}, - ElementCompute(alpha), - ref_A, - cutlass::ComplexTransform::kNone, - ref_B, - cutlass::ComplexTransform::kNone, - ElementCompute(beta), - ref_C, - ref_D, - ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); - -#if defined(CUTLASS_ENABLE_SYCL) - syclcompat::wait_and_throw(); -#else - cudaError_t result = cudaDeviceSynchronize(); - if (result != cudaSuccess) { - std::cerr << "Reference kernel failed. 
Last CUDA error: " - << cudaGetErrorString(result) << std::endl; - return false; - } -#endif - - // Check if output from CUTLASS kernel and reference kernel are equal or not - bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); - - return passed; - }*/ - - template + template bool verify_tensor(std::vector vector_Input, \ std::vector vector_Input_Ref, const Options& options) { @@ -362,7 +301,7 @@ struct ExampleRunner { ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, - ElementCompute, ElementCompute + ElementCompute, ElementCompute, ElementD >( problem_size, options.alpha, @@ -388,42 +327,20 @@ struct ExampleRunner { std::vector matrix_D(layout_C.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); - auto& matrix_Softmax = matrix_D; - //std::vector matrix_Softmax(layout_C.capacity(extent_C)); - //cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); - // Compute the norm for (int m = 0; m < options.m; ++m) { reference_N.at({m, 0}) = view_D_Ref.ref().at({m, 0}); - if(batch_idx == 0 && m < 3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref tmp " << m << " " << 0 << ": " << view_D_Ref.ref().at({m, 0}) << std::endl; - } for (int n = 1; n < options.n; ++n) { - //std::cout << "val: " << view_D_Ref.ref().at({m, n}) << std::endl; reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); - - if(batch_idx == 0 && m < 3 && n<3 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref tmp " << m << " " << n << ": " << view_D_Ref.ref().at({m, n}) << std::endl; - } - if(batch_idx == 0 && m==0 && n==127 /*abs(view_D_Ref.ref().at({m, n}) - 240395) < 0.1*/){ - std::cout << "ref max tmp " << m << " " << n << ": " << reference_N.at({m, 0}) << std::endl; - } - } - if(batch_idx == 0 && m == 0){ - std::cout << "ref max: " << reference_N.at({m, 0}) << std::endl; } } // Compute softmax for (int m = 0; m < options.m; ++m) { - float sum = float(); - + float sum = 0; for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } - if(batch_idx == 0 && m == 0){ - std::cout << "ref sum: " << sum << std::endl; - } float inv_sum = float(1.0f / sum); @@ -434,35 +351,9 @@ struct ExampleRunner { } } - // Verification checks - set any of these to 'true' to override the verification checks. 
- bool verified_D = false; - bool verified_Softmax = false; - - // Verify softmax output - if (!verified_D) { - verified_D = verify_tensor(matrix_D, matrix_D_Ref, options); - } - + bool verified_Softmax = verify_tensor(matrix_D, matrix_Softmax_Ref, options); if (!verified_Softmax) { - verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref, options); - } - //TODO(Tadej): just softmax - if (!verified_D && !verified_Softmax) { - std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; - - // Summarize which checks failed - if (!verified_D) { - std::cerr << "Verification of D tensor failed\n"; - } else{ - std::cerr << "Verification of D tensor passed\n"; - } - - if (!verified_Softmax) { - std::cerr << "Verification of Softmax tensor failed\n"; - } else{ - std::cerr << "Verification of Softmax tensor passed\n"; - } - + std::cerr << "Verification of Softmax tensor failed\n"; return false; } } @@ -505,7 +396,11 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, + {{options.alpha,//static_cast(options.alpha), + options.beta},//static_cast(options.beta)}, + block_C.get(), stride_C, + block_D.get(), stride_D, + block_max.get(), block_sum.get(), stride_tmp}, hw_info }; @@ -610,6 +505,10 @@ int main(int argc, char const **args) { hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); // Problem configuration + /*using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementAcc = float; + using ElementOutput = cutlass::half_t;*/ using ElementA = float; using ElementB = float; using ElementAcc = float; @@ -681,7 +580,7 @@ int main(int argc, char const **args) { // elements. 
This becomes the vector width of // math instructions in the epilogue too ElementAcc, // <- data type of accumulator - ElementOutput>; // <- data type for alpha/beta in linear combination function + ElementAcc>; // <- data type for alpha/beta in linear combination function using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 27156ea02d..7e663679b6 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -201,19 +201,23 @@ struct Testbed { // - using ElementA = cutlass::half_t; + /*using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; + using ElementCompute = float;*/ + using ElementA = float; + using ElementB = float; + using ElementC = float; using ElementCompute = float; using ElementD = ElementC; using ElementSoftmax = ElementC; - using LayoutA = cutlass::layout::RowMajor; + using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor; using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 08fa798684..bdc4f6052a 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -91,7 +91,7 @@ class GemmSoftmaxAdapter using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< ElementD, typename GemmKernel::StrideD, - ElementD, typename GemmKernel::StrideD, + ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StrideTmp, ElementD, typename GemmKernel::StrideD>; // Map back to 2.x type as best as possible diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 96c902e373..f5f807a84d 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -298,7 +298,6 @@ class SoftmaxEpilogue { template< class ProblemShapeMNKL, - //class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, class TiledMma, @@ -391,18 +390,18 @@ class SoftmaxEpilogue { bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; if(is_first){ - print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); + //print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); //print("partial_block: "); print(partial_block); print("\n"); //print("thr_mma: "); print(thr_mma); print("\n"); //print("tiled_mma: "); print(tiled_mma); print("\n"); //print("acc: "); print(accumulators); print("\n"); //print("mD_mnl: "); print(mD_mnl); print("\n"); - print("mMax_mnl: "); print(mMax_mnl); print("\n"); + //print("mMax_mnl: "); print(mMax_mnl); print("\n"); //print("gD_mnl: "); print(gD_mnl); print("\n"); - print("gMax_mnl: "); print(gMax_mnl); print("\n"); + //print("gMax_mnl: "); print(gMax_mnl); print("\n"); //print("gD: "); print(gD); print("\n"); - print("gMax: "); print(gMax); print("\n"); + //print("gMax: "); print(gMax); print("\n"); //print("tCgD: "); print(tCgD); 
print("\n"); //print("sC: "); print(sC); print("\n"); //print("tCsC: "); print(tCsC); print("\n"); @@ -493,55 +492,9 @@ class SoftmaxEpilogue { } } if(is_first){ - print("sum epilogue val:"); print(sum); print("\n"); + //print("sum epilogue val:"); print(sum); print("\n"); } - gSum(thread_idx,0) = sum; - - - /*//reduce_max(accumulators, acc_max); - reduce_max_wg(accumulators, acc_max, tCsC, sC, tCcD, residue_mnk, thread_idx); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(accumulators); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(accumulators); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(accumulators); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - accumulators(i,j,k) = expf(accumulators(i,j,k) - acc_max(k)); - } - } - } - } - - reduce_sum(accumulators, acc_sum); - if(wlid == 0){ - for (int k = 0; k < size<2>(accumulators); ++k) { - gSum(wid,k) = acc_sum(k); - } - }*/ - - //TODO write out reductions - - //second kernel: - // - finalize max reduction: mN = sum(mj) - // - finalize sum reduction: sN = sum(sj * exp(mj-mN)) - // - finalize softmax: yi = exp(xi-mN)/sN - - /*CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(accumulators); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(accumulators); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(accumulators); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCgD(i,j,k) = accumulators(i,j,k); - } - } - } - }*/ - } private: diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index bc91f85dd9..72a7c93f8c 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -197,7 +197,7 @@ class SoftmaxFinalize { Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) if(m==0 && batch_id==0){ - print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); + //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); } ElementPartial max_val = std::numeric_limits::min(); @@ -221,8 +221,8 @@ class SoftmaxFinalize { //sum_val = reduce_sum(sum_val); if(m==0 && batch_id==0){ - print("max_val: "); print(max_val); print("\n"); - print("sum_val: "); print(sum_val); print("\n"); + //print("max_val: "); print(max_val); print("\n"); + //print("sum_val: "); print(sum_val); print("\n"); } ElementPartial norm = 1 / sum_val; @@ -231,71 +231,6 @@ class SoftmaxFinalize { mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; } } - /* - // defining three vars for a general reduction module - cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; - int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; - int access_offset = isGroupedProblem ? 0 : bid * bdim; - - if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; - - ElementNorm *curr_ptr_Max = isGroupedProblem ? \ - params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ - params.args.block_Norm + block_batch * params.args.batch_stride_Max; - ElementSum *curr_ptr_Sum = isGroupedProblem ? 
\ - params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ - params.args.block_Sum + block_batch * params.args.batch_stride_Sum; - - int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - ConvertSum convert_sum; - ConvertNorm convert_norm; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - uint32_t float_max_bits = 0xff7fffff; - float min_float = reinterpret_cast(float_max_bits); - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { - ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; - ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; - ElementNorm *access_n_bak = access_n; - ElementSum *access_s_bak = access_s; - ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); - ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); - ElementNorm fetch_n; - ElementSum fetch_s; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); - access_n += problem_size.m(); - } - - access_n = access_n_bak; - - CUTLASS_PRAGMA_UNROLL - for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { - cutlass::arch::global_load(fetch_n, access_n, true); - cutlass::arch::global_load(fetch_s, access_s, true); - sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); - access_n += problem_size.m(); - access_s += problem_size.m(); - } - - ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; - - access_n = access_n_bak; - access_s = access_s_bak; - - access_n[0] = convert_norm_output(max_val); - access_s[0] = convert_sum_output(inv_sum); - } - - }*/ }; ///////////////////////////////////////////////////////////////////////////////////////////////// From 8d8c1bcbd7626e0e6521d634c1912e927486bcf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 25 Nov 2024 14:20:32 +0100 Subject: [PATCH 05/19] 2nd kernel work distribution fix --- .../35_gemm_softmax/gemm_online_softmax.cpp | 11 ++++- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 4 +- examples/35_gemm_softmax/softmax_finalize.hpp | 42 ++++++++++++++++--- 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 720de5d824..f8ebdc3221 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -250,8 +250,8 @@ struct ExampleRunner { float abs_diff = fabs(diff); float abs_ref = fabs((float)vector_Input_Ref.at(i)); float relative_diff = abs_ref > abs_tol ? 
abs_diff / abs_ref : 0; - if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { - printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("i = %d diff = %f, {%f, %f}.\n", i, abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } @@ -333,6 +333,9 @@ struct ExampleRunner { for (int n = 1; n < options.n; ++n) { reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); } + /*if(m == 3516 && batch_idx == 0){ + std:: cout << "max0: " << reference_N.at({m, 0}) << std::endl; + }*/ } // Compute softmax @@ -341,6 +344,10 @@ struct ExampleRunner { for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } + + /*if(m == 3516 && batch_idx == 0){ + std:: cout << "sum0: " << sum << std::endl; + }*/ float inv_sum = float(1.0f / sum); diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index bdc4f6052a..18c5b4ec15 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -448,12 +448,12 @@ class GemmSoftmaxAdapter #endif //EventManager::getInstance().addEvent(event); - const auto sycl_block2 = syclcompat::dim3(128, 1, 1); + const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), params.softmax_params.args.batch_count, 1); auto event2 = launch>(launch_policy{ - sycl_grid2, sycl_block2, local_mem_size{0}}, + sycl_grid2, sycl_block2, local_mem_size{SoftmaxFinalizeKernel::SharedStorageSize}}, params.softmax_params); EventManager::getInstance().addEvent(event2); #else diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 72a7c93f8c..ca6c3c458c 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -109,10 +109,11 @@ class SoftmaxFinalize { }; struct SharedStorage { - - + cute::array_aligned s_mem; }; + static constexpr int SharedStorageSize = sizeof(SharedStorage); + // // Params struct // @@ -180,7 +181,10 @@ class SoftmaxFinalize { int batch_id = m_batch_id / params.args.IOSize[1]; int m = m_batch_id % params.args.IOSize[1];*/ - int m = ThreadIdxX() + BlockDimX() * BlockIdxX(); + int x = ThreadIdxX(); + int m = x + BlockDimX() * BlockIdxX(); + int y = ThreadIdxY(); + int y_size = BlockDimY(); int batch_id = BlockIdxY(); if(m>=params.args.IOSize[0]){ @@ -196,28 +200,54 @@ class SoftmaxFinalize { Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + //Represent the shared tensor + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); + if(m==0 && batch_id==0){ //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); } ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ 
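          // Work-distribution sketch for this loop (a summary in comment form, not
          // code from the patch): thread x owns output row m, while the y threads of
          // the block stride across the partial-result columns, so each y lane reduces
          // a strided subset of mPartialMax(m, :, batch_id) into its private max_val;
          // the per-lane results are then combined through the shared tile
          // sPartial(x, y, 0) after the syncthreads() below.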
ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); /*if(m==0 && batch_id==0){ print("partial_max: "); print(partial_max); print("\n"); }*/ max_val = max_val > partial_max ? max_val : partial_max; } + sPartial(x,y,0) = max_val; + syncthreads(); + //TODO(Tadej): improve reduction + for(int y2 = 0; y2 < y_size; y2++){ + ElementPartial partial_max = sPartial(x,y2,0); + max_val = max_val > partial_max ? max_val : partial_max; + } + /*if(m == 3516 && y == 0){ + print("kernel max"); print(max_val); print("\n"); + }*/ //max_val = reduce_max(max_val); //mOut(0,0,0) = max_val; return; ElementPartial sum_val = 0; - for(int partial_n = 0; partial_n < params.args.partialSize[1]; partial_n += 1){ + for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); } + syncthreads(); + sPartial(x,y,1) = sum_val; + syncthreads(); + sum_val = 0; + //TODO(Tadej): improve reduction + for(int y2 = 0; y2 < y_size; y2++){ + ElementPartial partial_max = sPartial(x,y2,0); + ElementPartial partial_sum = sPartial(x,y2,1); + sum_val = sum_val + partial_sum /** cutlass::fast_exp(partial_max - max_val)*/; + } + /*if(m == 3516 && y == 0){ + print("kernel sum"); print(sum_val); print("\n"); + }*/ //sum_val = reduce_sum(sum_val); if(m==0 && batch_id==0){ @@ -227,7 +257,7 @@ class SoftmaxFinalize { ElementPartial norm = 1 / sum_val; - for(int n = 0; n < params.args.IOSize[1]; n += 1){ + for(int n = y; n < params.args.IOSize[1]; n += y_size){ mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; } } From 1ad0f8ec9f0c5f8e16ba7d0dfc1348953e7d0816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 2 Dec 2024 10:29:14 +0100 Subject: [PATCH 06/19] native exp + optimizations --- .../35_gemm_softmax/gemm_online_softmax.cpp | 2 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 19 +------------------ examples/35_gemm_softmax/softmax_finalize.hpp | 19 +++++++++++++++++-- include/cutlass/fast_math.h | 2 ++ 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index f8ebdc3221..cdafc65a50 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -430,7 +430,7 @@ struct ExampleRunner { std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; if (!result.passed) { - exit(-1); + //exit(-1); } // Run profiling loop diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index f5f807a84d..0e0505bf67 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -424,9 +424,6 @@ class SoftmaxEpilogue { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k), tCgC(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); tCsC(i,j,k) = accumulators(i,j,k); - /*if(is_first){ - print("acc1.1:"); print(tCsC(i,j,k)); print("\n"); - }*/ } } } @@ -442,10 +439,6 @@ class SoftmaxEpilogue { accumulators(i,j,k) = epilogue_op(accumulators(i,j,k)); tCgD(i,j,k) = accumulators(i,j,k); tCsC(i,j,k) = accumulators(i,j,k); - /*if(is_first){ - print("acc1.2:"); print(accumulators(i,j,k)); print(".\n"); - print("idx:"); print(tCsC.layout()(i,j,k)); print(".\n"); - }*/ } } } @@ -461,19 +454,8 @@ class SoftmaxEpilogue { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i) = sC(thread_idx, i); max = cutlass::fast_max(max, accumulators(i)); - /*if(is_first && i < 3){ - print("acc2 :"); print(accumulators(i)); print("\n"); - //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); - for (int j = 0; j < 3; ++j) { - print("shared :"); print(j); print(" "); print(i); print(": "); print(sC(j, i)); print("\n"); - } - }*/ } } - /*if(m_coord == 0 && n_coord == 1 && ThreadIdxX()==0){ - print("max epilogue val:"); print(max); print("\n"); - print("idx:"); print(n_coord); print("\n"); - }*/ gMax(thread_idx,0) = max; @@ -482,6 +464,7 @@ class SoftmaxEpilogue { for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { sum += cutlass::fast_exp(accumulators(i) - max); + //sum += sycl::native::exp(accumulators(i) - max); if(is_first){ //print("acc3 :"); print(accumulators(i)); print("\n"); //print("diff :"); print(accumulators(i) - max); print("\n"); diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index ca6c3c458c..56a88e8acd 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -257,8 +257,23 @@ class SoftmaxFinalize { ElementPartial norm = 1 / sum_val; - for(int n = y; n < params.args.IOSize[1]; n += y_size){ - mOut(m, n, batch_id) = cutlass::fast_exp(mIn(m, n, batch_id) - max_val) * norm; + int unroll = 2; + //_Pragma("unroll 2") + //for(int n = y; n < params.args.IOSize[1]; n += y_size){ + for(int n = y * unroll; n < params.args.IOSize[1]; n += y_size * unroll){ + auto inVal = mIn(m, n, batch_id); + auto inVal2 = mIn(m, n+1, batch_id); + //auto inVal3 = mIn(m, n+2, batch_id); + //auto inVal4 = mIn(m, n+3, batch_id); + mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; + mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; + //mOut(m, n+2, batch_id) = cutlass::fast_exp(inVal3 - max_val) * norm; + //mOut(m, n+3, batch_id) = cutlass::fast_exp(inVal4 - max_val) * norm; + } + if(params.args.IOSize[1]%2==1){ + int n = params.args.IOSize[1] - 1; + auto inVal = mIn(m, n, batch_id); + mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; } } }; diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index fa3873c5e7..c856afa76e 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -859,6 +859,8 @@ CUTLASS_HOST_DEVICE float 
fast_exp(float x) { #if defined(__CUDA_ARCH__) return ::expf(x); + #elif defined(__SYCL_CUDA_ARCH__) + return ::sycl::native::exp(x); #else return std::exp(x); #endif From cf34fdad589ec19f670709704f8a80ead58e5a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Mon, 2 Dec 2024 12:51:58 +0100 Subject: [PATCH 07/19] basic cleanup --- .../35_gemm_softmax/gemm_online_softmax.cpp | 18 +---- examples/35_gemm_softmax/gemm_softmax.cu | 5 -- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 2 - examples/35_gemm_softmax/softmax_epilogue.hpp | 53 +------------- examples/35_gemm_softmax/softmax_finalize.hpp | 69 +------------------ 5 files changed, 6 insertions(+), 141 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index cdafc65a50..67dff1460a 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -333,9 +333,6 @@ struct ExampleRunner { for (int n = 1; n < options.n; ++n) { reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementSoftmax(view_D_Ref.ref().at({m, n}))); } - /*if(m == 3516 && batch_idx == 0){ - std:: cout << "max0: " << reference_N.at({m, 0}) << std::endl; - }*/ } // Compute softmax @@ -344,11 +341,6 @@ struct ExampleRunner { for (int n = 0; n < options.n; ++n) { sum += std::exp( float(view_D_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } - - /*if(m == 3516 && batch_idx == 0){ - std:: cout << "sum0: " << sum << std::endl; - }*/ - float inv_sum = float(1.0f / sum); for (int n = 0; n < options.n; ++n) { @@ -403,8 +395,8 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha,//static_cast(options.alpha), - options.beta},//static_cast(options.beta)}, + {{options.alpha, + options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, block_max.get(), block_sum.get(), stride_tmp}, @@ -430,7 +422,7 @@ struct ExampleRunner { std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; if (!result.passed) { - //exit(-1); + exit(-1); } // Run profiling loop @@ -512,10 +504,6 @@ int main(int argc, char const **args) { hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); // Problem configuration - /*using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementAcc = float; - using ElementOutput = cutlass::half_t;*/ using ElementA = float; using ElementB = float; using ElementAcc = float; diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 7e663679b6..47673501bb 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -200,11 +200,6 @@ struct Testbed { // Type definitions // - - /*using ElementA = cutlass::half_t; - using ElementB = cutlass::half_t; - using ElementC = cutlass::half_t; - using ElementCompute = float;*/ using ElementA = float; using ElementB = float; using ElementC = float; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 18c5b4ec15..c56e509af0 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -446,8 +446,6 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - //EventManager::getInstance().addEvent(event); - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), params.softmax_params.args.batch_count, diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 0e0505bf67..2582784d29 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -322,9 +322,6 @@ class SoftmaxEpilogue { static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - //auto wlid = thread_idx % NumWarpsPerWarpGroup; // warp local id - //auto wid = thread_idx / NumWarpsPerWarpGroup; // warp id in tile - // Separate out problem and tile shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); @@ -365,11 +362,8 @@ class SoftmaxEpilogue { auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) - //Tensor tCgMax = thr_mma.partition_C(gMax); // (VEC,THR_M,THR_N) - //Tensor tCgSum = thr_mma.partition_C(gSum); // (VEC,THR_M,THR_N) Tensor tCsC = thr_mma.partition_C(sC); // (VEC,THR_M,THR_N) - static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), "Source and destination must have the same number of elements."); @@ -380,39 +374,6 @@ class SoftmaxEpilogue { auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); - //Tensor acc_max = make_tensor(Shape(accumulators)>, Int(accumulators)>>{}); - //Tensor acc_max = make_tensor(size<0>(accumulators)); - //Tensor acc_max = make_tensor_like(take<2,3>(accumulators)); - //Tensor acc_sum = make_tensor_like(acc_max); //TODO can reuse prev? 
- - //Tensor acc_max = make_tensor(shape<2>(accumulators), LayoutLeft{}); - //Tensor acc_sum = make_tensor(shape<2>(accumulators), LayoutLeft{}); //TODO can reuse prev? - - bool is_first = ThreadIdxX()==0 && BlockIdxX()==0 && BlockIdxY()==0 && BlockIdxZ()==0; - if(is_first){ - //print("blk_coord_mnkl: "); print(blk_coord_mnkl); print("\n"); - //print("blk_shape_MNK: "); print(blk_shape_MNK); print("\n"); - //print("partial_block: "); print(partial_block); print("\n"); - //print("thr_mma: "); print(thr_mma); print("\n"); - //print("tiled_mma: "); print(tiled_mma); print("\n"); - //print("acc: "); print(accumulators); print("\n"); - //print("mD_mnl: "); print(mD_mnl); print("\n"); - //print("mMax_mnl: "); print(mMax_mnl); print("\n"); - //print("gD_mnl: "); print(gD_mnl); print("\n"); - //print("gMax_mnl: "); print(gMax_mnl); print("\n"); - //print("gD: "); print(gD); print("\n"); - //print("gMax: "); print(gMax); print("\n"); - //print("tCgD: "); print(tCgD); print("\n"); - //print("sC: "); print(sC); print("\n"); - //print("tCsC: "); print(tCsC); print("\n"); - //print("sC.data: "); print(&sC(0)); print("\n"); - //print("tCsC.data: "); print(&tCsC(0)); print("\n"); - //decltype(tCsC(0)) a = "asd"; - //print("tCgMax: "); print(tCgMax); print("\n"); - //print("acc_max: "); print(acc_max); print("\n"); - //print("accumulators: "); print(accumulators); print("\n"); - } - if(is_source_needed()){ CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(accumulators); ++i) { @@ -447,7 +408,7 @@ class SoftmaxEpilogue { syncthreads(); - // assumption size<0>(sC) == wg size + // assumption: size<0>(sC) == wg size ElementAccumulator max = std::numeric_limits::min(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(sC); ++i) { @@ -456,7 +417,6 @@ class SoftmaxEpilogue { max = cutlass::fast_max(max, accumulators(i)); } } - gMax(thread_idx,0) = max; ElementAccumulator sum = 0; @@ -464,19 +424,8 @@ class SoftmaxEpilogue { for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { sum += cutlass::fast_exp(accumulators(i) - max); - //sum += sycl::native::exp(accumulators(i) - max); - if(is_first){ - //print("acc3 :"); print(accumulators(i)); print("\n"); - //print("diff :"); print(accumulators(i) - max); print("\n"); - //print("add :"); print(cutlass::fast_exp(accumulators(i) - max)); print("\n"); - //print("sum :"); print(sum); print("\n"); - //print("idx:"); print(sC.layout()(thread_idx, i)); print(".\n"); - } } } - if(is_first){ - //print("sum epilogue val:"); print(sum); print("\n"); - } gSum(thread_idx,0) = sum; } diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 56a88e8acd..ceb4de1813 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -83,29 +83,6 @@ class SoftmaxFinalize { ElementPartial* ptr_partial_max; ElementPartial* ptr_partial_sum; ElementOutput* ptr_out; - -/* - // - // Methods - // - Arguments() { } - - Arguments( - cutlass::gemm::GemmCoord problem_size, - ElementNorm* block_Norm, - ElementSum* block_Sum - ): - problem_size(problem_size), - block_Norm(block_Norm), - block_Sum(block_Sum), - problem_sizes(nullptr), - offset_Norm_Device(nullptr), - offset_Sum_Device(nullptr), - batch_stride_Max(0), - batch_stride_Sum(0) - { - - }*/ }; struct SharedStorage { @@ -168,19 +145,6 @@ class SoftmaxFinalize { using ConvertInput = cutlass::NumericConverter; using ConvertNormOutput = cutlass::NumericConverter; - 
/*int tid = ThreadIdxX(); - //int bid = BlockIdxX(); - int bdim = BlockDimX(); - - int warps_in_block = bdim / NumThreadsPerWarp; - - int wlid = tid % warps_in_block; // local id of thread in warp - int gid = tid;// + bid * GridDimX(); - int bsize = GridDimX(); - int m_batch_id = BlockIdxY(); - int batch_id = m_batch_id / params.args.IOSize[1]; - int m = m_batch_id % params.args.IOSize[1];*/ - int x = ThreadIdxX(); int m = x + BlockDimX() * BlockIdxX(); int y = ThreadIdxY(); @@ -203,16 +167,9 @@ class SoftmaxFinalize { //Represent the shared tensor Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); - if(m==0 && batch_id==0){ - //print("PartialTensorShape: "); print(PartialTensorShape); print("\n"); - } - ElementPartial max_val = std::numeric_limits::min(); for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); - /*if(m==0 && batch_id==0){ - print("partial_max: "); print(partial_max); print("\n"); - }*/ max_val = max_val > partial_max ? max_val : partial_max; } sPartial(x,y,0) = max_val; @@ -222,12 +179,6 @@ class SoftmaxFinalize { ElementPartial partial_max = sPartial(x,y2,0); max_val = max_val > partial_max ? max_val : partial_max; } - /*if(m == 3516 && y == 0){ - print("kernel max"); print(max_val); print("\n"); - }*/ - //max_val = reduce_max(max_val); - - //mOut(0,0,0) = max_val; return; ElementPartial sum_val = 0; for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ @@ -243,32 +194,16 @@ class SoftmaxFinalize { for(int y2 = 0; y2 < y_size; y2++){ ElementPartial partial_max = sPartial(x,y2,0); ElementPartial partial_sum = sPartial(x,y2,1); - sum_val = sum_val + partial_sum /** cutlass::fast_exp(partial_max - max_val)*/; - } - /*if(m == 3516 && y == 0){ - print("kernel sum"); print(sum_val); print("\n"); - }*/ - //sum_val = reduce_sum(sum_val); - - if(m==0 && batch_id==0){ - //print("max_val: "); print(max_val); print("\n"); - //print("sum_val: "); print(sum_val); print("\n"); + sum_val = sum_val + partial_sum; } ElementPartial norm = 1 / sum_val; - int unroll = 2; - //_Pragma("unroll 2") - //for(int n = y; n < params.args.IOSize[1]; n += y_size){ - for(int n = y * unroll; n < params.args.IOSize[1]; n += y_size * unroll){ + for(int n = y * 2; n < params.args.IOSize[1]; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); - //auto inVal3 = mIn(m, n+2, batch_id); - //auto inVal4 = mIn(m, n+3, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; - //mOut(m, n+2, batch_id) = cutlass::fast_exp(inVal3 - max_val) * norm; - //mOut(m, n+3, batch_id) = cutlass::fast_exp(inVal4 - max_val) * norm; } if(params.args.IOSize[1]%2==1){ int n = params.args.IOSize[1] - 1; From 84ce5c11d2e61c30b66b8ddf870e7ded038a71e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 3 Dec 2024 11:52:04 +0100 Subject: [PATCH 08/19] final cleanup --- examples/35_gemm_softmax/CMakeLists.txt | 2 +- examples/35_gemm_softmax/gemm_softmax.cu | 11 +- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 43 ++--- examples/35_gemm_softmax/softmax_epilogue.hpp | 164 ++---------------- examples/35_gemm_softmax/softmax_finalize.hpp | 82 ++++----- include/cutlass/fast_math.h | 1 + 6 files changed, 66 insertions(+), 237 deletions(-) diff --git a/examples/35_gemm_softmax/CMakeLists.txt 
b/examples/35_gemm_softmax/CMakeLists.txt index d7f2cd574b..824453a656 100644 --- a/examples/35_gemm_softmax/CMakeLists.txt +++ b/examples/35_gemm_softmax/CMakeLists.txt @@ -39,4 +39,4 @@ cutlass_example_add_executable( 35_gemm_online_softmax gemm_online_softmax.cpp ) -endif() \ No newline at end of file +endif() diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 47673501bb..27156ea02d 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -200,19 +200,20 @@ struct Testbed { // Type definitions // - using ElementA = float; - using ElementB = float; - using ElementC = float; + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; using ElementCompute = float; using ElementD = ElementC; using ElementSoftmax = ElementC; - using LayoutA = cutlass::layout::ColumnMajor; + using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index c56e509af0..0d31e19457 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -49,12 +50,6 @@ #include "cutlass/trace.h" #endif // !defined(__CUDACC_RTC__) -// 2.x -//#include "cutlass/gemm/device/gemm_universal_base.h" -//#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -//#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -//#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" - // 3.x #include "cutlass/gemm/kernel/gemm_universal.hpp" @@ -171,8 +166,6 @@ class GemmSoftmaxAdapter /// Argument structure: User API using Arguments = typename GemmKernel::Arguments; - /// Argument structure: Kernel API - //using Params = typename GemmKernel::Params; struct Params{ typename GemmKernel::Params gemm_params; @@ -271,6 +264,20 @@ class GemmSoftmaxAdapter return max_active_blocks; } + void initialize_softmax_params(Arguments const& args, typename SoftmaxFinalizeKernel::Arguments& softmax_args){ + softmax_args.M = get<0>(args.problem_shape); + softmax_args.dataN = get<1>(args.problem_shape); + softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{})); + softmax_args.batch_count = get<3>(args.problem_shape); + softmax_args.dInput = args.epilogue.dD; + softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dOutput = args.epilogue.dD; + softmax_args.ptr_in = args.epilogue.ptr_D; + softmax_args.ptr_partial_max = args.epilogue.ptr_max; + softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; + softmax_args.ptr_out = args.epilogue.ptr_D; + } + /// Initializes GEMM state from arguments. 
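  // Sizing of the finalize step, summarizing the quantities filled in by
  // initialize_softmax_params() above and consumed when the second kernel is
  // launched (the concrete numbers are illustrative defaults, not fixed here):
  //   partialN = ceil_div(N, tile_N)           // one partial max/sum per N-tile
  //   block    = (32, min(32, M), 1)           // x -> rows, y -> partial columns
  //   grid     = (ceil_div(M, block.x), L, 1)  // grid.y walks the batch
  // e.g. M = 5120, N = 4096 with an N-tile of 128 gives partialN = 32,
  // block = (32, 32, 1) and grid = (160, L, 1).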
Status initialize( @@ -289,19 +296,7 @@ class GemmSoftmaxAdapter } // Initialize the Params structure params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); - //TODO(Tadej) move to finalize kernel class? - auto& softmax_args = params_.softmax_params.args; - softmax_args.IOSize = {get<0>(args.problem_shape), get<1>(args.problem_shape)}; - softmax_args.partialSize = {get<0>(args.problem_shape), - cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{}))}; - softmax_args.batch_count = get<3>(args.problem_shape); - softmax_args.dInput = args.epilogue.dD; - softmax_args.dPartial = args.epilogue.dTmp; - softmax_args.dOutput = args.epilogue.dD; - softmax_args.ptr_in = args.epilogue.ptr_D; - softmax_args.ptr_partial_max = args.epilogue.ptr_max; - softmax_args.ptr_partial_sum = args.epilogue.ptr_sum; - softmax_args.ptr_out = args.epilogue.ptr_D; + initialize_softmax_params(args, params_.softmax_params.args); // Don't set the function attributes - require the CudaHostAdapter to set it. if constexpr (kEnableCudaHostAdapter) { @@ -345,7 +340,7 @@ class GemmSoftmaxAdapter } params_.gemm_params = GemmKernel::to_underlying_arguments(args, workspace); - //TODO(Tadej) update softmax args + initialize_softmax_params(args, params_.softmax_params.args); return Status::kSuccess; } @@ -446,8 +441,8 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.IOSize[0]), 1); - const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.IOSize[0], sycl_block2.x), + const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.M), 1); + const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), params.softmax_params.args.batch_count, 1); auto event2 = launch>(launch_policy{ diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 2582784d29..093a6d70d1 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,9 +29,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ #pragma once @@ -90,7 +88,6 @@ class SoftmaxEpilogue { static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { - //cute::array_aligned(BlockShapeMNK{})>,C(BlockShapeMNK{})>>>>> smem_c; cute::array_aligned(BlockShapeMNK{}) * get<1>(BlockShapeMNK{})> smem_c; }; @@ -155,147 +152,6 @@ class SoftmaxEpilogue { return epilogue_op.is_source_needed(); } - template < - bool zero_init, - class FragSrc, - class FragDst, - class Op - > - CUTLASS_DEVICE static void reduceSg(FragSrc const &src, FragDst &dst, Op op) { - // reduce across all the -N- M tiles in shape - CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - dst(z) = zero_init ? 
src(0, 0, z) : op(dst(z), src(0, 0, z)); - CUTLASS_PRAGMA_UNROLL - for(int x = 0; x < size<0>(src); x++) { - CUTLASS_PRAGMA_UNROLL - for(int y = 0; y < size<1>(src); y++) { - dst(z) = op(dst(z), src(x, y, z)); - } - } - } - - // reduce across the sub_group to get the final output - auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - CUTLASS_PRAGMA_UNROLL - for(int z = 1; z < size<2>(src); z++) { - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - dst(z) = op(dst(z), syclcompat::permute_sub_group_by_xor(sg, dst(z), laneMask, 16)); - } - } - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class ResidueMap, - class Residue, - class Op - > - CUTLASS_DEVICE static ElementAccumulator reduceWg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx, - ElementAccumulator init, Op op) { - //TODO(Tadej): single loop over all dims - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(src); ++i) { - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<1>(src); ++j) { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(src); ++k) { - if (elem_less(tCcD(i,j,k), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { - tCsC(i,j,k) = src(i,j,k); - } else{ - tCsC(i,j,k) = init; - } - } - } - } - - syncthreads(); - - ElementAccumulator acc = sC(0, thread_idx); - for (int i = 1; i < size(src); ++i) { - acc = op(acc, sC(i, thread_idx)); - } - - syncthreads(); - - //broadcast it back to threads - //TODO(Tadej): optimize - for (int i = 0; i < size(src); ++i) { - sC(i, thread_idx) = acc; - } - - syncthreads(); - - CUTLASS_PRAGMA_UNROLL - for(int k = 1; k < size<2>(src); k++) { - dst(k) = tCsC(0,0,k); - } - - return acc; - - /*reduceSg(src, dst, op); - for(int i=ThreadIdxX() % NumThreadsPerWarp; i - CUTLASS_DEVICE static void reduce_max(FragSrc const &src, FragMax& max) { - reduceSg(src, max, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? x : y; }); - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class ResidueMap, - class Residue - > - CUTLASS_DEVICE static ElementAccumulator reduce_max_wg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx) { - - return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, - std::numeric_limits::min(), - [](ElementAccumulator const & x, ElementAccumulator const & y) { return x > y ? 
x : y; }); - } - - template < - bool zero_init, - class FragSrc, - class FragSum - > - CUTLASS_DEVICE static void reduce_sum(FragSrc const &src, FragSum& sum) { - reduceSg(src, sum, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x + y; }); - } - - template < - class FragSrc, - class FragDst, - class SharedThreadTens, - class SharedTens, - class Residue, - class ResidueMap - > - CUTLASS_DEVICE static ElementAccumulator reduce_sum_wg(FragSrc const &src, FragDst &dst, - SharedThreadTens& tCsC, SharedTens& sC, - ResidueMap tCcD, Residue residue_mnk, int thread_idx) { - - return reduceWg(src, dst, tCsC, sC, tCcD, residue_mnk, thread_idx, - 0, [](ElementAccumulator const & x, ElementAccumulator const & y) { return x+y; }); - } - template< class ProblemShapeMNKL, class BlockCoordMNKL, @@ -333,7 +189,7 @@ class SoftmaxEpilogue { auto N_tmp = cute::ceil_div(N, N_tile); - cute::packed_tuple partial_block(M_tile, C<1>(), K_tile); + cute::packed_tuple partial_block(M_tile, K_tile); auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); @@ -341,19 +197,19 @@ class SoftmaxEpilogue { // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) - Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); // (m,n,l) + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); + Tensor gSum_mnl = local_tile(mSum_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gMax = gMax_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) - Tensor gSum = gSum_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gMax = gMax_mnl(_,m_coord,n_coord,l_coord); + Tensor gSum = gSum_mnl(_,m_coord,n_coord,l_coord); //Represent the shared tensor Tensor sC = make_tensor(make_smem_ptr(reinterpret_cast(smem_buf)), make_layout(make_shape(M_tile, N_tile))); @@ -417,7 +273,7 @@ class SoftmaxEpilogue { max = cutlass::fast_max(max, accumulators(i)); } } - gMax(thread_idx,0) = max; + gMax(thread_idx) = max; ElementAccumulator sum = 0; CUTLASS_PRAGMA_UNROLL @@ -426,7 +282,7 @@ class SoftmaxEpilogue { sum += cutlass::fast_exp(accumulators(i) - max); } } - gSum(thread_idx,0) = sum; + gSum(thread_idx) = sum; } private: diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp 
b/examples/35_gemm_softmax/softmax_finalize.hpp index ceb4de1813..a9e4b41dae 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -72,9 +72,9 @@ class SoftmaxFinalize { // struct Arguments { - //TODO(Tadej): duplicated part of sizes - cutlass::MatrixCoord IOSize; ///< Extent of input and output matrices - cutlass::MatrixCoord partialSize; ///< Extent of partial max and sum matrices + int M; + int dataN; + int partialN; int batch_count; ///< Batch count StrideInput dInput; StridePartial dPartial; @@ -106,8 +106,6 @@ class SoftmaxFinalize { Params(Arguments const &args_): args(args_) { } }; -private: - public: CUTLASS_DEVICE @@ -115,31 +113,11 @@ class SoftmaxFinalize { CUTLASS_DEVICE void operator()(Params const ¶ms, char* shared_storage) { - apply(params, shared_storage); } private: - template - CUTLASS_DEVICE static ElementPartial reduceSg(ElementPartial val, Op op) { - auto sg = syclcompat::get_nd_item<1>().get_sub_group(); - CUTLASS_PRAGMA_UNROLL - for(uint laneMask = 8; laneMask >= 1; laneMask /= 2) { - val = op(val, syclcompat::permute_sub_group_by_xor(sg, val, laneMask, 16)); - } - return val; - } - - CUTLASS_DEVICE static ElementPartial reduce_max(ElementPartial val) { - return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x > y ? x : y; }); - } - - CUTLASS_DEVICE static ElementPartial reduce_sum(ElementPartial val) { - return reduceSg(val, [](ElementPartial const & x, ElementPartial const & y) { return x + y; }); - } - - /// Full reduction CUTLASS_DEVICE void apply(Params const ¶ms, char* shared_storage) { using ConvertInput = cutlass::NumericConverter; @@ -151,62 +129,60 @@ class SoftmaxFinalize { int y_size = BlockDimY(); int batch_id = BlockIdxY(); - if(m>=params.args.IOSize[0]){ + if(m>=params.args.M){ return; } - // Represent the full tensors - auto IOTensorShape = make_shape(params.args.IOSize[0], params.args.IOSize[1], params.args.batch_count); - auto PartialTensorShape = make_shape(params.args.partialSize[0], params.args.partialSize[1], params.args.batch_count); - Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); // (m,n,l) - Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); // (m,n,l) - Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); // (m,n,l) - Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); // (m,n,l) + auto IOTensorShape = make_shape(params.args.M, params.args.dataN, params.args.batch_count); + auto PartialTensorShape = make_shape(params.args.M, params.args.partialN, params.args.batch_count); + Tensor mMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); + Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); + Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); + Tensor mIn = 
make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); //Represent the shared tensor - Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32, 2))); + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32))); ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ - ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); - max_val = max_val > partial_max ? max_val : partial_max; + for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mMax(m, partial_n, batch_id); + max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(x,y,0) = max_val; + sPartial(x,y) = max_val; syncthreads(); - //TODO(Tadej): improve reduction + //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2,0); - max_val = max_val > partial_max ? max_val : partial_max; + ElementPartial partial_max = sPartial(x,y2); + max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = y; partial_n < params.args.partialSize[1]; partial_n += y_size){ - ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); + for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); - sum_val = sum_val + partial_sum * cutlass::fast_exp(partial_max - max_val); + sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(x,y,1) = sum_val; + sPartial(x,y) = sum_val; syncthreads(); sum_val = 0; - //TODO(Tadej): improve reduction + //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2,0); - ElementPartial partial_sum = sPartial(x,y2,1); - sum_val = sum_val + partial_sum; + ElementPartial partial_sum = sPartial(x,y2); + sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = y * 2; n < params.args.IOSize[1]; n += y_size * 2){ + for(int n = y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; } - if(params.args.IOSize[1]%2==1){ - int n = params.args.IOSize[1] - 1; + if(params.args.dataN % 2 == 1){ + int n = params.args.dataN - 1; auto inVal = mIn(m, n, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; } diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index c856afa76e..59a6649d76 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -1,5 +1,6 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without From fcc7d2f78a78922b6cb43ce7d52c7a65a40a43fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 4 Dec 2024 13:28:42 +0000 Subject: [PATCH 09/19] addressed most review comments --- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 5 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 4 +- examples/35_gemm_softmax/softmax_finalize.hpp | 67 ++++++++++--------- include/cutlass/gpu_generics.h | 11 +-- 4 files changed, 46 insertions(+), 41 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 0d31e19457..3b3e97fead 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -441,7 +441,10 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(32, std::min(32, params.softmax_params.args.M), 1); + const auto sycl_block2 = syclcompat::dim3(NumThreadsPerWarp, + std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, + params.softmax_params.args.M), + 1); const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), params.softmax_params.args.batch_count, 1); diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 093a6d70d1..0c35f50458 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -168,7 +168,7 @@ class SoftmaxEpilogue { TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, - [[maybe_unused]] char* smem_buf) + char* smem_buf) { using namespace cute; using X = Underscore; @@ -265,7 +265,7 @@ class SoftmaxEpilogue { syncthreads(); // assumption: size<0>(sC) == wg size - ElementAccumulator max = std::numeric_limits::min(); + ElementAccumulator max = std::numeric_limits::lowest(); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<0>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index a9e4b41dae..30ecd90fef 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -72,21 +72,21 @@ class SoftmaxFinalize { // struct Arguments { - int M; - int dataN; - int partialN; - int batch_count; ///< Batch count - StrideInput dInput; - StridePartial dPartial; - StrideOutput dOutput; - ElementInput* ptr_in; - ElementPartial* ptr_partial_max; - ElementPartial* ptr_partial_sum; - ElementOutput* ptr_out; + int M; // dimension M of input, output and partially reduced tensors + int dataN; // dimension N of the input and output + int partialN; // dimension N of the partially reduced tensors + int batch_count; // batch count + StrideInput dInput; // stride of the input + StridePartial dPartial; // stride of the partially reduced tensors + StrideOutput dOutput; // stride of the output + ElementInput* ptr_in; // pointer to start of input data + ElementPartial* ptr_partial_max; // pointer to start of partially reduced max data + ElementPartial* ptr_partial_sum; // pointer to start of partially reduced sum data + ElementOutput* ptr_out; // pointer to start of output data }; struct SharedStorage { - cute::array_aligned s_mem; + cute::array_aligned s_mem; }; static constexpr int 
SharedStorageSize = sizeof(SharedStorage); @@ -123,11 +123,11 @@ class SoftmaxFinalize { using ConvertInput = cutlass::NumericConverter; using ConvertNormOutput = cutlass::NumericConverter; - int x = ThreadIdxX(); - int m = x + BlockDimX() * BlockIdxX(); - int y = ThreadIdxY(); - int y_size = BlockDimY(); - int batch_id = BlockIdxY(); + const int idx_x = ThreadIdxX(); + const int m = idx_x + BlockDimX() * BlockIdxX(); + const int idx_y = ThreadIdxY(); + const int y_size = BlockDimY(); + const int batch_id = BlockIdxY(); if(m>=params.args.M){ return; @@ -136,46 +136,47 @@ class SoftmaxFinalize { // Represent the full tensors auto IOTensorShape = make_shape(params.args.M, params.args.dataN, params.args.batch_count); auto PartialTensorShape = make_shape(params.args.M, params.args.partialN, params.args.batch_count); - Tensor mMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); + Tensor mPartialMax = make_tensor(make_gmem_ptr(params.args.ptr_partial_max), PartialTensorShape, params.args.dPartial); Tensor mPartialSum = make_tensor(make_gmem_ptr(params.args.ptr_partial_sum), PartialTensorShape, params.args.dPartial); Tensor mOut = make_tensor(make_gmem_ptr(params.args.ptr_out), IOTensorShape, params.args.dOutput); Tensor mIn = make_tensor(make_gmem_ptr(params.args.ptr_in), IOTensorShape, params.args.dInput); //Represent the shared tensor - Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), make_layout(make_shape(32, 32))); + Tensor sPartial = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage)), + make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); - ElementPartial max_val = std::numeric_limits::min(); - for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ - ElementPartial partial_max = mMax(m, partial_n, batch_id); + ElementPartial max_val = std::numeric_limits::lowest(); + for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(x,y) = max_val; + sPartial(idx_x,idx_y) = max_val; syncthreads(); - //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck - for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_max = sPartial(x,y2); + // tree-reduction could be better, although it does not seem to be a bottleneck + for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_max = sPartial(idx_x,idx_y2); max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = y; partial_n < params.args.partialN; partial_n += y_size){ - ElementPartial partial_max = mMax(m, partial_n, batch_id); + for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(x,y) = sum_val; + sPartial(idx_x,idx_y) = sum_val; syncthreads(); sum_val = 0; - //TODO(Tadej): tree-reduction could be better, although it does not seem to be a bottleneck - for(int y2 = 0; y2 < y_size; y2++){ - ElementPartial partial_sum = sPartial(x,y2); + // tree-reduction could be better, although it does not seem to be a bottleneck + for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_sum = 
sPartial(idx_x,idx_y2); sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = y * 2; n < params.args.dataN; n += y_size * 2){ + for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 5d46e4c370..037087e197 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -44,11 +44,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// static constexpr int NumThreadsPerWarp = 32; -static const int NumThreadsPerWarpGroup = 128; -static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; -static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -static const int NumThreadsPerQuad = 4; -static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static constexpr int NumThreadsPerWarpGroup = 128; +static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static constexpr int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static constexpr int NumThreadsPerQuad = 4; +static constexpr int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static constexpr int MaxNumThreadsPerBlock = 1024; //////////////////////////////////////////////////////////////////////////////////////////////////// From d5a77e302b84e8fefb130e50a1fe46be0f82ccbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Wed, 4 Dec 2024 13:58:18 +0000 Subject: [PATCH 10/19] addressed remaining review comments --- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 41 ++++++++++++++----- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 3b3e97fead..c367d8387f 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -354,9 +354,17 @@ class GemmSoftmaxAdapter CUTLASS_TRACE_HOST("GemmUniversal::run()"); dim3 const block = GemmKernel::get_block_shape(); dim3 const grid = get_grid_shape(params); + dim3 const block_finalize = syclcompat::dim3(NumThreadsPerWarp, + std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, + params.softmax_params.args.M), + 1); + dim3 const grid_finalize = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, block_finalize.x), + params.softmax_params.args.batch_count, + 1); // configure smem size and carveout int smem_size = GemmKernel::SharedStorageSize; + int smem_size_finalize = SoftmaxFinalizeKernel::SharedStorageSize; Status launch_result{ Status::kSuccess }; // Use extended launch API only for mainloops that use it @@ -367,7 +375,9 @@ class GemmSoftmaxAdapter dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); - void* kernel_params[] = {¶ms}; + dim3 cluster_finalize(1,1,1); + void* kernel_params[] = {¶ms.gemm_params}; + void* kernel_params_finalize[] = {¶ms.softmax_params}; if constexpr (kEnableCudaHostAdapter) { // @@ -388,6 +398,13 @@ class GemmSoftmaxAdapter stream, kernel_params, 0); + launch_result = cuda_adapter->launch(grid_finalize, + cluster_finalize, + block_finalize, + smem_size_finalize, + stream, + kernel_params_finalize, + 0); } else { return 
Status::kErrorInternal; @@ -396,13 +413,17 @@ class GemmSoftmaxAdapter else { CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; + void const* kernel_finalize = (void const*) device_kernel; if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 90) { if (is_static_1x1x1 && not launch_with_pdl) { - device_kernel<<>>(params); + device_kernel<<>>(params.gemm_params); + device_kernel<<>>(params.softmax_params); } else { launch_result = ClusterLauncher::launch( grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + launch_result = ClusterLauncher::launch( + grid_finalize, cluster_finalize, block_finalize, smem_size_finalize, stream, kernel_finalize, kernel_params, launch_with_pdl); } } } @@ -414,10 +435,14 @@ class GemmSoftmaxAdapter CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { void* kernel_params[] = {¶ms.gemm_params}; + void* kernel_params_finalize[] = {¶ms.softmax_params}; launch_result = cuda_adapter->launch( grid, block, smem_size, stream, kernel_params, 0 ); + launch_result = cuda_adapter->launch( + grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 0 + ); } else { @@ -441,19 +466,15 @@ class GemmSoftmaxAdapter sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params.gemm_params); #endif - const auto sycl_block2 = syclcompat::dim3(NumThreadsPerWarp, - std::min(MaxNumThreadsPerBlock / NumThreadsPerWarp, - params.softmax_params.args.M), - 1); - const auto sycl_grid2 = syclcompat::dim3(cute::ceil_div(params.softmax_params.args.M, sycl_block2.x), - params.softmax_params.args.batch_count, - 1); + const auto sycl_block_finalize = syclcompat::dim3(block_finalize.x, block_finalize.y, block_finalize.z); + const auto sycl_grid_finalize = syclcompat::dim3(grid_finalize.x, grid_finalize.y, grid_finalize.z); auto event2 = launch>(launch_policy{ - sycl_grid2, sycl_block2, local_mem_size{SoftmaxFinalizeKernel::SharedStorageSize}}, + sycl_grid_finalize, sycl_block_finalize, local_mem_size{static_cast(smem_size_finalize)}}, params.softmax_params); EventManager::getInstance().addEvent(event2); #else device_kernel<<>>(params.gemm_params); + device_kernel<<>>(params.softmax_params); #endif } } From 1ea4f40966750eac1f14a628abfb1498010e02e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 5 Dec 2024 09:40:23 +0000 Subject: [PATCH 11/19] addressed second round of review comments. 
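For reference, the per-row math that the softmax epilogue and the finalize kernel implement between them can be summarized with a small host-side sketch. Plain C++ below, with illustrative names (Partial, partial_stats, finalize_row) rather than the CUTLASS types used in these patches; assume x points to one row of the GEMM output and each N-tile contributes one Partial.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Pass 1: what the softmax epilogue produces for one row of one N-tile --
// the tile-local max m_j and the sum of exponentials s_j taken against it.
struct Partial { float m; float s; };

Partial partial_stats(const float* x, std::size_t n) {
  Partial p{std::numeric_limits<float>::lowest(), 0.f};
  for (std::size_t i = 0; i < n; ++i) p.m = std::max(p.m, x[i]);
  for (std::size_t i = 0; i < n; ++i) p.s += std::exp(x[i] - p.m);
  return p;
}

// Pass 2: what the finalize kernel does for the same row -- merge the per-tile
// partials into the global max M and sum S, then normalize every element.
void finalize_row(const float* x, std::size_t n,
                  const std::vector<Partial>& partials, float* out) {
  float M = std::numeric_limits<float>::lowest();
  for (const Partial& p : partials) M = std::max(M, p.m);
  float S = 0.f;
  for (const Partial& p : partials) S += p.s * std::exp(p.m - M);
  for (std::size_t i = 0; i < n; ++i) out[i] = std::exp(x[i] - M) / S;
}

On the device the first pass runs inside SoftmaxEpilogue (one Partial per CTA tile, stored through ptr_max / ptr_sum), and the second pass is the SoftmaxFinalize kernel that GemmSoftmaxAdapter launches right after the GEMM.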
--- examples/35_gemm_softmax/gemm_softmax_adapter.hpp | 2 +- examples/35_gemm_softmax/softmax_epilogue.hpp | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index c367d8387f..3eb7e9d537 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -404,7 +404,7 @@ class GemmSoftmaxAdapter smem_size_finalize, stream, kernel_params_finalize, - 0); + 1); } else { return Status::kErrorInternal; diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 0c35f50458..755f33d5f8 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -264,10 +264,12 @@ class SoftmaxEpilogue { syncthreads(); - // assumption: size<0>(sC) == wg size + // assumption for reductions: size<0>(sC) == block size + assert(size<0>(sC) == BlockDimX() * BlockDimy() * BlockDimZ()); + ElementAccumulator max = std::numeric_limits::lowest(); CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(sC); ++i) { + for (int i = 0; i < size<1>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { accumulators(i) = sC(thread_idx, i); max = cutlass::fast_max(max, accumulators(i)); @@ -277,7 +279,7 @@ class SoftmaxEpilogue { ElementAccumulator sum = 0; CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(sC); ++i) { + for (int i = 0; i < size<1>(sC); ++i) { if (elem_less(cD(thread_idx, i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { sum += cutlass::fast_exp(accumulators(i) - max); } From dc123265709763230ac5159c4fd1c3c426d8a807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Thu, 5 Dec 2024 11:52:29 +0100 Subject: [PATCH 12/19] Apply formatting suggestions from code review Co-authored-by: Finlay --- .../35_gemm_softmax/gemm_online_softmax.cpp | 4 ++-- examples/35_gemm_softmax/softmax_finalize.hpp | 20 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 67dff1460a..2a6e6f1d52 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -128,7 +128,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "14_ampere_tf32_tensorop_gemm_cute example\n\n" + out << "35_gemm_softmax example\n\n" << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" @@ -238,7 +238,7 @@ struct ExampleRunner { // Methods // template - bool verify_tensor(std::vector vector_Input, \ + bool verify_tensor(std::vector vector_Input, std::vector vector_Input_Ref, const Options& options) { auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? 
vector_Input.size() : vector_Input_Ref.size()); diff --git a/examples/35_gemm_softmax/softmax_finalize.hpp b/examples/35_gemm_softmax/softmax_finalize.hpp index 30ecd90fef..ca6e6ac93a 100644 --- a/examples/35_gemm_softmax/softmax_finalize.hpp +++ b/examples/35_gemm_softmax/softmax_finalize.hpp @@ -129,7 +129,7 @@ class SoftmaxFinalize { const int y_size = BlockDimY(); const int batch_id = BlockIdxY(); - if(m>=params.args.M){ + if (m >= params.args.M) { return; } @@ -146,43 +146,43 @@ class SoftmaxFinalize { make_layout(make_shape(NumThreadsPerWarp, MaxNumThreadsPerBlock / NumThreadsPerWarp))); ElementPartial max_val = std::numeric_limits::lowest(); - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); max_val = cutlass::fast_max(max_val, partial_max); } - sPartial(idx_x,idx_y) = max_val; + sPartial(idx_x, idx_y) = max_val; syncthreads(); // tree-reduction could be better, although it does not seem to be a bottleneck - for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ - ElementPartial partial_max = sPartial(idx_x,idx_y2); + for (int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ + ElementPartial partial_max = sPartial(idx_x, idx_y2); max_val = cutlass::fast_max(max_val, partial_max); } ElementPartial sum_val = 0; - for(int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ + for (int partial_n = idx_y; partial_n < params.args.partialN; partial_n += y_size){ ElementPartial partial_max = mPartialMax(m, partial_n, batch_id); ElementPartial partial_sum = mPartialSum(m, partial_n, batch_id); sum_val += partial_sum * cutlass::fast_exp(partial_max - max_val); } syncthreads(); - sPartial(idx_x,idx_y) = sum_val; + sPartial(idx_x, idx_y) = sum_val; syncthreads(); sum_val = 0; // tree-reduction could be better, although it does not seem to be a bottleneck for(int idx_y2 = 0; idx_y2 < y_size; idx_y2++){ - ElementPartial partial_sum = sPartial(idx_x,idx_y2); + ElementPartial partial_sum = sPartial(idx_x, idx_y2); sum_val += partial_sum; } ElementPartial norm = 1 / sum_val; - for(int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ + for (int n = idx_y * 2; n < params.args.dataN; n += y_size * 2){ auto inVal = mIn(m, n, batch_id); auto inVal2 = mIn(m, n+1, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; mOut(m, n+1, batch_id) = cutlass::fast_exp(inVal2 - max_val) * norm; } - if(params.args.dataN % 2 == 1){ + if (params.args.dataN % 2 == 1){ int n = params.args.dataN - 1; auto inVal = mIn(m, n, batch_id); mOut(m, n, batch_id) = cutlass::fast_exp(inVal - max_val) * norm; From 51cb0b5cb874c03224ee3bc45ff194ccb8aa3337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Fri, 6 Dec 2024 10:26:50 +0000 Subject: [PATCH 13/19] addressed third round of comments --- .../35_gemm_softmax/gemm_online_softmax.cpp | 26 +++++++++---------- .../35_gemm_softmax/gemm_softmax_adapter.hpp | 7 +++-- examples/35_gemm_softmax/softmax_epilogue.hpp | 13 +++++----- include/cutlass/fast_math.h | 1 - include/cutlass/gpu_generics.h | 12 ++++----- 5 files changed, 28 insertions(+), 31 deletions(-) diff --git a/examples/35_gemm_softmax/gemm_online_softmax.cpp b/examples/35_gemm_softmax/gemm_online_softmax.cpp index 2a6e6f1d52..b8844286fb 100644 --- a/examples/35_gemm_softmax/gemm_online_softmax.cpp +++ b/examples/35_gemm_softmax/gemm_online_softmax.cpp @@ -128,7 +128,7 @@ struct Options { 
/// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "35_gemm_softmax example\n\n" + out << "35_gemm_online_softmax example\n\n" << " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" @@ -193,13 +193,13 @@ struct ExampleRunner { using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; - using StrideTmp = typename Gemm::CollectiveEpilogue::StrideD; + using StridePartials = typename Gemm::CollectiveEpilogue::StrideD; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; - using LayoutTmp = typename Gemm::LayoutTmp; + using LayoutPartials = typename Gemm::LayoutPartials; using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; @@ -223,7 +223,7 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; - StrideTmp stride_tmp; + StridePartials stride_partials; uint64_t seed = 0; cutlass::DeviceAllocation block_A; @@ -281,8 +281,8 @@ struct ExampleRunner { LayoutA layout_A(lda); LayoutB layout_B(ldb); LayoutC layout_C(ldc); - LayoutTmp Layout_N(ldn); - LayoutTmp Layout_S(lds); + LayoutPartials Layout_N(ldn); + LayoutPartials Layout_S(lds); cutlass::MatrixCoord extent_A{options.m, options.k}; cutlass::MatrixCoord extent_B{options.k, options.n}; @@ -365,21 +365,21 @@ struct ExampleRunner { auto [M, N, K, L] = problem_shape_MNKL; auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{})); - auto tmp_size = M * partials_N * L; + auto partials_size = M * partials_N * L; stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(M, partials_N, L)); + stride_partials = cutlass::make_cute_packed_stride(StridePartials{}, cute::make_shape(M, partials_N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); - block_sum.reset(tmp_size); - block_max.reset(tmp_size); + block_sum.reset(partials_size); + block_max.reset(partials_size); initialize_block(block_A, seed + 2023); initialize_block(block_B, seed + 2022); @@ -399,7 +399,7 @@ struct ExampleRunner { options.beta}, block_C.get(), stride_C, block_D.get(), stride_D, - block_max.get(), block_sum.get(), stride_tmp}, + block_max.get(), block_sum.get(), stride_partials}, hw_info }; @@ -513,7 +513,7 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; - using LayoutTmp = cutlass::layout::ColumnMajor; + using LayoutPartials = cutlass::layout::ColumnMajor; // Tiling configuration selection using TileShape = Shape<_128,_128,_32>; @@ -580,7 +580,7 @@ int main(int argc, char const **args) { using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue< cutlass::detail::TagToStrideC_t, cutlass::detail::TagToStrideC_t, - cutlass::detail::TagToStrideC_t, + 
cutlass::detail::TagToStrideC_t, TileShape, EpilogueOp, cutlass::gemm::EpilogueDefault>; diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 3eb7e9d537..75a3feeeb0 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -1,5 +1,4 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * @@ -86,7 +85,7 @@ class GemmSoftmaxAdapter using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize< ElementD, typename GemmKernel::StrideD, - ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StrideTmp, + ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StridePartials, ElementD, typename GemmKernel::StrideD>; // Map back to 2.x type as best as possible @@ -94,7 +93,7 @@ class GemmSoftmaxAdapter using LayoutB = gemm::detail::StrideToLayoutTagB_t; using LayoutC = gemm::detail::StrideToLayoutTagC_t; using LayoutD = gemm::detail::StrideToLayoutTagC_t; - using LayoutTmp = gemm::detail::StrideToLayoutTagC_t; + using LayoutPartials = gemm::detail::StrideToLayoutTagC_t; static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; @@ -270,7 +269,7 @@ class GemmSoftmaxAdapter softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{})); softmax_args.batch_count = get<3>(args.problem_shape); softmax_args.dInput = args.epilogue.dD; - softmax_args.dPartial = args.epilogue.dTmp; + softmax_args.dPartial = args.epilogue.dPartials; softmax_args.dOutput = args.epilogue.dD; softmax_args.ptr_in = args.epilogue.ptr_D; softmax_args.ptr_partial_max = args.epilogue.ptr_max; diff --git a/examples/35_gemm_softmax/softmax_epilogue.hpp b/examples/35_gemm_softmax/softmax_epilogue.hpp index 755f33d5f8..438552a7f2 100644 --- a/examples/35_gemm_softmax/softmax_epilogue.hpp +++ b/examples/35_gemm_softmax/softmax_epilogue.hpp @@ -1,5 +1,4 @@ /*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * @@ -53,7 +52,7 @@ namespace collective { template < class StrideC_, class StrideD_, - class StrideTmp_, + class StridePartials_, class BlockShapeMNK, class ThreadEpilogueOp_, class EpilogueSchedule_ @@ -76,7 +75,7 @@ class SoftmaxEpilogue { using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; - using StrideTmp = StrideTmp_; + using StridePartials = StridePartials_; using GmemTiledCopyC = void; using GmemTiledCopyD = void; @@ -102,7 +101,7 @@ class SoftmaxEpilogue { StrideD dD{}; ElementAccumulator* ptr_max; ElementAccumulator* ptr_sum; - StrideTmp dTmp{}; + StridePartials dPartials{}; }; // Device side epilogue params @@ -187,7 +186,7 @@ class SoftmaxEpilogue { auto N_tile = get<1>(blk_shape_MNK); auto K_tile = get<2>(blk_shape_MNK); - auto N_tmp = cute::ceil_div(N, N_tile); + auto N_partials = cute::ceil_div(N, N_tile); cute::packed_tuple partial_block(M_tile, K_tile); @@ -197,8 +196,8 @@ class SoftmaxEpilogue { // Represent the full output tensors Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) - Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp); - Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp); + Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_partials,L), params.dPartials); + Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_partials,L), params.dPartials); Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{}); diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 59a6649d76..c856afa76e 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -1,6 +1,5 @@ /*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 037087e197..f6af850df7 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -43,12 +43,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumThreadsPerWarp = 32; -static constexpr int NumThreadsPerWarpGroup = 128; -static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; -static constexpr int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; -static constexpr int NumThreadsPerQuad = 4; -static constexpr int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; static constexpr int MaxNumThreadsPerBlock = 1024; //////////////////////////////////////////////////////////////////////////////////////////////////// From cb07f6fd8b42508138aa6d8f0f669c85bf1719a7 Mon Sep 17 00:00:00 2001 From: Finlay Marno Date: Tue, 10 Dec 2024 10:56:21 +0000 Subject: [PATCH 14/19] modify the examples run in github workflow --- .github/workflows/test.yml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 33b255c942..19bfa5f35c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -71,13 +71,4 @@ jobs: shell: bash run: | export LD_LIBRARY_PATH=~/dpcpp/lib/:$LD_LIBRARY_PATH - echo Run sgemm_1 - ./examples/cute/tutorial/sgemm_1 - echo Run sgemm_2 - ./examples/cute/tutorial/sgemm_2 - echo Run sgemm_sm70 - ./examples/cute/tutorial/sgemm_sm70 - echo Run sgemm_sm80 - ./examples/cute/tutorial/sgemm_sm80 - echo Run tiled_copy - ./examples/cute/tutorial/tiled_copy + cmake --build . --target test_examples -j 24 From bb925baa553dc2a0c0dbf24f283b56b883c29f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tadej=20Ciglari=C4=8D?= Date: Tue, 10 Dec 2024 12:41:11 +0000 Subject: [PATCH 15/19] fix from review --- examples/35_gemm_softmax/gemm_softmax_adapter.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp index 75a3feeeb0..13e0369e4e 100644 --- a/examples/35_gemm_softmax/gemm_softmax_adapter.hpp +++ b/examples/35_gemm_softmax/gemm_softmax_adapter.hpp @@ -440,7 +440,7 @@ class GemmSoftmaxAdapter grid, block, smem_size, stream, kernel_params, 0 ); launch_result = cuda_adapter->launch( - grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 0 + grid_finalize, block_finalize, smem_size_finalize, stream, kernel_params_finalize, 1 ); } From d5e2836341f41f78dfa387a2cb6434b6de369fd1 Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Tue, 10 Dec 2024 18:25:14 +0000 Subject: [PATCH 16/19] Cancel Previous runs on CI (#165) Cancel all the previous actions if the CI is updated. 
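[Editorial aside, not part of the commit message above or below.] The 0 -> 1 changes in patches 11 and 15 both touch the CUDA host-adapter path: the trailing integer passed to cuda_adapter->launch is, as this series uses it, the index of the registered kernel to launch, so the softmax-finalize launch selects slot 1 instead of reusing the GEMM's slot 0. A hedged sketch of the resulting call pair (the meaning of that last argument is an assumption about the host-adapter API rather than something stated in the patch; the variable names follow the hunks above):

    launch_result = cuda_adapter->launch(
        grid, block, smem_size, stream, kernel_params, /*kernel index*/ 0);
    launch_result = cuda_adapter->launch(
        grid_finalize, block_finalize, smem_size_finalize, stream,
        kernel_params_finalize, /*kernel index*/ 1);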
--- .github/workflows/cuda_test.yml | 4 ++++ .github/workflows/test.yml | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/.github/workflows/cuda_test.yml b/.github/workflows/cuda_test.yml index 2b839dcfa7..c4908b943f 100644 --- a/.github/workflows/cuda_test.yml +++ b/.github/workflows/cuda_test.yml @@ -9,6 +9,10 @@ on: permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: run-tests: name: Run cuda tests diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 33b255c942..6315031acf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,6 +13,10 @@ on: permissions: {} +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: run-tests: name: Run tests From a6573aba40fd976d113c2650440e10247b2d3fae Mon Sep 17 00:00:00 2001 From: jiyang1011 Date: Thu, 28 Nov 2024 18:38:53 -0800 Subject: [PATCH 17/19] gelu example && TensorRefGeLu --- examples/sycl/pvc/CMakeLists.txt | 5 + .../sycl/pvc/pvc_gemm_with_epilogue_gelu.cpp | 377 ++++++++++++++++++ .../util/reference/device/tensor_gelu.h | 148 +++++++ 3 files changed, 530 insertions(+) create mode 100644 examples/sycl/pvc/pvc_gemm_with_epilogue_gelu.cpp create mode 100644 tools/util/include/cutlass/util/reference/device/tensor_gelu.h diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt index f9c5fca18c..5736847e88 100644 --- a/examples/sycl/pvc/CMakeLists.txt +++ b/examples/sycl/pvc/CMakeLists.txt @@ -37,6 +37,11 @@ cutlass_example_add_executable( pvc_gemm_with_epilogue_relu.cpp ) +cutlass_example_add_executable( + pvc_gemm_with_epilogue_gelu + pvc_gemm_with_epilogue_gelu.cpp +) + cutlass_example_add_executable( pvc_collective_builder pvc_collective_builder.cpp diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_gelu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_gelu.cpp new file mode 100644 index 0000000000..9a0b814309 --- /dev/null +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_gelu.cpp @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_gelu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "common.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + using TensorView = cutlass::TensorView; + cutlass::reference::device::TensorGeLu(TensorView(block_ref_D.get(), LayoutD::packed({M, N}), + cutlass::make_Coord(M, N))); + + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, 
cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
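  // [Editorial sketch, not part of the patch.] The reference check added in
  // tools/util below (TensorGeLuFunc) applies the exact erf-based GELU,
  //   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))),
  // so a host-side equivalent (illustrative name, not from the patch) would be
  //   float gelu_ref(float x) { return 0.5f * x * (1.0f + std::erf(x * 0.70710678f)); }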
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _128, _16>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_64,_32,_16>>; // Subgroup level-tile + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinCombEltAct; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/tools/util/include/cutlass/util/reference/device/tensor_gelu.h b/tools/util/include/cutlass/util/reference/device/tensor_gelu.h new file mode 100644 index 0000000000..30f452a5cd --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/tensor_gelu.h @@ -0,0 +1,148 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorGeLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + + // + // Methods + // + + Params( + TensorView view_ = TensorView() + ): + view(view_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorGeLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + + params.view.at(coord) = Element(cutlass::constants::half() * value * + (cutlass::constants::one() + (Element)erff((float)(value * cutlass::constants::half_root_two())))); + } +}; +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply GeLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorGeLu( + TensorView view) { ///< destination tensor + + using Func = detail::TensorGeLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +#if (CUTLASS_ENABLE_SYCL) +namespace sycl { + template <> + struct is_device_copyable < + cutlass::reference::device::detail::TensorGeLuFunc::Params> : std::true_type {}; +} +#endif From f14d683769d74a7311c560ce3a1dbbf3e58a0b31 Mon Sep 17 00:00:00 2001 From: Atharva Dubey Date: Wed, 11 Dec 
2024 19:29:50 +0000 Subject: [PATCH 18/19] Device Agnostic Pipeline (#140) * initial changes for device_agnostic pipeline * add a TODO before raising draft PR * fix compilation issues part 1 * fix bugs in device_agnostic_gemm * fix issues with collective builder API for device agnostic * refactor * remove instances of PVC from device agnostic example * fix build * Address comments --------- Co-authored-by: Alejandro Acosta --- examples/sycl/CMakeLists.txt | 4 + examples/sycl/device_agnostic/CMakeLists.txt | 37 ++ .../device_agnostic_collective_builder.cpp | 372 +++++++++++++++++ .../device_agnostic/device_agnostic_gemm.cpp | 382 ++++++++++++++++++ examples/sycl/pvc/{common.h => common.hpp} | 0 examples/sycl/pvc/pvc_collective_builder.cpp | 2 +- examples/sycl/pvc/pvc_gemm.cpp | 2 +- .../sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 2 +- include/cutlass/arch/arch.h | 4 + .../builders/device_agnostic_builder.inl | 91 +++++ .../collective/collective_builder.hpp | 4 + .../builders/device_agnostic_mma_builder.inl | 128 ++++++ .../gemm/collective/collective_builder.hpp | 4 + .../gemm/collective/collective_mma.hpp | 5 + .../gemm/collective/device_agnostic_mma.hpp | 195 +++++++++ .../gemm/device/gemm_universal_adapter.h | 18 +- include/cutlass/gemm/dispatch_policy.hpp | 8 + 17 files changed, 1250 insertions(+), 8 deletions(-) create mode 100644 examples/sycl/device_agnostic/CMakeLists.txt create mode 100644 examples/sycl/device_agnostic/device_agnostic_collective_builder.cpp create mode 100644 examples/sycl/device_agnostic/device_agnostic_gemm.cpp rename examples/sycl/pvc/{common.h => common.hpp} (100%) create mode 100644 include/cutlass/epilogue/collective/builders/device_agnostic_builder.inl create mode 100644 include/cutlass/gemm/collective/builders/device_agnostic_mma_builder.inl create mode 100644 include/cutlass/gemm/collective/device_agnostic_mma.hpp diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt index b736ce35e8..197ce58b2d 100644 --- a/examples/sycl/CMakeLists.txt +++ b/examples/sycl/CMakeLists.txt @@ -30,3 +30,7 @@ if(SYCL_INTEL_TARGET) add_subdirectory(pvc) endif() + +if (CUTLASS_ENABLE_SYCL) + add_subdirectory(device_agnostic) +endif() diff --git a/examples/sycl/device_agnostic/CMakeLists.txt b/examples/sycl/device_agnostic/CMakeLists.txt new file mode 100644 index 0000000000..4bd75ca951 --- /dev/null +++ b/examples/sycl/device_agnostic/CMakeLists.txt @@ -0,0 +1,37 @@ +# Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + device_agnostic_gemm + device_agnostic_gemm.cpp +) + +cutlass_example_add_executable( + device_agnostic_collective_builder + device_agnostic_collective_builder.cpp +) diff --git a/examples/sycl/device_agnostic/device_agnostic_collective_builder.cpp b/examples/sycl/device_agnostic/device_agnostic_collective_builder.cpp new file mode 100644 index 0000000000..44cb129465 --- /dev/null +++ b/examples/sycl/device_agnostic/device_agnostic_collective_builder.cpp @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/kernel_hardware_info.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/GPU_Clock.hpp" + +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(128), n(128), k(128), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 128); + cmd.get_cmd_line_argument("n", n, 128); + cmd.get_cmd_line_argument("k", k, 128); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Device Agnostic GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + 
cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + using TensorView = cutlass::TensorView; + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + template + void initialize_block(cutlass::DeviceAllocation block_device, uint64_t seed) { + std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution<> dist(0.0f, 1.0f); + rng.seed(seed); + + auto block_host = std::vector(block_device.size()); + for (auto& element : block_host) { + element = static_cast(dist(rng)); + } + + block_device.copy_from_host(block_host.data()); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of CUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = float; // <- data type of elements in input matrix A + using ElementInputB = float; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + constexpr int AlignmentA = sizeof(ElementInputA); + constexpr int AlignmentB = sizeof(ElementInputB); + constexpr int AlignmentC = sizeof(ElementAccumulator); + constexpr int AlignmentD = sizeof(ElementOutput); + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Workgroup-level tile + using TileShape = Shape<_16, _16, _8>; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Agnostic, cutlass::arch::OpMultiplyAdd, + ElementInputA, LayoutA, AlignmentA, + ElementInputB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, Shape<_1, _1, _1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination< + ElementOutput, ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator>; + + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Agnostic, cutlass::arch::OpMultiplyAdd, + TileShape, Shape<_1, _1, _1>, + cutlass::epilogue::collective::EpilogueTileAuto, ElementComputeEpilogue, + ElementAccumulator, + ElementAccumulator, LayoutC, AlignmentC, + ElementOutput, LayoutD, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto, + EpilogueOp + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; 
+ + runner.run(options, hw_info); + + return 0; +} diff --git a/examples/sycl/device_agnostic/device_agnostic_gemm.cpp b/examples/sycl/device_agnostic/device_agnostic_gemm.cpp new file mode 100644 index 0000000000..142a2f2966 --- /dev/null +++ b/examples/sycl/device_agnostic/device_agnostic_gemm.cpp @@ -0,0 +1,382 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(128), n(128), k(128), l(1), iterations(20), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 128); + cmd.get_cmd_line_argument("n", n, 128); + cmd.get_cmd_line_argument("k", k, 128); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Device Agnostic GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + 
cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + template + void initialize_block(cutlass::DeviceAllocation block_device, uint64_t seed) { + std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution<> dist(0.0f, 1.0f); + rng.seed(seed); + + auto block_host = std::vector(block_device.size()); + for (auto& element : block_host) { + element = static_cast(dist(rng)); + } + + block_device.copy_from_host(block_host.data()); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of CUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = float; // <- data type of elements in input matrix A + using ElementInputB = float; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using TileShape = Shape<_4, _4, _8>; + + using TiledMma = TiledMMA>, + Layout>>; + + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementInputA>{}, + Layout, Stride<_4, _1>>{}, + Layout>{} + )); + + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementInputB>{}, + Layout, Stride <_1, _4>>{}, + Layout>{} + )); + + using SmemLayoutAtomA = Layout, Stride<_1, _4>>; + using SmemLayoutAtomB = Layout, Stride<_1, _4>>; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopDeviceAgnostic; + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementAccumulator, + 1, + ElementComputeEpilogue, + ElementOutput>; + + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + EpilogueOp, + cutlass::gemm::EpilogueDefault>; + + using SmemCopyAtomA = Copy_Atom, ElementInputA>; + using SmemCopyAtomB = Copy_Atom, ElementInputB>; + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + using Gemm = 
cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; + + ExampleRunner<Gemm> runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/examples/sycl/pvc/common.h b/examples/sycl/pvc/common.hpp similarity index 100% rename from examples/sycl/pvc/common.h rename to examples/sycl/pvc/common.hpp diff --git a/examples/sycl/pvc/pvc_collective_builder.cpp b/examples/sycl/pvc/pvc_collective_builder.cpp index 420fa4a8bc..6fbe7884a0 100644 --- a/examples/sycl/pvc/pvc_collective_builder.cpp +++ b/examples/sycl/pvc/pvc_collective_builder.cpp @@ -48,7 +48,7 @@ #include "cutlass/tensor_view.h" #include "cutlass/coord.h" -#include "common.h" +#include "common.hpp" using namespace cute; diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index c96fdff2a4..ee20a51eec 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -45,7 +45,7 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -#include "common.h" +#include "common.hpp" using namespace cute; diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index 47d538141d..459d9dccbc 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -49,7 +49,7 @@ #include "cutlass/tensor_view.h" #include "cutlass/coord.h" -#include "common.h" +#include "common.hpp" using namespace cute; diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index b87d899a32..cceef6cda1 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -102,6 +102,10 @@ struct IntelPVC { static int const kMinComputeCapability = 0; }; +struct Agnostic { + static int const kMinComputeCapability = 1; +}; + #endif /// Triggers a breakpoint on the device diff --git a/include/cutlass/epilogue/collective/builders/device_agnostic_builder.inl b/include/cutlass/epilogue/collective/builders/device_agnostic_builder.inl new file mode 100644 index 0000000000..da33c33186 --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/device_agnostic_builder.inl @@ -0,0 +1,91 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace cutlass::epilogue::collective { + +template< + class TileShape_MNK, + class ElementAccumulator_, + class ElementCompute_, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC_, + class ElementD_, + class GmemLayoutTagD_, + int AlignmentD_, + class FusionOpOrCallbacks +> +struct CollectiveBuilder< + arch::Agnostic, + arch::OpMultiplyAdd, + TileShape_MNK, + Shape<_1, _1, _1>, + EpilogueTileAuto, + ElementAccumulator_, + ElementCompute_, + ElementC_, + GmemLayoutTagC_, + AlignmentC_, + ElementD_, + GmemLayoutTagD_, + AlignmentD_, + EpilogueScheduleAuto, + FusionOpOrCallbacks, + cute::enable_if_t< + (cute::is_same_v>) + > +> { + using ElementD = ElementD_; + using ElementOutput = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute>; + + using CollectiveOp = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>; +}; + +} // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index 8ee169024a..c920295469 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -121,4 +121,8 @@ struct CallbacksBuilder< #if defined(SYCL_INTEL_TARGET) #include "builders/xe_builder.inl" #endif + +#if defined(CUTLASS_ENABLE_SYCL) +#include "builders/device_agnostic_builder.inl" +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/builders/device_agnostic_mma_builder.inl b/include/cutlass/gemm/collective/builders/device_agnostic_mma_builder.inl new file mode 100644 index 0000000000..013f994973 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/device_agnostic_mma_builder.inl @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include +#include + +#include "cutlass/gemm/collective/device_agnostic_mma.hpp" + + +namespace cutlass::gemm::collective { + +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class KernelScheduleType + > +struct CollectiveBuilder< + arch::Agnostic, + arch::OpMultiplyAdd, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + Shape<_1, _1, _1>, // Cluster Shape + cutlass::gemm::collective::StageCountAuto, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +>{ +#ifndef CUTLASS_ENABLE_SYCL + static_assert(cutlass::detail::dependent_false, + "Trying to use device Agnostic pipeline without SYCL enabled"); +#endif + + using TiledMMA = TiledMMA>, + Layout>>; + + using DispatchPolicy = MainloopDeviceAgnostic; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, Stride<_4, _1>>{}, + Layout>{} + )); + + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, Stride<_1, _4>>{}, + Layout>{} + )); + + using SmemCopyAtomA = Copy_Atom, ElementA>; + using SmemCopyAtomB = Copy_Atom, ElementB>; + + // + using SmemLayoutAtomA = decltype( + make_layout(make_shape(get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{})), + make_stride(_1{}, get<0>(TileShape_MNK{}))) + ); + + using SmemLayoutAtomB = decltype( + make_layout(make_shape(get<1>(TileShape_MNK{}), get<2>(TileShape_MNK{})), + make_stride(_1{}, get<1>(TileShape_MNK{}))) + ); + + using TransformA = cute::identity; + using TransformB = cute::identity; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + MainloopDeviceAgnostic, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMMA, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + TransformA, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + TransformB + >; +}; + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/collective_builder.hpp 
b/include/cutlass/gemm/collective/collective_builder.hpp index fa31aebaa1..f7fab76e1e 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -43,4 +43,8 @@ #if defined(SYCL_INTEL_TARGET) #include "cutlass/gemm/collective/builders/xe_mma_builder.inl" #endif + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/gemm/collective/builders/device_agnostic_mma_builder.inl" +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 9a1fc2c3d4..38f51d2c8f 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -50,4 +50,9 @@ #if defined(SYCL_INTEL_TARGET) #include "cutlass/gemm/collective/xe_mma.hpp" #endif + +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/gemm/collective/device_agnostic_mma.hpp" +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/device_agnostic_mma.hpp b/include/cutlass/gemm/collective/device_agnostic_mma.hpp new file mode 100644 index 0000000000..90157b24cf --- /dev/null +++ b/include/cutlass/gemm/collective/device_agnostic_mma.hpp @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/sm70_mma_twostage.hpp" + +namespace cutlass::gemm::collective { + using namespace cute; + +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma < + MainloopDeviceAgnostic, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ +> : + CollectiveMma< + MainloopSm70TwoStage, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + > + + { + using DispatchPolicy = MainloopDeviceAgnostic; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})))); + + struct SharedStorage + { + cute::array_aligned> smem_a; + cute::array_aligned> smem_b; + }; + + + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { + (void) workspace; + return args; + } + + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf) + { + // We can reuse the 2 stage blocking gemm in SM_70 predicated 
pipeline, giving a somewhat performant + // device agnostic pipeline + + CollectiveMma< + MainloopSm70TwoStage, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_ + >::operator()( + accum, + gA, + gB, + src_accum, + k_tile_iter, k_tile_count, + residue_mnk, thread_idx, + smem_buf + ); + } + }; +} diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 8c9d37e573..093eb20df6 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -456,16 +456,24 @@ class GemmUniversalAdapter< using namespace syclcompat::experimental; #if defined (SYCL_INTEL_TARGET) - auto event = launch>(launch_policy{ - sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, - kernel_properties{sycl_exp::sub_group_size} - }, params); + if constexpr (cute::is_same_v) { + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)} + }, params); + EventManager::getInstance().addEvent(event); + } else { + auto event = launch>(launch_policy{ + sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size} + }, params); + EventManager::getInstance().addEvent(event); + } #else auto event = launch>(launch_policy{ sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}}, params); -#endif EventManager::getInstance().addEvent(event); +#endif #else #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) CUTLASS_TRACE_HOST("GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index acc0961d64..4d9cee37b2 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -332,6 +332,14 @@ struct MainloopIntelPVC { }; #endif +#if defined(CUTLASS_ENABLE_SYCL) +struct MainloopDeviceAgnostic { + using ArchTag = arch::Agnostic; + using ClusterShape = Shape<_1,_1,_1>; + using Schedule = KernelMultistage; +}; +#endif + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm From 8398ba889fa44ac09e5ee9316d933fdaddbab304 Mon Sep 17 00:00:00 2001 From: jiyang1011 <110882834+jiyang1011@users.noreply.github.com> Date: Thu, 12 Dec 2024 07:30:02 +0800 Subject: [PATCH 19/19] add Habana UTs to benchmark (#163) Co-authored-by: Alejandro Acosta --- CMakeLists.txt | 4 ++ benchmarks/pvc/benchmarks.hpp | 55 ++++++++++++++++++++-- benchmarks/pvc/gemm_configuration.hpp | 45 ++++-------------- benchmarks/pvc/input.in | 38 ++++++++------- include/cutlass/gemm/collective/xe_mma.hpp | 15 +++--- include/cutlass/gemm/kernel/xe_gemm.hpp | 10 ++++ 6 files changed, 104 insertions(+), 63 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9187927b13..4f67d0b123 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,6 +94,10 @@ endif() set(CUTLASS_ENABLE_SYCL OFF CACHE BOOL "Enable SYCL") set(CUTLASS_SYCL_PROFILING_ENABLED OFF CACHE BOOL "Use SYCL events to calculate device execution time") +set(CUTLASS_SYCL_SWITCH_WG OFF CACHE BOOL "Enable SWITCH WG and for GEMM on Intel PVC during benchmarking") +if(CUTLASS_SYCL_SWITCH_WG) + add_compile_definitions(CUTLASS_SYCL_SWITCH_WG) +endif() if (CUTLASS_ENABLE_SYCL) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) diff --git 
a/benchmarks/pvc/benchmarks.hpp b/benchmarks/pvc/benchmarks.hpp index a8ebc0b678..3745a01085 100644 --- a/benchmarks/pvc/benchmarks.hpp +++ b/benchmarks/pvc/benchmarks.hpp @@ -34,15 +34,62 @@ #include "../benchmark_runner.hpp" #include "gemm_configuration.hpp" -using PvcGemmBF16BF16FP32_RRR = cutlass::gemm::device::GemmConfiguration< +using MMAAtom = MMA_Atom; +using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration< cutlass::arch::IntelPVC, cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, - float>; + float, Shape<_256, _256, _32>, + TiledMMA>>, + XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>; -CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR); +using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_128, _512, _32>, + TiledMMA>>, + XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>; + +using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_256, _128, _32>, + TiledMMA>>, + XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V>; + +using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_128, _256, _16>, + TiledMMA>>, + XE_2D_U16x32x16_LD_N, XE_2D_U16x16x32_LD_V>; + +using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration< + cutlass::arch::IntelPVC, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + cutlass::bfloat16_t, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, Shape<_8, _128, _32>, + TiledMMA>>, + XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V>; + +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); +CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); static void register_benchmarks() { - CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4); + CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5); } diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp index 3a07857f5e..c687c76afe 100644 --- a/benchmarks/pvc/gemm_configuration.hpp +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -57,7 +57,9 @@ template< class ElementA, class LayoutA, class ElementB, class LayoutB, class ElementC, class LayoutC, - class ElementAccumulator> + class ElementAccumulator, + class TileShape, class TiledMma, + class GmemTiledCopyA, class GmemTiledCopyB> struct GemmConfiguration { static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists."); }; @@ -66,47 +68,16 @@ struct GemmConfiguration { // bfloat16 -namespace detail { - -template -struct Gemm_OperandA; - -template -struct Gemm_OperandB; - -template<> -struct Gemm_OperandA { - using GmemTiledCopy = XE_2D_U16x32x32_LD_N; -}; - 
-template<> -struct Gemm_OperandB { - using GmemTiledCopy = XE_2D_U16x32x32_LD_V; -}; - -} // namespace details - -template +template struct GemmConfiguration< arch::IntelPVC, bfloat16_t, LayoutA, bfloat16_t, LayoutB, float, LayoutC, - float> { - using TileShape = Shape<_256, _256, _32>; - using DispatchPolicy = MainloopIntelPVC<3>;; - using TiledMma = TiledMMA< - MMA_Atom, - Layout>, - Tile<_64,_64,_32>>; - - // A - using OperandA = detail::Gemm_OperandA; - using GmemTiledCopyA = typename OperandA::GmemTiledCopy; - - // B - using OperandB = detail::Gemm_OperandB; - using GmemTiledCopyB = typename OperandB::GmemTiledCopy; + float, TileShape, TiledMma, + GmemTiledCopyA, GmemTiledCopyB> { + using DispatchPolicy = MainloopIntelPVC<3>; // Mainloop using CollectiveMainloop = collective::CollectiveMma< diff --git a/benchmarks/pvc/input.in b/benchmarks/pvc/input.in index 8e68fcd566..4f5d47648c 100644 --- a/benchmarks/pvc/input.in +++ b/benchmarks/pvc/input.in @@ -1,17 +1,23 @@ # BFloat16 benchmarks -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 -PvcGemmBF16BF16FP32_RRR --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=4096 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1 --k=5120 --n=13824 +PvcGemmBF16BF16FP32_RRR_2 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=28672 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=3072 --k=4096 --n=3072 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4 --k=4096 --n=12288 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192 
+PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024 +PvcGemmBF16BF16FP32_RRR_4 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384 +PvcGemmBF16BF16FP32_RRR_5 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096 +PvcGemmBF16BF16FP32_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128 +PvcGemmBF16BF16FP32_RRR_3 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128 diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index a340833989..77cddf7dfb 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -263,11 +263,14 @@ struct CollectiveMma< // Mainloop // auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; + #ifdef CUTLASS_SYCL_SWITCH_WG + const int m_coord = n_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = m_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #else const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + #endif const int l_coord = l_idx; - - int sub_group_id = get_sub_group_id(); Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); @@ -279,13 +282,13 @@ struct CollectiveMma< int prefetch_k = 0; Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( - make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (k_start_idx + (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), + make_coord(m_coord + (get_sub_group_id() % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), + (k_start_idx + (get_sub_group_id() % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord(((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, - n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), + make_coord(((get_sub_group_id() / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, + n_coord + (get_sub_group_id() / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 2b41f3d31f..d7b49666fc 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -186,8 +186,13 @@ class GemmUniversal< batch_count = cute::size<3>(params.problem_shape); } return dim3( + #ifdef CUTLASS_SYCL_SWITCH_WG + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), + #else 
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + #endif batch_count ); } @@ -221,8 +226,13 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); auto blk_shape = TileShape{}; + #ifdef CUTLASS_SYCL_SWITCH_WG + auto m_coord = BlockIdxX(); + auto n_coord = BlockIdxY(); + #else auto m_coord = BlockIdxY(); auto n_coord = BlockIdxX(); + #endif auto l_coord = BlockIdxZ(); auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); int sub_group_id = thread_idx / SubgroupSize;
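
Note on the CUTLASS_SYCL_SWITCH_WG toggle introduced above: it only transposes which launch-grid dimension walks the M tiles and which walks the N tiles; get_grid_shape, the block-coordinate readback in xe_gemm.hpp, and the coordinate computation in xe_mma.hpp are switched together so the mapping stays consistent. The standalone sketch below mirrors the two #ifdef branches with plain integers. It is an illustration only; the problem extents are assumed values, and the 256x256 work-group tile simply matches the PvcGemmBF16BF16FP32_RRR_1 configuration from this patch.

// Illustrative sketch of the work-group-to-tile mapping with and without
// CUTLASS_SYCL_SWITCH_WG. Problem extents below are assumptions for the demo.
#include <cstdio>

struct Grid { int x, y; };

// Default mapping: grid.x counts N tiles, grid.y counts M tiles
// (m_coord comes from BlockIdxY, n_coord from BlockIdxX).
static Grid grid_default(int M, int N, int BLK_M, int BLK_N) {
  return { (N + BLK_N - 1) / BLK_N, (M + BLK_M - 1) / BLK_M };
}

// CUTLASS_SYCL_SWITCH_WG mapping: the two grid dimensions are swapped
// (m_coord comes from BlockIdxX, n_coord from BlockIdxY).
static Grid grid_switched(int M, int N, int BLK_M, int BLK_N) {
  return { (M + BLK_M - 1) / BLK_M, (N + BLK_N - 1) / BLK_N };
}

int main() {
  const int M = 4096, N = 8192;        // assumed problem shape
  const int BLK_M = 256, BLK_N = 256;  // work-group tile of PvcGemmBF16BF16FP32_RRR_1

  Grid d = grid_default(M, N, BLK_M, BLK_N);   // (32, 16)
  Grid s = grid_switched(M, N, BLK_M, BLK_N);  // (16, 32)

  // The same work-group, e.g. (BlockIdxX=3, BlockIdxY=1), lands on tile
  // (m=1, n=3) by default and on tile (m=3, n=1) when the switch is enabled.
  std::printf("default : grid = (%d, %d)\n", d.x, d.y);
  std::printf("switched: grid = (%d, %d)\n", s.x, s.y);
  return 0;
}

As the CMakeLists.txt hunk shows, the option is plumbed through add_compile_definitions, so it is chosen at configure time, e.g. with -DCUTLASS_SYCL_SWITCH_WG=ON alongside -DCUTLASS_ENABLE_SYCL=ON; builds that leave the option OFF keep the original mapping.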