diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt
index 4c2267227d..006809faa3 100644
--- a/examples/sycl/pvc/CMakeLists.txt
+++ b/examples/sycl/pvc/CMakeLists.txt
@@ -97,6 +97,12 @@ cutlass_example_add_executable(
   TEST_BATCHES
 )
 
+cutlass_example_add_executable(
+  pvc_gemm_fp8
+  pvc_gemm_fp8.cpp
+  TEST_COMMAND_OPTIONS
+  TEST_BATCHES
+)
 cutlass_example_add_executable(
   pvc_gemm_group
   pvc_gemm_group.cpp
diff --git a/examples/sycl/pvc/common.hpp b/examples/sycl/pvc/common.hpp
index ea5763dbf3..d6dd7a5401 100644
--- a/examples/sycl/pvc/common.hpp
+++ b/examples/sycl/pvc/common.hpp
@@ -61,6 +61,27 @@ bool initialize_block(
   return true;
 }
 
+template
+void initialize_block(cutlass::DeviceAllocation& block_device, cutlass::DeviceAllocation& block_device_ref,
+                      uint64_t seed, int M, int N) {
+  static_assert(cute::sizeof_bits_v > 8);
+  std::ranlux24_base rng(std::random_device{}());
+  rng.seed(seed);
+
+  using Limits = cutlass::platform::numeric_limits;
+  std::uniform_int_distribution<> dist(Limits::lowest(), Limits::max());
+
+  auto block_host = std::vector(block_device.size());
+  auto block_host_ref = std::vector(block_device_ref.size());
+  for (int i = 0; i < block_host.size(); i++) {
+    block_host[i] = static_cast(dist(rng));
+    block_host_ref[i] = static_cast(block_host[i]);
+  }
+
+  block_device.copy_from_host(block_host.data());
+  block_device_ref.copy_from_host(block_host_ref.data());
+}
+
 template
 void initialize_mixed_dtype_block(cutlass::DeviceAllocation& block_device,
                                   cutlass::DeviceAllocation& block_device_dq,
diff --git a/examples/sycl/pvc/pvc_gemm_fp8.cpp b/examples/sycl/pvc/pvc_gemm_fp8.cpp
new file mode 100644
index 0000000000..96483f1438
--- /dev/null
+++ b/examples/sycl/pvc/pvc_gemm_fp8.cpp
@@ -0,0 +1,571 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+ #include "cutlass/epilogue/collective/default_epilogue.hpp"
+ #include "cutlass/epilogue/collective/xe_epilogue.hpp"
+ #include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+ #include "cutlass/gemm/device/gemm_universal.h"
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
+ #include "cutlass/gemm/collective/collective_mma.hpp"
+ #include "cutlass/util/GPU_Clock.hpp"
+
+ #include
+ #include
+
+ #include "cutlass/util/command_line.h"
+ #include "cutlass/util/device_memory.h"
+ #include "cutlass/util/packed_stride.hpp"
+ #include "cutlass/util/reference/device/gemm_complex.h"
+ #include "cutlass/util/reference/device/tensor_compare.h"
+ #include "common.hpp"
+ #include "helper.h"
+
+ using namespace cute;
+
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // Command line options parsing
+ struct Options {
+
+   bool help;
+   bool error;
+
+   int m, n, k, l, iterations;
+   float alpha, beta;
+
+   Options():
+     help(false),
+     error(false),
+     m(5120), n(4096), k(4096), l(1), iterations(20),
+     alpha(1.f), beta(0.f)
+   { }
+
+   // Parses the command line
+   void parse(int argc, char const **args) {
+     cutlass::CommandLine cmd(argc, args);
+
+     if (cmd.check_cmd_line_flag("help")) {
+       help = true;
+       return;
+     }
+
+     cmd.get_cmd_line_argument("m", m, 4096);
+     cmd.get_cmd_line_argument("n", n, 4096);
+     cmd.get_cmd_line_argument("k", k, 4096);
+     cmd.get_cmd_line_argument("l", l, 1);
+     cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+     cmd.get_cmd_line_argument("beta", beta, 0.f);
+     cmd.get_cmd_line_argument("iterations", iterations, 100);
+   }
+
+   /// Prints the usage statement.
+   std::ostream & print_usage(std::ostream &out) const {
+
+     out << "PVC GEMM FP8 Example\n\n"
+       << "Options:\n\n"
+       << "  --help             If specified, displays this usage statement\n\n"
+       << "  --m=               Sets the M extent of the GEMM\n"
+       << "  --n=               Sets the N extent of the GEMM\n"
+       << "  --k=               Sets the K extent of the GEMM\n"
+       << "  --l=               Sets the L extent (batch count) of the GEMM\n"
+       << "  --alpha=           Epilogue scalar alpha\n"
+       << "  --beta=            Epilogue scalar beta\n\n"
+       << "  --iterations=      Iterations\n\n";
+
+     return out;
+   }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
+#define A_ROW
+#define B_COL
+ template <
+   class Gemm
+ >
+ struct ExampleRunner {
+
+   using StrideA = typename Gemm::GemmKernel::StrideA;
+   using StrideB = typename Gemm::GemmKernel::StrideB;
+   using StrideC = typename Gemm::GemmKernel::StrideC;
+   using StrideD = typename Gemm::GemmKernel::StrideD;
+
+   using LayoutA = typename Gemm::LayoutA;
+   using LayoutB = typename Gemm::LayoutB;
+   using LayoutC = typename Gemm::LayoutC;
+   using LayoutD = typename Gemm::LayoutD;
+
+   using ElementA = typename Gemm::ElementA;
+   using ElementB = typename Gemm::ElementB;
+   using ElementAcc = typename Gemm::ElementAccumulator;
+
+   using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
+   using ElementC = typename Gemm::ElementC;
+   using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+   using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+   using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
+
+   using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+   using TiledMma = typename Gemm::CollectiveMainloop::TiledMma;
+
+   //
+   // Data members
+   //
+
+   /// Initialization
+   StrideA stride_A;
+   StrideB stride_B;
+
StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_A_ref; // Dequantized copy of A for validation + cutlass::DeviceAllocation block_B_ref; // Dequantized copy of B for validation + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const Options &options) { + #if defined(A_ROW) && defined(B_COL) + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; + #endif + #if defined(A_ROW) && defined(B_ROW) + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_N; + #endif + + #if defined(A_COL) && defined(B_ROW) + using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + #endif + + #if defined(A_COL) && defined(B_COL) + using GmemTiledCopyA = XE_2D_U16x16x16_LD_T; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_T; + #endif + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using MMAAtom = MMA_Atom; + using TiledMma = TiledMMA, Stride<_4,_1,_0>>, + Tile, Stride<_1, _32, _8>>, + Layout, Stride<_1, _64, _16>>, + _32>>; + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + + using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Mainloop + using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + typename TiledMma::ValTypeA, + cutlass::gemm::TagToStrideA_t, + typename TiledMma::ValTypeB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A_ref.get(), stride_A, block_B_ref.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_ref_D.get(), stride_D} + }; + + // Run the gemm where the scaling is performed outside of the kernel. 
+     GemmRef gemm_ref;
+     size_t workspace_size = GemmRef::get_workspace_size(arguments);
+     cutlass::device_memory::allocation workspace(workspace_size);
+     CUTLASS_CHECK(gemm_ref.can_implement(arguments));
+     CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get()));
+     CUTLASS_CHECK(gemm_ref.run());
+
+     // compare_reference
+     ElementOutput const epsilon(1e-2f);
+     ElementOutput const non_zero_floor(1e-4f);
+     bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor);
+     return passed;
+     // return true;
+   }
+
+   /// Initialize operands to be used in the GEMM and reference GEMM
+   void initialize(const ProblemShapeType& problem_size) {
+     auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+     auto [M, N, K, L] = problem_shape_MNKL;
+
+     stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+     stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
+     stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
+     stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
+
+     block_A.reset(M * K * L);
+     block_A_ref.reset(M * K * L);
+     block_B.reset(K * N * L);
+     block_B_ref.reset(K * N * L);
+     block_C.reset(M * N * L);
+     block_D.reset(M * N * L);
+     block_ref_D.reset(M * N * L);
+
+     initialize_block(block_A, block_A_ref, seed + 2023, 256, 32);
+     initialize_block(block_B, block_B_ref, seed + 2022, 32, 256);
+     initialize_block(block_C, seed + 2021);
+   }
+
+   cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
+     ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
+
+     initialize(problem_size);
+
+     typename Gemm::GemmKernel::Arguments arguments{
+       cutlass::gemm::GemmUniversalMode::kGemm,
+       problem_size,
+       {block_A.get(), stride_A, block_B.get(), stride_B},
+       {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
+       hw_info
+     };
+
+     Gemm gemm_op;
+
+     size_t workspace_size = Gemm::get_workspace_size(arguments);
+     cutlass::device_memory::allocation workspace(workspace_size);
+
+     if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){
+       std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
+       std::exit(1);
+     }
+
+     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
+
+     // Run the GEMM
+     CUTLASS_CHECK(gemm_op.run());
+
+     syclcompat::wait();
+
+     // Verify that the result is correct
+     bool passed = verify(options);
+     std::cout << "Disposition: " << (passed ?
"Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return cutlass::Status::kSuccess; + } + + }; + + struct TransformA { + template + CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) { + #if defined(A_ROW) + // auto mma_A = make_fragment_like(in); + Layout A_selector = make_layout(make_shape(_8{}, _4{}, _2{}), make_stride(_2{},_16{},_1{})); + // Layout A_selector = make_layout(make_shape(_8{}, _1{}, _2{}), make_stride(_2{},_16{}, _1{})); + // Layout A_selector = make_layout(make_shape(_8{}, _2{}, _2{}), make_stride(_2{}, _16{}, _1{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[A_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + #if defined(A_COL) + Layout A_selector = make_layout(make_shape(_8{},_4{},_2{}), make_stride(_1{},_8{},_32{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[A_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + } + }; + + struct TransformB { + template + CUTE_HOST_DEVICE auto operator()(RTensor const& in, Trait trait, TransTensor& out) { + #if defined(B_ROW) && defined(A_ROW) + // auto mma_B = make_fragment_like(in); + Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_4{}, make_stride(_1{}, _64{}) ,_2{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[B_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + #if defined(B_ROW) && defined(A_COL) + Layout B_selector = make_layout(make_shape(_16{}, make_shape(_2{}, _2{}), _2{}), make_stride(_2{}, make_stride(_1{}, _64{}) ,_32{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[B_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + #if defined(B_COL) && defined(A_COL) + Layout B_selector = make_layout(make_shape(_16{}, _4{},_2{}), make_stride(_1{}, _32{},_16{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[B_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + #if defined(B_COL) && 
defined(A_ROW) + Layout B_selector = make_layout(make_shape(_16{}, _4{}, _2{}), make_stride(_2{}, _32{},_1{})); + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < size<1>(out); i++) { + CUTLASS_PRAGMA_UNROLL + for(int j =0; j < size<2>(out); j++) { + CUTLASS_PRAGMA_UNROLL + for(int v = 0; v < size<0>(out); v++) { + out(v, i, j) = static_cast(in.data()[B_selector(v, i, j)]); + // out(v, i, j) = (bfloat16_t)(1.0f); + } + } + } + #endif + } + }; + + int main(int argc, const char** argv) + { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = cutlass::float_e4m3_t; // <- data type of elements in input matrix A + using ElementInputB = cutlass::float_e4m3_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + // Note: XE_2D_U8x32x32_LD_V is incompatible with our bf16 MMA atoms + // 2.8tflops U8x32x32NLD_N + // 1.4tflops U8x16x32NLD_N + // 0.7tflops U8x 8x32NLD_N + #if defined(A_COL) && defined(B_ROW) + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using GmemTiledCopyA = XE_2D_U8x16x32_LD_T; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_N; + #endif + #if defined(A_ROW) && defined(B_ROW) + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x32x32_LD_N; + #endif + + #if defined(A_COL) & defined(B_COL) + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using GmemTiledCopyA = XE_2D_U8x16x32_LD_T; + using GmemTiledCopyB = XE_2D_U8x16x32_LD_T; + #endif + + #if defined(A_ROW) && defined(B_COL) + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U8x16x32_LD_T; + #endif + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCLowPrecision; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + 
EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + #if defined(B_COL) + XE_2D_U32x8x16_ST_N, + #else + void, + #endif + void, void>; + + // Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, TransformA, // A + GmemTiledCopyB, void, void, TransformB // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + + return 0; + } + \ No newline at end of file diff --git a/include/cute/arch/xe_copy_2B.hpp b/include/cute/arch/xe_copy_2B.hpp index c502f19c30..261e213e96 100644 --- a/include/cute/arch/xe_copy_2B.hpp +++ b/include/cute/arch/xe_copy_2B.hpp @@ -778,6 +778,27 @@ struct XE_2D_U16x16x16_LD_T { } }; +struct XE_2D_U8x16x32_LD_T { + using BlockShape = Shape<_32, _16>; + using inst_dtype = uint32_t; + + static constexpr bool is_transpose = true; + + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { + #if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k8( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif + } + }; + struct XE_2D_U16x1x16_ST_N { using BlockShape = Shape<_1, _16>; diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 4dc460df42..e39992c3f8 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -1667,6 +1667,26 @@ struct Copy_Traits_ : XE_2D_LD_Unpack(args...) {} }; +template +struct Copy_Traits_ + : XE_2D_LD_Unpack { + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + // TODO(joe): Not convinced that changing from <_16, _256> should be required here + // but get_logical_layout assumes get<1,0>(layout.shape) is the type size + using SrcLayout = Layout>, + Stride< _0,Stride<_1,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride<_1, _8>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits_(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + // template // struct Copy_Traits // : XE_2D_LD_Unpack { @@ -2251,6 +2271,7 @@ COPY_TRAIT_LD_DEF(XE_2D_U16x32x16_LD_V) COPY_TRAIT_LD_DEF(XE_2D_U16x32x32_LD_V) COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V) COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T) +COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_T) COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N) COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N) COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N) diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index a9ee6d2e13..603bb98a79 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -376,6 +376,22 @@ class CollectiveEpilogue< static_assert(ValuesLoaded == MN, "the total elements loaded by all threads should be the same as MxN" ); auto synchronize = [&] () {}; + +// 32 x 64 +// if(cute::thread0()) { +// print("accumulators: ");print(accumulators);print("\n"); +// } +if constexpr(!is_same_v) { +auto D = make_tensor(make_gmem_ptr(params.ptr_D), make_layout(make_shape(4096, 4096), make_stride(4096, 1))); +for(int i = 0; i < size<1>(accumulators); i++) { + for(int j = 0; j < size<2>(accumulators); j++) { + for(int v = 0; v < size<0>(accumulators); v++) { + D(v + i * 8 + m_sg * 32 + BlockIdxY() * 256 , BlockIdxX() * 256 + n_sg * 64 + (thread_idx % 16) * 2 + (j % 2) + (j / 2) * 32) = accumulators(v, i, j); + // D(v + i * 8 + m_sg * 16 + BlockIdxY() * 128 , BlockIdxX() * 256 + n_sg * 64 + (thread_idx % 16) * 2 + (j % 2) + (j / 2) * 32) = accumulators(v, i, j); + } + } +} +} else{ CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { CUTLASS_PRAGMA_UNROLL @@ -402,7 +418,7 @@ class CollectiveEpilogue< } } - cst_callbacks.end(); + cst_callbacks.end();} } private: diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 6bf9eaac42..33136d66d8 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -71,6 +71,7 @@ #include "cutlass/gemm/collective/xe_mma.hpp" #include "cutlass/gemm/collective/xe_array_mma.hpp" #include "cutlass/gemm/collective/xe_mma_mixed_input.hpp" +#include "cutlass/gemm/collective/xe_mma_fp8.hpp" #endif #if defined(CUTLASS_ENABLE_SYCL) diff --git a/include/cutlass/gemm/collective/xe_mma_fp8.hpp b/include/cutlass/gemm/collective/xe_mma_fp8.hpp new file mode 100644 index 0000000000..7c5249e230 --- /dev/null +++ b/include/cutlass/gemm/collective/xe_mma_fp8.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMma, TileShape_, ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, + GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, + SmemCopyAtomB_, TransformB_> { + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelPVCLowPrecision; + using WorkgroupTileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(platform::is_same::value, "MainloopIntelPVC requires that A and B have same type."); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static_assert(BLK_M % TiledMma{}.template tile_size_mnk<0>() == 0, "TiledMma permutation size must match block size."); + static_assert(BLK_N % TiledMma{}.template tile_size_mnk<1>() == 0, "TiledMma permutation size must match block size."); + static_assert(BLK_K % TiledMma{}.template tile_size_mnk<2>() == 0, "TiledMma permutation size must match block size."); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, 
ATOM_K); + using SubgroupTileShape = Shape; + + // 32 + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using CopyThreadShape = Shape<_1, Int>; + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + + using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k) + using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k) + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + }; + + struct Params { + TensorMKL mA; + TensorNKL mB; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + auto [M,N,K,L] = problem_shape; + + auto mA_mkl = make_tensor(make_gmem_ptr(static_cast(args.ptr_A)), + make_layout(make_shape(M, K, L), args.dA)); + + auto mB_nkl = make_tensor(make_gmem_ptr(static_cast(args.ptr_B)), + make_layout(make_shape(N, K, L), args.dB)); + + return Params{mA_mkl, mB_nkl}; + } + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, + BlkCoord const &blk_coord, int const &K_start, int thread_idx, char *smem_buf, + Params const &mainloop) { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + (void)residue_mnk; + (void)thread_idx; + (void)smem_buf; + + auto tiled_copy_a = make_tiled_copy(atom_load_A{}.with(mainloop.mA), + Layout{}, + make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{}))); + auto tiled_copy_b = make_tiled_copy(atom_load_B{}.with(mainloop.mB), + Layout{}, + make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{}))); + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + // TODO(Codeplay): see if we can make this nicer + // To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition global counting tensors for MMA + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + Tensor tCrA = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor tCrB = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); + auto mma_A = make_fragment_like(tCrA); + auto mma_B = make_fragment_like(tCrB); + + // Retile registers for copies + Tensor tArA = thr_copy_A.retile_D(tCrA); + Tensor tBrB = thr_copy_B.retile_D(tCrB); + + // Retile global counting tensors for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = tiled_copy_a.template prefetch_selector,Int>, Num_SGs>(mainloop.mA); + auto tiled_prefetch_b = 
tiled_copy_b.template prefetch_selector,Int>, Num_SGs>(mainloop.mB); + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + TransformA transformA{}; + TransformB transformB{}; + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) print(#x ": "); print(x); print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + PRINT(tCgA); + PRINT(tAgA); + + PRINT(tCrA); + PRINT(tArA); + PRINT(mainloop.copy_A); + + print("======================= B: \n"); + PRINT(tCgB); + PRINT(tBgB); + + PRINT(tCrB); + PRINT(tBrB); + PRINT(mainloop.copy_B); + } +#undef PRINT +#endif + + // + // Mainloop + // + const auto k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = 0; + + CUTLASS_PRAGMA_UNROLL + for (; prefetch_k < DispatchPolicy::Stages; prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + CUTLASS_PRAGMA_UNROLL + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_,_,_,k_tile), tArA); + copy(tiled_copy_b, tBgB(_,_,_,k_tile), tBrB); + + transformA(tCrA, GmemTiledCopyA{}, mma_A); + transformB(tCrB, GmemTiledCopyB{}, mma_B); + if (prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_, _, _, prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_, _, _, prefetch_k)); + } + + cute::gemm(tiled_mma, mma_A, mma_B, accum); + barrier_wait(barrier_scope); + } + } +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index fb6254e423..897980305b 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -990,6 +990,15 @@ struct MainloopIntelPVCMixedPrecision { using Schedule = KernelPVC; using ClusterShape = Shape<_1,_1,_1>; }; + +template +struct MainloopIntelPVCLowPrecision { + constexpr static int Stages = Stages_; + constexpr static int SubgroupSize = 16; + using ArchTag = arch::IntelPVC; + using Schedule = KernelPVC; + using ClusterShape = Shape<_1,_1,_1>; +}; #endif #if defined(CUTLASS_ENABLE_SYCL)
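
The example validates the fp8 kernel by keeping two copies of each operand: the `float_e4m3_t` data that the kernel loads, and a wider reference copy that a plain bf16 reference GEMM (`GemmRef` in `verify()`) consumes. Inside the new mainloop, `TransformA`/`TransformB` widen each fp8 register value to `bfloat16_t` before the bf16 DPAS MMA. The snippet below is a hedged, host-only sketch of that idea using the public CUTLASS numeric types already included by the example; the buffer sizes, seed, variable names, and overall structure are illustrative and are not taken from the patch.

```cpp
// Host-only sketch: fill an fp8 operand and a wider reference copy from the same integer
// draws, then widen the fp8 values to bf16 the way TransformA/TransformB do per register.
#include <iostream>
#include <random>
#include <vector>

#include "cutlass/numeric_types.h"

int main() {
  constexpr int M = 8, K = 16;                       // illustrative tile, not the example's sizes
  std::vector<cutlass::float_e4m3_t> A(M * K);       // quantized operand (what the kernel loads)
  std::vector<cutlass::bfloat16_t>   A_ref(M * K);   // widened copy (what the reference GEMM loads)

  std::ranlux24_base rng(2023);                      // fixed seed, as in the example's initialize()
  std::uniform_int_distribution<int> dist(-4, 4);    // small integers are exact in both formats

  for (size_t i = 0; i < A.size(); ++i) {
    float v  = static_cast<float>(dist(rng));
    A[i]     = static_cast<cutlass::float_e4m3_t>(v);
    A_ref[i] = static_cast<cutlass::bfloat16_t>(v);
  }

  // The same widening the transforms perform per value: e4m3 -> float -> bf16.
  // Every finite e4m3 value is exactly representable in bf16, so the conversion is lossless.
  std::vector<cutlass::bfloat16_t> A_widened(A.size());
  for (size_t i = 0; i < A.size(); ++i) {
    A_widened[i] = static_cast<cutlass::bfloat16_t>(static_cast<float>(A[i]));
  }

  // With inputs drawn this way, the widened operand matches the reference copy exactly,
  // which is what lets verify() compare the fp8 kernel against a bf16 reference GEMM.
  bool match = true;
  for (size_t i = 0; i < A.size(); ++i) {
    match &= (static_cast<float>(A_widened[i]) == static_cast<float>(A_ref[i]));
  }
  std::cout << (match ? "widened operand matches reference copy\n" : "mismatch\n");
  return 0;
}
```

With the CMakeLists.txt entry above, the example builds as the `pvc_gemm_fp8` target and accepts the same `--m`, `--n`, `--k`, `--l`, `--alpha`, `--beta`, and `--iterations` options that `Options::parse` reads for the other PVC examples.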