diff --git a/CMakeLists.txt b/CMakeLists.txt index d442254..8faee83 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,12 +171,13 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_HEADERS_ONLY "ON" CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "main" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "9baca2cff3a28590fcd03e55515e2d91ff2cbc8b" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided FetchContent_Declare( cutlass-sycl GIT_REPOSITORY https://github.com/intel/cutlass-sycl + # Please keep this in sync with CUTLASS_REVISION line above. GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE @@ -184,7 +185,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) # cutlass compilation flags @@ -196,7 +197,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA") # list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " ) # list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " ) - FetchContent_MakeAvailable(cutlass-sycl) set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") @@ -269,11 +269,15 @@ endif () # # xpu only ops/kernels, implemented with cutlass/onednn/sycl. # +file(GLOB CUTLASS_BACKEND_SRCS + csrc/xpu/cutlass_kernels/*.cpp +) if(VLLM_GPU_LANG STREQUAL "SYCL") set(VLLM_EXT_XPU_SRC "csrc/xpu/torch_bindings.cpp" "csrc/xpu/lora/lora_shrink.cpp" "csrc/xpu/lora/lora_expand.cpp" + ${CUTLASS_BACKEND_SRCS} ) include_directories("/usr/include") set(CMPLR_ROOT $ENV{CMPLR_ROOT}) @@ -282,6 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL") list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" ) list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64") list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" ) + # CUTLASS FLAGS + list(APPEND VLLM_GPU_FLAGS "-O3" "-DNDEBUG") + list(APPEND VLLM_GPU_FLAGS "-gline-tables-only") + list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10") + list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen") + list(APPEND VLLM_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread") endif() if(ONEDNN_FOUND) @@ -305,6 +315,8 @@ define_gpu_extension_target( ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/csrc/core/registration.h b/csrc/core/registration.h index 9dbf34b..576b5e1 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -1,5 +1,4 @@ #pragma once - #include #define _CONCAT(A, B) A##B diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h new file mode 100644 index 0000000..f2743bf --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/default_gemm_universal.h @@ -0,0 +1,306 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix + multiply-add with the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major + outputs are accommodated by exchanging A and B operands and assuming + transposed layouts. Partial specializations here choose + 'device::GemmTransposed' to implement this functionality. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "gemm_universal_k.h" +#include "cutlass/gemm/kernel/gemm_universal_streamk.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" + +#include "cutlass/layout/permute.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute, + /// + typename Enable = void> +struct DefaultGemmUniversal; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// 
Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout> +struct DefaultGemmUniversal< + ElementA, LayoutA, + ComplexTransform::kNone, // transform A + kAlignmentA, ElementB, LayoutB, + ComplexTransform::kNone, // transform B + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB, + ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout, + typename platform::enable_if< + !cutlass::is_complex::value>::type> { + using DefaultGemmKernel = typename kernel::DefaultGemm< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, + LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, + WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD, + PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel; + + /// Universal kernel without StreamkFeature member type + template + class SelectBase + : public kernel::GemmUniversal {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase + : public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Complex-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: 
GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DefaultGemmUniversal< + ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB, + kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false, + false, layout::NoPermute, layout::NoPermute, layout::NoPermute, + typename platform::enable_if< + cutlass::is_complex::value>::type> { + using DefaultGemmKernel = typename kernel::DefaultGemmComplex< + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, + InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages, + TransformA, TransformB, Operator, false>::GemmKernel; + + /// Universal kernel without StreamkFeature member type + template + class SelectBase + : public kernel::GemmUniversal {}; + + /// Universal kernel with StreamkFeature member type + template + class SelectBase + : public kernel::GemmUniversalStreamk< + typename DefaultGemmKernel::Mma, + typename DefaultGemmKernel::Epilogue, SwizzleT> {}; + + /// Select kernel by ThreadblockSwizzle's support for StreamkFeature + using GemmKernel = SelectBase; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h new file mode 100644 index 0000000..411f673 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.h @@ -0,0 +1,366 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "gemm_universal_k.h" + +#include "default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "gemm_universal_base.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversal is a stateful, reusable GEMM handle. Once initialized for a + given GEMM computation (problem geometry and data references), it can be + reused across different GEMM problems having the geometry. (Once initialized, + details regarding problem geometry and references to workspace memory cannot + be updated.) + + The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. 
+ typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB = ComplexTransform::kNone, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout_ = layout::NoPermute, + /// Permute operand A + typename PermuteALayout_ = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout_ = layout::NoPermute> +class GemmUniversal + : public GemmUniversalBase::GemmKernel> { + public: + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using Base = GemmUniversalBase::GemmKernel>; + + using Arguments = typename Base::Arguments; + using GemmKernel = typename Base::GemmKernel; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for column-major output exchanges problem size and +/// operand. 
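The partial specialization introduced by the comment above (and defined immediately below) relies on the usual transpose identity: a GEMM whose output is column-major is the same computation as a row-major GEMM on swapped, transposed operands. As a worked equation (using the customary linear-combination epilogue scalars as notation; this equation is an editorial illustration, not text from the patch):

```latex
% D, C stored column-major:
D = \alpha\, A B + \beta\, C
% is evaluated as the row-major problem
D^{\mathsf T} = \alpha\, B^{\mathsf T} A^{\mathsf T} + \beta\, C^{\mathsf T}
```

This is why the specialization's UnderlyingOperator names ElementB/ElementA, kAlignmentB/kAlignmentA, GatherB/GatherA, and PermuteBLayout/PermuteALayout in swapped order and applies LayoutTranspose to the A and B layouts.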
+template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for. This is the minimum SM that + /// supports the intended feature. The device kernel can be built + /// targeting any SM larger than this number. + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Access granularity of A matrix in units of elements + int AlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB, + /// Operation performed by GEMM + typename Operator_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout_, + /// Permute operand A + typename PermuteALayout_, + /// Permute operand B + typename PermuteBLayout_> +class GemmUniversal< + ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, + layout::ColumnMajor, // partially specialized on LayoutC + ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, + WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, + Stages, AlignmentA, AlignmentB, Operator_, TransformA, TransformB, GatherA, + GatherB, ScatterD, PermuteDLayout_, PermuteALayout_, PermuteBLayout_> { + public: + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = layout::ColumnMajor; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + using PermuteDLayout = PermuteDLayout_; + using PermuteALayout = PermuteALayout_; + using PermuteBLayout = PermuteBLayout_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static ComplexTransform const kTransformA = TransformA; + static ComplexTransform const kTransformB = TransformB; + + using UnderlyingOperator = typename GemmUniversal< + ElementB, typename layout::LayoutTranspose::type, ElementA, + typename layout::LayoutTranspose::type, ElementC, + layout::RowMajor, ElementAccumulator, OperatorClass, 
ArchTag, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, Stages, kAlignmentB, kAlignmentA, Operator, + kTransformB, kTransformA, GatherB, GatherA, ScatterD, PermuteDLayout, + PermuteBLayout, PermuteALayout>::Base; + + using GemmKernel = typename UnderlyingOperator::GemmKernel; + static int const kAlignmentC = EpilogueOutputOp::kCount; + + /// Argument structure + using Arguments = typename UnderlyingOperator::Arguments; + + private: + UnderlyingOperator underlying_operator_; + + public: + /// Constructs the GEMM. + GemmUniversal() {} + + /// Helper to construct a transposed equivalent for the underlying GEMM + /// operator + static Arguments to_underlying_arguments(Arguments const& args) { + return args.transposed_problem(); + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + return UnderlyingOperator::can_implement(to_underlying_arguments(args)); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + return UnderlyingOperator::get_workspace_size( + to_underlying_arguments(args)); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + return underlying_operator_.initialize(to_underlying_arguments(args), + workspace, stream); + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + return underlying_operator_.update(to_underlying_arguments(args), + workspace); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + return underlying_operator_.run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { return run(stream); } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp new file mode 100644 index 0000000..3b59cc8 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal.hpp @@ -0,0 +1,57 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal_decl.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +// In cases where ProblemShape is not a tuple, this is used to check if the +// underlying problem shape type is aliased within or not. +// Used for dispatching GemmUniversal to 2.x API or 3.x API +template +struct IsCutlass3ArrayKernel : cute::false_type {}; + +template +struct IsCutlass3ArrayKernel< + ProblemShape, cute::void_t> + : cute::true_type {}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// +#include "xe_gemm_array_cooperative.hpp" diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h new file mode 100644 index 0000000..0c923e8 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_adapter.h @@ -0,0 +1,844 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, + batched strided, and batched array variants. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) + #include "cutlass/cluster_launch.hpp" + #include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +// 2.x +#include "gemm_universal_base.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" + +// 3.x +#include "gemm_universal.hpp" + +#if defined(CUTLASS_ENABLE_SYCL) + #include "cutlass/util/sycl_event_manager.hpp" +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! + GemmUniversalAdapter is a stateful, reusable GEMM handle built around a kernel + of type cutlass::gemm::kernel::Gemm or cutlass::gemm::kernel::GemmUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes + APIs to create it from the host facing arguments. For power users, new static + methods are exposed in 3.x APIs that bypass the stateful methods or + args->params lowering. + + It supports kernel types that implement both the 2.x and 3.0 APIs, + however, this is done by specializing the implementation of + GemmUniversalAdapter on the two kernel API types, and thus, + GemmUniversalAdapter's behaviour might differ between the two specializations. +*/ +template +class GemmUniversalAdapter; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Work-around for some DispatchPolicy types not having a Stages member. +// In that case, the Stages value is 0. Most code should static_assert +// that the number of stages is valid. + +// Whether DispatchPolicy::Stages is valid. 
+// It should also be convertible to int, but if not, that will show up +// as a build error when GemmUniversalAdapter attempts to assign it to kStages. +template +struct has_Stages : cute::false_type {}; + +template +struct has_Stages> + : cute::true_type {}; + +template +constexpr int stages_member(DispatchPolicy) { + if constexpr (has_Stages::value) { + return DispatchPolicy::Stages; + } else { + return 0; + } +} + +} // namespace detail + +template +class GemmUniversalAdapter>::value>> { + public: + using GemmKernel = GetUnderlyingKernel_t; + using TileShape = typename GemmKernel::TileShape; + using ElementA = typename GemmKernel::ElementA; + using ElementB = typename GemmKernel::ElementB; + using ElementC = typename GemmKernel::ElementC; + using ElementD = typename GemmKernel::ElementD; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; + using DispatchPolicy = typename GemmKernel::DispatchPolicy; + using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; + using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; + + // Map back to 2.x type as best as possible + using LayoutA = + gemm::detail::StrideToLayoutTagA_t; + using LayoutB = + gemm::detail::StrideToLayoutTagB_t; + using LayoutC = + gemm::detail::StrideToLayoutTagC_t; + using LayoutD = + gemm::detail::StrideToLayoutTagC_t; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + static ComplexTransform const kTransformA = + cute::is_same_v + ? ComplexTransform::kConjugate + : ComplexTransform::kNone; + static ComplexTransform const kTransformB = + cute::is_same_v + ? ComplexTransform::kConjugate + : ComplexTransform::kNone; + + // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 + using MathOperator = cutlass::arch::OpMultiplyAdd; + + using OperatorClass = cutlass::detail::get_operator_class_t< + typename CollectiveMainloop::TiledMma>; + + using ArchTag = typename GemmKernel::ArchTag; + + // NOTE: Assume identity swizzle for now + using ThreadblockSwizzle = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our + // TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = GemmKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and + // rest along N, none along K We also always round up the warp count to 4 if + // the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max( + 4, CUTE_STATIC_V(cute::size(typename GemmKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM 
= 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = + cutlass::gemm::GemmShape( + typename CollectiveMainloop::TiledMma{})) / + WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>( + typename CollectiveMainloop::TiledMma{})) / + WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>( + typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = + detail::stages_member(typename CollectiveMainloop::DispatchPolicy{}); + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA, + typename CollectiveMainloop::TiledMma::ValTypeA>(); + static int constexpr kAlignmentB = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB, + typename CollectiveMainloop::TiledMma::ValTypeB>(); + static int constexpr kAlignmentC = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = + cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + // Split-K preserves splits that are 128b aligned + static int constexpr kSplitKAlignment = cute::max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + /// Argument structure: User API + using Arguments = typename GemmKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename GemmKernel::Params; + + private: + /// Kernel API parameters object + Params params_; + + public: + /// Access the Params structure + Params const& params() const { return params_; } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const& args) { + if (GemmKernel::can_implement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + workspace_bytes += sizeof(int) * size_t(cute::size<0>(TileShape{})) * + size_t(cute::size<1>(TileShape{})); + } + + workspace_bytes += GemmKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = GemmKernel::to_underlying_arguments(args, workspace); + return GemmKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Params const& params) { + return GemmKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("GemmUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = GemmKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute(device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, device_kernel, + GemmKernel::MaxThreadsPerBlock, smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = + GemmKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = GemmKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set + // it. 
+ if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = GemmKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if !defined(CUTLASS_ENABLE_SYCL) + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } +#endif + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight + /// update of params. + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = GemmKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and + /// manage their own params. Supplied params struct must be construct by + /// calling GemmKernel::to_underlying_arguments() + static Status run(Params& params, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("GemmUniversal::run()"); + dim3 const block = GemmKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + +#if defined(CUTLASS_ENABLE_SYCL) + const syclcompat::dim3 sycl_block(block.x, block.y, block.z); + const syclcompat::dim3 sycl_grid(grid.x, grid.y, grid.z); +#endif + + // configure smem size and carveout + int smem_size = GemmKernel::SharedStorageSize; + + Status launch_result{Status::kSuccess}; + // Use extended launch API only for mainloops that use it + if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("GemmUniversal::run: Use extended launch API"); +#endif +#if !defined(CUTLASS_ENABLE_SYCL) + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v< + typename GemmKernel::DispatchPolicy::ClusterShape> and + cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1; + [[maybe_unused]] dim3 cluster( + cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename GemmKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename GemmKernel::DispatchPolicy::ClusterShape{})); + + // Dynamic cluster support + [[maybe_unused]] dim3 fallback_cluster = dim3{0, 0, 0}; + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || + GemmKernel::ArchTag::kMinComputeCapability == 101) { + if constexpr (!cute::is_static_v< + typename GemmKernel::DispatchPolicy::ClusterShape>) { + fallback_cluster = params.hw_info.cluster_shape_fallback; + cluster = params.hw_info.cluster_shape; + } + } + + [[maybe_unused]] void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run() does not support launching with PDL and " + "a custom cuda adapter."); + return 
Status::kErrorInternal; + } + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with CUDA host adapter"); + #endif + if constexpr (is_static_1x1x1) { + launch_result = cuda_adapter->launch(grid, block, smem_size, stream, + kernel_params, 0); + } else { + launch_result = + cuda_adapter->launch(grid, cluster, fallback_cluster, block, + smem_size, stream, kernel_params, 0); + } + } else { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: kEnableCudaHostAdapter is true, but CUDA " + "host adapter is null"); + return Status::kErrorInternal; + } + } else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + [[maybe_unused]] void const* kernel = + (void const*)device_kernel; + static constexpr bool kClusterLaunch = + GemmKernel::ArchTag::kMinComputeCapability == 90; + if constexpr (kClusterLaunch) { + if constexpr (is_static_1x1x1) { + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching static 1x1x1 kernel"); + #endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports failure"); + } + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports success"); + } + #endif + } else { + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching dynamic cluster kernel"); + #endif + launch_result = + ClusterLauncher::launch(grid, cluster, block, smem_size, stream, + kernel, kernel_params, launch_with_pdl); + } + } + + else { + if constexpr (GemmKernel::ArchTag::kMinComputeCapability == 100 || + GemmKernel::ArchTag::kMinComputeCapability == 101 || + GemmKernel::ArchTag::kMinComputeCapability == 120) { + if constexpr (is_static_1x1x1) { + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching static 1x1x1 kernel"); + #endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports " + "failure"); + } + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports " + "success"); + } + #endif + } else { + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with fall-back " + "cluster"); + #endif + launch_result = ClusterLauncher::launch_with_fallback_cluster( + grid, cluster, fallback_cluster, block, smem_size, stream, + kernel, kernel_params, launch_with_pdl); + } + } + } + } +#endif + } else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with CUDA host adapter"); +#endif + launch_result = cuda_adapter->launch(grid, block, smem_size, stream, + kernel_params, 0); + + } else { + CUTLASS_TRACE_HOST("GemmUniversal::run: CUDA host adapter is null"); + return Status::kErrorInternal; + } + } else { + CUTLASS_ASSERT(cuda_adapter == nullptr); +#if defined(CUTLASS_ENABLE_SYCL) + // sycl::queue q = stream; // ? 
*stream : + // syclcompat::get_default_queue(); + #if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace syclcompat::experimental; + if constexpr (cute::is_same_v) { + auto event = launch>( + launch_policy{sycl_grid, sycl_block, + local_mem_size { + static_cast(smem_size) + }}, + q, params); + EventManager::getInstance().addEvent(event); + } else { + auto event = launch>( + launch_policy{ + sycl_grid, sycl_block, + local_mem_size{static_cast(smem_size)} + #if defined(SYCL_INTEL_TARGET) + , + kernel_properties { + sycl_exp::sub_group_size + } + #endif + }, + stream, params); + EventManager::getInstance().addEvent(event); + } + #else + #if defined(SYCL_INTEL_TARGET) + constexpr bool allow_subgroup_size_prop = true; + #else + constexpr bool allow_subgroup_size_prop = false; + #endif + auto kernel_props = [] { + constexpr bool is_device_agnostic = + cute::is_same_v; + if constexpr (!allow_subgroup_size_prop or is_device_agnostic) { + using EmptyProperties = + decltype(sycl::ext::oneapi::experimental::properties()); + return syclcompat::experimental::kernel_properties< + EmptyProperties>{}; + } else { + return syclcompat::experimental::kernel_properties{ + sycl::ext::oneapi::experimental::sub_group_size< + DispatchPolicy::SubgroupSize>}; + } + }(); + syclcompat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + syclcompat::experimental::launch_policy policy{ + sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = + syclcompat::experimental::launch>( + policy, stream, params); + EventManager::getInstance().addEvent(event); + #endif // !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) +#else + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: Launching kernel with cutlass::kernel_launch"); + #endif + launch_result = cutlass::kernel_launch( + grid, block, smem_size, stream, params, launch_with_pdl); + if (launch_result != Status::kSuccess) { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports failure"); + } + #if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cutlass::kernel_launch reports success"); + } + #endif +#endif + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST( + "GemmUniversal::run: cudaGetLastError reports success"); +#endif + return Status::kSuccess; + } else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params + // struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from + /// supplied arguments. + Status run(Arguments const& args, void* workspace, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from + /// supplied arguments. 
+ Status operator()(Arguments const& args, void* workspace, sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating + /// internal params struct. + Status run(sycl::queue& stream, CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating + /// internal params struct. + Status operator()(sycl::queue& stream, + CudaHostAdapter* cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 2.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalAdapter< + GemmKernel_, cute::enable_if_t>::value>> { + public: + using GemmKernel = GetUnderlyingKernel_t; + + static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v< + typename GemmKernel::Epilogue> && // 2.x EVT does not require + // internal transpose + cute::is_same::value; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + using WarpShape = typename GemmKernel::WarpShape; + using InstructionShape = typename GemmKernel::InstructionShape; + + // warp-level, arch-level (instruction), math operator + using WarpMmaOperator = typename GemmKernel::Mma::Policy::Operator; + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + + // Operator class and arch tag extract bottom-up + // set it for top-level gemm device-level template + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + // Type, layout, and complex transform deliberately exchanged with B + using MapArguments = kernel::detail::MapArguments< + typename GemmKernel::ElementA, typename GemmKernel::LayoutA, + GemmKernel::kTransformA, GemmKernel::kAlignmentA, + typename GemmKernel::ElementB, typename GemmKernel::LayoutB, + GemmKernel::kTransformB, GemmKernel::kAlignmentB, + typename GemmKernel::LayoutC, kInternalTranspose>; + + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static int const kAlignmentA = MapArguments::kAlignmentA; + + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + static int const kAlignmentB = MapArguments::kAlignmentB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename MapArguments::LayoutC; + static int const kAlignmentC = GemmKernel::kAlignmentC; + + // C and D same type for 2.x kernel + using ElementD = ElementC; + using LayoutD = LayoutC; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + static int const kStages = GemmKernel::Mma::kStages; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + 
using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using UnderlyingOperator = GemmUniversalBase; + using Arguments = typename UnderlyingOperator::Arguments; + + private: + UnderlyingOperator underlying_operator_; + + public: + /// Constructs the GEMM. + GemmUniversalAdapter() {} + + /// Helper to construct a transposed equivalent for the underlying GEMM + /// operator + static Arguments to_underlying_arguments(Arguments const& args) { + if (kInternalTranspose) { + return args.transposed_problem(); + } else { + return args; + } + } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + return UnderlyingOperator::can_implement(to_underlying_arguments(args), + cuda_adapter); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args), + cuda_adapter); + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + return UnderlyingOperator::get_grid_shape(to_underlying_arguments(args)); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + return UnderlyingOperator::maximum_active_blocks(smem_capacity); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return underlying_operator_.initialize(to_underlying_arguments(args), + workspace, stream, cuda_adapter); + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const& args) { + return underlying_operator_.update(to_underlying_arguments(args)); + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return underlying_operator_.run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h new file mode 100644 index 0000000..b909318 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_base.h @@ -0,0 +1,524 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates streamk, batched strided, and batched + array variants. +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) + #include +#else + #include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/gemm/gemm.h" +#include "gemm_universal_k.h" + +#include "default_gemm_universal.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalBase { + public: + using GemmKernel = GemmKernel_; + + /// Boolean indicating whether the CudaHostAdapter is enabled + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + /// Numerical accumulation element type + using ElementAccumulator = typename GemmKernel::Mma::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + /// Index of the GEMM 
Kernel within the CudaHostAdapter + static int32_t const kGemmKernelIndex = 0; + + /// Kernel dynamic shared memory allocation requirement + /// Update the kernel function's shared memory configuration for the current + /// device + static constexpr size_t kSharedStorageSize = + sizeof(typename GemmKernel::SharedStorage); + + protected: + // + // Device properties (uniform across all instances of the current thread) + // + + // Device ordinal + CUTLASS_THREAD_LOCAL static int device_ordinal_; + + /// Device SM count + CUTLASS_THREAD_LOCAL static int device_sms_; + + /// Kernel SM occupancy (in thread blocks) + CUTLASS_THREAD_LOCAL static int sm_occupancy_; + + protected: + /// Initialize static thread-local members for the thread's current device, + /// if necessary. + static Status init_device_props() { + CUTLASS_TRACE_HOST("GemmUniversalBase::init_device_props()"); + + cudaError_t cudart_result; + + // Get current device ordinal + int current_ordinal; + cudart_result = cudaGetDevice(¤t_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " + << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Done if matches the current static member + if (current_ordinal == device_ordinal_) { + // Already initialized + return Status::kSuccess; + } + + // Update SM count member + cudart_result = cudaDeviceGetAttribute( + &device_sms_, cudaDevAttrMultiProcessorCount, current_ordinal); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " + << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // If requires more than 48KB: configure for extended, dynamic shared memory + if constexpr (kSharedStorageSize >= (48 << 10)) { + cudart_result = cudaFuncSetAttribute( + Kernel2, cudaFuncAttributeMaxDynamicSharedMemorySize, + kSharedStorageSize); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " + << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + } + + // Update SM occupancy member + cudart_result = cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &sm_occupancy_, Kernel2, GemmKernel::kThreadCount, + kSharedStorageSize, cudaOccupancyDisableCachingOverride); + if (cudart_result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned " + "error " + << cudaGetErrorString(cudart_result)); + return Status::kErrorInternal; + } + + // Update device ordinal member on success + device_ordinal_ = current_ordinal; + + CUTLASS_TRACE_HOST( + " " + "device_ordinal: (" + << device_ordinal_ + << "), " + "device_sms: (" + << device_sms_ + << "), " + "sm_occupancy: (" + << sm_occupancy_ + << ") " + "smem_size: (" + << kSharedStorageSize + << ") " + "GemmKernel::kThreadCount: (" + << GemmKernel::kThreadCount << ")"); + + return Status::kSuccess; + } + + protected: + // + // Instance data members + // + + /// Kernel parameters + typename GemmKernel::Params params_; + + /// Initialize params member + Status init_params(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + // + // Occupancy query using CudaHostAdapter::query_occupancy(). 
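+      // The adapter supplies the SM count and kernel occupancy, so this path
+      // avoids querying the device runtime directly.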
+ // + + if (cuda_adapter) { + Status status = cuda_adapter->query_occupancy( + &device_sms, &sm_occupancy, kGemmKernelIndex, + GemmKernel::kThreadCount, kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return status; + } + } else { + return Status::kErrorInternal; + } + } else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + + // Initialize static device properties, if necessary + Status result = init_device_props(); + + if (result != Status::kSuccess) { + return result; + } + + // + // Use thread-local static members for occupancy query initialized by call + // to `init_device_props()` + // + + device_sms = device_sms_; + sm_occupancy = sm_occupancy_; + } + + // Initialize params member + params_ = typename GemmKernel::Params(args, device_sms, sm_occupancy); + return Status::kSuccess; + } + + public: + //--------------------------------------------------------------------------------------------- + // Stateless API + //--------------------------------------------------------------------------------------------- + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::can_implement()"); + + if (!kEnableCudaHostAdapter || cuda_adapter) { + dim3 grid = get_grid_shape(args, cuda_adapter); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + return Status::kErrorInvalidProblem; + } + } else { + // + // With a null host adapter, a conservative grid shape is computed and + // required to conform to CUDA grid dimension limits. + // + + int64_t logicalGridM = + (int64_t(args.problem_size.m()) + ThreadblockShape::kM - 1) / + ThreadblockShape::kM; + int64_t logicalGridN = + (int64_t(args.problem_size.n()) + ThreadblockShape::kN - 1) / + ThreadblockShape::kN; + int32_t logicalGridL = args.batch_count; + + if ((int64_t(std::numeric_limits::max()) < logicalGridM) || + (int64_t(std::numeric_limits::max()) < logicalGridN) || + (int32_t(std::numeric_limits::max()) < logicalGridL)) { + return Status::kErrorInvalidProblem; + } + } + + return GemmKernel::can_implement(args); + } + + /// Returns the workspace size (in bytes) needed for the problem + /// geometry expressed by these arguments + static size_t get_workspace_size(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return 0; + } + + // Get size from parameters + size_t workspace_bytes = base.params_.get_workspace_size(); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return workspace_bytes; + } + + /// Returns the grid extents in thread blocks to launch + static dim3 get_grid_shape(Arguments const& args, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); + + // Initialize parameters from args + GemmUniversalBase base; + if (base.init_params(args, cuda_adapter) != Status::kSuccess) { + return dim3(0, 0, 0); + } + + // Get dims from parameters + dim3 grid_dims = base.params_.get_grid_dims(); + + CUTLASS_TRACE_HOST(" tiled_shape: " + << base.params_.get_tiled_shape() << "\n" + << " grid_dims: {" << grid_dims << "}"); + + return grid_dims; + } + + /// Returns the maximum number of active thread blocks per multiprocessor + 
static int maximum_active_blocks(CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + + int32_t device_sms = 0; + int32_t sm_occupancy = 0; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + + if (cuda_adapter) { + Status status = cuda_adapter->query_occupancy( + &device_sms, &sm_occupancy, kGemmKernelIndex, + GemmKernel::kThreadCount, kSharedStorageSize); + + CUTLASS_ASSERT(status == Status::kSuccess); + + if (status != Status::kSuccess) { + return -1; + } + } else { + return -1; + } + } else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + // Initialize static device properties, if necessary + if (init_device_props() != Status::kSuccess) { + return -1; + } + + sm_occupancy = sm_occupancy_; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << sm_occupancy_); + return sm_occupancy; + } + + //--------------------------------------------------------------------------------------------- + // Stateful API + //--------------------------------------------------------------------------------------------- + + /// Initializes GEMM state from arguments and workspace memory + Status initialize(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace + << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize parameters from args + Status result = init_params(args, cuda_adapter); + if (result != Status::kSuccess) { + return result; + } + + // Assign and prepare workspace memory + if (args.mode == GemmUniversalMode::kGemm) { + return params_.init_workspace(workspace, stream); + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments. + Status update(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); + params_.update(args); + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); + + // Configure grid and block dimensions + dim3 block(GemmKernel::kThreadCount, 1, 1); + dim3 grid = params_.get_grid_dims(); + + // Launch kernel + CUTLASS_TRACE_HOST( + " " + "grid: (" + << grid + << "), " + "block: (" + << block + << "), " + "SMEM: (" + << kSharedStorageSize << ")"); + + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms_}; + return cuda_adapter->launch(grid, block, kSharedStorageSize, stream, + kernel_params, 0); + } else { + return Status::kErrorInternal; + } + } else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + +#if defined(CUTLASS_ENABLE_SYCL) + const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z); + const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z); + + sycl::queue q = stream ? 
*stream : syclcompat::get_default_queue(); + syclcompat::experimental::launch>( + syclcompat::experimental::launch_policy{ + sycl_grid, sycl_block, + #if defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + sycl::ext::oneapi::experimental::work_group_scratch_size( + kSharedStorageSize) + #else + syclcompat::experimental::local_mem_size{ + static_cast(kSharedStorageSize)} + #endif + }, + q, params_); +#else + Kernel2<<>>(params_); +#endif + + // Query for errors + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return run(stream, cuda_adapter); + } + + /// Runs the kernel using initialized state. + Status operator()(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (status == Status::kSuccess) { + status = run(stream, cuda_adapter); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Static initializers +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device ordinal +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_ordinal_ = -1; + +/// Device SM count +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; + +/// Kernel SM occupancy (in thread blocks) +template +CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h new file mode 100644 index 0000000..19871ee --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/gemm_universal_k.h @@ -0,0 +1,649 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" +#include "gemm_universal.hpp" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/params_universal_base.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal< + Mma_, Epilogue_, ThreadblockSwizzle_, void, + // 3.x kernels use the first template argument to define the ProblemShape + // We use this invariant to SFINAE dispatch against either the 2.x API or + // the 3.x API + cute::enable_if_t::value || + IsCutlass3ArrayKernel::value)>> { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments : UniversalArgumentsBase { + // + // 
Data members + // + + typename EpilogueOutputOp::Params epilogue; + + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + + typename LayoutA::Stride stride_a; + typename LayoutB::Stride stride_b; + typename LayoutC::Stride stride_c; + typename LayoutC::Stride stride_d; + + typename LayoutA::Stride::LongIndex lda; + typename LayoutB::Stride::LongIndex ldb; + typename LayoutC::Stride::LongIndex ldc; + typename LayoutC::Stride::LongIndex ldd; + + int const* ptr_gather_A_indices; + int const* ptr_gather_B_indices; + int const* ptr_scatter_D_indices; + + // + // Methods + // + + Arguments() + : ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_gather_A_indices(nullptr), + ptr_gather_B_indices(nullptr), + ptr_scatter_D_indices(nullptr) {} + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, + typename EpilogueOutputOp::Params epilogue, void const* ptr_A, + void const* ptr_B, void const* ptr_C, void* ptr_D, + int64_t batch_stride_A, int64_t batch_stride_B, + int64_t batch_stride_C, int64_t batch_stride_D, + typename LayoutA::Stride stride_a, + typename LayoutB::Stride stride_b, + typename LayoutC::Stride stride_c, + typename LayoutC::Stride stride_d, + int const* ptr_gather_A_indices = nullptr, + int const* ptr_gather_B_indices = nullptr, + int const* ptr_scatter_D_indices = nullptr) + : UniversalArgumentsBase(mode, problem_size, batch_count, + batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + stride_a(stride_a), + stride_b(stride_b), + stride_c(stride_c), + stride_d(stride_d), + ptr_gather_A_indices(ptr_gather_A_indices), + ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + lda = 0; + ldb = 0; + ldc = 0; + ldd = 0; + CUTLASS_TRACE_HOST( + "GemmUniversal::Arguments::Arguments() - problem_size: " + << problem_size); + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode, GemmCoord problem_size, int batch_count, + typename EpilogueOutputOp::Params epilogue, void const* ptr_A, + void const* ptr_B, void const* ptr_C, void* ptr_D, + int64_t batch_stride_A, int64_t batch_stride_B, + int64_t batch_stride_C, int64_t batch_stride_D, + typename LayoutA::Stride::LongIndex lda, + typename LayoutB::Stride::LongIndex ldb, + typename LayoutC::Stride::LongIndex ldc, + typename LayoutC::Stride::LongIndex ldd, + int const* ptr_gather_A_indices = nullptr, + int const* ptr_gather_B_indices = nullptr, + int const* ptr_scatter_D_indices = nullptr) + : UniversalArgumentsBase(mode, problem_size, batch_count, + batch_stride_D), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + ptr_gather_A_indices(ptr_gather_A_indices), + ptr_gather_B_indices(ptr_gather_B_indices), + ptr_scatter_D_indices(ptr_scatter_D_indices) { + stride_a = make_Coord(lda); + stride_b = make_Coord(ldb); + stride_c = make_Coord(ldc); + stride_d = make_Coord(ldd); + CUTLASS_TRACE_HOST( + "GemmUniversal::Arguments::Arguments() - problem_size: " + << problem_size); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() 
const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.stride_a, args.stride_b); + std::swap(args.batch_stride_A, args.batch_stride_B); + std::swap(args.ptr_gather_A_indices, args.ptr_gather_B_indices); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + : UniversalParamsBase { + using ParamsBase = + UniversalParamsBase; + + // + // Data members + // + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + + typename EpilogueOutputOp::Params output_op; + + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + + int* ptr_gather_A_indices; + int* ptr_gather_B_indices; + int* ptr_scatter_D_indices; + + // + // Host dispatch API + // + + /// Default constructor + Params() = default; + + /// Constructor + Params(Arguments const& args, /// GEMM application arguments + int device_sms, /// Number of SMs on the device + int sm_occupancy) /// Kernel SM occupancy (in thread blocks) + : ParamsBase(args, device_sms, sm_occupancy), + params_A(args.lda + ? make_Coord_with_padding(args.lda) + : args.stride_a), + params_B(args.ldb + ? make_Coord_with_padding(args.ldb) + : args.stride_b), + params_C(args.ldc + ? make_Coord_with_padding(args.ldc) + : args.stride_c), + params_D(args.ldd + ? make_Coord_with_padding(args.ldd) + : args.stride_d), + output_op(args.epilogue), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + ptr_gather_A_indices(const_cast(args.ptr_gather_A_indices)), + ptr_gather_B_indices(const_cast(args.ptr_gather_B_indices)), + ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) {} + + /// Lightweight update given a subset of arguments. + void update(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); + + // Update input/output pointers + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); + ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); + ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); + + output_op = args.epilogue; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + public: + // + // Host dispatch API + // + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + + static int const kAlignmentA = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 
64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + (cute::is_same>::value) ? 32 + : (cute::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (cute::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (cute::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (cute::is_same>::value || + cute::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (cute::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (cute::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (cute::is_same>::value || + cute::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (cute::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (cute::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (cute::is_same>::value || + cute::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + public: + // + // Device-only API + // + + // Factory invocation + CUTLASS_DEVICE + static void invoke(Params const& params, SharedStorage& shared_storage) { + GemmUniversal op; + op(params, shared_storage); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const& params, SharedStorage& shared_storage, + ThreadblockSwizzle& threadblock_swizzle) { + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + + // + // Fetch pointers based on mode. 
+ // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast( + params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast( + params.ptr_B)[threadblock_tile_offset.k()]; + } + + syncthreads(); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = ThreadIdxX(); + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, + thread_idx, tb_offset_A, params.ptr_gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, + thread_idx, tb_offset_B, params.ptr_gather_B_indices); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = ThreadIdxX() % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC* ptr_C = static_cast(params.ptr_C); + ElementC* ptr_D = static_cast(params.ptr_D); + + // + // Fetch pointers based on mode. + // + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial + // synchronization + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. 
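+        // The fetched value is consumed later by semaphore.wait(), before the
+        // epilogue reads the partial results of the previous k-partition.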
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + } else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast( + params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast( + params.ptr_D)[threadblock_tile_offset.k()]; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, ptr_C, params.problem_size.mn(), thread_idx, + threadblock_offset, params.ptr_scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, ptr_D, params.problem_size.mn(), thread_idx, + threadblock_offset, params.ptr_scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp new file mode 100644 index 0000000..bd49242 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_epilogue.hpp @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +// #include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor_softmax.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, + "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class CollectiveEpilogue { + public: + // + // Type Aliases + // + using DispatchPolicy = IntelXeXMX16Group; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = + typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using GmemTiledCopyD = cute::conditional_t && + not cute::is_void_v, + CopyOpR2G, XE_2D_U32x8x16_ST_N>; + using ElementOutput = ElementD; + using ElementCompute = ElementAccumulator; + using 
ElementSource = typename FusionCallbacks::ElementSource; + using ElementScalar = typename FusionCallbacks::ElementScalar; + static constexpr FloatRoundStyle RoundStyle = + FloatRoundStyle::round_to_nearest; + + static_assert( + cute::is_same_v< + typename FusionCallbacks::Operation, + fusion::LinearCombination>, + "Only Linear Combination Epilogue is supported for Grouped GEMM at the " + "moment."); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, + "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, + "Copy operation to shared memory is not supported"); + + using CopyThreadShape = Shape<_1, Int>; + using Trait_C = Copy_Traits; + using XE_Copy_C = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + make_layout( + shape_div(typename Trait_C::BlockShape{}, CopyThreadShape{})))); + using Trait_D = Copy_Traits; + using XE_Copy_D = decltype(make_tiled_copy( + Copy_Atom{}, Layout{}, + make_layout( + shape_div(typename Trait_D::BlockShape{}, CopyThreadShape{})))); + + private: + // constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_source_supported = false; + constexpr static bool is_destination_supported = + not cute::is_void_v && not cute::is_void_v; + + public: + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl : cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + using TensorC = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideC{})); //(m, n) + using TensorD = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideD{})); //(m, n) + using EpilogueTensors = cute::tuple; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only + // rank-3 (MNK) + auto problem_shape_MNL = repeat_like( + typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + auto [M, N, L] = problem_shape_MNL; + + XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch = + reinterpret_cast(args.ptr_C); + TensorC mC_mnl = + make_tensor(make_gmem_ptr(ptr_C_first_batch), + make_layout(make_shape(M, N, L), 
InternalStrideC{})); + xe_load_c = {xe_load_c.with(mC_mnl)}; + } + + XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + ElementD* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + TensorD mD_mnl = + make_tensor(make_gmem_ptr(ptr_D_first_batch), + make_layout(make_shape(M, N, L), InternalStrideD{})); + xe_store_d = {xe_store_d.with(mD_mnl)}; + } + + return {FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, + workspace), + xe_load_c, + xe_store_d, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD}; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, + Arguments const& args) { + return 0; + } + + template + static cutlass::Status initialize_workspace( + ProblemShape const& problem_shape, Arguments const& args, void* workspace, + cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + static bool can_implement(ProblemShape problem_shape, Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + + bool implementable = true; + bool fusion_implementable = true; + + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = + append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int min_aligned_elements_D = + copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment( + cute::make_shape(M, N, L), InternalStrideD{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_D = + batch_alignment_bits / sizeof_bits::value; + implementable &= + get<2>(InternalStrideD{}) % min_batch_aligned_elements_D == 0; + } + } + + if constexpr (is_source_supported) { + constexpr int min_aligned_elements_C = + copy_alignment_bits / sizeof_bits::value; + implementable &= + cutlass::detail::check_alignment( + cute::make_shape(M, N, L), InternalStrideC{}); + if (L > 1) { + constexpr int min_batch_aligned_elements_C = + batch_alignment_bits / sizeof_bits::value; + implementable &= + get<2>(InternalStrideC{}) % min_batch_aligned_elements_C == 0; + } + } + + fusion_implementable = + fusion_implementable && + FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for XE 2D copy.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements " + "for FusionCallbacks.\n"); + } + + return implementable && fusion_implementable; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, + TensorStorage const& shared_storage_) + : params(params_), + fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template + CUTLASS_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, TiledMma tiled_mma, + int thread_idx, + LoadStoreTensor const& load_store_tensors) { + (void)tiled_mma; + using namespace cute; + + static_assert(cute::rank(CtaTileMNK{}) == 3, + "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(InternalStrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 
3, + "StrideD must be rank-3: [M, N, L]"); + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, + // "assertion fail"); + static constexpr auto ATOM_M = + get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = + get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = + get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert( + BLK_M % ATOM_M == 0 && BLK_N % ATOM_N == 0 && BLK_K % ATOM_K == 0, + "expected CTATileMNK to be evenly divided by TiledMma::ThrLayoutVMNK"); + static constexpr auto SG_M = BLK_M / ATOM_M; + static constexpr auto SG_N = BLK_N / ATOM_N; + static constexpr auto SG_K = BLK_K / ATOM_K; + using SubgroupTileShape = + Shape; + + static constexpr int FragsM = + get<0>(SubgroupTileShape{}) / + get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = + get<1>(SubgroupTileShape{}) / + get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + + // Get the layout and reconstruct the MN mapping equivalent to the old + // get_layoutS_MN() + auto layoutS_TV = params.xe_store_d.get_layoutS_TV(); + auto mn_shape = shape(typename decltype(params.xe_store_d)::Tiler_MN{}); + auto layoutS_MN = right_inverse(layoutS_TV).with_shape(mn_shape); + using EpilogueTile = decltype(layoutS_MN.shape()); + + auto sg_local_m_coord = get_sub_group_id() / ATOM_N; + auto sg_local_n_coord = get_sub_group_id() % ATOM_N; + + auto sg_m_coord = m_coord * ATOM_M + sg_local_m_coord; + auto sg_n_coord = n_coord * ATOM_N + sg_local_n_coord; + auto sg_coord = make_coord(sg_m_coord, sg_n_coord, k_coord, l_coord); + + bool is_C_load_needed = + is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Represent the full output tensor + Tensor mD_mnl = cute::get_xe_tensor(make_shape(M, N, L)); + + // Tile the output tensor per WG and select the tile for current WG + Tensor g_wg_D = + local_tile(mD_mnl, take<0, 2>(CtaTileMNK{}), + make_coord(m_coord, n_coord, l_coord)); // (BLK_M,BLK_N) + + // Tile the output tensor per SG and select tile for the current SG + Tensor gD = local_tile(g_wg_D, take<0, 2>(SubgroupTileShape{}), + make_coord(m_sg, n_sg)); // (SG_M,SG_N) + + auto thread_xe_store_d = params.xe_store_d.get_thread_slice(thread_idx); + Tensor tCgD = thread_xe_store_d.partition_D(gD); + + Tensor trC = + make_tensor(Shape>{}); + Tensor trD_compute = + make_tensor(Shape>{}); + + // Because Sm90 uses shared memory, they are not tied to using the same + // accumulator values for MMA and Epilogue. But because we are operating + // directly in the accumulators, we need to be sure that we are operating on + // the same values. 
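+    // Per-thread slice of the C-load copy, used below to partition the
+    // coordinate tensors for out-of-bounds predication.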
+ ThrCopy thread_g2r = params.xe_load_c.get_slice(thread_idx); + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M, N)); // (M,N) + Tensor cD = local_tile(mD_crd, take<0, 2>(SubgroupTileShape{}), + make_coord(sg_m_coord, sg_n_coord)); + Tensor cD_mn = local_tile(mD_crd, take<0, 2>(CtaTileMNK{}), + make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = thread_g2r.partition_S( + flat_divide(cD_mn, EpilogueTile{})); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + Tensor tRS_cD = + make_coord_tensor(tRS_cD_mn.layout()); // (G2R,G2R_M,G2R_N,EPI_M,EPI_N) + + // Get the fusion callbacks + // Arguments passed here relate to sub-group tiles, rather than CTA + // (work-group) tiles + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); // TODO(Codeplay): this is not correct + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + SubgroupTileShape{}, + sg_coord, + tiled_mma, + EpilogueTile{}, + params.xe_store_d, + cD, + residue_mn, + tRS_cD, + residue_mn, + trC, + thread_idx, + }; + auto cst_callbacks = + fusion_callbacks.template get_consumer_store_callbacks( + cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_compute_frag = + recast>(trD_compute); + + Tensor trD = make_tensor(Shape>{}); + auto trD_frag = recast>(trD); + + constexpr int ValuesLoaded = FragsM * FragsN * FragmentSize * SubgroupSize * + ATOM_M * ATOM_N * ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert( + ValuesLoaded == MN, + "the total elements loaded by all threads should be the same as MxN"); + + auto synchronize = [&]() {}; + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + if (is_C_load_needed) { + // coordinates for C and D are the same + copy(params.xe_load_c.with(get<0>(load_store_tensors)), + tCgD(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size<0>(trD_compute_frag); ++epi_v) { + trD_compute_frag(epi_v) = + cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + cst_callbacks.reduce(nullptr, synchronize, epi_m, epi_n, + (epi_m == FragsM - 1 && epi_n == FragsN - 1), + trD_compute_frag); + + if constexpr (is_destination_supported) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(trD_compute_frag); ++i) { + trD_frag(i) = + cutlass::NumericArrayConverter{}( + trD_compute_frag(i)); + } + copy(params.xe_store_d.with(get<1>(load_store_tensors)), trD, + tCgD(_, epi_m, epi_n)); + } + } + } + + cst_callbacks.end(); + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + int32_t const& next_group, ProblemShape_MNKL const& problem_shape_mnkl) { + auto [M, N, K, L] = problem_shape_mnkl; + + TensorC mC_mnl; + TensorD mD_mnl; + if constexpr (is_source_supported) { + ElementC const* ptr_C_curr_batch = + reinterpret_cast(params.ptr_C[next_group]); + mC_mnl = + make_tensor(make_gmem_ptr(ptr_C_curr_batch), + make_layout(make_shape(M, N, L), params.dC[next_group])); + } + + if constexpr (is_destination_supported) { + ElementD* ptr_D_curr_batch = + reinterpret_cast(params.ptr_D[next_group]); + mD_mnl = + make_tensor(make_gmem_ptr(ptr_D_curr_batch), + make_layout(make_shape(M, N, L), params.dD[next_group])); + } + return 
    cute::make_tuple(mC_mnl, mD_mnl);
+  }
+
+ private:
+  Params const& params;
+  FusionCallbacks fusion_callbacks;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace collective
+}  // namespace epilogue
+}  // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp
new file mode 100644
index 0000000..a2abb4b
--- /dev/null
+++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_array_mma.hpp
@@ -0,0 +1,360 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, + TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16Group; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using ElementB = ElementB_; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert( + platform::is_same::value, + "MainloopIntelXeXMX16Array requires that A and B have same type."); + + static_assert(std::is_same_v, + "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, + "Transformation for B is not currently supported on Intel PVC"); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = + get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = + get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = + get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = + Shape; + + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using Copy_A = typename Copy_Traits< + GmemTiledCopyA, InternalStrideA>::template DefaultTiledCopy; + using Copy_B = typename Copy_Traits< + GmemTiledCopyB, InternalStrideB>::template DefaultTiledCopy; + + using TensorMKL = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideA{})); //(m, k) + using TensorNKL = + decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_shape(0, 0, 0), InternalStrideB{})); //(n, k) + using MainloopTensors = cute::tuple; + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + }; + + struct Params { + ElementA const** ptr_A; + StrideA dA; 
+ ElementB const** ptr_B; + StrideB dB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, + void* workspace) { + (void)workspace; + + auto problem_shape_MNK = repeat_like( + typename ProblemShape::UnderlyingProblemShape{}, int32_t(1)); + ; + auto init_M = get<0>(problem_shape_MNK); + auto init_N = get<1>(problem_shape_MNK); + auto init_K = get<2>(problem_shape_MNK); + + return Params{args.ptr_A, args.dA, args.ptr_B, args.dB}; + } + + template + static bool can_implement(ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = + copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = + copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = + batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = + batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = + append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment( + cute::make_shape(M, K, L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment( + cute::make_shape(N, K, L), InternalStrideB{}); + + if (L > 1) { + implementable &= + get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= + get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment " + "requirements for XE 2D copy.\n"); + } + + return implementable; + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void operator()(FrgTensorD& accum, TensorA gA, TensorB gB, + FrgTensorC const& src_accum, + KTileIterator k_tile_iter, + int const& k_tile_count, + BlkCoord const& blk_coord, int const& K_start, + int const& thread_idx, Params const& mainloop, + LoadTensors const& load_tensors) { + static_assert(is_rmem::value, + "D tensor must be rmem resident."); + static_assert(is_rmem::value, + "C tensor must be rmem resident."); + + (void)thread_idx; + + Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; + Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; + + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + // TODO(Codeplay): see if we can make this nicer + // To make all work items in a subgroup have the same global tensors pass in + // the index of work item 0 in each subgroup + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition global counting tensors for MMA + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + Tensor tCrA = make_tensor( + make_fragment_layout(tiled_copy_a, tCgA(_, _, _, 0).shape())); + Tensor tCrB = make_tensor( + make_fragment_layout(tiled_copy_b, tCgB(_, _, _, 
0).shape())); + + // Retile registers for copies + Tensor tArA = thr_copy_A.retile_D(tCrA); + Tensor tBrB = thr_copy_B.retile_D(tCrB); + + // Retile global counting tensors for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = + cute::prefetch_selector, Int>, Num_SGs>( + tiled_copy_a); + auto tiled_prefetch_b = + cute::prefetch_selector, Int>, Num_SGs>( + tiled_copy_b); + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); + print(gA); + print("\n"); + print("tCgA : "); + print(tCgA); + print("\n"); + print("tAgA : "); + print(tAgA); + print("\n"); + + print("===================== B :\n"); + print(" gB : "); + print(gB); + print("\n"); + print("tCgB : "); + print(tCgB); + print("\n"); + print("tBgB : "); + print(tBgB); + print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); + print(MaxThreadsPerBlock); + print("\n"); + print(" SubgroupTileShape : "); + print(SubgroupTileShape{}); + print("\n"); + } +#endif + + // + // Mainloop + // + const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; + k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_, _, _, k_tile), tArA); + copy(tiled_copy_b, tBgB(_, _, _, k_tile), tBrB); + + if (prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + cute::gemm(tiled_mma, tCrA, tCrB, accum); + barrier_wait(barrier_scope); + } + } + + template + CUTLASS_DEVICE auto update_tensor_shape_stride( + Params const& mainloop_params, int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl) { + const int32_t M = get<0>(problem_shape_mnkl); + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = + reinterpret_cast(mainloop_params.ptr_A[next_group]); + ElementB const* ptr_B_curr_batch = + reinterpret_cast(mainloop_params.ptr_B[next_group]); + + Tensor mA = make_tensor(make_gmem_ptr(ptr_A_curr_batch), + make_shape(M, K, (int32_t)1), + mainloop_params.dA[next_group]); + Tensor mB = make_tensor(make_gmem_ptr(ptr_B_curr_batch), + make_shape(N, K, (int32_t)1), + mainloop_params.dB[next_group]); + + return cute::make_tuple(mA, mB); + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp new file mode 100644 index 0000000..ca749c3 --- /dev/null +++ b/csrc/xpu/cutlass_kernels/collective/gemm/xe_builder.hpp @@ -0,0 +1,234 @@ 
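For context, the can_implement() checks in the mainloop above boil down to the arithmetic below: the XE 2D block copies require 128-bit row alignment and a 512-bit aligned batch stride, which translates into a minimum element count per element type. The 16-bit element width used here is an assumption chosen only to make the numbers concrete.

// Standalone illustration of the alignment arithmetic used by can_implement().
#include <cstdio>

int main() {
  constexpr int copy_alignment_bits = 128;
  constexpr int batch_alignment_bits = 512;
  constexpr int bits_per_element = 16;  // e.g. fp16 / bf16 (assumption)

  // The contiguous dimension of A/B must be a multiple of this.
  constexpr int min_aligned_elements = copy_alignment_bits / bits_per_element;        // 8
  // The L (batch) stride must be a multiple of this when L > 1.
  constexpr int min_batch_aligned_elements = batch_alignment_bits / bits_per_element; // 32

  std::printf("16-bit types: row alignment %d elements, batch-stride alignment %d elements\n",
              min_aligned_elements, min_batch_aligned_elements);
}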
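The mainloop's prefetch pipelining can also be summarized with a scalar model: prime the pipeline with Stages prefetches, then keep prefetching Stages k-tiles ahead of each copy/MMA pair. The sketch below deliberately omits the split-barrier (barrier_arrive/barrier_wait) calls and all real data movement, and the stage count of 2 is an assumption; it only shows the ordering of operations.

// Minimal scalar model of the software-pipelined k-loop above.
#include <cstdio>

constexpr int Stages = 2;  // assumption: the dispatch policy's stage count

void prefetch_tile(int k) { std::printf("prefetch k-tile %d\n", k); }
void copy_tile(int k)     { std::printf("copy     k-tile %d\n", k); }
void mma_tile(int k)      { std::printf("mma      k-tile %d\n", k); }

int main() {
  const int k_tile_count = 6;

  int prefetch_k = 0;
  for (; prefetch_k < Stages; ++prefetch_k)  // prime the pipeline
    prefetch_tile(prefetch_k);

  for (int k_tile = 0; k_tile < k_tile_count; ++k_tile, ++prefetch_k) {
    copy_tile(k_tile);                       // gmem -> registers for this k-tile
    if (prefetch_k < k_tile_count)
      prefetch_tile(prefetch_k);             // stay `Stages` tiles ahead
    mma_tile(k_tile);                        // accumulate into the C fragments
  }
}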
+/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include // cute::DefaultCopy +#include // cute::is_base_of_v +// #include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "xe_array_epilogue.hpp" +#include "xe_callbacks.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation +// of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. 
+// Can be overridden with kernel schedule tags in +// cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; + +template < + class ArchTag, class OpClass, class TileShape_MNK, class ClusterShape_MNK, + class EpilogueTileType, class ElementAccumulator, class ElementCompute, + class ElementC, class GmemLayoutTagC, int AlignmentC, class ElementD, + class GmemLayoutTagD, int AlignmentD, class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination< + ElementD, ElementCompute, ElementC, ElementCompute>, + class Enable = void> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// helper sub-builder for epilogue fusion callbacks (for internal use by +// CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template +struct CallbacksBuilder>> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +namespace detail { +template +struct FusionOpInfo { + static_assert(cutlass::detail::dependent_false, + "Could not find a builder specialization."); +}; + +template +struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinearCombination, + TileShape_MNK, EpilogueTile>; +}; + +template