diff --git a/cmake/FindDPCPP.cmake b/cmake/FindDPCPP.cmake index 6ed98d3e4b..3574ea1604 100644 --- a/cmake/FindDPCPP.cmake +++ b/cmake/FindDPCPP.cmake @@ -57,6 +57,10 @@ if(NOT "${DPCPP_SYCL_ARCH}" STREQUAL "") endif() endif() +if("${DPCPP_SYCL_TARGET}" STREQUAL "intel_gpu_pvc") + list(APPEND DPCPP_FLAGS "-Xspirv-translator;--spirv-ext=+SPV_INTEL_split_barrier") +endif() + if(UNIX) set_target_properties(DPCPP::DPCPP PROPERTIES INTERFACE_COMPILE_OPTIONS "${DPCPP_FLAGS};${DPCPP_COMPILE_ONLY_FLAGS}" diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt index 322896e20e..f9c5fca18c 100644 --- a/examples/sycl/pvc/CMakeLists.txt +++ b/examples/sycl/pvc/CMakeLists.txt @@ -41,3 +41,8 @@ cutlass_example_add_executable( pvc_collective_builder pvc_collective_builder.cpp ) + +cutlass_example_add_executable( + pvc_gemm_streamk + pvc_gemm_streamk.cpp +) diff --git a/examples/sycl/pvc/pvc_gemm_streamk.cpp b/examples/sycl/pvc/pvc_gemm_streamk.cpp new file mode 100644 index 0000000000..5eb48e15f7 --- /dev/null +++ b/examples/sycl/pvc/pvc_gemm_streamk.cpp @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "common.h" + +#include "cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp" +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_SYCL_PROFILING_ENABLED + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool splitk; + bool dp; + + int m, n, k, l, iterations, splits; + float alpha, beta; + + Options(): + help(false), + error(false), + splitk(false), + dp(false), + m(5120), n(4096), k(4096), l(1), iterations(20), splits(1), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("splitk")) { + splitk = true; + } + + if (cmd.check_cmd_line_flag("dp")) { + dp = true; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("splits", splits, 1); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --dp If specified, uses Data Parallel decomposition\n" + << " --splitk If specified, uses SplitK decomposition\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --splits= Sets the splitting factor for GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int32_t count; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info, + {options.splits, + options.dp ? cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::DataParallel : + options.splitk ? cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::SplitK : + cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::StreamK} + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + float elapsed_time_seconds = 0.f; + for (int i = 0; i < options.iterations; ++i) { + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info, + {options.splits, + options.dp ? cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::DataParallel : + options.splitk ? cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::SplitK : + cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode::StreamK} + }; + gemm_op.initialize(arguments, workspace.get()); + timer.start(); + gemm_op.run(); + syclcompat::wait(); + elapsed_time_seconds += timer.seconds(); + } + + float cute_time = elapsed_time_seconds / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_64,_64,_32>>; // Subgroup level-tile + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index cd2d7be3cb..0e1f344f27 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -36,7 +36,16 @@ #include #include -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) + +#if defined(SYCL_INTEL_TARGET) +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierWaitINTEL(int execution_scope, int memory_scope, int memory_semantics); +SYCL_EXTERNAL __attribute__((convergent)) void __spirv_ControlBarrierArriveINTEL(int execution_scope, int memory_scope, int memory_semantics); + +#define EXECUTION_SCOPE_WORK_GROUP 2 +#define MEMORY_SCOPE_WORK_GROUP 2 +#define MEMORY_SEMANTICS_RELAXED 0 + +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) #define CUDA_BARRIER_ENABLED 1 #else #define CUDA_BARRIER_ENABLED 0 @@ -151,7 +160,10 @@ class NamedBarrier { private: CUTLASS_DEVICE static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED +#if defined(SYCL_INTEL_TARGET) + __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED); + __spirv_ControlBarrierWaitINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED); +#elif CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); @@ -160,7 +172,9 @@ class NamedBarrier { CUTLASS_DEVICE static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { -#if CUDA_BARRIER_ENABLED +#if defined(SYCL_INTEL_TARGET) + __spirv_ControlBarrierArriveINTEL(EXECUTION_SCOPE_WORK_GROUP, MEMORY_SCOPE_WORK_GROUP, MEMORY_SEMANTICS_RELAXED); +#elif CUDA_BARRIER_ENABLED asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); #elif defined(__CUDA_ARCH__) asm volatile ("brkpt;\n" ::); diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 9b2362a9c5..1cfc73c1cb 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -97,7 +97,12 @@ struct GenericBarrier { { int state = 0; -#if (__CUDA_ARCH__ >= 700) +#if defined (SYCL_INTEL_TARGET) + auto atm = sycl::atomic_ref(*ptr); + return atm.load(sycl::memory_order::acquire); +#elif (__CUDA_ARCH__ >= 700) /// SM70 and newer use memory consistency qualifiers // Acquire pattern using acquire modifier diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 9fa7bf7ba1..57835f8d35 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -782,7 +782,7 @@ struct atomic_add CUTLASS_DEVICE void operator()(T *ptr, const T &data) { -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__SYCL_DEVICE_ONLY__) atomicAdd(ptr, data); #endif } diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index b08b79c2f8..a340833989 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -187,12 +187,15 @@ struct CollectiveMma< /// Perform a subgroup-scoped matrix multiply-accumulate template < + int PrefetchStrideA, + int PrefetchStrideB, class FrgTensorD, class TensorA, class TensorB, class FrgTensorC, class KTileIterator, - class ResidueMNK + class ResidueMNK, + class BlkCoord > CUTLASS_DEVICE void operator() ( @@ -202,6 +205,8 @@ struct CollectiveMma< FrgTensorC const &src_accum, KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, + BlkCoord const &blk_coord, + int const &K, int thread_idx, char *smem_buf, Params const& mainloop) @@ -257,10 +262,12 @@ struct CollectiveMma< // // Mainloop // + auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; + const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + const int l_coord = l_idx; + int sub_group_id = get_sub_group_id(); - const int m_coord = BlockIdxY() * BLK_M + (sub_group_id / ATOM_N) * SG_M; - const int n_coord = BlockIdxX() * BLK_N + (sub_group_id % ATOM_N) * SG_N; - const int l_coord = BlockIdxZ(); Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), append<4>(tCrA_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); @@ -268,44 +275,49 @@ struct CollectiveMma< make_coord(0, n_coord, l_coord), append<4>(tCrB_copy_view.shape(), k_tile_count), append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{}); + const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K)); + int prefetch_k = 0; + Tensor prefetch_iter_a = mainloop.gmem_prefetch_a.get_pvc_tensor( make_coord(m_coord + (sub_group_id % ATOM_N) / get<1>(PrefetchAThrShape{}) * get<0>(PrefetchATileSize{}), - (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{}) * get<1>(PrefetchATileSize{}), l_coord), + (k_start_idx + (sub_group_id % ATOM_N) % get<1>(PrefetchAThrShape{})) * PrefetchStrideA, l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_M, SG_K), BLK_K), seq<0, 1, 1>{}); Tensor prefetch_iter_b = mainloop.gmem_prefetch_b.get_pvc_tensor( - make_coord((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) * get<0>(PrefetchBTileSize{}), + make_coord(((sub_group_id / ATOM_N) / get<1>(PrefetchBThrShape{}) + k_start_idx) * PrefetchStrideB, n_coord + (sub_group_id / ATOM_N) % get<1>(PrefetchBThrShape{}) * get<1>(PrefetchBTileSize{}), l_coord), append<4>(make_shape(_1{}, _1{}, _1{}), k_tile_count), append<3>(make_shape(SG_K, SG_N), BLK_K), seq<0,1,0>{}); -#pragma unroll - for (int i = 0; i < DispatchPolicy::Stages; i++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,i)); + prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); } if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,i)); + prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); } } -#pragma unroll - for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + + CUTLASS_PRAGMA_UNROLL + for (int k_tile = 0, k = k_start_idx; k_tile < k_tile_count; ++k_tile, ++k, ++prefetch_k) { // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view); - copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view); + copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k), tCrA_copy_view); + copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k), tCrB_copy_view); - if(k_tile + DispatchPolicy::Stages < k_tile_count) { - if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,k_tile + DispatchPolicy::Stages)); + if(prefetch_k < k_tile_count) { + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_a, prefetch_iter_a(_,_,_,prefetch_k)); + } + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,prefetch_k)); + } } - if constexpr(cute::detail::has_prefetch) { - prefetch(mainloop.gmem_tiled_copy_b, prefetch_iter_b(_,_,_,k_tile + DispatchPolicy::Stages)); + + for (int i = 0; i < SG_K / SubgroupSize; i++) { + cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); } } - for (int i = 0; i < SG_K / SubgroupSize; i++) { - cute::gemm(tiled_mma, accum, tCrA(_, _, i), tCrB(_, i, _), src_accum); - } - } } }; diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 81327ef7b8..839ebb23cd 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -65,5 +65,6 @@ struct IsCutlass3ArrayKernel; }; +#if defined (SYCL_INTEL_TARGET) +template < + class TileShape, + class ClusterShape +> +struct TileSchedulerSelector< + StreamKScheduler, + arch::IntelPVC, + TileShape, + ClusterShape + > { + using Scheduler = PersistentTileSchedulerXeStreamK; +}; +#endif + template < class TileShape, class ClusterShape diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 9e4968acad..1a014ba16f 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -52,7 +52,8 @@ class GemmUniversal< CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, - cute::enable_if_t>> + cute::enable_if_t + && !cute::is_same_v>> { public: // @@ -104,6 +105,10 @@ class GemmUniversal< static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + using PrefetchATileSize = typename CollectiveMainloop::PrefetchATileSize; + using PrefetchBTileSize = typename CollectiveMainloop::PrefetchBTileSize; + static constexpr int PrefetchStrideA = static_cast(get<1>(PrefetchATileSize{})); + static constexpr int PrefetchStrideB = static_cast(get<0>(PrefetchBTileSize{})); // Kernel level shared memory storage struct SharedStorage { @@ -235,18 +240,20 @@ class GemmUniversal< Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); clear(accumulators); - auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(workgroup_shape))); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(0, make_shape(K)), make_shape(K)); int k_tile_count = K / get<2>(workgroup_shape); // Perform the collective scoped MMA CollectiveMainloop collective_mma; - collective_mma( + collective_mma.template operator()( accumulators, gA, gB, accumulators, k_tile_iter, k_tile_count, residue_mnk, + blk_coord_mnkl, + K, thread_idx, smem_buf, params.mainloop diff --git a/include/cutlass/gemm/kernel/xe_gemm_cooperative.hpp b/include/cutlass/gemm/kernel/xe_gemm_cooperative.hpp new file mode 100644 index 0000000000..7817b1396d --- /dev/null +++ b/include/cutlass/gemm/kernel/xe_gemm_cooperative.hpp @@ -0,0 +1,322 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t + && cute::is_same_v>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + + using PrefetchATileSize = typename CollectiveMainloop::PrefetchATileSize; + static constexpr int PrefetchStrideA = static_cast(get<1>(PrefetchATileSize{})); + static constexpr int PrefetchStrideB = static_cast(CollectiveMainloop::SG_K); + + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, hw_info, args.scheduler, workspace_ptr); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace_ptr), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace_ptr), + hw_info, + scheduler, + workspace + }; + } + + static bool + can_implement(Arguments const& args) { + bool mode_implementable = args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + return mode_implementable && TileScheduler::can_implement(args.scheduler); + } + + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info); + return workspace_size; + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr, args.problem_shape, args.hw_info); + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, params.hw_info); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.initial_work_tile_info(); + + int thread_idx = int(ThreadIdxX()); + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (BLK_M,BLK_N,BLK_K) + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + + while (work_tile_info.is_valid()) { + const int m_coord = work_tile_info.M_idx; + const int n_coord = work_tile_info.N_idx; + const int l_coord = work_tile_info.L_idx; + const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); + + Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l) + Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) + Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) + + auto gA = local_tile(mA_mk, workgroup_shape, take<0, 3>(tile_coord), Step<_1, X, _1>{}); + auto gB = local_tile(mB_nk, workgroup_shape, take<0, 3>(tile_coord), Step< X, _1, _1>{}); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + const int work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, workgroup_shape); + const int work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, make_shape(K)), make_shape(K)); + + auto k_residue = K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + + // Compute tile residues for predication + auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord + auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(workgroup_shape)); + + CollectiveMainloop collective_mma; + + // Perform the collective scoped MMA + collective_mma.template operator()( + accumulators, + gA, + gB, + accumulators, + k_tile_iter, work_k_tile_count, + residue_mnk, + tile_coord, + K, + thread_idx, + smem_buf, + params.mainloop + ); + + // Perform reduction across splits, if needed + TileScheduler::template fixup( + params.scheduler, work_tile_info, accumulators); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + + epilogue( + problem_shape_MNKL, + subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } + + // Get next work tile + work_tile_info = scheduler.fetch_next_work(work_tile_info); + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp b/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp new file mode 100644 index 0000000000..fc84b99ff7 --- /dev/null +++ b/include/cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp @@ -0,0 +1,704 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief Parameters structures for persistent tile schedulers +*/ + +#include "cutlass/coord.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/workspace.h" +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm_coord.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////// +// Parameters for Xe persistent stream-K scheduler +struct PersistentTileSchedulerXeStreamKParams { + + // Strategies for computing reductions between work-groups computing portions of a given output tile + enum class ReductionMode { + // Participating work-groups perform reduction in a turnstile fashion in order of the K extent + // covered by each work-group. This requires a lock to be held exclusively be the work-group that is + // currently accumulating. + // + // Turnstile accumulation ensures deterministic numeric behavior when using this mode. + Deterministic, + + // Participating work-groups perform reduction atomically to the same workspace (mostly) without locking. + // Locks are used only to wait for the first work-group to write its partial values (to initialize the + // workspace), and for all but the final work-group to have accumulated (so that the final work-group can load + // the accumulated value and accumulate it into registers on top of which the epilogue will + // be performed). + // + // Due to the nondeterminsitic ordering of accumulation, deterministic numeric behavior cannot + // be guaranteed with this mode (e.g., floating-point rounding error will depend on the order + // of accumulation) + Nondeterministic + }; + + // Strategies for decomposing the problem + enum class DecompositionMode { + // Use a heuristic to determine whether data-parallel, split-K, or stream-K decomposition should be performed + Heuristic, + // Force a data-parallel decomposition + DataParallel, + // Force a split-K decomposition. This should be paired with setting the `splits` parameter + SplitK, + // Force a stream-K decomposition + StreamK + }; + + FastDivmodU64 divmod_batch_{}; + FastDivmodU64 divmod_blk_major_{}; + + // Divide up the number of stream-K tiles amongst G groups of stream-K units. + // Currently defaults to 1 since we don't create groups for Xe. + FastDivmodU64 divmod_sk_groups_{}; + + // Number of stream-K units in each group + FastDivmodU64 divmod_sk_units_per_group_{}; + + uint64_t units_per_problem_ = 0; + FastDivmod divmod_tiles_per_output_tile_{}; + + // The splitting factor to be used in a split-K decomposition of the problem. + // If this is set to a value greater than 1, stream-K decomposition logic + // is bypassed in favor of a split-K decomposition. + FastDivmod divmod_splits_{}; + + // Number of stream-K or split-K work units that compute an extra k iteration. + // This is done to handle residuals in dividing up the k iteration space. + uint32_t big_units_ = 0; + + // The number of groups of stream-K units that will process an extra stream-K tile. + uint32_t big_groups_ = 0; + + // Workspace for holding partial accumulators to be reduced across stream-K/split-K units + void* reduction_workspace_ = nullptr; + + // Number of tiles covered by stream-K work units + uint32_t sk_tiles_ = 0; + + // Number of work units computing stream-K tiles + uint32_t sk_units_ = 0; + + // Number of tiled k iterations computed by each stream-K work unit. This + // can potentially cover more than one output tile. + FastDivmod divmod_k_tiles_per_sk_unit_{}; + // Number of tiled k iterations computed by each "big" stream-K units, which + // processes one more K chunk than a "normal" stream-K unit. + FastDivmod divmod_k_tiles_per_sk_big_unit_{}; + + // Strategy to use when reducing between collaborating work-groups + ReductionMode reduction_mode_ = ReductionMode::Deterministic; + + // Minimum number of k tiles that can be assigned to a stream-K unit + static constexpr uint32_t min_iters_per_sk_unit_ = 8u; + + // Maximum number of groups of stream-K units + static constexpr uint32_t max_sk_groups_ = 1u; + + // ktile start from even for each cta + uint32_t ktile_start_alignment_count { 1u }; + + // Initializes members. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + void + initialize( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + KernelHardwareInfo hw_info, + int splits, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + void* workspace + ) { + + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape, tile_shape); + // Number of k tiles in each output tile + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + initialize( + problem_blocks, + k_tiles_per_output_tile, + hw_info, + splits, + reduction_mode, + decomposition_mode, + workspace + ); + } + + // Version of initialize that takes in as input the number of work-groups in the M and N and L dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or work-group shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + void + initialize( + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + KernelHardwareInfo hw_info, + int splits, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + void* workspace + ) { + + auto problem_blocks_l = problem_blocks.z; + + auto problem_blocks_m = problem_blocks.x; + auto problem_blocks_n = problem_blocks.y; + uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + + // Reduction workspace is at the beginning of the workspace. Lock workspace follows. + void* reduction_workspace = workspace; + + if (decomposition_mode == DecompositionMode::SplitK || + (decomposition_mode == DecompositionMode::Heuristic && splits > 1)) { + // Short circuit to basic split-K decomposition + + // Don't split by more than the available number of SMs + if (splits > hw_info.sm_count) { + splits = hw_info.sm_count; + } + + // Don't split by more than the K tile iterations + // + // splits is almost certainly nonnegative here (e.g., hw_info.sm_count, + // despite being an int, is a count), so it can safely be converted to unsigned + // in the comparison to avoid a signed-unsigned comparison warning-as-error. + if (static_cast(splits) > k_tiles_per_output_tile) { + splits = k_tiles_per_output_tile; + } + + // If splits == k_tiles_per_output_tiles, there will be one k_tile per cta + // and this violate k_tile start from even requirements. Thus we need to + // reduce the number of splits. + if (ktile_start_alignment_count > 1u && + static_cast(splits) == k_tiles_per_output_tile) { + splits = k_tiles_per_output_tile / ktile_start_alignment_count; + } + + set_params_basic( + problem_blocks_m, + problem_blocks_n, + problem_blocks_l, + splits, + k_tiles_per_output_tile, + reduction_workspace, + reduction_mode + ); + return; + } + + // Calculate the maximum number of blocks that we can fit within sm_count SMs. + dim3 grid = get_grid_shape( + problem_blocks, + hw_info + ); + + uint64_t wgs_per_wave = grid.x * grid.y; + // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. + uint32_t sk_tiles = get_num_sk_tiles( + output_tiles, + wgs_per_wave, + k_tiles_per_output_tile, + decomposition_mode + ); + uint64_t dp_tiles = output_tiles - sk_tiles; + + // Calculate the number of work units covering the data-parallel and stream-K tiles. + // A "work unit" is a single index in the linearized ID space used by the scheduler. + // A work unit can encompass multiple output tiles worth of work (as will be the + // case for stream-K blocks). + // Since splitting is not required for data-parallel tiles, only one data-parallel unit + // is needed per data-parallel tile. + uint64_t dp_units = dp_tiles; + + uint64_t wgs_per_sk_wave = wgs_per_wave; + uint64_t sk_units = get_num_sk_units(wgs_per_sk_wave, sk_tiles, k_tiles_per_output_tile); + + if (decomposition_mode == DecompositionMode::DataParallel || + (decomposition_mode == DecompositionMode::Heuristic && sk_tiles == 0) || + sk_units == 0) { + // Short circuit to basic data-parallel decomposition + set_params_basic( + problem_blocks_m, + problem_blocks_n, + problem_blocks_l, + /* splits = */ 1, + k_tiles_per_output_tile, + reduction_workspace, + reduction_mode + ); + return; + } + + uint32_t groups = max_sk_groups_; + + auto sk_units_per_group = sk_units / groups; + + uint64_t sk_tiles_per_group = sk_tiles / groups; + + // Groups that will process an extra stream-K tile. These differ from "big_units," which + // are stream-K units within a group that process an extra K chunk. + uint64_t sk_big_groups = sk_tiles % groups; + + uint64_t k_tiles_per_group = k_tiles_per_output_tile * sk_tiles_per_group; + + // Number of k tiles computed per stream-K unit + uint64_t k_tiles_per_sk_unit = k_tiles_per_group / sk_units_per_group; + + uint32_t reduction_units = 0; + + // Use separate reduction when we have less than one wave of output tiles (dp_tiles == 0) + // and when each tile will be operated on by at least two stream-K units (sk_units > 2 * sk_tiles) + if (decomposition_mode == DecompositionMode::Heuristic && sk_tiles < sk_units && sk_units % sk_tiles == 0) { + // If the number of stream-K units is a multiple of the number of stream-K tiles, then + // the problem can leverage a basic split-K decomposition for the stream-K tiles. + // This case happens when separate reduction is disable. + uint32_t sk_splits = static_cast(sk_units / sk_tiles); + set_params_basic( + problem_blocks_m, + problem_blocks_n, + problem_blocks_l, + sk_splits, + k_tiles_per_output_tile, + reduction_workspace, + reduction_mode + ); + return; + } + + divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + divmod_sk_groups_ = FastDivmodU64(static_cast(groups)); + divmod_sk_units_per_group_ = FastDivmodU64(static_cast(sk_units / groups)); + + divmod_blk_major_ = FastDivmodU64(problem_blocks_n); + + divmod_splits_ = FastDivmod(splits); + units_per_problem_ = static_cast(dp_units + sk_units); + + // Assign big_units_ assuming that group count == 1. This is unused by stream-K + // when group count > 1. + big_units_ = static_cast(k_tiles_per_group % k_tiles_per_sk_unit); + + big_groups_ = static_cast(sk_big_groups); + reduction_workspace_ = reduction_workspace; + sk_tiles_ = sk_tiles; + sk_units_ = static_cast(sk_units); + divmod_k_tiles_per_sk_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit)); + divmod_k_tiles_per_sk_big_unit_ = FastDivmod(static_cast(k_tiles_per_sk_unit + 1)); + reduction_mode_ = reduction_mode; + } + + static CUTLASS_DEVICE + cute::tuple + get_work_idx_m_and_n( + uint64_t blk_per_grid_dim, + FastDivmodU64 const& divmod_blk_major) { + + uint64_t m_idx, n_idx; + divmod_blk_major(m_idx, n_idx, blk_per_grid_dim); + auto i = static_cast(m_idx); + auto j = static_cast(n_idx); + + return {i, j}; + } + + // Computes the linear index within a batch given M and N tile offsets within the batch. + // This essentially inverts the mapping performed in get_work_idx_m_and_n + static CUTLASS_DEVICE + uint64_t + get_linear_idx_from_m_and_n( + int32_t tile_m, + int32_t tile_n, + FastDivmodU64 const& divmod_blk_major) { + return static_cast(tile_m * divmod_blk_major.divisor + tile_n); + } + + // Get the number of work-group tiles in this problem. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + CUTLASS_HOST_DEVICE + static dim3 + get_tiled_wg_shape_mnl(BatchedGemmCoord problem_shape, GemmCoord cta_shape) { + auto cta_m = (problem_shape.m() + cta_shape.m() - 1) / cta_shape.m(); + auto cta_n = (problem_shape.n() + cta_shape.n() - 1) / cta_shape.n(); + + return { + static_cast(cta_m), + static_cast(cta_n), + static_cast(problem_shape.batch()) + }; + } + + CUTLASS_HOST_DEVICE + static dim3 + get_grid_shape( + dim3 problem_blocks, + KernelHardwareInfo hw_info, + bool truncate_range = true + ) { + uint32_t available_sms = hw_info.sm_count / 8; + auto possibly_truncate = [&](int x, int y) { + if(truncate_range) + return static_cast(platform::min(x, y)); + else + return static_cast(x); + }; + return dim3{possibly_truncate(available_sms, problem_blocks.x * problem_blocks.y * problem_blocks.z), 1, 1}; + } + + // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total + // output tiles on a device with `wgs_per_wave` work-groups in each wave. + static uint32_t + get_num_sk_tiles( + uint64_t output_tiles, + uint64_t wgs_per_wave, + uint32_t k_tiles_per_output_tile, + DecompositionMode decomposition_mode + ) { + uint32_t full_waves = static_cast(output_tiles / wgs_per_wave); + uint32_t total_waves = static_cast((output_tiles + wgs_per_wave - 1) / wgs_per_wave); + + if (decomposition_mode == DecompositionMode::DataParallel || + decomposition_mode == DecompositionMode::SplitK) { + return 0; + } + + // If there is wave quantization, assign the first two waves worth of tiles to be + // covered by stream-K work and the remainder to be data-parallel. Since we know + // that full_waves == total_waves - 1 in this case, the number of data-parallel + // waves is simply full_waves-1 (unless full_waves == 0). + uint32_t dp_waves = full_waves > 1 ? full_waves - 1 : 0; + uint64_t dp_tiles = dp_waves * wgs_per_wave; + uint64_t sk_tiles = output_tiles - dp_tiles; + + if (decomposition_mode == DecompositionMode::Heuristic) { + if (full_waves == total_waves || k_tiles_per_output_tile <= min_iters_per_sk_unit_) { + // All tiles will be data-parallel tiles if there is either no quantization + // or if there is no work to be split. + return 0; + } + + // + // The final wave is not full. Perform some stream-K work. + // + + // Rudimentary heuristic: prefer data-parallel decomposition if we have more than + // one wave and the tail wave is more than half full. This is subject to change. + uint64_t tail_tiles = output_tiles - (full_waves * wgs_per_wave); + if (2 * tail_tiles >= wgs_per_wave) { + return 0; + } + } + + return static_cast(sk_tiles); + } + + CUTLASS_HOST_DEVICE + static uint64_t + get_num_sk_units(uint64_t wgs_per_sk_wave, uint32_t sk_tiles, uint32_t k_tiles_per_output_tile) { + // If there are stream-K tiles to compute and a sufficiently large number of k iterations + // across them, they will be covered by a single wave of persistent work_groups. Thus, there + // will be as many work units as there are work_groups in a single wave. + // + // When the total k iterations across stream-K tiles is too small to justify distributing + // across an entire wave of work_groups, we instead distribute the iterations over a smaller + // set of work_groups. + + // Calculate the number of stream-K units that would be needed if each stream-K unit + // computed the minimum allowable k iterations. + + // Number of k iterations computed by the stream-K units as a whole + uint64_t k_tiles_sk_total = k_tiles_per_output_tile * sk_tiles; + + // Calculate the number of stream-K units that would be needed if each stream-K unit + // computed the minimum allowable k iterations. + uint64_t min_sized_sk_units = (k_tiles_sk_total / min_iters_per_sk_unit_); + + uint64_t sk_units = platform::min(wgs_per_sk_wave, min_sized_sk_units); + return sk_units; + } + + // Calculates the size of the workspace needed for holding reduction barriers + CUTLASS_HOST_DEVICE + static size_t + get_barrier_workspace_size(uint64_t num_tiles, uint32_t barrier_bits) { + size_t workspace_bits = num_tiles * static_cast(barrier_bits); + return bits_to_bytes(workspace_bits); + } + + // Calculates the size of the workspace needed for holding partial outputs from splits + CUTLASS_HOST_DEVICE + static size_t + get_reduction_workspace_size(uint64_t num_tiles, GemmCoord tile_shape, uint32_t accumulator_bits, uint32_t num_accumulator_mtxs = 1) { + size_t output_tile_size = tile_shape.m() * tile_shape.n(); + size_t workspace_bits = accumulator_bits * output_tile_size * num_tiles * num_accumulator_mtxs; + return bits_to_bytes(workspace_bits); + } + + static void + get_workspace_component_sizes( + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord tile_shape, + size_t& barrier_workspace_size, + size_t& reduction_workspace_size, + KernelHardwareInfo const& hw_info, + int splits, + DecompositionMode decomposition_mode, + uint32_t barrier_bits, + uint32_t accumulator_bits) { + + // Workspace is needed only for output tiles that will be split. Thus, we first determine the number + // of output tiles that will be split, and then calculate the workspace needed to cover these. + uint64_t output_tiles = problem_blocks.x * problem_blocks.y * problem_blocks.z; + + if (decomposition_mode == DecompositionMode::DataParallel) { + barrier_workspace_size = 0; + reduction_workspace_size = 0; + } + else if (splits > 1 && + (decomposition_mode == DecompositionMode::SplitK || decomposition_mode == DecompositionMode::Heuristic)) { + // Basic split-K variant requires workspace for all output tiles + barrier_workspace_size = get_barrier_workspace_size(output_tiles, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(output_tiles, tile_shape, accumulator_bits); + } + else { + KernelHardwareInfo new_hw_info; + new_hw_info.device_id = hw_info.device_id; + new_hw_info.sm_count = hw_info.sm_count; + if (new_hw_info.sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + new_hw_info.sm_count = KernelHardwareInfo::query_device_multiprocessor_count(new_hw_info.device_id); + } + + dim3 grid = get_grid_shape( + problem_blocks, + new_hw_info + ); + uint64_t wgs_per_wave = grid.x * grid.y; + uint32_t sk_tiles = get_num_sk_tiles( + output_tiles, + wgs_per_wave, + static_cast(k_tiles_per_output_tile), + decomposition_mode + ); + uint64_t wgs_per_sk_wave = wgs_per_wave; + uint64_t sk_units = get_num_sk_units(wgs_per_sk_wave, sk_tiles, k_tiles_per_output_tile); + uint64_t dp_tiles = output_tiles - sk_tiles; + + uint64_t reduction_tiles = sk_tiles; + + barrier_workspace_size = get_barrier_workspace_size(sk_tiles, barrier_bits); + reduction_workspace_size = get_reduction_workspace_size(reduction_tiles, tile_shape, accumulator_bits); + } + } + + // Get the amount of scratch workspace needed for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static size_t + get_workspace_size( + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + KernelHardwareInfo const& hw_info, + int splits, + DecompositionMode decomposition_mode, + uint32_t barrier_bits, + uint32_t element_accumulator_bits) { + + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape, tile_shape); + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + return get_workspace_size( + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + hw_info, + splits, + decomposition_mode, + barrier_bits, + element_accumulator_bits + ); + } + + // Version of get_workspace_size that takes in as input the number of work-groups in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or work-group shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + static size_t + get_workspace_size( + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord tile_shape, + KernelHardwareInfo const& hw_info, + int splits, + DecompositionMode decomposition_mode, + uint32_t barrier_bits, + uint32_t element_accumulator_bits) { + + size_t barrier_workspace_size = 0; + size_t reduction_workspace_size = 0; + + get_workspace_component_sizes( + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + barrier_workspace_size, + reduction_workspace_size, + hw_info, + splits, + decomposition_mode, + barrier_bits, + element_accumulator_bits + ); + + return barrier_workspace_size + reduction_workspace_size; + } + + // Initialize the workspace to be used for the kernel. This variant of the method should only be used when + // problem_shape and tile_shape contain modes of only rank 1. + static cutlass::Status + initialize_workspace( + void* workspace, + BatchedGemmCoord problem_shape, + GemmCoord tile_shape, + KernelHardwareInfo const& hw_info, + int splits, + DecompositionMode decomposition_mode, + uint32_t barrier_bits, + uint32_t element_accumulator_bits) { + + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape, tile_shape); + uint32_t k_tiles_per_output_tile = (problem_shape.k() + tile_shape.k() - 1) / tile_shape.k(); + + return initialize_workspace( + workspace, + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + hw_info, + splits, + decomposition_mode, + barrier_bits, + element_accumulator_bits + ); + } + + // Version of initialize_workspace that takes in as input the number of work-groups in the M and N dimensions. + // This is useful for calculating the tiled shape when a mode of problem and/or work-group shape has rank > 1, + // for which using CuTe algebra for calculating tile shapes is easiest. + static cutlass::Status + initialize_workspace( + void* workspace, + dim3 problem_blocks, + uint32_t k_tiles_per_output_tile, + GemmCoord tile_shape, + KernelHardwareInfo const& hw_info, + int splits, + DecompositionMode decomposition_mode, + uint32_t barrier_bits, + uint32_t element_accumulator_bits) { + + uint64_t barrier_workspace_size = 0; + uint64_t reduction_workspace_size = 0; + + get_workspace_component_sizes( + problem_blocks, + k_tiles_per_output_tile, + tile_shape, + barrier_workspace_size, + reduction_workspace_size, + hw_info, + splits, + decomposition_mode, + barrier_bits, + element_accumulator_bits + ); + + if (barrier_workspace_size > 0) { + if (workspace == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Only the barrier workspace needs to be cleared for stream-K. + // Barrier workspace follows reduction workspace. + uint8_t* barrier_workspace = reinterpret_cast(workspace) + reduction_workspace_size; + return zero_workspace(static_cast(barrier_workspace), barrier_workspace_size); + } + + return Status::kSuccess; + } + + void + set_params_basic( + uint32_t blocks_m, + uint32_t blocks_n, + uint32_t blocks_l, + uint32_t splits, + uint32_t k_tiles_per_output_tile, + void* reduction_workspace, + ReductionMode reduction_mode) { + + divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + divmod_sk_groups_ = FastDivmodU64(1u); + divmod_splits_ = FastDivmod(splits); + divmod_blk_major_ = FastDivmodU64(blocks_n); + units_per_problem_ = blocks_m * blocks_n * blocks_l; + big_units_ = k_tiles_per_output_tile % splits; + reduction_workspace_ = reduction_workspace; + reduction_mode_ = reduction_mode; + divmod_k_tiles_per_sk_unit_ = FastDivmod(k_tiles_per_output_tile / splits); + + // No stream-K work is performed for "basic" data-parallel and split-K decompositions + sk_tiles_ = 0; + sk_units_ = 0; + divmod_sk_units_per_group_ = FastDivmodU64(blocks_m * blocks_n * blocks_l); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace detail +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp b/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp new file mode 100644 index 0000000000..4b405b39cc --- /dev/null +++ b/include/cutlass/gemm/kernel/xe_tile_scheduler_streamk.hpp @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/barrier.h" +#include "cutlass/block_striped.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/gemm/kernel/xe_persistent_tile_scheduler_params_streamk.hpp" + +namespace cutlass::gemm::kernel::detail { + +// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition +template < + class TileShape +> +class PersistentTileSchedulerXeStreamK { + // + // Data members + // + +private: + uint64_t current_work_linear_idx_ = 0; + +public: + + // Use a dummy barrier manager to simply get the type used to store the barrier + using BarrierType = typename NamedBarrierManager<1>::T; + + using Params = PersistentTileSchedulerXeStreamKParams; + using ReductionMode = Params::ReductionMode; + using DecompositionMode = Params::DecompositionMode; + + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t K_idx = 0; + int32_t L_idx = 0; + + // Number of k tiles to compute for this unit of work. For stream-K, this + // can indicate the number of K tiles across multiple output tiles. + uint32_t k_tile_count = 0; + + // Number of k tiles remaining for the work unit as a whole + uint32_t k_tile_remaining = 0; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + // A work tile that computes no K tiles is invalid + return k_tile_count > 0; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, -1, 0}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return (K_idx + k_tile_count) == k_tiles_per_output_tile; + } + }; + + struct Arguments { + + Arguments() = default; + Arguments(Arguments const&) = default; + Arguments(Arguments&&) = default; + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments const& args) { + splits = args.splits; + reduction_mode = args.reduction_mode; + decomposition_mode = args.decomposition_mode; + return *this; + } + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments&& args) noexcept { + splits = args.splits; + reduction_mode = args.reduction_mode; + decomposition_mode = args.decomposition_mode; + return *this; + } + + CUTLASS_HOST_DEVICE + Arguments(int splits_) : splits(splits_) {} + + CUTLASS_HOST_DEVICE + Arguments(int splits_, DecompositionMode decomposition_mode_) : + splits(splits_), + decomposition_mode(decomposition_mode_) {} + + // The splitting factor to be used in a split-K decomposition of the problem. + // If this is set to a value greater than 1, stream-K decomposition logic + // is bypassed in favor of a split-K decomposition. + int splits = 1; + ReductionMode reduction_mode = ReductionMode::Deterministic; + DecompositionMode decomposition_mode = DecompositionMode::Heuristic; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + template + static Params + to_underlying_arguments( + ProblemShape problem_shape, + TileShape tile_shape, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace) { + + static_assert(cute::is_static::value); + + auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{}); + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape_mnkl, tile_shape); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + Params params; + params.initialize( + problem_blocks, + k_tile_per_output_tile, + hw_info, + args.splits, + args.reduction_mode, + args.decomposition_mode, + workspace + ); + return params; + } + + static bool + can_implement(Arguments const& args) { + // Split count > 1 is only valid for heuristic and split-K decomposition modes + return (args.splits == 1 || + args.decomposition_mode == DecompositionMode::Heuristic || + args.decomposition_mode == DecompositionMode::SplitK); + } + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerXeStreamK() { }; + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerXeStreamK(Params const& params_) : scheduler_params(params_) { + current_work_linear_idx_ = uint64_t(BlockIdxX()); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + return get_current_work_for_linear_idx(current_work_linear_idx_, scheduler_params); + } + + CUTLASS_DEVICE + static WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) { + // The maximum number of work units is units_per_problem_ * splits_. + // The multiplication by splits_ is used for handling split-K, in which + // units_per_problem_ is equal to the total number of output tiles. To account + // for the fact that we have splits_ peers per output tile, we multiply this + // value by splits_. For stream-K, this multiplication ends up being a no-op + // because splits_ is set to 1 for stream-K. + if(linear_idx >= (params.units_per_problem_ * params.divmod_splits_.divisor)) { + // Invalid work. Return an empty result. + return WorkTileInfo::invalid_work_tile(); + } + + WorkTileInfo work_tile_info; + assign_work(params, linear_idx, work_tile_info); + return work_tile_info; + } + + // Returns whether the current work_tile_info passed in should continue to be used. This + // occurs only in the stream-K decomposition with stream-K work units, which encompass + // work over multiple output tiles. If the current work_tile_info should continue to be + // used, it is updated to advance to the next output tile it should cover. + CUTLASS_DEVICE + bool + continue_current_work(WorkTileInfo& work_tile_info) const { + return continue_current_work_for_linear_idx( + current_work_linear_idx_, work_tile_info, scheduler_params); + } + + CUTLASS_DEVICE + static bool + continue_current_work_for_linear_idx( + uint64_t linear_idx, + WorkTileInfo& work_tile_info, + Params const& params) { + + work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count; + + if (work_tile_info.k_tile_remaining == 0) { + return false; + } + assign_work(params, linear_idx, work_tile_info); + return work_tile_info.is_valid(); + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += uint64_t(GridDimX()) * uint64_t(GridDimY()) * uint64_t(GridDimZ()) * uint64_t(advance_count); + } + + // Given the inputs, computes the total number of output work-groups this problem will compute over. + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_wg_shape_mnl(ProblemShape problem_shape_mnkl, TileShape cta_shape) { + return Params::get_tiled_wg_shape_mnl(to_gemm_coord(problem_shape_mnkl), to_gemm_coord(cta_shape)); + } + + // Computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + ProblemShape problem_shape, + TileShape tile_shape, + KernelHardwareInfo hw_info) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, cute::Int<1>{}); + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape_mnkl, tile_shape); + + return Params::get_grid_shape( + problem_blocks, + hw_info + ); + } + + // Returns whether fixup is needed for `work_tile_info`. + CUTLASS_HOST_DEVICE + static bool + requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) { + // Fixup is not needed for invalid or data-parallel tiles + return work_tile_info.is_valid() && work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor; + } + + // Performs the reduction across splits for a given output tile. +template + CUTLASS_DEVICE + static void + fixup( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers = 1, + uint32_t barrier_idx = 0) { + static constexpr uint32_t Offset = static_cast(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + static constexpr uint32_t MaxNumNamedBarriers = 1; + using BarrierManager = NamedBarrierManager; + return fixup_helper( + params, work_tile_info, accumulators, num_barriers, barrier_idx); + } + + // Helper for performing the reduction across splits for a given output tile. + template + CUTLASS_DEVICE + static void + fixup_helper( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx, + uint32_t num_accumulator_mtxs = 1) { + + using ElementAccumulator = typename FrgTensorC::value_type; + + if (!requires_fixup(params, work_tile_info)) { + return; + } + auto tile_idx = output_tile_index(params, work_tile_info); + + // Index of the lock on which to wait + auto lock_idx = (tile_idx * num_barriers) + barrier_idx; + + auto reduction_tile_idx = tile_idx; + auto reduction_peer_offset = 0; + int barrier_group_thread_idx = ThreadIdxX(); + + // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. + // Thus, the start of the reduction space is the same across all threads in a work group. + int reduction_offset = + (cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * reduction_tile_idx * num_accumulator_mtxs) + + reduction_peer_offset; + + ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; + + using AccumulatorArrayT = Array; + using BlockStripedReduceT = BlockStripedReduce; + + AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); + AccumulatorArrayT* accumulator_array = reinterpret_cast(accumulators.data()); + + // The number of tiles for which reduction is required is either: + // (a) the total number of output tiles (in the case of split-K) + // (b) the number of stream-K tiles + // To calculate the total number of output tiles in the split-K case, we + // note that, in the split-K case, the units_per_problem_ member of Params will be + // the total number of output tiles. + uint32_t reduction_tiles = 0; + if (params.divmod_splits_.divisor > 1) { + reduction_tiles = params.units_per_problem_; + } + else { + reduction_tiles = params.sk_tiles_; + } + + auto reduction_workspace_size = Params::get_reduction_workspace_size( + reduction_tiles, to_gemm_coord(TileShape{}), sizeof_bits::value, num_accumulator_mtxs); + BarrierType* lock_workspace = reinterpret_cast( + reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); + + if (!compute_epilogue(work_tile_info, params)) { + if (work_tile_info.K_idx == 0) { + // The first peer initializes the workspace partials + BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); + } + else { + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + + // Perform reduction in workspace + BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); + } + + // Each participating stream-K unit increments the barrier by the K tile count that this unit has + // processed. + int32_t increment = work_tile_info.k_tile_count; + + // Signal our arrival + BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment); + } + else { + if (params.reduction_mode_ == ReductionMode::Deterministic) { + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + } + else { + // Wait until the first split has stored its accumulators + BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); + } + + // The block computing the final split for the tile adds previously-reduced partials + // to its accumulators and computes the epilogue. + BlockStripedReduceT::load_add(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); + } + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) { + // `is_final_split` will be set to `true` for the following scenarios, all of which must compute the epilogue: + // 1. The tile is computed in data-parallel mode + // 2. The tile is computed in split-/stream-K mode and this work unit represents the final split of the tile + return work_tile_info.is_valid() && + work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor); + } + + // Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K] + CUTLASS_DEVICE + static int + output_tile_index(Params const& params, WorkTileInfo const& work_tile_info) { + uint64_t linear_idx_in_batch = Params::get_linear_idx_from_m_and_n( + work_tile_info.M_idx, work_tile_info.N_idx, + params.divmod_blk_major_ + ); + + uint64_t tiles_mn = params.divmod_batch_.divisor; + return tiles_mn * work_tile_info.L_idx + linear_idx_in_batch; + } + + template + static size_t + get_workspace_size( + Arguments const& args, + ProblemShape problem_shape, + KernelHardwareInfo const& hw_info) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + TileShape tile_shape; + + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape_mnkl, tile_shape); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + return Params::get_workspace_size( + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(tile_shape), + hw_info, + args.splits, + args.decomposition_mode, + sizeof_bits::value, + sizeof_bits::value + ); + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + ProblemShape const& problem_shape, + KernelHardwareInfo const& hw_info) { + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + TileShape tile_shape; + + dim3 problem_blocks = get_tiled_wg_shape_mnl(problem_shape_mnkl, tile_shape); + uint32_t k_tile_per_output_tile = cute::size(cute::ceil_div(cute::shape<2>(problem_shape_mnkl), cute::shape<2>(TileShape{}))); + + return Params::initialize_workspace( + workspace, + problem_blocks, + k_tile_per_output_tile, + to_gemm_coord(tile_shape), + hw_info, + args.splits, + args.decomposition_mode, + sizeof_bits::value, + sizeof_bits::value + ); + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape, TileShape) { + return work_tile_info.k_tile_count; + } + + CUTLASS_HOST_DEVICE + static uint32_t + get_work_k_tile_start(WorkTileInfo const& work_tile_info) { + return work_tile_info.K_idx; + } + + // Kernel helper function to get next work tile + CUTLASS_DEVICE + auto + fetch_next_work(WorkTileInfo work_tile_info) { + if (continue_current_work(work_tile_info)) { + return work_tile_info; + } + + advance_to_next_work(); + return get_current_work(); + } + + // Returns the initial work tile info that will be computed over + CUTLASS_DEVICE + WorkTileInfo + initial_work_tile_info() { + return get_current_work(); + } + +private: + // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info + // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining + // iterations) is used to find the next tile in the current work unit. +CUTLASS_DEVICE + static void + assign_work( + Params const& params, + uint64_t linear_idx, + WorkTileInfo& work_tile_info) { + + uint64_t output_tile_id = linear_idx; + if (linear_idx >= params.sk_units_ && params.divmod_splits_.divisor == 1) { + // Data-parallel work + output_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; + work_tile_info.K_idx = 0; + work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; + work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; + } + else { + + // Determine whether we are in a "big unit" within the group, that will process + // an additional K chunk in the group. + auto sk_tiles_in_group = params.sk_tiles_; + auto k_tiles_in_group = sk_tiles_in_group * params.divmod_tiles_per_output_tile_.divisor; + auto k_tiles_per_unit_in_group = params.divmod_sk_units_per_group_.divide(k_tiles_in_group); + auto big_units_in_group = k_tiles_in_group - (k_tiles_per_unit_in_group * params.divmod_sk_units_per_group_.divisor); + + uint64_t split; + params.divmod_sk_units_per_group_(split, output_tile_id, output_tile_id); + + bool is_split_k = params.divmod_splits_.divisor > 1; + auto big_unit_cmp_lhs = is_split_k ? split : output_tile_id; + auto big_unit_cmp_rhs = is_split_k ? params.big_units_ : big_units_in_group; + auto linear_idx_mult = is_split_k ? params.divmod_tiles_per_output_tile_.divisor : k_tiles_per_unit_in_group; + auto k_tiles_per_split = is_split_k ? params.divmod_k_tiles_per_sk_unit_.divisor : k_tiles_per_unit_in_group; + + // Determine the starting k iteration computed by this stream-K work unit + uint32_t unit_iter_start = (linear_idx_mult * linear_idx) + (k_tiles_per_split * split); + + // Adjust the starting position and number of k iterations for "big units," which + // compute one extra iteration. If there are any big units, they will be the first + // in the linearized ID space. + auto k_tiles_in_my_split = k_tiles_per_split; + if (big_unit_cmp_lhs < big_unit_cmp_rhs) { + // Since the "big units" are the first units in the linearized ID space, each + // of the units preceding this big unit computed one extra iteration. Thus, + // we must offset our start iteration by the number of units that precede + // the current unit in the linearized ID space. + unit_iter_start += big_unit_cmp_lhs; + ++k_tiles_in_my_split; + } + else { + // Increment by one for each of the big clusters (since all big units precede this unit) + unit_iter_start += big_unit_cmp_rhs; + } + + if (!is_split_k) { + // Adjust the unit starting position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + int unused, start_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, start_tile_k_tile, unit_iter_start); + if (start_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Starting K tile is in range [0, Params::min_iters_per_sk_unit_), which means that another + // stream-K unit will be computing a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take over these K tiles. + unit_iter_start -= start_tile_k_tile; + k_tiles_in_my_split += start_tile_k_tile; + } + else if (start_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Starting K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + auto adjustment_tiles = (params.divmod_tiles_per_output_tile_.divisor - start_tile_k_tile); + unit_iter_start += adjustment_tiles; + k_tiles_in_my_split -= adjustment_tiles; + } + else if (params.ktile_start_alignment_count == 2 && start_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // if end on odd number ktile within the output tile + // now end at ktile that one before my ktile end (give one ktile to next sm) + unit_iter_start -= 1; + k_tiles_in_my_split += 1; + } + } + + if (work_tile_info.k_tile_count == 0) { + // This is a new unit + + if (!is_split_k) { + // + // Adjust the unit ending position and number of tiles to avoid + // computing splits of size less than min_iters_per_sk_unit_ + // + + // Begin by assuming that no adjustment is needed + auto initial_unit_iter_end = unit_iter_start + k_tiles_in_my_split; + + int unused, end_tile_k_tile; + params.divmod_tiles_per_output_tile_(unused, end_tile_k_tile, initial_unit_iter_end); + + if (end_tile_k_tile < Params::min_iters_per_sk_unit_) { + // Ending K tile is within the first Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that this unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to shed these K tiles to a neighboring stream-K unit that will compute more consecutive K tiles. + k_tiles_in_my_split -= end_tile_k_tile; + } + else if (end_tile_k_tile > (params.divmod_tiles_per_output_tile_.divisor - Params::min_iters_per_sk_unit_)) { + // Ending K tile is within the final Params::min_iters_per_sk_unit_ K tiles of some output tile, + // which means that some other unit will compute a split with fewer than Params::min_iters_per_sk_unit_ K tiles. + // Adjust our work to take on these K tiles. + k_tiles_in_my_split += (params.divmod_tiles_per_output_tile_.divisor - end_tile_k_tile); + } + else if (params.ktile_start_alignment_count == 2 && end_tile_k_tile % 2 != 0) { + // ktile for each SM start from even number + // If start from odd number ktile within the output tile + // now start at the ktile one before my initial ktile start (take one ktile from prev sm) + // If end on odd number ktile within the output tile, + // now end at ktile that one before my ktile end (give one ktile to next sm) + k_tiles_in_my_split -= 1; + } + } + + work_tile_info.k_tile_remaining = k_tiles_in_my_split; + } + + uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; + + // Find the output tile corresponding to the final k tile covered by this + // work unit. Stream-K work units will work backwards in terms of the tiles they + // are responsible computing. This is beneficial because the final (partial) + // tile computed by a stream-K block is typically the beginning of the output + // tile, while the beginning (partial) tile is typically the ending of another + // output tile. Since ending portions of an output tile must reduce across + // other work units computing portions of that output tile, it is preferable + // for them to be computed later, so as to reduce the likelihood of blocking + // on other work. + + auto output_tile_id_in_group = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); + uint32_t output_tile_iter_start = output_tile_id_in_group * params.divmod_tiles_per_output_tile_.divisor; + uint32_t output_tile_iter_end = output_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + + // Convert the output tile from the linearized space within each group to the + // overall linearized space. + output_tile_id = output_tile_id_in_group * params.divmod_sk_groups_.divisor; + + // The unit's starting k iteration in the current tile is either the starting + // iteration for the tile as a whole, or the starting k iteration for the unit + // as a whole (if the latter is greater than the former). + uint32_t tile_iter_start = max(output_tile_iter_start, unit_iter_start); + + // Similarly, the unit's ending k iteration (exclusive) is either the end of + // the current tile it is assigned, or the ending iteration of the unit as a whole + // (if the latter is less than the former). + uint32_t tile_iter_end = min(output_tile_iter_end, unit_iter_end + 1); + + // Set the k offset to be the starting k tile for this output tile + work_tile_info.K_idx = static_cast(tile_iter_start - output_tile_iter_start); + work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; + } + + uint64_t work_idx_l, remainder; + + if(params.divmod_splits_.divisor > 1) { + output_tile_id %= params.units_per_problem_; + } + + params.divmod_batch_(work_idx_l, remainder, output_tile_id); + + uint64_t cta_per_grid_dim = remainder; + + auto [work_idx_m, work_idx_n] = Params::get_work_idx_m_and_n( + cta_per_grid_dim, + params.divmod_blk_major_ + ); + + // Set the M, N, and L block offsets + work_tile_info.M_idx = work_idx_m; + work_tile_info.N_idx = work_idx_n; + work_tile_info.L_idx = work_idx_l; + } + +}; + +} // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gpu_generics.h b/include/cutlass/gpu_generics.h index 22d82e9d5d..44b5a92acb 100644 --- a/include/cutlass/gpu_generics.h +++ b/include/cutlass/gpu_generics.h @@ -314,19 +314,20 @@ using cudaStream_t = void *; using dim3 = syclcompat::dim3; // Atomic - -CUTLASS_DEVICE int atomicAdd(int *address, int val) { +template +CUTLASS_DEVICE T atomicAdd(T *address, T val) { #if defined(__SYCL_DEVICE_ONLY__) - return syclcompat::atomic_fetch_add(address, val); + return syclcompat::atomic_fetch_add(address, val); #endif return 0; } CUTLASS_DEVICE int atomicCAS(int *address, int compare, int val) { + int result = 0; #if defined(__SYCL_DEVICE_ONLY__) - syclcompat::atomic_compare_exchange_strong(address, compare, val); + result = syclcompat::atomic_compare_exchange_strong(address, compare, val); #endif - return 0; + return result; } // Error diff --git a/include/cutlass/workspace.h b/include/cutlass/workspace.h index 31c48435b1..79f3aa3d0c 100644 --- a/include/cutlass/workspace.h +++ b/include/cutlass/workspace.h @@ -61,7 +61,9 @@ zero_workspace(void* workspace, size_t workspace_size, cudaStream_t stream = nul CUTLASS_TRACE_HOST(" clearing workspace"); -#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER +#if defined (CUTLASS_ENABLE_SYCL) + syclcompat::memset_async(workspace, 0, workspace_size); +#elif defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER // // Use the cuda host adapter //