diff --git a/examples/sycl/08_bmg_gemm_f8/08_bmg_gemm_f8_deepseek_scaling.cpp b/examples/sycl/08_bmg_gemm_f8/08_bmg_gemm_f8_deepseek_scaling.cpp new file mode 100644 index 0000000000..cc2760731c --- /dev/null +++ b/examples/sycl/08_bmg_gemm_f8/08_bmg_gemm_f8_deepseek_scaling.cpp @@ -0,0 +1,581 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file +\brief CUTLASS Intel BMG Gemm FP8 with optional quantization. 
+ The GemmMode enum describes the two modes of operation: + + - ConvertOnly: Narrower type is simply converted to the wider type before MMA + - ConvertAndScale: Narrower type is converted to the wider type, then scaled + + - Requirements: + - dequantization group size (options.g) must be a multiple of the k-block size + - scales & zeros must be MN-major + + To build & run this example (from your build dir): + + $ ninja 08_bmg_gemm_f8_deepseek_scaling + $ ./examples/sycl/08_bmg_gemm_f8/08_bmg_gemm_f8_deepseek_scaling + + Call with `--help` for information about available options +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" +#include "cutlass/util/mixed_dtype_utils.hpp" + +using namespace cute; + +enum GemmMode { + ConvertOnly, + ConvertAndScale +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int mode; + int m, n, k, l, iterations; + int g; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(20), + g(128), mode(0), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("g", g, 128); + cmd.get_cmd_line_argument("mode", mode, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "BMG GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 
0 is Convert Only, 1 is Convert and Scale\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using CollectiveMainloop = typename Gemm::CollectiveMainloop; + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementMMA = typename CollectiveMainloop::ElementMMA; + using ElementQuant = ElementA; + + using ElementScaleA = typename CollectiveMainloop::NonVoidElementScaleA; + using ElementScaleB = half_t; + + using StrideScaleA = typename CollectiveMainloop::NonVoidStrideScaleA; + using StrideScaleB = cute::Stride<_1, int64_t, int64_t>; + + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + StrideScaleA stride_SA; + StrideScaleB stride_SB; + + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_A_dq; // Dequantized copy of A for validation + cutlass::DeviceAllocation block_B_dq; // Dequantized copy of B for validation + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_scaleA; + cutlass::DeviceAllocation block_scaleB; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + template + void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) { + syclcompat::get_default_queue().parallel_for(size, [=](auto indx) { + d_dst[indx] = static_cast(d_src[indx]); + }).wait(); + } + + bool verify(const Options &options) { + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x16x16_LD_T; + + // Workgroup-level tile + using TileShape = Shape<_128, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + + using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + //using realStrideOfB = 
decltype(cute::reverse(cutlass::gemm::TagToStrideB_t{})); + + // Mainloop + using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementMMA, + cutlass::gemm::TagToStrideA_t, + ElementMMA, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + /* + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A_dq.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementOutput const epsilon(1e-2f); + ElementOutput const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + */ + return true; + } + + template + bool initialize_scale( + cutlass::DeviceAllocation& block, + Options const& options) { + + if (options.mode == GemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + // Need to fix max_dequant_val and min_dequant_val? + const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + const float scale_max = max_dequant_val / elt_max_f; + const float scale_min = min_dequant_val / elt_max_f; + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scale_max), Element(scale_min)); + } + return true; + } + + template + bool initialize_zero( + cutlass::DeviceAllocation& block, + Options const& options) { + + // No bias, so just initialize with 0 so we can use the same kernel to dequantize the data. 
+ std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + + return true; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(Options const& options) { + auto [M, N, K, L] = ProblemShapeType{options.m, options.n, options.k, options.l}; + + const int scale_k = cute::ceil_div(options.k, options.g); + const int scale_n = cute::ceil_div(options.n, options.g); + auto shape_A = cute::make_shape(M, K, L); + auto shape_B = cute::make_shape(N, K, L); + auto shape_CD = cute::make_shape(M, N, L); + auto shape_scale_zeroA = cute::make_shape(M, scale_k, L); + auto shape_scale_zeroB = cute::make_shape(scale_n, scale_k, L); + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, shape_A); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, shape_CD); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, shape_CD); + stride_SA = cutlass::make_cute_packed_stride(StrideScaleA{}, shape_scale_zeroA); + stride_SB = cutlass::make_cute_packed_stride(StrideScaleB{}, shape_scale_zeroB); + + block_A.reset(static_cast(M) * K * L); + block_A_dq.reset(static_cast(M) * K * L); + block_B.reset(static_cast(K) * N * L); + block_B_dq.reset(static_cast(K) * N * L); + block_C.reset(static_cast(M) * N * L); + block_D.reset(static_cast(M) * N * L); + block_ref_D.reset(static_cast(M) * N * L); + block_scaleA.reset(static_cast(scale_k) * L * M); + block_scaleB.reset(static_cast(scale_k) * L * scale_n); + + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + + convert_fp8_to_fp16( + block_A.get(), + block_A_dq.get(), + block_A.size() + ); + convert_fp8_to_fp16( + block_B.get(), + block_B_dq.get(), + block_B.size() + ); + + initialize_scale(block_scaleA, options); + initialize_scale(block_scaleB, options); + + auto layout_A = make_layout(shape_A, stride_A); + auto layout_B = make_layout(shape_B, stride_B); + auto layout_scale_zeroA = make_layout(shape_scale_zeroA, stride_SA); + auto layout_scale_zeroB = make_layout(shape_scale_zeroB, stride_SB); + + /* + // Note that we are overwriting the relevant `block_X_dq` here, both were + // filled by initialize_mixed_dtype_block above + cutlass::dequantize(block_A_dq.get(), block_A.get(), layout_A, + block_scaleA.get(), block_zeroA.get(), layout_scale_zeroA, layout_scale_zeroA, + options.g); + cutlass::dequantize(block_B_dq.get(), block_B.get(), layout_B, + block_scaleB.get(), block_zeroB.get(), layout_scale_zeroB, layout_scale_zeroB, + options.g); + */ + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(options); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B, + block_scaleA.get(), stride_SA, block_scaleB.get(), + options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + 
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e4m3_t"<< std::endl; + } else if constexpr (std::is_same_v) { + std::cout << "Datatype: float_e5m2_t"<< std::endl; + } + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + +}; + +template +int launcher(Options& options) +{ + // + // Run examples + // + + cutlass::KernelHardwareInfo hw_info; + + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + using QuantType = ElementType; + + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = QuantType; + using ElementInputB = QuantType; + using ElementOutput = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementScale = half_t; + + using StrideScale = cute::Stride<_1, int64_t, int64_t>; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x16x16_LD_T; + + using TileShape = Shape<_128, _256, _32>; + + // TODO: Consider smaller tile size to reduce register pressure + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 2; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16FP8DeepSeekScaling; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + using ConvertOnlyCollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + cute::tuple, + cutlass::gemm::TagToStrideA_t, + cute::tuple, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, + GmemTiledCopyB, void, void, cute::identity + >; + + using ConvertAndScaleCollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + cute::tuple, + cutlass::gemm::TagToStrideA_t, + cute::tuple, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, + GmemTiledCopyB, void, void, cute::identity + >; + + if(options.mode == GemmMode::ConvertOnly) { + std::cout << "Running in ConvertOnly mode." 
<< std::endl; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + ConvertOnlyCollectiveMainloop, + CollectiveEpilogue + >; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } else if(options.mode == GemmMode::ConvertAndScale){ + std::cout << "Running in ConvertAndScale mode." << std::endl; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + ConvertAndScaleCollectiveMainloop, + CollectiveEpilogue + >; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } + return 0; +} + +int main(int argc, const char** argv) { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + launcher(options); + launcher(options); + return 0; +} diff --git a/examples/sycl/08_bmg_gemm_f8/CMakeLists.txt b/examples/sycl/08_bmg_gemm_f8/CMakeLists.txt index 2184aaf1c3..efcecb166f 100644 --- a/examples/sycl/08_bmg_gemm_f8/CMakeLists.txt +++ b/examples/sycl/08_bmg_gemm_f8/CMakeLists.txt @@ -44,7 +44,16 @@ cutlass_example_add_executable( TEST_MODE_0 ) +cutlass_example_add_executable( + 08_bmg_gemm_f8_deepseek_scaling + 08_bmg_gemm_f8_deepseek_scaling.cpp + TEST_COMMAND_OPTIONS + TEST_BATCHES + TEST_MODE_0 +) + if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") # TODO(codeplay): Remove these once IGC block load loop hoisting bug is fixed target_link_options( 08_bmg_gemm_f8_scaling PRIVATE -Xs "-options \"-igc_opts 'allowDecompose2DBlockFuncs=0'\"" ) + target_link_options( 08_bmg_gemm_f8_deepseek_scaling PRIVATE -Xs "-options \"-igc_opts 'allowDecompose2DBlockFuncs=0'\"" ) endif() diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index df97578080..525ffe4f9e 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -76,6 +76,7 @@ #include "cutlass/gemm/collective/xe_mma_mixed_input.hpp" #include "cutlass/gemm/collective/xe_mma_w8a8.hpp" #include "cutlass/gemm/collective/xe_mma_fp8_scaling.hpp" +#include "cutlass/gemm/collective/xe_mma_fp8_deepseek_scaling.hpp" #endif #if defined(CUTLASS_ENABLE_SYCL) diff --git a/include/cutlass/gemm/collective/xe_mma_fp8_deepseek_scaling.hpp b/include/cutlass/gemm/collective/xe_mma_fp8_deepseek_scaling.hpp new file mode 100644 index 0000000000..43f4e2a9f6 --- /dev/null +++ b/include/cutlass/gemm/collective/xe_mma_fp8_deepseek_scaling.hpp @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/fp8_to_fp16.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/tensor_ref.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopIntelXeXMX16FP8DeepSeekScaling, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ +private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + +public: + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16FP8DeepSeekScaling; + using WorkgroupTileShape = TileShape_; + + + static_assert((cute::is_tuple::value & cute::is_tuple::value & + (cute::is_any_of_v, float_e4m3_t, float_e5m2_t>)), + "Both A and B must be tuples. They must take the form {ElementOperand, [ElementScale]," + "[ElementZero]}. 
Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + static constexpr bool IsBTransformed = cute::is_tuple::value; + + using ElementMMA = typename TiledMma_::ValTypeA; + using ElementQuant = cute::conditional_t; + + // TODO(Codeplay): Create a ScaledTensor class to encapsulate scale logic + using ElementScaleA = half_t; + using StrideScaleA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ElementZeroA = detail::deduce_mixed_width_dtype_t<3, ElementAOptionalTuple>; + using StrideZeroA = detail::deduce_mixed_width_dtype_t<4, ElementAOptionalTuple>; + + using ElementScaleB = half_t; + + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. + using NonVoidElementScaleA = cute::conditional_t, ElementMMA, ElementScaleA>; + using NonVoidStrideScaleA = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScaleA>; + using NonVoidElementZeroA = cute::conditional_t, ElementMMA, ElementZeroA>; + using NonVoidStrideZeroA = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideZeroA>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaType = typename TiledMma::ValTypeA; // ValTypeA and ValTypeB are always same and reflects MMA type on intel Xe + + static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + +private: + static constexpr ConversionMode + get_conversion_modeA() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode + get_conversion_modeB() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionModeA = get_conversion_modeA(); + static constexpr ConversionMode KernelConversionModeB = get_conversion_modeB(); + static constexpr bool ModeHasScalesA = KernelConversionModeA == ConversionMode::ConvertAndScale || + KernelConversionModeA == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool ModeHasScalesZeroA = KernelConversionModeA == ConversionMode::ConvertAndScaleWithZero; + + static constexpr bool ModeHasScalesB = KernelConversionModeB == ConversionMode::ConvertAndScale || + KernelConversionModeB == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool ModeHasScalesZeroB = KernelConversionModeB == ConversionMode::ConvertAndScaleWithZero; + +public: + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = 
get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + using GmemTiledCopyScaleA = typename scale_zero_copy_traits::type; + + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using CopyThreadShape = Shape<_1, Int>; + using CopyThreadShapeRev = decltype(cute::reverse(CopyThreadShape{})); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{}))); + using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout{}, val_layout_load_A{})); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{}))); + using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout{}, val_layout_load_B{})); + + using traits_load_scaleA = Copy_Traits; + using atom_load_scaleA = Copy_Atom; + using val_layout_load_scaleA = decltype(make_layout(shape_div(typename traits_load_scaleA::BlockShape{}, CopyThreadShapeRev{}))); + using Copy_ScaleA = decltype(make_tiled_copy(atom_load_scaleA{}, Layout{}, val_layout_load_scaleA{})); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + NonVoidElementScaleA const* ptr_SA = nullptr; + NonVoidStrideScaleA dSA{}; + half_t const* ptr_SB = nullptr; + int group_size = 1; + }; + + struct Params { + Copy_A tiled_copy_a; + Copy_B tiled_copy_b; + Copy_ScaleA tiled_copy_scaleA; + half_t* scales_B; + int group_size; + int original_N; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + + auto [M, N, K, L] = problem_shape; + + auto mA_mkl = + make_tensor(make_gmem_ptr(args.ptr_A), make_layout(make_shape(M, K, L), args.dA)); + + auto ptr_B = [&]() { + if constexpr (sizeof_bits_v < 8) { + return cute::subbyte_iterator(args.ptr_B); + } else { + return make_gmem_ptr(static_cast(args.ptr_B)); + } + }(); + + + auto mB_nkl = + make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + + Copy_A tiled_copy_a{Copy_A{}.with(mA_mkl)}; + Copy_B tiled_copy_b{Copy_B{}.with(mB_nkl)}; + + Copy_ScaleA tiled_copy_scaleA; + + auto scale_k = cute::ceil_div(K, args.group_size); + auto scale_n = cute::ceil_div(N, args.group_size); + if constexpr(ModeHasScalesA) { + auto mScale = make_tensor( + make_gmem_ptr(static_cast(args.ptr_SA)), + make_layout(make_shape(M, scale_k, L), args.dSA)); + tiled_copy_scaleA = {Copy_ScaleA{}.with(mScale)}; + } else { + tiled_copy_scaleA = {}; + } + return Params{tiled_copy_a, tiled_copy_b, tiled_copy_scaleA, const_cast(args.ptr_SB), args.group_size, N}; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + 
constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), args.dA); + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), args.dB); + + if (L > 1) { + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dA) % min_batch_aligned_elements_A == 0; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + implementable &= get<2>(args.dB) % min_batch_aligned_elements_B == 0; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + + /// Utilities to transform A. + template + CUTLASS_DEVICE typename std::enable_if_t >= 8> + transform_A( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + Tensor& tCrS_input + ) { + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(std::is_same_v); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + convert_FP8_to_FP16(tCrA_load, tCrA_mma); + + if constexpr (IsATransformed && ModeHasScalesA) { + half_t s0 = tCrS_input(0); + half_t s1 = tCrS_input(1); + auto ptr_tcrA_mm = tCrA_mma.data(); + // The current scale load atom (1x32) gives 2 scale values to + // each thread. All threads need access to all other threads + // scale values, and each scale value is reused twice (unrolled) + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < 16; ++i) { + auto scale0 = group_broadcast(sg, s0, i); + ptr_tcrA_mm[i] *= scale0; + ptr_tcrA_mm[32 + i] *= scale0; + auto scale1 = group_broadcast(sg, s1, i); + ptr_tcrA_mm[16 + i] *= scale1; + ptr_tcrA_mm[32 + 16 + i] *= scale1; + } + } + } + + /// Utilities to transform B. 
+ template + CUTLASS_DEVICE typename std::enable_if_t >= 8> + transform_B( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + half_t B_scale + ) { + static_assert(is_rmem::value, "Input tensor for B conversion must come from registers"); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + + using DstType = typename EngineOut::value_type; + + convert_FP8_to_FP16(tCrA_load, tCrA_mma); + if constexpr (IsBTransformed && ModeHasScalesB) { + static constexpr auto N = decltype(size<1>(tCrA_load))::value; + DstType *pDst = const_cast(tCrA_mma.data()); + using DstArray = cutlass::Array; + DstArray *pDstArr = reinterpret_cast(pDst); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; ++n) { + pDstArr[n] = pDstArr[n] * B_scale; + } + } + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + BlkCoord const &blk_coord, + int const &K_start, + int thread_idx, + Params const& mainloop) + { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + // Partition the copying of A and B tiles across the threads + auto thr_copy_A = mainloop.tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = mainloop.tiled_copy_b.get_slice(thread_idx); + auto thr_copy_scaleA = mainloop.tiled_copy_scaleA.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + // Create fragments + Tensor mma_A = make_tensor(make_fragment_layout(mainloop.tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor mma_B = make_tensor(make_fragment_layout(mainloop.tiled_copy_b, tCgB(_,_,_,0).shape())); + + // If IsATransformed, we need modes M_atom, and M_iter from fragment_A + // layout else we need mode N_iter from fragment_B layout. 
+ using FragScaleALayout = Layout>; + Tensor fragment_scaleA_input = make_tensor(FragScaleALayout{}); + + + // narrow input fragment + Tensor quantA_frag = make_tensor(mma_A.layout()); + Tensor quantB_frag = make_tensor(mma_B.layout()); + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + // Retile for copy + auto frag_copy_A = thr_copy_A.retile_D(quantA_frag); + auto frag_copy_B = thr_copy_B.retile_D(quantB_frag); + + Tensor copy_tCrSA = thr_copy_scaleA.retile_D(fragment_scaleA_input); + + // Retile global tile for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_a);; + auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(mainloop.tiled_copy_b);; + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + // + // Mainloop + // + // TODO(Codeplay): Define these coord tensors using proper cute logic + auto [m_idx, n_idx, k_idx, l_idx] = blk_coord; + const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + const int l_coord = l_idx; + + Tensor copy_iter_sA = make_tensor(make_inttuple_iter(make_coord(m_coord, 0, l_coord)), + make_layout(make_shape(_2{}, _1{}, _1{}, k_tile_count), + make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); + + #define LOG_GROUP 0 + #define LOG_THREAD 0 + #define CUTLASS_ENABLE_DEBUG_PRINTS 0 + #if CUTLASS_ENABLE_DEBUG_PRINTS + #define PRINT(x) print(#x ": "); print(x); print("\n"); + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); print(gA); print("\n"); + print(" tCgA : "); print(tCgA); print("\n"); + print(" tAgA : "); print(tAgA); print("\n"); + print(" mma_A : "); print(mma_A); print("\n"); + print(" frag_copy_A : "); print(frag_copy_A); print("\n"); + + print("===================== B :\n"); + print(" gB : "); print(gB); print("\n"); + print(" tCgB : "); print(tCgB); print("\n"); + print(" tBgB : "); print(tBgB); print("\n"); + print(" mma_B : "); print(mma_B); print("\n"); + print(" frag_copy_B : "); print(frag_copy_B); print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); + print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); + + print(" tiled_prefetch_a : "); print(tiled_prefetch_a); print("\n"); + print(" tiled_prefetch_b : "); print(tiled_prefetch_b); print("\n"); + print(" pAgA : "); print(pAgA); print("\n"); + print(" pBgB : "); print(pBgB); print("\n"); + } + #undef PRINT + #endif + + const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + } + + const int k_reload_factor = mainloop.group_size / BLK_K; + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + + // Copy gmem to rmem for the first k_tile + copy(mainloop.tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A); + 
copy(mainloop.tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B); + + if constexpr(ModeHasScalesA){ + copy(mainloop.tiled_copy_scaleA, copy_iter_sA(_, _, _, k_tile / k_reload_factor), copy_tCrSA); + } + + if(prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + } + + int group_row = k_tile / k_reload_factor; + int group_col = n_coord / mainloop.group_size ; + int groups_per_row = (mainloop.original_N + mainloop.group_size - 1) / mainloop.group_size; + int group_index = group_row * groups_per_row + group_col; + + transform_A(quantA_frag, mma_A, fragment_scaleA_input); + transform_B(quantB_frag, mma_B, mainloop.scales_B[group_index]); + + cute::gemm(tiled_mma, mma_A, mma_B, accum); + barrier_wait(barrier_scope); + } + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 42fc5e8d6d..d2a7316096 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -1052,6 +1052,10 @@ template struct MainloopIntelXeXMX16FP8Scaling : MainloopIntelXeXMX16 { }; +template +struct MainloopIntelXeXMX16FP8DeepSeekScaling : MainloopIntelXeXMX16 { +}; + #endif #if defined(CUTLASS_ENABLE_SYCL)
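For reference, the dequantization that the new DeepSeek-scaling mainloop performs on the fly after the FP8-to-FP16 conversion can be summarized as: each FP8 element of A is multiplied by a half-precision scale shared by its (row, K-group), and each FP8 element of B by a scale shared by its (N-group, K-group) block, with group size options.g. The sketch below is an illustrative host-side reference only, not part of this patch: the helper names are assumptions, float_e4m3_t is shown (the e5m2 case is analogous), and the indexing is derived from the example's strides (A row-major, B column-major, A scales M-major) and from the group_index computed in the mainloop. A future verify() implementation could dequantize A and B with helpers like these and compare against a plain half-precision reference GEMM.

// Host-side reference sketch (assumed helper names; not part of the patch).
// Batch dimension (L) omitted for brevity.
//   A_dq(m, k) = half(A(m, k)) * scales_A(m, k / g)      // scales_A: shape (M, ceil(K/g)), M-major
//   B_dq(n, k) = half(B(n, k)) * scales_B(k / g, n / g)  // scales_B: one scale per (K-group, N-group) block
#include "cutlass/numeric_types.h"

inline void dequantize_a_ref(cutlass::float_e4m3_t const* A, cutlass::half_t const* scales_A,
                             cutlass::half_t* A_dq, int M, int K, int g) {
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k)   // A is row-major: element (m, k) lives at m * K + k
      A_dq[m * K + k] = cutlass::half_t(float(A[m * K + k])) * scales_A[m + (k / g) * M];
}

inline void dequantize_b_ref(cutlass::float_e4m3_t const* B, cutlass::half_t const* scales_B,
                             cutlass::half_t* B_dq, int N, int K, int g) {
  int scale_n = (N + g - 1) / g;  // number of N-groups per K-group
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k)   // B is column-major over (N, K): element (n, k) lives at n * K + k
      B_dq[n * K + k] = cutlass::half_t(float(B[n * K + k])) * scales_B[(k / g) * scale_n + (n / g)];
}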