From 07f36e5e73419b65932ef191c01e4c72187570d3 Mon Sep 17 00:00:00 2001
From: Jiaxingla
Date: Wed, 3 Jul 2024 00:58:48 -0700
Subject: [PATCH 01/36] apply patch of gemm pipeline

---
 build.sh                                      |  40 ++
 examples/sycl/pvc/pvc_gemm.cpp                | 645 ++++++++++++------
 include/cute/arch/copy_xe.hpp                 | 141 +++-
 include/cute/arch/mma_xe.hpp                  |   9 +-
 include/cute/atom/copy_atom.hpp               |   4 +
 include/cute/atom/copy_traits_xe.hpp          | 601 +++++++++-------
 include/cute/atom/mma_traits_xe.hpp           |   2 +-
 .../epilogue/collective/default_epilogue.hpp  |  35 +
 .../intel_pvc_epilogue_tensor_softmax.hpp     | 156 +++++
 .../epilogue/thread/linear_combination_relu.h |  23 +-
 .../cutlass/gemm/collective/intel_pvc_mma.hpp | 245 ++++---
 .../cutlass/gemm/kernel/intel_pvc_gemm.hpp    | 302 ++++----
 include/cutlass/relatively_equal.h            |   2 +-
 .../util/reference/device/tensor_compare.h    |   8 +-
 .../util/reference/device/tensor_foreach.h    |  10 +-
 15 files changed, 1525 insertions(+), 698 deletions(-)
 create mode 100644 build.sh
 create mode 100644 include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp

diff --git a/build.sh b/build.sh
new file mode 100644
index 0000000000..c483462146
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,40 @@
+script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+cp ${script_dir}/tools/clang-format/clang-format.hook ${script_dir}/.git/hooks/pre-commit
+chmod +x ${script_dir}/.git/hooks/pre-commit
+
+# https://github.com/intel/llvm/releases/tag/nightly-2024-05-16
+sycl_compiler_path=/opt/cutlass/compiler/0516/
+
+# https://ubit-gfx.intel.com/build/19168301/artifacts
+gpu_driver_path=/opt/cutlass/gpu_driver/gfx-driver-ci-comp_igc-25012/extract/
+
+# AOT compile
+output=intel_gpu_pvc
+# jit compile
+#output=spir64
+
+unset epilogue
+
+# epilogue relu
+#epilogue+=" -DEPILOGUE_RELU "
+
+# epilogue softmax
+#epilogue+=" -DEPILOGUE_SOFTMAX "
+
+export ZE_AFFINITY_MASK=0
+export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/
+export LIBRARY_PATH=$gpu_driver_path/usr/lib/x86_64-linux-gnu/:$sycl_compiler_path/lib/
+export LD_LIBRARY_PATH=$LIBRARY_PATH
+export IGC_EnableVISANoSchedule=1
+export IGC_ShaderDumpEnable=1
+export IGC_DumpToCustomDir=./mm_dumps
+export IGC_VATemp=1
+export ONEAPI_DEVICE_SELECTOR=level_zero:gpu
+
+target=./examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute
+rm -rf *
+
+cmake .. -G Ninja -DCMAKE_CUDA_HOST_COMPILER=${sycl_compiler_path}/bin/clang++ \
+-DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=$output -DCMAKE_CXX_COMPILER=${sycl_compiler_path}/bin/clang++ \
+-DCMAKE_CXX_FLAGS=" -DPREFETCH_DEFAULT -DSYCL_INTEL_TARGET ${epilogue} " \
+&& ninja -v $target && $target
diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp
index 5141a084cd..9ceaed637b 100644
--- a/examples/sycl/pvc/pvc_gemm.cpp
+++ b/examples/sycl/pvc/pvc_gemm.cpp
@@ -5,8 +5,8 @@
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions are met:
  *
- * 1. Redistributions of source code must retain the above copyright notice, this
- * list of conditions and the following disclaimer.
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
  *
  * 2. Redistributions in binary form must reproduce the above copyright notice,
  * this list of conditions and the following disclaimer in the documentation
@@ -18,63 +18,78 @@
  *
  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
- * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
  *
  **************************************************************************************************/

 #define CUTLASS_SYCLCOMPAT_PROFILING_ENABLED

 #include "cutlass/epilogue/collective/default_epilogue.hpp"
-#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp"
-#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/gemm/device/gemm.h"
 #include "cutlass/gemm/device/gemm_universal.h"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/collective/collective_mma.hpp"
 #include "cutlass/util/GPU_Clock.hpp"

 #include <cute/tensor.hpp>
 #include <random>

+#include "cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp"
 #include "cutlass/util/command_line.h"
 #include "cutlass/util/device_memory.h"
 #include "cutlass/util/packed_stride.hpp"
 #include "cutlass/util/reference/device/gemm_complex.h"
 #include "cutlass/util/reference/device/tensor_compare.h"

-template <typename T>
-static void fill_matrix(std::vector<T> &vector)
-{
-  std::generate(std::begin(vector), std::end(vector), [&] {
-    return static_cast<T>( (rand() / double(RAND_MAX)) );
-  });
+// 0 - None
+// 1 - FLUSH by memset
+// 2 - FLUSH by input offset with pingpong
+#define CACHE_FLUSH 2
+
+template <typename T> static void fill_matrix(std::vector<T> &M) {
+  std::random_device dev;
+  std::mt19937 rng(dev());
+  std::uniform_real_distribution dist((T)0.0,
+#ifdef EPILOGUE_SOFTMAX
+                                      (T)0.1);
+#else
+                                      (T)1.0);
+#endif
+  std::generate(std::begin(M), std::end(M),
+                [&] { return static_cast<T>(dist(rng)); });
 }

 template <typename T>
-static void vnni_matrix(
-    T* dst, const T* src,
-    int batch, int numRows, int numCols, int factor)
-{
-  for (int b = 0; b < batch; b++) {
-    for (int r = 0; r < numRows / factor; r++) {
-      for (int c = 0; c < numCols; c++) {
-        for (int k = 0; k < factor; k++) {
-          dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] =
-            src[((b * (numRows / factor) + r) * factor + k) * numCols + c];
-        }
-      }
+static void vnni_matrix(T *dst, const T *src, int batch, int numRows, + int numCols, int factor) { + for (int b = 0; b < batch; b++) { + for (int r = 0; r < numRows / factor; r++) { + for (int c = 0; c < numCols; c++) { + for (int k = 0; k < factor; k++) { + dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = + src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; + } } } + } } using namespace cute; +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = float; // <- data type of epilogue operations +using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A +using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D + /////////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing @@ -86,12 +101,9 @@ struct Options { int m, n, k, l, iterations; float alpha, beta; - Options(): - help(false), - error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), - alpha(1.f), beta(0.f) - { } + Options() + : help(false), error(false), m(4096), n(4096), k(4096), l(1), + iterations(100), alpha(1.f), beta(0.f) {} // Parses the command line void parse(int argc, char const **args) { @@ -112,18 +124,20 @@ struct Options { } /// Prints the usage statement. - std::ostream & print_usage(std::ostream &out) const { + std::ostream &print_usage(std::ostream &out) const { out << "PVC GEMM Example\n\n" - << "Options:\n\n" - << " --help If specified, displays this usage statement\n\n" - << " --m= Sets the M extent of the GEMM\n" - << " --n= Sets the N extent of the GEMM\n" - << " --k= Sets the K extent of the GEMM\n" - << " --l= Sets the L extent (batch count) of the GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Iterations\n\n"; + << "Options:\n\n" + << " --help If specified, displays this " + "usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) " + "of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; return out; } @@ -131,10 +145,7 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template < - class Gemm -> -struct ExampleRunner { +template struct ExampleRunner { using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; @@ -170,95 +181,204 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_B_vnni; - cutlass::DeviceAllocation block_C; + // cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; + static constexpr auto l3_cache_size = 256 * 1024 * 1024; + +#if CACHE_FLUSH == 2 + size_t PINGPONG_ITER = 3; + size_t pingpong_size_a; + size_t pingpong_size_b; + size_t pingpong_size_d; +#endif + + std::vector a; + std::vector b; + std::vector d; // // Methods // - bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + bool verify(const ProblemShapeType &problem_size, ElementCompute alpha, + ElementCompute beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), 
LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_C((ElementC *)nullptr /*block_C.get()*/, + LayoutC::packed({M, N})); cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); cutlass::reference::device::GemmComplex( - {M, N, K}, - alpha, - ref_A, - cutlass::ComplexTransform::kNone, - ref_B, - cutlass::ComplexTransform::kNone, - beta, - ref_C, - ref_D, - ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); + {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, + cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + +#ifdef EPILOGUE_SOFTMAX +#define IDX (l * M * N + i * N + j) + + ElementOutput *ptr = + (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr, block_ref_D.get(), + M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + for (int l = 0; l < L; l++) { + for (int i = 0; i < M; i++) { + + auto row_max = ptr[l * M * N + i * N]; + for (int j = 0; j < N; j++) { + row_max = max(row_max, ptr[IDX]); + } + + ElementOutput exp_sum = (ElementOutput)0; + for (int j = 0; j < N; j++) { + ptr[IDX] = ptr[IDX] - row_max; + ptr[IDX] = exp(ptr[IDX]); + exp_sum += ptr[IDX]; + } + + for (int j = 0; j < N; j++) { + ptr[IDX] = ptr[IDX] / exp_sum; + } + } + } + syclcompat::memcpy(block_ref_D.get(), ptr, + M * N * L * sizeof(ElementOutput)); syclcompat::wait(); - // Check if output from CUTLASS kernel and reference kernel are relatively equal or not - // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.1f); - auto nonzero_floor = static_cast(0.1f); + std::free(ptr); + +#undef IDX + +#endif + +#if 0 + ElementOutput *ptr = + (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr, block_D.get(), M * N * L * sizeof(ElementOutput)); + + ElementOutput *ptr_refD = + (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr_refD, block_ref_D.get(), + (size_t)M * N * L * sizeof(ElementOutput)); + syclcompat::wait(); + for (int b = 0; b < L; b++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int idx = b * M * N + i * N + j; + if (abs(ptr[idx] - ptr_refD[idx]) / ptr_refD[idx] >= 0.01f) + std::cout << "(" << b << ", " << i << ", " << j << "): " << "host: " << ptr[idx] + << " and device: " << ptr_refD[idx] << std::endl; + } + } + } + std::free(ptr); + std::free(ptr_refD); +#endif + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are relatively + // equal or not need to set a larger error margin for comparison to succeed bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), block_D.get(), block_D.size(), - epsilon, nonzero_floor); + block_ref_D.get(), block_D.get(), M * N * L, 0.5f, 0.5f); return passed; } - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(const ProblemShapeType& problem_size) { + void init_cache_flush(const ProblemShapeType &problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - stride_A = cutlass::make_cute_packed_stride(StrideA{}, 
cute::make_shape(M, K, L)); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); +#if CACHE_FLUSH == 1 + auto ref_d_element = max(l3_cache_size / sizeof(ElementOutput), M * N * L); + block_ref_D.reset(ref_d_element); + syclcompat::memset(block_ref_D.get(), 0, + ref_d_element * sizeof(ElementOutput)); + +#elif CACHE_FLUSH == 2 + + pingpong_size_a = max((size_t)M * K * L, l3_cache_size / sizeof(ElementA)); + pingpong_size_b = max((size_t)K * N * L, l3_cache_size / sizeof(ElementB)); + pingpong_size_d = + max((size_t)M * N * L, l3_cache_size / sizeof(ElementOutput)); + auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); + PINGPONG_ITER = + std::min((size_t)3, + std::max((size_t)1, + (size_t)gmem_size / + ((pingpong_size_a * sizeof(ElementA) + + pingpong_size_b * sizeof(ElementB) + + pingpong_size_d * sizeof(ElementOutput))) - + 1)); + block_A.reset(pingpong_size_a * PINGPONG_ITER); + block_B.reset(pingpong_size_b * PINGPONG_ITER); + // block_C.reset(M * N * L * ITER); + block_D.reset(pingpong_size_d * PINGPONG_ITER); + + for (int i = 0; i < PINGPONG_ITER; i++) { + syclcompat::memcpy(block_A.get() + i * pingpong_size_a, a.data(), + a.size() * sizeof(ElementA)); + syclcompat::memcpy(block_B.get() + i * pingpong_size_b, b.data(), + b.size() * sizeof(ElementB)); + syclcompat::memcpy(block_D.get() + i * pingpong_size_d, d.data(), + d.size() * sizeof(ElementC)); + } +#endif + + // syclcompat::wait(); + } - block_A.reset(M * K * L); - block_B.reset(K * N * L); - block_B_vnni.reset(K * N * L); - block_C.reset(M * N * L); - block_D.reset(M * N * L); - block_ref_D.reset(M * N * L); + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType &problem_size) { + auto [M, N, K, L] = problem_size; + + stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L)); + stride_C = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + block_A.reset((size_t)M * K * L); + block_B.reset((size_t)K * N * L); + // block_C.reset(M * N * L); + block_D.reset((size_t)M * N * L); + block_ref_D.reset( + (size_t)max(l3_cache_size / sizeof(ElementOutput), (size_t)M * N * L)); // TODO: Enable initialization on device directly once RNG is // available through SYCL. - std::vector a(K * M * L); - std::vector b(K * N * L); - std::vector b_vnni(b.size()); - std::vector c(M * N * L); - std::vector d(M * N * L, ElementC{0}); - + a = std::vector((size_t)M * K * L); + b = std::vector((size_t)K * N * L); + d = std::vector((size_t)M * N * L, ElementC{0}); + std::cout << "random generating..." 
<< std::endl; fill_matrix(a); fill_matrix(b); - fill_matrix(c); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); - syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); + // syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } - void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { - ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; - + template + void run(int M, int K, int N, int L, + const cutlass::KernelHardwareInfo &hw_info) { + static constexpr auto warmup = 10; + static constexpr auto testIterations = 10; + static constexpr auto total_iterations = warmup + testIterations; + ProblemShapeType problem_size = ProblemShapeType{M, N, K, L}; initialize(problem_size); sycl::property_list prop = { @@ -270,91 +390,175 @@ struct ExampleRunner { syclcompat::set_default_queue(q); typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, - hw_info - }; - - Gemm gemm_op; + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{1, 0.f}, + nullptr /*block_C.get()*/, + stride_C, + block_D.get(), + stride_D}, + hw_info}; + Gemm gemm_op_verify; size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op.can_implement(arguments); + gemm_op_verify.can_implement(arguments); - gemm_op.initialize(arguments, workspace.get()); + gemm_op_verify.initialize(arguments, workspace.get()); // Run the GEMM - gemm_op.run(); - + gemm_op_verify.run(); syclcompat::wait(); // Verify that the result is correct - bool passed = verify(problem_size, options.alpha, options.beta); - std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + bool passed = verify(problem_size, 1, 0.f); + if (!passed) { + printf("PVC GEMM%s%s Example %s, MKNL(%d, %d,%d,%d), Config(%d, " + "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", +#ifdef EPILOGUE_RELU + "-relu" +#else + "" +#endif + , +#ifdef EPILOGUE_SOFTMAX + "-softmax" +#else + "" +#endif + , + (passed ? 
"Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, + sg_tile_m, sg_tile_n, sg_tile_k); + // return; + } - if (passed && options.iterations > 0) { - GPU_Clock timer; - timer.start(); - for (int i = 0; i < options.iterations; ++i) { + // ================ init cache flush ================ + init_cache_flush(problem_size); + + // ================ run and collect performance data ================ + if (total_iterations > 0) { + auto total_time = 0.f; + auto best = 999.f; + auto worst = 0.f; + + for (int i = 0; i < testIterations + warmup; ++i) { +#if CACHE_FLUSH == 1 + init_cache_flush(problem_size); + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{1, 0.f}, + nullptr /*block_C.get() + i * M * N * L*/, + stride_C, + block_D.get(), + stride_D}, + hw_info}; +#elif CACHE_FLUSH == 2 + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get() + (i % PINGPONG_ITER) * pingpong_size_a, stride_A, + block_B.get() + (i % PINGPONG_ITER) * pingpong_size_b, stride_B}, + {{1, 0.f}, + nullptr /*block_C.get() + i * M * N * L*/, + stride_C, + block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, + stride_D}, + hw_info}; +#endif + + Gemm gemm_op; + gemm_op.can_implement(arguments); + gemm_op.initialize(arguments, workspace.get()); + + GPU_Clock timer; + timer.start(); gemm_op.run(); + syclcompat::wait(); + + auto current_time = timer.seconds(); + if (i >= warmup) { + total_time += current_time; + + best = min(best, current_time); + + worst = max(worst, current_time); + } } - syclcompat::wait(); - float cute_time = timer.seconds() / options.iterations; - double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; - std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; - printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + float average = total_time / testIterations; + double tflops = (2.0 * M * N * K * L) * 1e-12; + double gflops = (2.0 * M * N * K * L) * 1e-9; + + double hbm = + L * + (M * K * sizeof(ElementInputA) + K * N * sizeof(ElementInputB) + + M * N * sizeof(ElementOutput)) * + 1e-9; + + printf("Collective pvc gemm%s, MKNL(%d, %d, %d, %d), Config(%d, %d, " + "%d, %d, %d):\n max: (%6.4f)ms, (%4.2f)TFlop/s, " + "(%4.2f)GB/s\n min: (%6.4f)ms, (%4.2f)TFlop/s, " + "(%4.2f)GB/s\n average: (%6.4f)ms, (%4.2f)TFlop/s, " + "(%4.2f)GB/s\n\n\n", +#if defined(EPILOGUE_RELU) + "-relu" +#elif defined(EPILOGUE_SOFTMAX) + "softmax" +#else + "" +#endif + , + M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, + best * 1000, tflops / best, hbm / best, worst * 1000, + tflops / worst, hbm / worst, average * 1000, tflops / average, + hbm / average); } - - return; } - }; -int main(int argc, const char** argv) -{ +template +void collective_gemm(int M, int K, int N, int L = 1) { // // Parse options // Options options; - options.parse(argc, argv); + // options.parse(argc, argv); if (options.help) { options.print_usage(std::cout) << std::endl; - return 0; + return; } if (options.error) { std::cerr << "Aborting execution." << std::endl; - return -1; + return; } // // Run examples // - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This - // information is used by the underlying kernel. 
+ // The KernelHardwareInfo struct holds the number of EUs on the GPU with a + // given device ID. This information is used by the underlying kernel. cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with multiple GPUs and wish - // to use a GPU other than that with device ID 0. - hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + // Change device_id to another value if you are running on a machine with + // multiple GPUs and wish to use a GPU other than that with device ID 0. + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); bool passed; - // The code section below describes datatype for input, output matrices and computation between - // elements in input matrices. - using ElementAccumulator = float; // <- data type of accumulator - using ElementComputeEpilogue = float; // <- data type of epilogue operations - using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A - using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B - using ElementOutput = float; // <- data type of elements in output matrix D + // The code section below describes datatype for input, output matrices and + // computation between elements in input matrices. using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; @@ -362,60 +566,115 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; - - // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; - - using TiledMma = TiledMMA, - Layout>, - Tile<_32,_64,_32>>; // Subgroup level-tile - - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - - using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, - void, void, - XE_2D_U32x8x16x1x1_ST_N, - void, void>; - -// Mainloop + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; + + using TileShape = Shape, Int, Int, + Int, Int>; + + using TiledMma = TiledMMA, + Layout>>; + + using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; +#ifdef EPILOGUE_RELU + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of + // elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear + +#else + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of + // elements per vectorized + // memory access. For a byte, it's 16 + // elements. 
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear + // combination function +#endif + // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - GEMMDispatchPolicy, - TileShape, - ElementInputA, - cutlass::gemm::TagToStrideA_t, - ElementInputB, - cutlass::gemm::TagToStrideB_t, - TiledMma, - GmemTiledCopyA, void, void, cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; + DispatchPolicy, TileShape, ElementInputA, + cutlass::gemm::TagToStrideA_t, ElementInputB, + cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, + void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + +#ifdef EPILOGUE_SOFTMAX + using CollectiveEpilogue = + cutlass::epilogue::collective::PvcEpilogueTensorSoftmax< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, EpilogueOp, + cutlass::gemm::EpilogueDefault, CollectiveMainloop::sg_tile_m, + CollectiveMainloop::sg_tile_n / CollectiveMainloop::SubgroupSize>; +#else + using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, EpilogueOp, + cutlass::gemm::EpilogueDefault>; +#endif using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; + Shape, CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; ExampleRunner runner; - runner.run(options, hw_info); + runner.template run( + M, K, N, L, hw_info); +} - return 0; +int main() { + auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); +#if !defined(EPILOGUE_RELU) && !defined(EPILOGUE_SOFTMAX) + collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); + collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192); + collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824); + collective_gemm<256, 256, 32, 64, 32>(1024, 28672, 8192); + collective_gemm<256, 256, 32, 64, 32>(3072, 4096, 3072); + collective_gemm<256, 256, 32, 64, 32>(4, 4096, 12288); + + // collective shape from habana + collective_gemm<256, 256, 32, 64, 32>(512, 8192, 8192); + collective_gemm<256, 256, 32, 64, 32>(512, 8192, 32768); + collective_gemm<256, 256, 32, 64, 32>(512, 32768, 8192); + collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 1024); + collective_gemm<256, 256, 32, 64, 32>(16384, 1024, 8192); + collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 4096); + collective_gemm<256, 256, 32, 64, 32>(16384, 4096, 8192); + collective_gemm<256, 256, 32, 64, 32>(4096, 16384, 8192); + collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 4096); + collective_gemm<256, 256, 32, 64, 32>(1024, 16384, 8192); + collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 1024); + + collective_gemm<256, 256, 32, 64, 32>(8, 128, 16384, 4096); + collective_gemm<16, 512, 16, 16, 32>(8, 16384, 128, 4096); + + collective_gemm<256, 256, 32, 64, 32>(32768, 128, 4096, 4); + collective_gemm<256, 256, 32, 64, 32>(32768, 4096, 128, 4); + collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 128, 32); +#endif + +#if defined(EPILOGUE_SOFTMAX) + // gemm + softmax + collective_gemm<64, 1024, 16, 64, 32>(1024, 64, 1024, 4); + collective_gemm<128, 512, 16, 64, 32>(512, 64, 512, 32); + collective_gemm<64, 1024, 16, 64, 32>(1024, 64, 1024, 16); + collective_gemm<32, 2048, 16, 64, 16>(2048, 64, 2048, 8); + collective_gemm<16, 4096, 16, 64, 32>(4096, 64, 4096, 4); + 
collective_gemm<8, 8192, 8, 128, 16>(8192, 64, 8192, 2); +#endif + +#if defined(EPILOGUE_RELU) + // gemm + softmax + collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); +#endif } diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 3bfc5c8535..db0f3a6334 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -44,6 +44,18 @@ namespace cute inline x { assert(false); } #endif +enum LSC_LDCC { + LSC_LDCC_DEFAULT = 0, + LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + LSC_LDCC_L1IAR_L3C + = 7, // Override to L1 invalidate-after-read, and L3 cached +}; + SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord, intel::uint8 data)); @@ -80,7 +92,24 @@ SYCL_DEVICE_BUILTIN(intel::int16 __builtin_IB_subgroup_block_read_flat_transform SYCL_DEVICE_BUILTIN(intel::int8 intel_subgroup_block_read_transform_u16_k16( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord)); - +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); #undef SYCL_DEVICE_BUILTIN @@ -98,6 +127,22 @@ struct XE_2D_U16x8x16x1x1_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U32x8x16x1x1_LD_N @@ -130,6 +175,22 @@ struct XE_2D_U16x16x16x1x1_LD_N 
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U16x8x16x4x2_LD_N @@ -146,6 +207,23 @@ struct XE_2D_U16x8x16x4x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U16x8x16x2x2_LD_N @@ -162,6 +240,22 @@ struct XE_2D_U16x8x16x2x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U16x8x16x1x2_LD_N @@ -179,6 +273,22 @@ struct XE_2D_U16x8x16x1x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U16x8x16x4x1_LD_N @@ -195,6 +305,22 @@ struct XE_2D_U16x8x16x4x1_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + LSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; struct XE_2D_U32x8x16x2x1_LD_N @@ -229,6 +355,8 @@ struct XE_2D_U16x16x16x2x1_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; struct 
XE_2D_U16x16x16x2x2_V @@ -242,6 +370,9 @@ struct XE_2D_U16x16x16x2x2_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + // using PREFETCH = typename XE_2D_U16x8x16x4x2_LD_N::PREFETCH; + using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; struct XE_2D_U16x16x16x1x2_V @@ -255,6 +386,8 @@ struct XE_2D_U16x16x16x1x2_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; struct XE_2D_U16x16x16x2x1_V @@ -268,6 +401,8 @@ struct XE_2D_U16x16x16x2x1_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; struct XE_2D_U16x16x16x1x1_V @@ -282,6 +417,8 @@ struct XE_2D_U16x16x16x1x1_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x16x16x1x1_LD_N::PREFETCH; }; struct XE_2D_U32x8x16x1x1_ST_N @@ -300,4 +437,4 @@ struct XE_2D_U32x8x16x1x1_ST_N } }; -} // end namespace cute +} // end namespace cute \ No newline at end of file diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 3d1bfb8f68..878c587fdc 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -45,7 +45,11 @@ SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::i #undef SYCL_DEVICE_OCL namespace cute { -struct XE_8x16x16_F32BF16BF16F32_TN +//MxNxK_A,B,C,D +//# of vector component of a x subgroup-size x function name +//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); +//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8 +struct XE_8x16x16_BF16BF16F32F32_NN { using DRegisters = intel::float8[1]; using ARegisters = intel::short8[1]; @@ -65,7 +69,8 @@ struct XE_8x16x16_F32BF16BF16F32_TN #endif } }; -struct XE_1x16x16_F32BF16BF16F32_TN +//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +struct XE_1x16x16_BF16BF16F32F32_NN { using DRegisters = float[1]; using ARegisters = short[1]; diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index e20bace705..a619b725d8 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -111,6 +111,10 @@ struct Copy_Atom, CopyInternalType> // recurse this rank-1 layout by peeling off the mode // ((A,B,C,...)) -> (A,B,C,...) 
return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else if constexpr (is_tuple::engine_type::iterator:: + value_type>::value) { + return copy_unpack(*this, src, dst); } else { static_assert(dependent_false, "No instruction match and no recursion possible."); } diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 63eae73dd4..7562dbc144 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -35,318 +35,423 @@ #include -namespace cute -{ +namespace cute { +template +struct XE_2D_LD_Unpack { + GTensor tensor; -template -CUTE_HOST_DEVICE constexpr -auto get_shape_WHD(cute::Stride, IntT, IntT> , cute::Shape shape_MKL) { - return shape_MKL; -} + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, + Tensor>, SLayout> const &src, + Tensor &dst) { + static_assert(is_rmem::value); + int H = size<0>(traits.tensor); + int W = size<1>(traits.tensor) + * sizeof(typename Copy_Traits::CopyInternalType); + auto [y, x, z] = src.data().coord_; + CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y}, + &*dst.data()); + } -template -CUTE_HOST_DEVICE constexpr -auto get_shape_WHD(cute::Stride, IntT> , cute::Shape shape_MKL) { - return Shape(get<1>(shape_MKL), get<0>(shape_MKL), get<2>(shape_MKL)); -} - -template -CUTE_HOST_DEVICE constexpr -auto get_coordinates(cute::Stride, IntT, IntT> , - Tensor>, SLayout> const &src) { - auto [x, y, z] = src.data().coord_; - return make_coord(x, y, z); -} + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, GStride const &stride_mul) const { + return make_tensor(make_inttuple_iter(coord), + make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), + get<2>(shape)), + make_stride(_1 {}, E<0> {} * get<0>(stride_mul), + E<1> {} * get<1>(stride_mul), + E<2> {} * get<2>(stride(tensor))))); + } +}; -template -CUTE_HOST_DEVICE constexpr -auto get_coordinates(cute::Stride, IntT> , - Tensor>, SLayout> const &src) { - auto [x, y, z] = src.data().coord_; - return make_coord(y, x, z); -} +template +struct Copy_Traits + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; template -struct XE_2D_LD_Unpack -{ - GTensor tensor; - - using Copy_Traits = Copy_Traits; - - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const &traits, - Tensor>, SLayout> const &src, - Tensor &dst) - { - static_assert(is_rmem::value); - auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); - int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); - int H = size<1>(shape_whd); - auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data()); - } - - template - CUTE_HOST_DEVICE constexpr auto - get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const - { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)), - make_stride(_1{}, E<0>{} * 
get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor))))); - } +struct XE_2D_PF_Unpack { + GTensor tensor; + + template + CUTE_HOST_DEVICE XE_2D_PF_Unpack(G_Tensor const &t) : tensor(t) {} + + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, + Tensor>, SLayout> const &src, + Tensor &dst) { + using T = typename Copy_Traits::CopyInternalType; + int H = size<0>(traits.tensor); + int W = size<1>(traits.tensor) * sizeof(T); + auto [y, x, z] = src.data().coord_; + CopyOp::template copy(traits.tensor.data() + z * W * H / sizeof(T), W, H, W, + intel::coord_t {static_cast(x), static_cast(y)}); + } }; template -struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +struct Copy_Traits + : XE_2D_PF_Unpack { + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_256, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from 
(src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = uint; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = uint; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>>; // one coordinate - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>, - 
Stride<_16, Stride<_512, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_512, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - // 32 bits register file - using CopyInternalType = uint; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + // 32 bits register file + using CopyInternalType = uint; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>>; // expected 4 coordinates - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - // 32 bits register file - using CopyInternalType = uint; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + // 32 bits register file + using CopyInternalType = uint; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape< Shape< _8, _2>, Shape<_16, _2, _2>>>, - Stride<_16, Stride, Stride< _1, _512, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + 
using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout< + Shape<_16, Shape, Shape<_16, _2, _2>>>, + Stride<_16, Stride, Stride<_1, _512, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape< _8, Shape<_16, _2, _2>>>, - Stride<_16, Stride<_1024, Stride< _1, _512, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_1024, Stride<_1, _512, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape<_8, Shape<_16, _2>>>, - Stride<_16, Stride<_512, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_512, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template -struct XE_2D_ST_Unpack -{ - GTensor tensor; - using Copy_Traits = Copy_Traits; - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const &traits, - Tensor const &src, - Tensor>, DLayout> &dst) - { - static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); - auto [y, x, z] = dst.data().coord_; - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data()); - } - - template - CUTE_HOST_DEVICE constexpr auto - 
get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const - { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)), - make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor))))); - } +struct XE_2D_ST_Unpack { + GTensor tensor; + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, Tensor const &src, + Tensor>, DLayout> &dst) { + static_assert(is_rmem::value); + int H = size<0>(traits.tensor); + int W = size<1>(traits.tensor) + * sizeof(typename Copy_Traits::CopyInternalType); + auto [y, x, z] = dst.data().coord_; + + CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y}, + &*src.data()); + } + + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, GStride const &stride_mul) const { + return make_tensor(make_inttuple_iter(coord), + make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), + get<2>(shape)), + make_stride(_1 {}, E<0> {} * get<0>(stride_mul), + E<1> {} * get<1>(stride_mul), + E<2> {} * get<2>(stride(tensor))))); + } }; template struct Copy_Traits - : XE_2D_ST_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = uint; + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint; }; template -auto make_xe_2d_copy(Tensor gtensor) -{ +auto make_xe_2d_copy(Tensor gtensor) { using GTensor = Tensor; using Traits = Copy_Traits; - Traits traits{gtensor}; - return Copy_Atom{traits}; + Traits traits {gtensor}; + return Copy_Atom {traits}; } } // end namespace cute diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index 1cbefc872d..a5ef6dbec2 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -38,7 +38,7 @@ namespace cute { template <> -struct MMA_Traits +struct MMA_Traits { using ValTypeD = float; using ValTypeA = bfloat16_t; diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index bbeeacacd3..71ba713ba3 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -147,6 +147,41 @@ class DefaultEpilogue { return epilogue_op.is_source_needed(); } +#ifdef EPILOGUE_RELU + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout> + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor & accumulators){ + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto [m_coord, n_coord, 
k_coord, l_coord] = blk_coord_mnkl; + if (epilogue_op.is_source_needed()) { + auto source = make_fragment_like(accumulators); + auto gmem_tiled_copy_c = + make_xe_2d_copy(make_tensor( + params.ptr_C, make_shape(M, N, L), params.dC)); + + Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), + make_shape(size<1>(accumulators), size<2>(accumulators), L), + make_stride(size<0>(blk_shape_MNK), size<1>(blk_shape_MNK))); + copy(gmem_tiled_copy_c, tCi(_, _, _, l_coord), source); + epilogue_op(accumulators, source); + } else { + epilogue_op(accumulators); + } + } +#endif + template< class ProblemShapeMNKL, class BlockShapeMNK, diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp new file mode 100644 index 0000000000..00700c5372 --- /dev/null +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
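The header that follows adds PvcEpilogueTensorSoftmax, which applies a numerically stable softmax across each row of the accumulator tile: subtract the row maximum, exponentiate, accumulate the row sum, then divide, with the max and the sum reduced across the sub-group. A plain scalar sketch of the same math, independent of CUTLASS and SYCL (the function name and the row-major layout here are assumptions, for illustration only):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference row-wise softmax over a row-major m x n matrix.
inline void softmax_rows_reference(std::vector<float> &data, int m, int n) {
  for (int i = 0; i < m; ++i) {
    float *row = data.data() + static_cast<std::size_t>(i) * n;
    float row_max = *std::max_element(row, row + n); // stabilize before exp()
    float sum = 0.f;
    for (int j = 0; j < n; ++j) {
      row[j] = std::exp(row[j] - row_max);
      sum += row[j];
    }
    for (int j = 0; j < n; ++j) {
      row[j] /= sum; // each row now sums to 1
    }
  }
}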
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class PvcEpilogueTensorSoftmax { +public: + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const *ptr_C = nullptr; + StrideC dC{}; + ElementD *ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments([[maybe_unused]] ProblemShape const &_, + Arguments const &args, + [[maybe_unused]] void *workspace) { + return args; + } + + template CUTLASS_DEVICE void operator()(T &t) { + static_assert(cute::is_same_v && m <= 32); + + auto const &group = + sycl::ext::oneapi::experimental::this_nd_item<3>().get_group(); + + static constexpr auto vec_size = 4; + + static_assert((m % vec_size) == 0 && vec_size <= 16); + static constexpr auto loop_cnt = m / vec_size; + + sycl::vec local_max; + sycl::vec local_plus; + + for (int loop = 0; loop < loop_cnt; loop++) { + + auto base_row = loop * vec_size; + // init local max + for (int i = 0; i < vec_size; i++) { + local_max[i] = t[(base_row + i) * n]; + } + + for (int i = 0; i < vec_size; i++) { + for (int j = 0; j < n; j++) { + local_max[i] = max(local_max[i], t((base_row + i) * n + j)); + } + } + + // get group max + auto group_max = reduce_over_group(group, local_max, sycl::maximum<>()); + + // -max, exp, and get local plus + for (int i = 0; i < vec_size; i++) { + for (int j = 0; j < n; j++) { + auto offset = (base_row + i) * n + j; + t[offset] -= group_max[i]; + t[offset] = sycl::exp(t[offset]); + + local_plus[i] += t[offset]; + } + } + + // get group plus + auto group_plus = reduce_over_group(group, local_plus, sycl::plus<>()); + + // last div + for (int i = 0; i < vec_size; i++) { + for (int j = 0; j < n; j++) { + auto offset = (base_row + i) * n + j; + t[offset] = t[offset] / group_plus[i]; + // local_sum += t[i * n + j]; + } + } + } + + // printf("verify softmax, local_sum: %f, group_sum: %f\n", local_sum, + // reduce_over_group(group, local_sum, sycl::plus<>())); + // } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 07ebdec93d..7ecba15006 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -183,7 +183,28 @@ class LinearCombinationRelu { threshold_ = reinterpret_cast(allones); } } - + +#ifdef EPILOGUE_RELU + using ElementC = ElementOutput_; + using ElementD = ElementOutput_; + template + CUTLASS_HOST_DEVICE + void operator()(cute::Tensor &accumulators) const { + for (int i = 0; i < size(accumulators); i++) { + accumulators(i) = accumulators(i) < 0 ? 0 : accumulators(i); + } + } + + template + CUTLASS_HOST_DEVICE + void operator()(cute::Tensor &accumulators, + cute::Tensor const &source) const { + for (int i = 0; i < size(accumulators); i++) { + accumulators(i) = accumulators(i) < 0 ? source(i) : accumulators(i) + source(i); + } + } +#endif + /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index f69ae7bdf0..2cd540d58d 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
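The mainloop rewrite in this file overlaps global-memory traffic with compute: three K tiles are prefetched before the loop, and every iteration copies the current tile, prefetches the tile three iterations ahead, and then issues the DPAS fragments. A minimal sketch of that schedule, with placeholder callables standing in for the copy atoms, the prefetch, and cute::gemm (the names below are illustrative, not the library API):

// Schematic software pipeline with a prefetch distance of three K tiles.
template <class Prefetch, class Copy, class Compute>
void pipelined_k_loop(int k_tile_count, Prefetch prefetch, Copy copy,
                      Compute compute) {
  constexpr int prefetch_distance = 3;
  int prefetch_k = 0;
  for (int i = 0; i < prefetch_distance; ++i) {
    prefetch(prefetch_k++); // warm up: keep three tiles in flight
  }
  for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
    copy(k_tile);           // load the tile that was prefetched earlier
    prefetch(prefetch_k++); // top up the pipeline
    compute(k_tile);        // MMA on the freshly copied fragments
  }
}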
* **************************************************************************************************/ #pragma once @@ -34,54 +35,35 @@ #include "cutlass/gemm/dispatch_policy.hpp" #include "cute/algorithm/functional.hpp" -#include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// - + namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopIntelPVCUnpredicated, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ +#define get_sub_group_id() \ + (sycl::ext::oneapi::experimental::this_nd_item<3>() \ + .get_sub_group() \ + .get_group_id()[0]) + +template +struct CollectiveMma { // // Type Aliases // using DispatchPolicy = MainloopIntelPVCUnpredicated; - using WorkgroupTileShape = TileShape_; + using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; using ElementB = ElementB_; @@ -98,38 +80,55 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; + TileShape tile_shape; + static constexpr auto wg_tile_m = decltype(get<0>(tile_shape))::value; + static constexpr auto wg_tile_n = decltype(get<1>(tile_shape))::value; + static constexpr auto sg_tile_m = decltype(get<2>(tile_shape))::value; + static constexpr auto sg_tile_n = decltype(get<3>(tile_shape))::value; + static constexpr auto sg_tile_k = decltype(get<4>(tile_shape))::value; + static constexpr auto sg_per_wg_m = wg_tile_m / sg_tile_m; + static constexpr auto sg_per_wg_n = wg_tile_n / sg_tile_n; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - using MmaAtomShape = typename TiledMma::AtomShape_MNK; - using SubgroupTileShape = decltype(tile_shape(TiledMma())); - - static constexpr uint32_t MaxThreadsPerBlock = - cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; - - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape()); - - // Calculate the vector width based on the amount of registers - // required per work item by dividing the total fragment size by + static constexpr int DpasM = get<0>( + shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per + // sub_group for Matrix A + static constexpr int DpasN = get<1>( + shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per + // sub_group for Matrix B + static constexpr int DpasK = get<1>( + shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per + // sub_group for Matrix A + + static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN; + static constexpr uint32_t 
MinBlocksPerMultiprocessor = 1; + + static constexpr int FragsM = sg_tile_m / DpasM; // A frags per sub_group + static constexpr int FragsN = sg_tile_n / DpasN; // B frags per sub_group + static constexpr int FragsK = sg_tile_k / DpasK; + + // Calculate the vector width based on the amount of registers + // required per work item by dividing the total fragment size by // the sub_group size. - static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize; - static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; - static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; + static constexpr int VecC = (DpasN * DpasM) / SubgroupSize; + static constexpr int VecA = (DpasM * DpasK) / SubgroupSize; + static constexpr int VecB = (DpasN * DpasK) / SubgroupSize; // Host side kernel arguments struct Arguments { - ElementA const* ptr_A; + ElementA const *ptr_A; StrideA dA; - ElementB const* ptr_B; + ElementB const *ptr_B; StrideB dB; }; struct Params { - using XE_Copy_A = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}))); - using XE_Copy_B = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}))); + using XE_Copy_A = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), StrideA{}))); + using XE_Copy_B = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), StrideB{}))); XE_Copy_A gmem_tiled_copy_a; XE_Copy_B gmem_tiled_copy_b; }; @@ -142,14 +141,17 @@ struct CollectiveMma< template static constexpr Params - to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - (void) workspace; + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M,N,K,L] = problem_shape_MNKL; + auto [M, N, K, L] = problem_shape_MNKL; - Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); + Tensor tensorA = + make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensorB = + make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -157,59 +159,84 @@ struct CollectiveMma< } /// Perform a subgroup-scoped matrix multiply-accumulate - template < - class FrgTensorD, - class TensorA, - class TensorB, - class FrgTensorC, - class KTileIterator, - class ResidueMNK - > - CUTLASS_DEVICE void - operator() ( - FrgTensorD &accum, - TensorA gA, - TensorB gB, - FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int k_tile_count, - ResidueMNK residue_mnk, - int thread_idx, - char *smem_buf, - Params const& mainloop) - { + template + CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, int thread_idx, + char *smem_buf, Params const &mainloop) { (void)residue_mnk; (void)thread_idx; (void)smem_buf; - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_tuple::value, "A tensor must be a tuple iterator."); - static_assert(is_tuple::value, "B 
tensor must be a tuple iterator."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(is_rmem::value, + "D tensor must be rmem resident."); + static_assert( + is_tuple::value, + "A tensor must be a tuple iterator."); + static_assert( + is_tuple::value, + "B tensor must be a tuple iterator."); + static_assert(is_rmem::value, + "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, _1>{}); - Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) / 2>, Int>{}); + Tensor tAr = make_tensor( + Shape, Int<1>>{}); + + constexpr int version = + is_same_v ? 1 : 2; + Tensor tBr = make_tensor( + Shape, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), - Shape, Int, Int>{}); + Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}, - Stride<_1, Int(SubgroupTileShape{}) / 2>, Int>{}); + Shape, Int, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; + int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); + Tensor tAi = make_tensor( + make_inttuple_iter( + *gA.data() + + make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, E<0>{}, E<1>{}))); + Tensor tBi = make_tensor( + make_inttuple_iter( + *gB.data() + + make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, + (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), + make_layout(make_shape(_1{}, K, _1{}), + make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop // - for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK) - { - // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr); - copy(mainloop.gmem_tiled_copy_b, gB(_,_,k/2), tBr); - - cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); - } + int prefetch_k = 0; + for (int i = 0; i < 3; i++) { + prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); + prefetch_k += sg_tile_k; + } + + for (int k_tile = 0, k = 0; k_tile < k_tile_count; + ++k_tile, k += DpasK * FragsK) { + // Copy gmem to rmem for the first k_tile + copy(mainloop.gmem_tiled_copy_a, gA(_, _, k), tAr); + copy(mainloop.gmem_tiled_copy_b, gB(_, k, _), tBr); + + prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); + prefetch_k += sg_tile_k; + + for (int kl = 0; kl < FragsK; kl++) { + cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), + src_accum); + } + } } }; diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 6e7aee895b..5db3fe49a4 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,22 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" #include "cute/tensor.hpp" @@ -41,19 +42,13 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template < - class ProblemShape_, - class CollectiveMainloop_, - class CollectiveEpilogue_, - class TileScheduler_ -> +template class GemmUniversal< - ProblemShape_, - CollectiveMainloop_, - CollectiveEpilogue_, - TileScheduler_, - cute::enable_if_t>> -{ + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + cute::enable_if_t>> { public: // // Type Aliases @@ -61,50 +56,60 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); + "ProblemShape{} should be or "); // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::WorkgroupTileShape; - using WorkgroupTileShape = TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename 
CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; - static_assert(cute::is_void_v or cute::is_same_v, - "Intel PVC does not support specializing the tile scheduler."); + static_assert(cute::is_void_v or + cute::is_same_v, + "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, WorkgroupTileShape, - cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + TileScheduler_, ArchTag, TileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; + using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; + using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); + static_assert( + cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. static constexpr int SharedStorageSize = 0; - static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size - static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + static constexpr uint32_t MinBlocksPerMultiprocessor = + CollectiveMainloop::MinBlocksPerMultiprocessor; - using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; - using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + static constexpr int num_sg = + MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group + + static constexpr int DpasM = CollectiveMainloop::DpasM; + static constexpr int DpasN = CollectiveMainloop::DpasN; + static constexpr int DpasK = CollectiveMainloop::DpasK; static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -139,68 +144,70 @@ class GemmUniversal< // Methods // - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static - Params - to_underlying_arguments(Arguments const& args, void* workspace) { - (void) workspace; - return { - args.mode, - args.problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) - }; + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. 
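The get_grid_shape and get_block_shape methods below map the problem onto hardware: one work-group per wg_tile_m x wg_tile_n output tile, the batch in the third grid dimension, and one sub-group per sg_tile_m x sg_tile_n sub-tile inside each work-group. A host-side sketch of the same arithmetic, assuming the 256x256 work-group tile, 32x64 sub-group tile, and sub-group size 16 used by the example runner (the helper and its defaults are illustrative only):

struct LaunchShape {
  int grid_x, grid_y, grid_z; // work-groups along N, along M, batch
  int block_x, block_y;       // work-items per work-group
};

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Mirrors get_grid_shape()/get_block_shape() for an M x N x L problem.
constexpr LaunchShape pvc_launch_shape(int M, int N, int L, int wg_m = 256,
                                       int wg_n = 256, int sg_m = 32,
                                       int sg_n = 64, int sub_group_size = 16) {
  return {ceil_div(N, wg_n),                     // one work-group per N tile
          ceil_div(M, wg_m),                     // one work-group per M tile
          L,                                     // one grid slice per batch
          ceil_div(wg_n, sg_n / sub_group_size), // x: 16 work-items per sub-group along N
          ceil_div(wg_m, sg_m)};                 // y: one work-item row per sub-group along M
}

// e.g. pvc_launch_shape(4096, 4096, 1) yields grid (16, 16, 1), block (64, 8).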
+ static Params to_underlying_arguments(Arguments const &args, + void *workspace) { + (void)workspace; + return {args.mode, args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace)}; } - static bool - can_implement(Arguments const& args) { - bool mode_implementable = args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + static bool can_implement(Arguments const &args) { + bool mode_implementable = + args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); return mode_implementable && TileScheduler::can_implement(args.scheduler); } - static int - get_workspace_size(Arguments const& args) { - return 0; - } + static int get_workspace_size(Arguments const &args) { return 0; } - static - cutlass::Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { return Status::kSuccess; } - static dim3 - get_grid_shape(Params const& params) { - int batch_count = 1; - if constexpr (cute::rank(ProblemShape{}) == 4) { - batch_count = cute::size<3>(params.problem_shape); - } - - return dim3( - cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), - cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), - batch_count - ); + static dim3 get_grid_shape(Params const ¶ms) { + auto M = get<0>(params.problem_shape); + auto N = get<1>(params.problem_shape); + auto L = get<3>(params.problem_shape); + + const int sg_m = + cute::ceil_div(M, + CollectiveMainloop::wg_tile_m); // sub_groups required to + // process A fragments + const int sg_n = + cute::ceil_div(N, + CollectiveMainloop::wg_tile_n); // sub_groups required to + // process B fragments + + return dim3(sg_n, sg_m, L); } - static dim3 - get_block_shape() { - return dim3(MaxThreadsPerBlock, 1, 1); + static dim3 get_block_shape() { + return dim3(cute::ceil_div(CollectiveMainloop::wg_tile_n, + CollectiveMainloop::sg_tile_n / SubgroupSize), + cute::ceil_div(CollectiveMainloop::wg_tile_m, + CollectiveMainloop::sg_tile_m), + 1); } CUTLASS_DEVICE - void - operator()(Params const& params, char* smem_buf) { + void operator()(Params const ¶ms, char *smem_buf) { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions - CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only + // rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -208,70 +215,101 @@ class GemmUniversal< auto L = get<3>(problem_shape_MNKL); // Preconditions - static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this sub_group -- potential for sub_group locality + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this sub_group -- potential for sub_group + // locality int thread_idx = int(ThreadIdxX()); - constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = BlockIdxX() * get<0>(subgroup_shape); - const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); + int thread_idy = int(ThreadIdxY()); + + static constexpr auto sg_per_wg_n = + CollectiveMainloop::wg_tile_n / CollectiveMainloop::sg_tile_n; + + auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) + const int m_coord = + BlockIdxY() * CollectiveMainloop::wg_tile_m + + (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; + const int n_coord = + BlockIdxX() * CollectiveMainloop::wg_tile_n + + (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; const int l_coord = BlockIdxZ(); const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, 0), - make_shape(_1{}, K, L), - make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); + make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}), + make_stride(Int{}, _1{})); + constexpr int version = + is_same_v + ? 
1 + : 2; Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(n_coord, 0, 0), - make_shape(Int{}, K / 2, L), - make_stride(get<1>(MmaAtomShape()), _1{})); + make_coord(0, n_coord, l_coord), + make_shape(K, Int{}, _1{}), + make_stride(_1{}, Int{})); // Compute tile residues for predication - auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord - auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord - auto k_residue = K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + auto m_max_coord = + M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord + auto n_max_coord = + N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord + auto k_residue = + K - get<2>(subgroup_shape) * + (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; - Tensor accumulators = make_tensor(Shape, Int, Int>{}); + Tensor accumulators = + make_tensor(Shape, Int, Int>{}); clear(accumulators); - auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(subgroup_shape))); - int k_tile_count = K / get<2>(subgroup_shape); + int k_tile_count = cute::ceil_div(K, CollectiveMainloop::sg_tile_k); + auto k_tile_iter = cute::make_coord_iterator(make_shape(k_tile_count)); // Perform the collective scoped MMA CollectiveMainloop collective_mma; - collective_mma( - accumulators, - tAi(_,_,_,l_coord), - tBi(_,_,_,l_coord), - accumulators, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - smem_buf, - params.mainloop - ); - - CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; - epilogue( - problem_shape_MNKL, - subgroup_shape, - tile_coord, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - smem_buf - ); + collective_mma(accumulators, tAi(_, _, _, 0), tBi(_, _, _, 0), accumulators, + k_tile_iter, k_tile_count, residue_mnk, thread_idx, smem_buf, + params.mainloop); + +#ifdef EPILOGUE_RELU + // relu + CollectiveEpilogue collective_relu(params.epilogue); + collective_relu(problem_shape_MNKL, + make_shape(Int{}, Int{}, Int{}), + make_coord(m_coord, n_coord, 0, l_coord), accumulators); +#endif + +#ifdef EPILOGUE_SOFTMAX + // softmax + CollectiveEpilogue collective_softmax; + collective_softmax(accumulators); +#endif + + auto gmem_tiled_copy_c = + make_xe_2d_copy(make_tensor( + params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); + + Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), + make_shape(Int{}, Int{}, _1{}), + make_stride(Int{}, Int{})); + + copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); } }; diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index fd900b6605..a3eee0405c 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -71,7 +71,7 @@ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { return true; } else if (a == zero || b == zero || diff < nonzero_floor) { - return diff < epsilon * nonzero_floor; + return diff < (epsilon * nonzero_floor) || (diff / abs_B) < (T)0.001f; } return diff < epsilon * (abs_A + abs_B); diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 3c312f5ff8..de96d53122 100644 --- 
a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -95,16 +95,18 @@ __global__ void size_t idx = ThreadIdxX() + BlockDimX() * BlockIdxX(); - for (; idx < capacity; idx += GridDimX() * BlockDimX()) { - + //for (; idx < capacity; idx += GridDimX() * BlockDimX()) { + if (idx < capacity ){ Element a = cutlass::ReferenceFactory::get(ptr_A, idx); Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { *equal = 0; + //printf("error, idx at: %lu, capacity: %lu, a: %f, b: %f\n", idx, capacity, a, b); return; } } + // } } } // namespace kernel @@ -239,7 +241,7 @@ bool BlockCompareRelativelyEqual( #if defined (CUTLASS_ENABLE_SYCL) block_size = 128; grid_size = (capacity + block_size - 1) / block_size; - grid_size = (grid_size < 64 ? grid_size : 64); // limit grid size to avoid out_of_resources runtime error. + //grid_size = (grid_size < 64 ? grid_size : 64); // limit grid size to avoid out_of_resources runtime error. #else // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API cudaError_t result = cudaOccupancyMaxPotentialBlockSize( diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index 37e238e86e..728c0a02f0 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -54,9 +54,7 @@ struct TensorForEach { #if defined (CUTLASS_ENABLE_SYCL) // TODO: query the queue for block size block_size = 128; - grid_size = (size.product() + block_size - 1) / block_size; - int sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); - grid_size = grid_size > sm_count / 2 ? 
sm_count / 2 : grid_size; + grid_size = (size(size) + block_size - 1) / block_size; #else // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API cudaError_t result = cudaOccupancyMaxPotentialBlockSize( @@ -77,7 +75,7 @@ struct TensorForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, size, params); + syclcompat::launch>(sycl_grid, sycl_block, 0, size, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); @@ -105,7 +103,7 @@ struct TensorDiagonalForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3((end - start + block_size - 1) / block_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, size, params, start, end); + syclcompat::launch>(sycl_grid, sycl_block, 0, size, params, start, end); #else dim3 block(block_size, 1, 1); dim3 grid((end - start + block_size - 1) / block_size, 1, 1); @@ -155,7 +153,7 @@ struct BlockForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, ptr, capacity, params); + syclcompat::launch>(sycl_grid, sycl_block, 0, ptr, capacity, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); From d4cf3eb9fabb0b51488d7c5d74ad39f495217407 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Tue, 9 Jul 2024 18:33:14 -0700 Subject: [PATCH 02/36] fix format of copyright --- examples/sycl/pvc/pvc_gemm.cpp | 21 +++++++++---------- .../intel_pvc_epilogue_tensor_softmax.hpp | 21 +++++++++---------- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 21 +++++++++---------- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 21 +++++++++---------- 4 files changed, 40 insertions(+), 44 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 9ceaed637b..2ec4f72484 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp index 00700c5372..88f787f764 100644 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 2cd540d58d..bbe7a73c57 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 5db3fe49a4..0d87c14241 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. 
Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once From 665f9be32e947b9509c270aa1d2fe7a086309774 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 10 Jul 2024 00:07:19 -0700 Subject: [PATCH 03/36] replace the macro of cache flush and idx --- examples/sycl/pvc/pvc_gemm.cpp | 72 ++++++++++------------------------ 1 file changed, 21 insertions(+), 51 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 2ec4f72484..62cc1cdf4b 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -48,11 +48,6 @@ #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -// 0 - None -// 1 - FLUSH by memset -// 2 - FLUSH by input offset with pingpong -#define CACHE_FLUSH 2 - template static void fill_matrix(std::vector &M) { std::random_device dev; std::mt19937 rng(dev()); @@ -144,7 +139,7 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template struct ExampleRunner { +template struct ExampleRunner { using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; @@ -186,12 +181,10 @@ template struct ExampleRunner { static constexpr auto l3_cache_size = 256 * 1024 * 1024; -#if CACHE_FLUSH == 2 - size_t PINGPONG_ITER = 3; + size_t PINGPONG_ITER = 1; size_t pingpong_size_a; size_t pingpong_size_b; size_t pingpong_size_d; -#endif std::vector a; std::vector b; @@ -222,7 +215,6 @@ template struct ExampleRunner { ); #ifdef EPILOGUE_SOFTMAX -#define IDX (l * M * N + i * N + j) ElementOutput *ptr = (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); @@ -231,21 +223,24 @@ template struct ExampleRunner { syclcompat::wait(); for (int l = 0; l < L; l++) { for (int i = 0; i < M; i++) { - + auto row_idx = l * M * N + i * N; auto row_max = ptr[l * M * N + i * N]; - for (int j = 0; j < N; j++) { - row_max = max(row_max, ptr[IDX]); - } ElementOutput exp_sum = (ElementOutput)0; for (int j = 0; j < N; j++) { - ptr[IDX] = ptr[IDX] - row_max; - ptr[IDX] = exp(ptr[IDX]); - exp_sum += ptr[IDX]; + auto idx = row_idx + j; + row_max = max(row_max, ptr[idx]); + } + for (int j = 0; j < N; j++) { + auto idx = row_idx + j; + ptr[idx] = ptr[idx] - row_max; + ptr[idx] = exp(ptr[idx]); + exp_sum += ptr[idx]; } for (int j = 0; j < N; j++) { - ptr[IDX] = ptr[IDX] / exp_sum; + auto idx = row_idx + j; + ptr[idx] = ptr[idx] / exp_sum; } } } @@ -256,8 +251,6 @@ template struct ExampleRunner { std::free(ptr); -#undef IDX - #endif #if 0 @@ -294,18 +287,10 @@ template struct ExampleRunner { return passed; } - void init_cache_flush(const ProblemShapeType &problem_size) { + void init_cache_clear(const ProblemShapeType &problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; -#if CACHE_FLUSH == 1 - auto ref_d_element = max(l3_cache_size / sizeof(ElementOutput), M * N * L); - block_ref_D.reset(ref_d_element); - syclcompat::memset(block_ref_D.get(), 0, - ref_d_element * sizeof(ElementOutput)); - -#elif CACHE_FLUSH == 2 - pingpong_size_a = max((size_t)M * K * L, l3_cache_size / sizeof(ElementA)); pingpong_size_b = max((size_t)K * N * L, l3_cache_size / sizeof(ElementB)); pingpong_size_d = @@ -332,8 +317,6 @@ template struct ExampleRunner { syclcompat::memcpy(block_D.get() + i * pingpong_size_d, d.data(), d.size() * sizeof(ElementC)); } -#endif - // syclcompat::wait(); } @@ -433,8 
+416,10 @@ template struct ExampleRunner { // return; } - // ================ init cache flush ================ - init_cache_flush(problem_size); + // ================ init cache clear ================ + if constexpr(cache_clear) { + init_cache_clear(problem_size); + } // ================ run and collect performance data ================ if (total_iterations > 0) { @@ -443,19 +428,6 @@ template struct ExampleRunner { auto worst = 0.f; for (int i = 0; i < testIterations + warmup; ++i) { -#if CACHE_FLUSH == 1 - init_cache_flush(problem_size); - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B.get(), stride_B}, - {{1, 0.f}, - nullptr /*block_C.get() + i * M * N * L*/, - stride_C, - block_D.get(), - stride_D}, - hw_info}; -#elif CACHE_FLUSH == 2 typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, @@ -467,7 +439,6 @@ template struct ExampleRunner { block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, stride_D}, hw_info}; -#endif Gemm gemm_op; gemm_op.can_implement(arguments); @@ -490,7 +461,6 @@ template struct ExampleRunner { float average = total_time / testIterations; double tflops = (2.0 * M * N * K * L) * 1e-12; - double gflops = (2.0 * M * N * K * L) * 1e-9; double hbm = L * @@ -520,7 +490,7 @@ template struct ExampleRunner { }; template + int sg_tile_k, bool wg_order_m_first = false, uint32_t snake_n = 0, bool cache_clear = true> void collective_gemm(int M, int K, int N, int L = 1) { // // Parse options @@ -625,7 +595,7 @@ void collective_gemm(int M, int K, int N, int L = 1) { using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - ExampleRunner runner; + ExampleRunner runner; runner.template run( M, K, N, L, hw_info); @@ -673,7 +643,7 @@ int main() { #endif #if defined(EPILOGUE_RELU) - // gemm + softmax + // gemm + relu collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); #endif } From 59c0ce403f2e2726b447df73bca6cea0cfdfe492 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 11 Jul 2024 00:36:46 -0700 Subject: [PATCH 04/36] auto format --- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 203 ++++++++---------- 1 file changed, 85 insertions(+), 118 deletions(-) diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 0d87c14241..4af016e6b2 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -41,12 +41,15 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template -class GemmUniversal< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, - cute::enable_if_t +class GemmUniversal>> { public: // @@ -55,7 +58,7 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); + "ProblemShape{} should be or "); // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; @@ -71,12 +74,13 @@ class GemmUniversal< using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; - static_assert(cute::is_void_v or - cute::is_same_v, - "Intel PVC does not support specializing the tile scheduler."); + static_assert( + cute::is_void_v or cute::is_same_v, + "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; - using 
TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, + using TileScheduler = typename detail::TileSchedulerSelector, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; @@ -89,37 +93,28 @@ class GemmUniversal< using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; static_assert( - cute::is_same_v, + cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. - static constexpr int SharedStorageSize = 0; + static int constexpr SharedStorageSize = 0; - static constexpr int SubgroupSize = - CollectiveMainloop::SubgroupSize; // sub_group size - static constexpr uint32_t MaxThreadsPerBlock = - CollectiveMainloop::MaxThreadsPerBlock; - static constexpr uint32_t MinBlocksPerMultiprocessor = + static int constexpr SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static uint32_t constexpr MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + static uint32_t constexpr MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor; - static constexpr int num_sg = + static int constexpr num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group - static constexpr int DpasM = CollectiveMainloop::DpasM; - static constexpr int DpasN = CollectiveMainloop::DpasN; - static constexpr int DpasK = CollectiveMainloop::DpasK; - - static constexpr int FragsM = CollectiveMainloop::FragsM; - static constexpr int FragsN = CollectiveMainloop::FragsN; + static int constexpr DpasM = CollectiveMainloop::DpasM; + static int constexpr DpasN = CollectiveMainloop::DpasN; + static int constexpr DpasK = CollectiveMainloop::DpasK; - static constexpr int VecC = CollectiveMainloop::VecC; + static int constexpr FragsM = CollectiveMainloop::FragsM; + static int constexpr FragsN = CollectiveMainloop::FragsN; - // Kernel level shared memory storage - struct SharedStorage { - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - EpilogueTensorStorage epilogue; - }; + static int constexpr VecC = CollectiveMainloop::VecC; // Device side arguments struct Arguments { @@ -143,61 +138,53 @@ class GemmUniversal< // Methods // - // Convert to underlying arguments. In this case, a simple copy for the - // aliased type. - static Params to_underlying_arguments(Arguments const &args, - void *workspace) { + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
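  // (Call pattern, sketched from the usual CUTLASS 3.x flow rather than from this
  //  patch: the GemmUniversalAdapter invokes
  //  GemmKernel::to_underlying_arguments(args, workspace) once on the host and
  //  passes the resulting Params by value to the device kernel; since
  //  get_workspace_size() returns 0 below, workspace can simply be nullptr here.)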
+ static Params to_underlying_arguments(Arguments const& args, void* workspace) { (void)workspace; return {args.mode, args.problem_shape, - CollectiveMainloop::to_underlying_arguments( - args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments( - args.problem_shape, args.epilogue, workspace)}; + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace)}; } - static bool can_implement(Arguments const &args) { + static bool can_implement(Arguments const& args) { bool mode_implementable = args.mode == GemmUniversalMode::kGemm or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); return mode_implementable && TileScheduler::can_implement(args.scheduler); } - static int get_workspace_size(Arguments const &args) { return 0; } + static int get_workspace_size(Arguments const& args) { return 0; } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static cutlass::Status initialize_workspace(Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } - static dim3 get_grid_shape(Params const ¶ms) { + static dim3 get_grid_shape(Params const& params) { auto M = get<0>(params.problem_shape); auto N = get<1>(params.problem_shape); auto L = get<3>(params.problem_shape); - const int sg_m = - cute::ceil_div(M, - CollectiveMainloop::wg_tile_m); // sub_groups required to - // process A fragments - const int sg_n = - cute::ceil_div(N, - CollectiveMainloop::wg_tile_n); // sub_groups required to - // process B fragments + int const sg_m = cute::ceil_div(M, + CollectiveMainloop::wg_tile_m); // sub_groups required to + // process A fragments + int const sg_n = cute::ceil_div(N, + CollectiveMainloop::wg_tile_n); // sub_groups required to + // process B fragments return dim3(sg_n, sg_m, L); } static dim3 get_block_shape() { - return dim3(cute::ceil_div(CollectiveMainloop::wg_tile_n, - CollectiveMainloop::sg_tile_n / SubgroupSize), - cute::ceil_div(CollectiveMainloop::wg_tile_m, - CollectiveMainloop::sg_tile_m), - 1); + return dim3( + cute::ceil_div(CollectiveMainloop::wg_tile_n, CollectiveMainloop::sg_tile_n / SubgroupSize), + cute::ceil_div(CollectiveMainloop::wg_tile_m, CollectiveMainloop::sg_tile_m), 1); } CUTLASS_DEVICE - void operator()(Params const ¶ms, char *smem_buf) { + void operator()(Params const& params, char* smem_buf) { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); @@ -205,8 +192,7 @@ class GemmUniversal< CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only - // rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -215,65 +201,51 @@ class GemmUniversal< // Preconditions static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideA must be rank-3: [M, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. 
If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideB must be rank-3: [N, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideC must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideD must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); - // Get the appropriate blocks for this sub_group -- potential for sub_group - // locality + // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); int thread_idy = int(ThreadIdxY()); - static constexpr auto sg_per_wg_n = + static auto constexpr sg_per_wg_n = CollectiveMainloop::wg_tile_n / CollectiveMainloop::sg_tile_n; auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = - BlockIdxY() * CollectiveMainloop::wg_tile_m + - (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; - const int n_coord = - BlockIdxX() * CollectiveMainloop::wg_tile_n + - (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; - const int l_coord = BlockIdxZ(); - const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); - - Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}), - make_stride(Int{}, _1{})); - constexpr int version = - is_same_v - ? 1 - : 2; - - Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, l_coord), - make_shape(K, Int{}, _1{}), - make_stride(_1{}, Int{})); + int const m_coord = BlockIdxY() * CollectiveMainloop::wg_tile_m + + (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; + int const n_coord = BlockIdxX() * CollectiveMainloop::wg_tile_n + + (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; + int const l_coord = BlockIdxZ(); + + Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, l_coord), + make_shape(_1{}, K, _1{}), make_stride(Int{}, _1{})); + int constexpr version = + is_same_v ? 
1 : 2; + + Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, l_coord), + make_shape(K, Int{}, _1{}), make_stride(_1{}, Int{})); // Compute tile residues for predication - auto m_max_coord = - M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord - auto n_max_coord = - N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord + auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord + auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord auto k_residue = - K - get<2>(subgroup_shape) * - (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; - Tensor accumulators = - make_tensor(Shape, Int, Int>{}); + Tensor accumulators = make_tensor(Shape, Int, Int>{}); clear(accumulators); int k_tile_count = cute::ceil_div(K, CollectiveMainloop::sg_tile_k); @@ -281,16 +253,14 @@ class GemmUniversal< // Perform the collective scoped MMA CollectiveMainloop collective_mma; - collective_mma(accumulators, tAi(_, _, _, 0), tBi(_, _, _, 0), accumulators, - k_tile_iter, k_tile_count, residue_mnk, thread_idx, smem_buf, - params.mainloop); + collective_mma(accumulators, tAi(_, _, _, 0), tBi(_, _, _, 0), accumulators, k_tile_iter, + k_tile_count, residue_mnk, thread_idx, smem_buf, params.mainloop); #ifdef EPILOGUE_RELU // relu CollectiveEpilogue collective_relu(params.epilogue); - collective_relu(problem_shape_MNKL, - make_shape(Int{}, Int{}, Int{}), - make_coord(m_coord, n_coord, 0, l_coord), accumulators); + collective_relu(problem_shape_MNKL, make_shape(Int{}, Int{}, Int{}), + make_coord(m_coord, n_coord, 0, l_coord), accumulators); #endif #ifdef EPILOGUE_SOFTMAX @@ -299,14 +269,11 @@ class GemmUniversal< collective_softmax(accumulators); #endif - auto gmem_tiled_copy_c = - make_xe_2d_copy(make_tensor( - params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); + auto gmem_tiled_copy_c = make_xe_2d_copy( + make_tensor(params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), - make_shape(Int{}, Int{}, _1{}), - make_stride(Int{}, Int{})); + Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + make_shape(Int{}, Int{}, _1{}), make_stride(Int{}, Int{})); copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); } From bdadf1e78537b556a989686ff40db8ff7690bec5 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 11 Jul 2024 00:39:54 -0700 Subject: [PATCH 05/36] auto format --- examples/sycl/pvc/pvc_gemm.cpp | 249 ++++++++---------- .../intel_pvc_epilogue_tensor_softmax.hpp | 28 +- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 210 ++++++++------- 3 files changed, 239 insertions(+), 248 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 62cc1cdf4b..635dca4dd1 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -48,22 +48,20 @@ #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -template static void fill_matrix(std::vector &M) { +template static void fill_matrix(std::vector& M) { std::random_device dev; std::mt19937 rng(dev()); std::uniform_real_distribution dist((T)0.0, #ifdef EPILOGUE_SOFTMAX - (T)0.1); 
+ (T)0.1); #else - (T)1.0); + (T)1.0); #endif - std::generate(std::begin(M), std::end(M), - [&] { return static_cast(dist(rng)); }); + std::generate(std::begin(M), std::end(M), [&] { return static_cast(dist(rng)); }); } template -static void vnni_matrix(T *dst, const T *src, int batch, int numRows, - int numCols, int factor) { +static void vnni_matrix(T* dst, T const* src, int batch, int numRows, int numCols, int factor) { for (int b = 0; b < batch; b++) { for (int r = 0; r < numRows / factor; r++) { for (int c = 0; c < numCols; c++) { @@ -80,9 +78,9 @@ using namespace cute; using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = float; // <- data type of epilogue operations -using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A -using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B -using ElementOutput = float; // <- data type of elements in output matrix D +using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A +using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -96,11 +94,11 @@ struct Options { float alpha, beta; Options() - : help(false), error(false), m(4096), n(4096), k(4096), l(1), - iterations(100), alpha(1.f), beta(0.f) {} + : help(false), error(false), m(4096), n(4096), k(4096), l(1), iterations(100), alpha(1.f), + beta(0.f) {} // Parses the command line - void parse(int argc, char const **args) { + void parse(int argc, char const** args) { cutlass::CommandLine cmd(argc, args); if (cmd.check_cmd_line_flag("help")) { @@ -118,7 +116,7 @@ struct Options { } /// Prints the usage statement. 
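Looking back at the vnni_matrix helper above: for every output row it interleaves factor consecutive source rows element by element, which appears to be the interleaved B-operand packing that the _V ("VNNI") copy atoms elsewhere in this series expect. A minimal single-batch sketch of the same indexing (illustrative only, batch fixed at 0; variable names follow the helper):

// VNNI packing with factor = 2: dst row r holds src rows 2r and 2r+1,
// so src(2r, c) and src(2r + 1, c) become adjacent elements of dst.
for (int r = 0; r < numRows / factor; ++r) {
  for (int c = 0; c < numCols; ++c) {
    for (int k = 0; k < factor; ++k) {
      dst[(r * numCols + c) * factor + k] = src[(r * factor + k) * numCols + c];
    }
  }
}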
- std::ostream &print_usage(std::ostream &out) const { + std::ostream& print_usage(std::ostream& out) const { out << "PVC GEMM Example\n\n" << "Options:\n\n" @@ -179,7 +177,7 @@ template struct ExampleRunner { cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; - static constexpr auto l3_cache_size = 256 * 1024 * 1024; + static auto constexpr l3_cache_size = 256 * 1024 * 1024; size_t PINGPONG_ITER = 1; size_t pingpong_size_a; @@ -193,20 +191,17 @@ template struct ExampleRunner { // Methods // - bool verify(const ProblemShapeType &problem_size, ElementCompute alpha, - ElementCompute beta) { + bool verify(ProblemShapeType const& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C((ElementC *)nullptr /*block_C.get()*/, - LayoutC::packed({M, N})); + cutlass::TensorRef ref_C((ElementC*)nullptr /*block_C.get()*/, LayoutC::packed({M, N})); cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); cutlass::reference::device::GemmComplex( {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, - cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, - ElementAccumulator(0), + cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, ElementAccumulator(0), L, // batch_count M * K, // batch_stride_A K * N, // batch_stride_B @@ -216,10 +211,8 @@ template struct ExampleRunner { #ifdef EPILOGUE_SOFTMAX - ElementOutput *ptr = - (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); - syclcompat::memcpy(ptr, block_ref_D.get(), - M * N * L * sizeof(ElementOutput)); + ElementOutput* ptr = (ElementOutput*)std::malloc(M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(ptr, block_ref_D.get(), M * N * L * sizeof(ElementOutput)); syclcompat::wait(); for (int l = 0; l < L; l++) { for (int i = 0; i < M; i++) { @@ -245,8 +238,7 @@ template struct ExampleRunner { } } - syclcompat::memcpy(block_ref_D.get(), ptr, - M * N * L * sizeof(ElementOutput)); + syclcompat::memcpy(block_ref_D.get(), ptr, M * N * L * sizeof(ElementOutput)); syclcompat::wait(); std::free(ptr); @@ -287,57 +279,48 @@ template struct ExampleRunner { return passed; } - void init_cache_clear(const ProblemShapeType &problem_size) { + void init_cache_clear(ProblemShapeType const& problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; pingpong_size_a = max((size_t)M * K * L, l3_cache_size / sizeof(ElementA)); pingpong_size_b = max((size_t)K * N * L, l3_cache_size / sizeof(ElementB)); - pingpong_size_d = - max((size_t)M * N * L, l3_cache_size / sizeof(ElementOutput)); + pingpong_size_d = max((size_t)M * N * L, l3_cache_size / sizeof(ElementOutput)); auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); - PINGPONG_ITER = - std::min((size_t)3, - std::max((size_t)1, - (size_t)gmem_size / - ((pingpong_size_a * sizeof(ElementA) + - pingpong_size_b * sizeof(ElementB) + - pingpong_size_d * sizeof(ElementOutput))) - - 1)); + PINGPONG_ITER = std::min((size_t)3, + std::max((size_t)1, (size_t)gmem_size / ((pingpong_size_a * sizeof(ElementA) + + pingpong_size_b * sizeof(ElementB) + + pingpong_size_d * sizeof(ElementOutput))) - + 1)); block_A.reset(pingpong_size_a * PINGPONG_ITER); block_B.reset(pingpong_size_b * PINGPONG_ITER); // block_C.reset(M * N * L * ITER); block_D.reset(pingpong_size_d * PINGPONG_ITER); for (int i = 0; i < 
PINGPONG_ITER; i++) { - syclcompat::memcpy(block_A.get() + i * pingpong_size_a, a.data(), - a.size() * sizeof(ElementA)); - syclcompat::memcpy(block_B.get() + i * pingpong_size_b, b.data(), - b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_D.get() + i * pingpong_size_d, d.data(), - d.size() * sizeof(ElementC)); + syclcompat::memcpy( + block_A.get() + i * pingpong_size_a, a.data(), a.size() * sizeof(ElementA)); + syclcompat::memcpy( + block_B.get() + i * pingpong_size_b, b.data(), b.size() * sizeof(ElementB)); + syclcompat::memcpy( + block_D.get() + i * pingpong_size_d, d.data(), d.size() * sizeof(ElementC)); } // syclcompat::wait(); } /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(const ProblemShapeType &problem_size) { + void initialize(ProblemShapeType const& problem_size) { auto [M, N, K, L] = problem_size; - stride_A = - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L)); - stride_C = - cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = - cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); block_A.reset((size_t)M * K * L); block_B.reset((size_t)K * N * L); // block_C.reset(M * N * L); block_D.reset((size_t)M * N * L); - block_ref_D.reset( - (size_t)max(l3_cache_size / sizeof(ElementOutput), (size_t)M * N * L)); + block_ref_D.reset((size_t)max(l3_cache_size / sizeof(ElementOutput), (size_t)M * N * L)); // TODO: Enable initialization on device directly once RNG is // available through SYCL. @@ -353,13 +336,11 @@ template struct ExampleRunner { syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } - template - void run(int M, int K, int N, int L, - const cutlass::KernelHardwareInfo &hw_info) { - static constexpr auto warmup = 10; - static constexpr auto testIterations = 10; - static constexpr auto total_iterations = warmup + testIterations; + template + void run(int M, int K, int N, int L, cutlass::KernelHardwareInfo const& hw_info) { + static auto constexpr warmup = 10; + static auto constexpr testIterations = 10; + static auto constexpr total_iterations = warmup + testIterations; ProblemShapeType problem_size = ProblemShapeType{M, N, K, L}; initialize(problem_size); @@ -400,24 +381,24 @@ template struct ExampleRunner { printf("PVC GEMM%s%s Example %s, MKNL(%d, %d,%d,%d), Config(%d, " "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", #ifdef EPILOGUE_RELU - "-relu" + "-relu" #else - "" + "" #endif - , + , #ifdef EPILOGUE_SOFTMAX - "-softmax" + "-softmax" #else - "" + "" #endif - , - (passed ? "Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, - sg_tile_m, sg_tile_n, sg_tile_k); + , + (passed ? 
"Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, + sg_tile_k); // return; } // ================ init cache clear ================ - if constexpr(cache_clear) { + if constexpr (cache_clear) { init_cache_clear(problem_size); } @@ -428,16 +409,12 @@ template struct ExampleRunner { auto worst = 0.f; for (int i = 0; i < testIterations + warmup; ++i) { - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, + typename Gemm::GemmKernel::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get() + (i % PINGPONG_ITER) * pingpong_size_a, stride_A, - block_B.get() + (i % PINGPONG_ITER) * pingpong_size_b, stride_B}, - {{1, 0.f}, - nullptr /*block_C.get() + i * M * N * L*/, - stride_C, - block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, - stride_D}, + block_B.get() + (i % PINGPONG_ITER) * pingpong_size_b, stride_B}, + {{1, 0.f}, nullptr /*block_C.get() + i * M * N * L*/, stride_C, + block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, stride_D}, hw_info}; Gemm gemm_op; @@ -462,11 +439,10 @@ template struct ExampleRunner { float average = total_time / testIterations; double tflops = (2.0 * M * N * K * L) * 1e-12; - double hbm = - L * - (M * K * sizeof(ElementInputA) + K * N * sizeof(ElementInputB) + - M * N * sizeof(ElementOutput)) * - 1e-9; + double hbm = L * + (M * K * sizeof(ElementInputA) + K * N * sizeof(ElementInputB) + + M * N * sizeof(ElementOutput)) * + 1e-9; printf("Collective pvc gemm%s, MKNL(%d, %d, %d, %d), Config(%d, %d, " "%d, %d, %d):\n max: (%6.4f)ms, (%4.2f)TFlop/s, " @@ -474,23 +450,28 @@ template struct ExampleRunner { "(%4.2f)GB/s\n average: (%6.4f)ms, (%4.2f)TFlop/s, " "(%4.2f)GB/s\n\n\n", #if defined(EPILOGUE_RELU) - "-relu" + "-relu" #elif defined(EPILOGUE_SOFTMAX) - "softmax" + "softmax" #else - "" + "" #endif - , - M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, - best * 1000, tflops / best, hbm / best, worst * 1000, - tflops / worst, hbm / worst, average * 1000, tflops / average, - hbm / average); + , + M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, best * 1000, + tflops / best, hbm / best, worst * 1000, tflops / worst, hbm / worst, average * 1000, + tflops / average, hbm / average); } } }; -template +template void collective_gemm(int M, int K, int N, int L = 1) { // // Parse options @@ -521,8 +502,7 @@ void collective_gemm(int M, int K, int N, int L = 1) { // Change device_id to another value if you are running on a machine with // multiple GPUs and wish to use a GPU other than that with device ID 0. hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); bool passed; @@ -537,68 +517,63 @@ void collective_gemm(int M, int K, int N, int L = 1) { using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; - using TileShape = Shape, Int, Int, - Int, Int>; + using TileShape = + Shape, Int, Int, Int, Int>; - using TiledMma = TiledMMA, - Layout>>; + using TiledMma = TiledMMA, Layout>>; using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; #ifdef EPILOGUE_RELU - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits::value, // <- the number of - // elements per vectorized - // memory access. For a byte, it's 16 - // elements. 
This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear + using EpilogueOp = + cutlass::epilogue::thread::LinearCombinationRelu::value, // <- the number of + // elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear #else - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits::value, // <- the number of - // elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear + using EpilogueOp = + cutlass::epilogue::thread::LinearCombination::value, // <- the number of + // elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear // combination function #endif // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, TileShape, ElementInputA, - cutlass::gemm::TagToStrideA_t, ElementInputB, - cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, - void, cute::identity, // A + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma, ElementInputB, + cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, void, + cute::identity, // A GmemTiledCopyB, void, void, cute::identity // B >; #ifdef EPILOGUE_SOFTMAX - using CollectiveEpilogue = - cutlass::epilogue::collective::PvcEpilogueTensorSoftmax< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, EpilogueOp, - cutlass::gemm::EpilogueDefault, CollectiveMainloop::sg_tile_m, - CollectiveMainloop::sg_tile_n / CollectiveMainloop::SubgroupSize>; + using CollectiveEpilogue = cutlass::epilogue::collective::PvcEpilogueTensorSoftmax< + cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, EpilogueOp, + cutlass::gemm::EpilogueDefault, CollectiveMainloop::sg_tile_m, + CollectiveMainloop::sg_tile_n / CollectiveMainloop::SubgroupSize>; #else - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, EpilogueOp, - cutlass::gemm::EpilogueDefault>; + using CollectiveEpilogue = + cutlass::epilogue::collective::DefaultEpilogue, + cutlass::gemm::TagToStrideC_t, EpilogueOp, cutlass::gemm::EpilogueDefault>; #endif - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, CollectiveMainloop, CollectiveEpilogue>; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, + CollectiveMainloop, CollectiveEpilogue>; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; ExampleRunner runner; - runner.template run( - M, K, N, L, hw_info); + runner.template run(M, K, N, L, hw_info); } int main() { diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp index 88f787f764..01bd25b7ec 100644 --- 
a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp @@ -44,8 +44,12 @@ namespace epilogue { namespace collective { ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template class PvcEpilogueTensorSoftmax { public: using EpilogueSchedule = EpilogueSchedule_; @@ -68,9 +72,9 @@ class PvcEpilogueTensorSoftmax { // Host side epilogue arguments struct Arguments { typename ThreadEpilogueOp::Params thread{}; - ElementC const *ptr_C = nullptr; + ElementC const* ptr_C = nullptr; StrideC dC{}; - ElementD *ptr_D = nullptr; + ElementD* ptr_D = nullptr; StrideD dD{}; }; @@ -78,23 +82,21 @@ class PvcEpilogueTensorSoftmax { using Params = Arguments; template - static constexpr Params - to_underlying_arguments([[maybe_unused]] ProblemShape const &_, - Arguments const &args, - [[maybe_unused]] void *workspace) { + static Params constexpr to_underlying_arguments([[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { return args; } - template CUTLASS_DEVICE void operator()(T &t) { + template CUTLASS_DEVICE void operator()(T& t) { static_assert(cute::is_same_v && m <= 32); - auto const &group = - sycl::ext::oneapi::experimental::this_nd_item<3>().get_group(); + auto const& group = sycl::ext::oneapi::experimental::this_nd_item<3>().get_group(); - static constexpr auto vec_size = 4; + static auto constexpr vec_size = 4; static_assert((m % vec_size) == 0 && vec_size <= 16); - static constexpr auto loop_cnt = m / vec_size; + static auto constexpr loop_cnt = m / vec_size; sycl::vec local_max; sycl::vec local_plus; diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index bbe7a73c57..7bb07c02d2 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -43,21 +43,38 @@ namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -#define get_sub_group_id() \ - (sycl::ext::oneapi::experimental::this_nd_item<3>() \ - .get_sub_group() \ - .get_group_id()[0]) - -template -struct CollectiveMma { +#define get_sub_group_id() \ + (sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group().get_group_id()[0]) + +template +struct CollectiveMma { // // Type Aliases // @@ -80,54 +97,56 @@ struct CollectiveMma(tile_shape))::value; - static constexpr auto wg_tile_n = decltype(get<1>(tile_shape))::value; - static constexpr auto sg_tile_m = decltype(get<2>(tile_shape))::value; - static constexpr auto sg_tile_n = decltype(get<3>(tile_shape))::value; - static constexpr auto sg_tile_k = decltype(get<4>(tile_shape))::value; - static constexpr auto sg_per_wg_m = wg_tile_m / sg_tile_m; - static constexpr auto sg_per_wg_n = wg_tile_n / sg_tile_n; - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - static constexpr int DpasM = get<0>( - shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per - // sub_group for Matrix A - static constexpr int DpasN = get<1>( - shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per - // sub_group for Matrix B - static constexpr int DpasK = get<1>( - shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per - // sub_group for Matrix A - - static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN; - static 
constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - static constexpr int FragsM = sg_tile_m / DpasM; // A frags per sub_group - static constexpr int FragsN = sg_tile_n / DpasN; // B frags per sub_group - static constexpr int FragsK = sg_tile_k / DpasK; + static auto constexpr wg_tile_m = decltype(get<0>(tile_shape))::value; + static auto constexpr wg_tile_n = decltype(get<1>(tile_shape))::value; + static auto constexpr sg_tile_m = decltype(get<2>(tile_shape))::value; + static auto constexpr sg_tile_n = decltype(get<3>(tile_shape))::value; + static auto constexpr sg_tile_k = decltype(get<4>(tile_shape))::value; + static auto constexpr sg_per_wg_m = wg_tile_m / sg_tile_m; + static auto constexpr sg_per_wg_n = wg_tile_n / sg_tile_n; + static int constexpr SubgroupSize = DispatchPolicy::SubgroupSize; + + static int constexpr DpasM = + get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per + // sub_group for Matrix A + static int constexpr DpasN = + get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per + // sub_group for Matrix B + static int constexpr DpasK = + get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per + // sub_group for Matrix A + + static uint32_t constexpr MaxThreadsPerBlock = DpasM * DpasN; + static uint32_t constexpr MinBlocksPerMultiprocessor = 1; + + static int constexpr FragsM = sg_tile_m / DpasM; // A frags per sub_group + static int constexpr FragsN = sg_tile_n / DpasN; // B frags per sub_group + static int constexpr FragsK = sg_tile_k / DpasK; // Calculate the vector width based on the amount of registers // required per work item by dividing the total fragment size by // the sub_group size. - static constexpr int VecC = (DpasN * DpasM) / SubgroupSize; - static constexpr int VecA = (DpasM * DpasK) / SubgroupSize; - static constexpr int VecB = (DpasN * DpasK) / SubgroupSize; + static int constexpr VecC = (DpasN * DpasM) / SubgroupSize; + static int constexpr VecA = (DpasM * DpasK) / SubgroupSize; + static int constexpr VecB = (DpasN * DpasK) / SubgroupSize; // Host side kernel arguments struct Arguments { - ElementA const *ptr_A; + ElementA const* ptr_A; StrideA dA; - ElementB const *ptr_B; + ElementB const* ptr_B; StrideB dB; }; struct Params { - using XE_Copy_A = decltype(make_xe_2d_copy( - make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}))); - using XE_Copy_B = decltype(make_xe_2d_copy( - make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}))); + using XE_Copy_A = + decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), + StrideA{}))); + using XE_Copy_B = + decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), + StrideB{}))); XE_Copy_A gmem_tiled_copy_a; XE_Copy_B gmem_tiled_copy_b; }; @@ -139,18 +158,16 @@ struct CollectiveMma - static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { + static Params constexpr to_underlying_arguments(ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { (void)workspace; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - Tensor tensorA = - make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensorB = - make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); + Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), 
args.dA)); + Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -158,59 +175,58 @@ struct CollectiveMma - CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, - FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int k_tile_count, - ResidueMNK residue_mnk, int thread_idx, - char *smem_buf, Params const &mainloop) { + template + CUTLASS_DEVICE void operator()(FrgTensorD& accum, + TensorA gA, + TensorB gB, + FrgTensorC const& src_accum, + KTileIterator k_tile_iter, + int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf, + Params const& mainloop) { (void)residue_mnk; (void)thread_idx; (void)smem_buf; - static_assert(is_rmem::value, - "D tensor must be rmem resident."); - static_assert( - is_tuple::value, + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_tuple::value, "A tensor must be a tuple iterator."); - static_assert( - is_tuple::value, + static_assert(is_tuple::value, "B tensor must be a tuple iterator."); - static_assert(is_rmem::value, - "C tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor( - Shape, Int<1>>{}); + Tensor tAr = make_tensor(Shape, Int<1>>{}); - constexpr int version = - is_same_v ? 1 : 2; + int constexpr version = is_same_v ? 1 : 2; Tensor tBr = make_tensor( Shape, Int>{}); - Tensor tAr_view = make_tensor(static_cast(tAr).data(), - Shape, Int, Int>{}); - Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}); + Tensor tAr_view = make_tensor( + static_cast(tAr).data(), Shape, Int, Int>{}); + Tensor tBr_view = make_tensor( + static_cast(tBr).data(), Shape, Int, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); - Tensor tAi = make_tensor( - make_inttuple_iter( - *gA.data() + - make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), - make_layout(make_shape(_1{}, _1{}, K), - make_stride(_1{}, E<0>{}, E<1>{}))); - Tensor tBi = make_tensor( - make_inttuple_iter( - *gB.data() + - make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, - (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), - make_layout(make_shape(_1{}, K, _1{}), - make_stride(_1{}, E<0>{}, E<1>{}))); + Tensor tAi = + make_tensor(make_inttuple_iter( + *gA.data() + make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), + make_layout(make_shape(_1{}, _1{}, K), make_stride(_1{}, E<0>{}, E<1>{}))); + Tensor tBi = + make_tensor(make_inttuple_iter( + *gB.data() + make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, + (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), + make_layout(make_shape(_1{}, K, _1{}), make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop // @@ -221,8 +237,7 @@ struct CollectiveMma Date: Sat, 13 Jul 2024 22:31:09 -0700 Subject: [PATCH 06/36] fix comments about prefetch --- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 200 +++++++++--------- 1 file changed, 97 insertions(+), 103 deletions(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 7bb07c02d2..b3772bdac4 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or 
without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,14 +18,15 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once @@ -43,38 +44,21 @@ namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -#define get_sub_group_id() \ - (sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group().get_group_id()[0]) - -template -struct CollectiveMma { +#define get_sub_group_id() \ + (sycl::ext::oneapi::experimental::this_nd_item<3>() \ + .get_sub_group() \ + .get_group_id()[0]) + +template +struct CollectiveMma { // // Type Aliases // @@ -106,15 +90,15 @@ struct CollectiveMma(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per - // sub_group for Matrix A - static int constexpr DpasN = - get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per - // sub_group for Matrix B - static int constexpr DpasK = - get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per - // sub_group for Matrix A + static int constexpr DpasM = get<0>( + shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per + // sub_group for Matrix A + static int constexpr DpasN = get<1>( + shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per + // sub_group for Matrix B + static int constexpr DpasK = get<1>( + shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per + // sub_group for Matrix A static uint32_t constexpr MaxThreadsPerBlock = DpasM * DpasN; static uint32_t constexpr MinBlocksPerMultiprocessor = 1; @@ -132,21 +116,19 @@ struct CollectiveMma(make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), - StrideA{}))); - using XE_Copy_B = - decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), - StrideB{}))); + using XE_Copy_A = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), StrideA{}))); + using XE_Copy_B = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), StrideB{}))); XE_Copy_A gmem_tiled_copy_a; XE_Copy_B gmem_tiled_copy_b; }; @@ -158,16 +140,18 @@ struct CollectiveMma - static Params constexpr to_underlying_arguments(ProblemShape const& problem_shape, - Arguments const& args, - void* workspace) { + static Params constexpr to_underlying_arguments( + ProblemShape const &problem_shape, Arguments const &args, + void *workspace) { (void)workspace; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); + Tensor tensorA = + make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensorB = + make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -175,58 +159,66 @@ struct CollectiveMma - CUTLASS_DEVICE void operator()(FrgTensorD& accum, - TensorA gA, - TensorB gB, - FrgTensorC const& src_accum, - KTileIterator k_tile_iter, - int k_tile_count, - ResidueMNK residue_mnk, - int thread_idx, - char* smem_buf, - Params const& mainloop) { + template + CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + 
ResidueMNK residue_mnk, int thread_idx, + char *smem_buf, Params const &mainloop) { (void)residue_mnk; (void)thread_idx; (void)smem_buf; - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_tuple::value, + static_assert(is_rmem::value, + "D tensor must be rmem resident."); + static_assert( + is_tuple::value, "A tensor must be a tuple iterator."); - static_assert(is_tuple::value, + static_assert( + is_tuple::value, "B tensor must be a tuple iterator."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(is_rmem::value, + "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape, Int<1>>{}); + Tensor tAr = make_tensor( + Shape, Int<1>>{}); - int constexpr version = is_same_v ? 1 : 2; + int constexpr version = + is_same_v ? 1 : 2; Tensor tBr = make_tensor( Shape, Int>{}); - Tensor tAr_view = make_tensor( - static_cast(tAr).data(), Shape, Int, Int>{}); - Tensor tBr_view = make_tensor( - static_cast(tBr).data(), Shape, Int, Int>{}); + Tensor tAr_view = make_tensor(static_cast(tAr).data(), + Shape, Int, Int>{}); + Tensor tBr_view = make_tensor(static_cast(tBr).data(), + Shape, Int, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); - Tensor tAi = - make_tensor(make_inttuple_iter( - *gA.data() + make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), - make_layout(make_shape(_1{}, _1{}, K), make_stride(_1{}, E<0>{}, E<1>{}))); - Tensor tBi = - make_tensor(make_inttuple_iter( - *gB.data() + make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, - (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), - make_layout(make_shape(_1{}, K, _1{}), make_stride(_1{}, E<0>{}, E<1>{}))); + + // Cooperative prefetch + // Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in + // one row/col use the same tile A/B. Each thread loads sizeof(tile A or B) + // / numof(sg_per_wg_n or sg_per_wg_m) Currently, sg_per_wg_m x sg_per_wg_n + // = 4 x 8 is the most efficient + // TODO: Replace the demo cooperative prefetch with more general way. 
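A minimal sketch of the partitioning described in the comment above (illustrative only: sg_row, sg_col and the *_per_sg names are hypothetical, the tile extents are assumed to divide evenly, and the actual offsets computed just below are hand-tuned modulo arithmetic for the specific sub-group arrangement the comment mentions):

// Each sub-group owns one cell of a sg_per_wg_m x sg_per_wg_n grid.
int sg_id  = get_sub_group_id();
int sg_row = sg_id / sg_per_wg_n;
int sg_col = sg_id % sg_per_wg_n;
// Sub-groups in one row all consume the same A tile, so each of the
// sg_per_wg_n of them prefetches 1/sg_per_wg_n of it; likewise the
// sg_per_wg_m sub-groups sharing a column split their common B tile.
int a_rows_per_sg = sg_tile_m / sg_per_wg_n; // rows of A prefetched per sub-group
int b_cols_per_sg = sg_tile_n / sg_per_wg_m; // cols of B prefetched per sub-group
int a_row_offset  = sg_col * a_rows_per_sg;  // start of this sub-group's A slice
int b_col_offset  = sg_row * b_cols_per_sg;  // start of this sub-group's B slice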
+ Tensor tAi = make_tensor( + make_inttuple_iter( + *gA.data() + + make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, E<0>{}, E<1>{}))); + Tensor tBi = make_tensor( + make_inttuple_iter( + *gB.data() + + make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, + (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), + make_layout(make_shape(_1{}, K, _1{}), + make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop // @@ -237,7 +229,8 @@ struct CollectiveMma Date: Sun, 14 Jul 2024 22:11:42 -0700 Subject: [PATCH 07/36] fix comments of enum and sycl macro --- include/cute/arch/copy_xe.hpp | 122 ++++++---------------------------- include/cute/arch/mma_xe.hpp | 16 ++--- 2 files changed, 27 insertions(+), 111 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index db0f3a6334..51200dd08b 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -38,22 +38,26 @@ namespace cute { #ifdef __SYCL_DEVICE_ONLY__ +#ifdef SYCL_INTEL_TARGET #define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x #else #define SYCL_DEVICE_BUILTIN(x) \ - inline x { assert(false); } + inline x { CUTE_INVALID_CONTROL_PATH("Trying to use IGC built-in on non-Intel hardware"); } #endif +#else +#define SYCL_DEVICE_BUILTIN(x) \ + inline x { CUTE_INVALID_CONTROL_PATH("Trying to use device built-in on host."); } +#endif enum LSC_LDCC { - LSC_LDCC_DEFAULT = 0, - LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached - LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached - LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached - LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached - LSC_LDCC_L1IAR_L3C - = 7, // Override to L1 invalidate-after-read, and L3 cached + kLSC_LDCC_DEFAULT = 0, + kLSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + kLSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + kLSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + kLSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + kLSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + kLSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + kLSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached }; SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( @@ -119,28 +123,19 @@ struct XE_2D_U16x8x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort8 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif + 
kLSC_LDCC_L1C_L3C); } }; }; @@ -151,13 +146,9 @@ struct XE_2D_U32x8x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } }; @@ -167,28 +158,19 @@ struct XE_2D_U16x16x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif + kLSC_LDCC_L1C_L3C); } }; }; @@ -199,29 +181,20 @@ struct XE_2D_U16x8x16x4x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort64 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( long(baseoffset), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif + kLSC_LDCC_L1C_L3C); } }; }; @@ -232,28 +205,19 @@ struct XE_2D_U16x8x16x2x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( long(baseoffset), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC 
hardware"); -#endif + kLSC_LDCC_L1C_L3C); } }; }; @@ -264,29 +228,20 @@ struct XE_2D_U16x8x16x1x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); intel::ushort16 tmp = (intel_subgroup_block_read_u16_m8k16v2( (long)baseoffset, width, height, pitch, coord)); *(intel::ushort16 *)dst = *reinterpret_cast(&tmp); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif + kLSC_LDCC_L1C_L3C); } }; }; @@ -297,28 +252,19 @@ struct XE_2D_U16x8x16x4x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - LSC_LDCC_L1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif + kLSC_LDCC_L1C_L3C); } }; }; @@ -329,14 +275,10 @@ struct XE_2D_U32x8x16x2x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); *(intel::uint16 *)dst = *reinterpret_cast(&tmp); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } }; @@ -346,14 +288,10 @@ struct XE_2D_U16x16x16x2x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); *(intel::uint16 *)dst = *reinterpret_cast(&tmp); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; @@ -363,12 +301,8 @@ struct XE_2D_U16x16x16x2x2_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::uint32*) dst = 
__builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long(base_address), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } // using PREFETCH = typename XE_2D_U16x8x16x4x2_LD_N::PREFETCH; @@ -379,12 +313,8 @@ struct XE_2D_U16x16x16x1x2_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long(base_address), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; @@ -394,12 +324,8 @@ struct XE_2D_U16x16x16x2x1_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long(base_address), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; @@ -409,13 +335,9 @@ struct XE_2D_U16x16x16x1x1_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: *(intel::int8*) dst = intel_subgroup_block_read_transform_u16_k16((long)base_address, width, height, pitch, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } using PREFETCH = typename XE_2D_U16x16x16x1x1_LD_N::PREFETCH; @@ -426,14 +348,10 @@ struct XE_2D_U32x8x16x1x1_ST_N template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, const T *src) { - #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, *(intel::uint8 *)src); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif } }; diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 878c587fdc..2850576d6e 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -35,10 +35,16 @@ #include #ifdef __SYCL_DEVICE_ONLY__ +#ifdef SYCL_INTEL_TARGET #define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x #else -#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#define SYCL_DEVICE_OCL(x) \ + inline x { CUTE_INVALID_CONTROL_PATH("Trying to use IGC built-in on non-Intel hardware"); } #endif +#else +#define SYCL_DEVICE_OCL(x) \ + inline x { CUTE_INVALID_CONTROL_PATH("Trying to use device built-in on host."); } +#endif SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc)); SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::intel::int8 b, float acc)); @@ -62,11 +68,7 @@ struct 
XE_8x16x16_BF16BF16F32F32_NN intel::int8 const& b, intel::float8 const& c) { -#if defined(SYCL_INTEL_TARGET) d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_BF16BF16F32F32_NN on non-PVC hardware"); -#endif } }; //float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) @@ -83,11 +85,7 @@ struct XE_1x16x16_BF16BF16F32F32_NN intel::int8 const& b, float const& c) { -#if defined(SYCL_INTEL_TARGET) d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_BF16BF16F32F32_NN on non-PVC hardware"); -#endif } }; } //namespace cute From 1e3f855998d63cc26cf464ebee4c20c49882dd2a Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 15 Jul 2024 01:05:36 -0700 Subject: [PATCH 08/36] update from tensor library repo --- .../epilogue/collective/default_epilogue.hpp | 33 +++ .../epilogue/thread/linear_combination_relu.h | 13 +- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 62 ++--- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 218 ++++++++++-------- .../util/reference/device/tensor_foreach.h | 10 +- 5 files changed, 200 insertions(+), 136 deletions(-) diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index 71ba713ba3..ebb7aabcd5 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -182,6 +182,39 @@ class DefaultEpilogue { } #endif + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout> + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor & accumulators){ + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + if (epilogue_op.is_source_needed()) { + auto source = make_fragment_like(accumulators); + auto gmem_tiled_copy_c = + make_xe_2d_copy(make_tensor( + params.ptr_C, make_shape(M, N, L), params.dC)); + + Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), + make_shape(size<1>(accumulators), size<2>(accumulators), L), + make_stride(size<0>(blk_shape_MNK), size<1>(blk_shape_MNK))); + copy(gmem_tiled_copy_c, tCi(_, _, _, l_coord), source); + epilogue_op(accumulators, source); + } else { + epilogue_op(accumulators); + } + } + template< class ProblemShapeMNKL, class BlockShapeMNK, diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 7ecba15006..343e2a9ec2 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -184,26 +184,25 @@ class LinearCombinationRelu { } } -#ifdef EPILOGUE_RELU using ElementC = ElementOutput_; using ElementD = ElementOutput_; - template + + template CUTLASS_HOST_DEVICE - void operator()(cute::Tensor &accumulators) const { + void operator()(TensorType &accumulators) const { for (int i = 0; i < size(accumulators); i++) { accumulators(i) = accumulators(i) < 0 ? 
0 : accumulators(i); } } - template + template CUTLASS_HOST_DEVICE - void operator()(cute::Tensor &accumulators, - cute::Tensor const &source) const { + void operator()(TensorDst &accumulators, + TensorSrc const &source) const { for (int i = 0; i < size(accumulators); i++) { accumulators(i) = accumulators(i) < 0 ? source(i) : accumulators(i) + source(i); } } -#endif /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index b3772bdac4..79b8abc029 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -40,7 +40,6 @@ #include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// - namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,38 +80,38 @@ struct CollectiveMma(tile_shape))::value; - static auto constexpr wg_tile_n = decltype(get<1>(tile_shape))::value; - static auto constexpr sg_tile_m = decltype(get<2>(tile_shape))::value; - static auto constexpr sg_tile_n = decltype(get<3>(tile_shape))::value; - static auto constexpr sg_tile_k = decltype(get<4>(tile_shape))::value; - static auto constexpr sg_per_wg_m = wg_tile_m / sg_tile_m; - static auto constexpr sg_per_wg_n = wg_tile_n / sg_tile_n; - static int constexpr SubgroupSize = DispatchPolicy::SubgroupSize; - - static int constexpr DpasM = get<0>( + static constexpr auto wg_tile_m = decltype(get<0>(tile_shape))::value; + static constexpr auto wg_tile_n = decltype(get<1>(tile_shape))::value; + static constexpr auto sg_tile_m = decltype(get<2>(tile_shape))::value; + static constexpr auto sg_tile_n = decltype(get<3>(tile_shape))::value; + static constexpr auto sg_tile_k = decltype(get<4>(tile_shape))::value; + static constexpr auto sg_per_wg_m = wg_tile_m / sg_tile_m; + static constexpr auto sg_per_wg_n = wg_tile_n / sg_tile_n; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static constexpr int DpasM = get<0>( shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per // sub_group for Matrix A - static int constexpr DpasN = get<1>( + static constexpr int DpasN = get<1>( shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per // sub_group for Matrix B - static int constexpr DpasK = get<1>( + static constexpr int DpasK = get<1>( shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per // sub_group for Matrix A - static uint32_t constexpr MaxThreadsPerBlock = DpasM * DpasN; - static uint32_t constexpr MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - static int constexpr FragsM = sg_tile_m / DpasM; // A frags per sub_group - static int constexpr FragsN = sg_tile_n / DpasN; // B frags per sub_group - static int constexpr FragsK = sg_tile_k / DpasK; + static constexpr int FragsM = sg_tile_m / DpasM; // A frags per sub_group + static constexpr int FragsN = sg_tile_n / DpasN; // B frags per sub_group + static constexpr int FragsK = sg_tile_k / DpasK; // Calculate the vector width based on the amount of registers // required per work item by dividing the total fragment size by // the sub_group size. 
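  // As a worked illustration (assuming the 8x16x16 bf16 DPAS atom used by this
  // patch, i.e. DpasM = 8, DpasN = 16, DpasK = 16 with SubgroupSize = 16, and
  // the <wg 256x256, sg 32x64, k 32> tile configuration from the PVC example):
  //   sg_per_wg_m = 256 / 32 = 8,   sg_per_wg_n = 256 / 64 = 4   -> 32 sub-groups
  //   FragsM = 32 / 8 = 4,   FragsN = 64 / 16 = 4,   FragsK = 32 / 16 = 2
  //   VecC = (16 * 8) / 16 = 8,   VecA = (8 * 16) / 16 = 8,   VecB = (16 * 16) / 16 = 16
  // i.e. each work-item keeps 8 accumulator elements, 8 elements of A and 16
  // elements of B in registers per DPAS fragment; other tile shapes scale the
  // same way.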
- static int constexpr VecC = (DpasN * DpasM) / SubgroupSize; - static int constexpr VecA = (DpasM * DpasK) / SubgroupSize; - static int constexpr VecB = (DpasN * DpasK) / SubgroupSize; + static constexpr int VecC = (DpasN * DpasM) / SubgroupSize; + static constexpr int VecA = (DpasM * DpasK) / SubgroupSize; + static constexpr int VecB = (DpasN * DpasK) / SubgroupSize; // Host side kernel arguments struct Arguments { @@ -140,9 +139,9 @@ struct CollectiveMma - static Params constexpr to_underlying_arguments( - ProblemShape const &problem_shape, Arguments const &args, - void *workspace) { + static constexpr Params + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { (void)workspace; auto problem_shape_MNKL = append<4>(problem_shape, 1); @@ -185,7 +184,7 @@ struct CollectiveMma( Shape, Int<1>>{}); - int constexpr version = + constexpr int version = is_same_v ? 1 : 2; Tensor tBr = make_tensor( Shape, Int>{}); @@ -200,11 +199,14 @@ struct CollectiveMma(mainloop.gmem_tiled_copy_a.tensor); - // Cooperative prefetch - // Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in - // one row/col use the same tile A/B. Each thread loads sizeof(tile A or B) - // / numof(sg_per_wg_n or sg_per_wg_m) Currently, sg_per_wg_m x sg_per_wg_n - // = 4 x 8 is the most efficient + /* Cooperative prefetch + Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in + one row/col use the same tile A/B. + Each thread loads sizeof(tile A or B) / numof(sg_per_wg_n or + sg_per_wg_m). + + Currently, sg_per_wg_m x sg_per_wg_n = 4 x 8 is the most efficient + */ // TODO: Replace the demo cooperative prefetch with more general way. Tensor tAi = make_tensor( make_inttuple_iter( diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 4af016e6b2..a98c22ccf1 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,19 +18,21 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/kernel_hardware_info.hpp" @@ -41,15 +43,12 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template -class GemmUniversal +class GemmUniversal< + ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, + cute::enable_if_t>> { public: // @@ -58,7 +57,7 @@ class GemmUniversal or "); + "ProblemShape{} should be or "); // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; @@ -74,13 +73,12 @@ class GemmUniversal or cute::is_same_v, - "Intel PVC does not support specializing the tile scheduler."); + static_assert(cute::is_void_v or + cute::is_same_v, + "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; - using TileScheduler = typename detail::TileSchedulerSelector, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; @@ -93,28 +91,31 @@ class GemmUniversal, + cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. 
- static int constexpr SharedStorageSize = 0; + static constexpr int SharedStorageSize = 0; - static int constexpr SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size - static uint32_t constexpr MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; - static uint32_t constexpr MinBlocksPerMultiprocessor = + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor; - static int constexpr num_sg = + static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group - static int constexpr DpasM = CollectiveMainloop::DpasM; - static int constexpr DpasN = CollectiveMainloop::DpasN; - static int constexpr DpasK = CollectiveMainloop::DpasK; + static constexpr int DpasM = CollectiveMainloop::DpasM; + static constexpr int DpasN = CollectiveMainloop::DpasN; + static constexpr int DpasK = CollectiveMainloop::DpasK; - static int constexpr FragsM = CollectiveMainloop::FragsM; - static int constexpr FragsN = CollectiveMainloop::FragsN; + static constexpr int FragsM = CollectiveMainloop::FragsM; + static constexpr int FragsN = CollectiveMainloop::FragsN; - static int constexpr VecC = CollectiveMainloop::VecC; + static constexpr int VecC = CollectiveMainloop::VecC; // Device side arguments struct Arguments { @@ -138,53 +139,61 @@ class GemmUniversal(params.problem_shape); auto N = get<1>(params.problem_shape); auto L = get<3>(params.problem_shape); - int const sg_m = cute::ceil_div(M, - CollectiveMainloop::wg_tile_m); // sub_groups required to - // process A fragments - int const sg_n = cute::ceil_div(N, - CollectiveMainloop::wg_tile_n); // sub_groups required to - // process B fragments + int const sg_m = + cute::ceil_div(M, + CollectiveMainloop::wg_tile_m); // sub_groups required to + // process A fragments + int const sg_n = + cute::ceil_div(N, + CollectiveMainloop::wg_tile_n); // sub_groups required to + // process B fragments return dim3(sg_n, sg_m, L); } static dim3 get_block_shape() { - return dim3( - cute::ceil_div(CollectiveMainloop::wg_tile_n, CollectiveMainloop::sg_tile_n / SubgroupSize), - cute::ceil_div(CollectiveMainloop::wg_tile_m, CollectiveMainloop::sg_tile_m), 1); + return dim3(cute::ceil_div(CollectiveMainloop::wg_tile_n, + CollectiveMainloop::sg_tile_n / SubgroupSize), + cute::ceil_div(CollectiveMainloop::wg_tile_m, + CollectiveMainloop::sg_tile_m), + 1); } CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) { + void operator()(Params const ¶ms, char *smem_buf) { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); @@ -192,7 +201,8 @@ class GemmUniversal::value); // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only + // rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -201,51 +211,64 @@ class GemmUniversal."); + "StrideA must be rank-3: [M, K, L]. If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideB must be rank-3: [N, K, L]. 
If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideC must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not " - "needed, set L stride to Int<0>."); + "StrideD must be rank-3: [M, N, L]. If batch mode is not " + "needed, set L stride to Int<0>."); - // Get the appropriate blocks for this sub_group -- potential for sub_group locality + // Get the appropriate blocks for this sub_group -- potential for sub_group + // locality int thread_idx = int(ThreadIdxX()); int thread_idy = int(ThreadIdxY()); - static auto constexpr sg_per_wg_n = + static constexpr auto sg_per_wg_n = CollectiveMainloop::wg_tile_n / CollectiveMainloop::sg_tile_n; auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) - int const m_coord = BlockIdxY() * CollectiveMainloop::wg_tile_m + - (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; - int const n_coord = BlockIdxX() * CollectiveMainloop::wg_tile_n + - (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; - int const l_coord = BlockIdxZ(); - - Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, l_coord), - make_shape(_1{}, K, _1{}), make_stride(Int{}, _1{})); - int constexpr version = - is_same_v ? 1 : 2; - - Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, l_coord), - make_shape(K, Int{}, _1{}), make_stride(_1{}, Int{})); + const int m_coord = + BlockIdxY() * CollectiveMainloop::wg_tile_m + + (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; + const int n_coord = + BlockIdxX() * CollectiveMainloop::wg_tile_n + + (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; + const int l_coord = BlockIdxZ(); + + Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( + make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}), + make_stride(Int{}, _1{})); + constexpr int version = + is_same_v + ? 
1 + : 2; + + Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( + make_coord(0, n_coord, l_coord), + make_shape(K, Int{}, _1{}), + make_stride(_1{}, Int{})); // Compute tile residues for predication - auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord - auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord + auto m_max_coord = + M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord + auto n_max_coord = + N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord auto k_residue = - K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + K - get<2>(subgroup_shape) * + (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; - Tensor accumulators = make_tensor(Shape, Int, Int>{}); + Tensor accumulators = make_tensor( + Shape, Int, Int>{}); clear(accumulators); int k_tile_count = cute::ceil_div(K, CollectiveMainloop::sg_tile_k); @@ -253,14 +276,16 @@ class GemmUniversal{}, Int{}, Int{}), - make_coord(m_coord, n_coord, 0, l_coord), accumulators); + collective_relu(problem_shape_MNKL, + make_shape(Int{}, Int{}, Int{}), + make_coord(m_coord, n_coord, 0, l_coord), accumulators); #endif #ifdef EPILOGUE_SOFTMAX @@ -269,11 +294,14 @@ class GemmUniversal( - make_tensor(params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); + auto gmem_tiled_copy_c = + make_xe_2d_copy(make_tensor( + params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), - make_shape(Int{}, Int{}, _1{}), make_stride(Int{}, Int{})); + Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), + make_shape(Int{}, Int{}, _1{}), + make_stride(Int{}, Int{})); copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index 728c0a02f0..37e238e86e 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -54,7 +54,9 @@ struct TensorForEach { #if defined (CUTLASS_ENABLE_SYCL) // TODO: query the queue for block size block_size = 128; - grid_size = (size(size) + block_size - 1) / block_size; + grid_size = (size.product() + block_size - 1) / block_size; + int sm_count = KernelHardwareInfo::query_device_multiprocessor_count(); + grid_size = grid_size > sm_count / 2 ? 
sm_count / 2 : grid_size; #else // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API cudaError_t result = cudaOccupancyMaxPotentialBlockSize( @@ -75,7 +77,7 @@ struct TensorForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, size, params); + syclcompat::launch>(sycl_grid, sycl_block, size, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); @@ -103,7 +105,7 @@ struct TensorDiagonalForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3((end - start + block_size - 1) / block_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, size, params, start, end); + syclcompat::launch>(sycl_grid, sycl_block, size, params, start, end); #else dim3 block(block_size, 1, 1); dim3 grid((end - start + block_size - 1) / block_size, 1, 1); @@ -153,7 +155,7 @@ struct BlockForEach { #if defined(CUTLASS_ENABLE_SYCL) const auto sycl_block = syclcompat::dim3(block_size, 1, 1); const auto sycl_grid = syclcompat::dim3(grid_size, 1, 1); - syclcompat::launch>(sycl_grid, sycl_block, 0, ptr, capacity, params); + syclcompat::launch>(sycl_grid, sycl_block, ptr, capacity, params); #else dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); From 8e951d115e598ac7abc5fb9ce301801b5bff9f93 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 15 Jul 2024 01:16:59 -0700 Subject: [PATCH 09/36] fix format --- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 137 +++++++++------ .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 165 ++++++++---------- 2 files changed, 155 insertions(+), 147 deletions(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 79b8abc029..8c03e9e74c 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -18,15 +18,14 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once @@ -48,16 +47,38 @@ using namespace cute; .get_sub_group() \ .get_group_id()[0]) -template -struct CollectiveMma { +template < + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopIntelPVCUnpredicated, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ // // Type Aliases // @@ -115,19 +136,17 @@ struct CollectiveMma( - make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}))); - using XE_Copy_B = decltype(make_xe_2d_copy( - make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}))); + using XE_Copy_A = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), + repeat_like(StrideA{}, int32_t(0)), StrideA{}))); + using XE_Copy_B = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), + repeat_like(StrideB{}, int32_t(0)), StrideB{}))); XE_Copy_A gmem_tiled_copy_a; XE_Copy_B gmem_tiled_copy_b; }; @@ -140,17 +159,14 @@ struct CollectiveMma static constexpr Params - to_underlying_arguments(ProblemShape const &problem_shape, - Arguments const &args, void *workspace) { - (void)workspace; + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; + auto [M,N,K,L] = problem_shape_MNKL; - Tensor tensorA = - make_tensor(args.ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensorB = - make_tensor(args.ptr_B, make_layout(make_shape(K, N, L), args.dB)); + Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -158,27 +174,34 @@ struct CollectiveMma - CUTLASS_DEVICE void operator()(FrgTensorD &accum, TensorA gA, TensorB gB, - FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int k_tile_count, - ResidueMNK residue_mnk, int thread_idx, - char *smem_buf, Params const &mainloop) { + template < + class FrgTensorD, + class TensorA, + class TensorB, + class FrgTensorC, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + char *smem_buf, + Params const& mainloop) + { (void)residue_mnk; (void)thread_idx; (void)smem_buf; - static_assert(is_rmem::value, - "D tensor must be rmem resident."); - static_assert( - is_tuple::value, - "A tensor must be a tuple iterator."); - static_assert( - is_tuple::value, - "B tensor must be a tuple iterator."); - static_assert(is_rmem::value, - "C tensor must be rmem resident."); + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_tuple::value, "A tensor must be a tuple iterator."); + static_assert(is_tuple::value, "B 
tensor must be a tuple iterator."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data Tensor tAr = make_tensor( @@ -190,7 +213,7 @@ struct CollectiveMma, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), - Shape, Int, Int>{}); + Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), Shape, Int, Int>{}); @@ -200,11 +223,9 @@ struct CollectiveMma(mainloop.gmem_tiled_copy_a.tensor); /* Cooperative prefetch - Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in - one row/col use the same tile A/B. - Each thread loads sizeof(tile A or B) / numof(sg_per_wg_n or - sg_per_wg_m). - + Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in one row/col use the same tile A/B. + Each thread loads sizeof(tile A or B) / numof(sg_per_wg_n or sg_per_wg_m). + Currently, sg_per_wg_m x sg_per_wg_n = 4 x 8 is the most efficient */ // TODO: Replace the demo cooperative prefetch with more general way. diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index a98c22ccf1..186c5196c1 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation @@ -18,24 +18,23 @@ * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - *POSSIBILITY OF SUCH DAMAGE. + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include "cutlass/cutlass.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" #include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/tensor.hpp" @@ -43,13 +42,19 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// -template +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> class GemmUniversal< - ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, TileScheduler_, - cute::enable_if_t>> { + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ public: // // Type Aliases @@ -57,7 +62,7 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); + "ProblemShape{} should be or "); // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; @@ -73,9 +78,8 @@ class GemmUniversal< using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; - static_assert(cute::is_void_v or - cute::is_same_v, - "Intel PVC does not support specializing the tile scheduler."); + static_assert(cute::is_void_v or cute::is_same_v, + "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< TileScheduler_, ArchTag, TileShape, @@ -85,15 +89,13 @@ class GemmUniversal< // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; + using StrideC = typename CollectiveEpilogue::StrideC; using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; + using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert( - cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); + static_assert(cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); // MSVC requires the cast to fix a warning-as-error. static constexpr int SharedStorageSize = 0; @@ -139,31 +141,35 @@ class GemmUniversal< // Methods // - // Convert to underlying arguments. In this case, a simple copy for the - // aliased type. - static Params to_underlying_arguments(Arguments const &args, - void *workspace) { - (void)workspace; - return {args.mode, args.problem_shape, - CollectiveMainloop::to_underlying_arguments( - args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments( - args.problem_shape, args.epilogue, workspace)}; + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
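  // (On the host this conversion is typically driven by the GemmUniversalAdapter,
  //  roughly along these lines -- an illustrative sketch only, with Gemm, args and
  //  workspace_ptr standing in for the example's own names:
  //    Gemm gemm_op;
  //    if (gemm_op.can_implement(args) == cutlass::Status::kSuccess) {
  //      gemm_op.initialize(args, workspace_ptr);   // calls to_underlying_arguments()
  //      gemm_op.run();                             // launches the kernel with the Params
  //    }
  //  see gemm_universal_adapter.h for the actual flow.)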
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; } - static bool can_implement(Arguments const &args) { - bool mode_implementable = - args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + static bool + can_implement(Arguments const& args) { + bool mode_implementable = args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); return mode_implementable && TileScheduler::can_implement(args.scheduler); } - static int get_workspace_size(Arguments const &args) { return 0; } + static int + get_workspace_size(Arguments const& args) { + return 0; + } - static cutlass::Status - initialize_workspace(Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter *cuda_adapter = nullptr) { + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { return Status::kSuccess; } @@ -172,28 +178,24 @@ class GemmUniversal< auto N = get<1>(params.problem_shape); auto L = get<3>(params.problem_shape); - int const sg_m = - cute::ceil_div(M, - CollectiveMainloop::wg_tile_m); // sub_groups required to - // process A fragments - int const sg_n = - cute::ceil_div(N, - CollectiveMainloop::wg_tile_n); // sub_groups required to - // process B fragments + int const sg_m = cute::ceil_div(M, + CollectiveMainloop::wg_tile_m); // sub_groups required to + // process A fragments + int const sg_n = cute::ceil_div(N, + CollectiveMainloop::wg_tile_n); // sub_groups required to + // process B fragments return dim3(sg_n, sg_m, L); } static dim3 get_block_shape() { - return dim3(cute::ceil_div(CollectiveMainloop::wg_tile_n, - CollectiveMainloop::sg_tile_n / SubgroupSize), - cute::ceil_div(CollectiveMainloop::wg_tile_m, - CollectiveMainloop::sg_tile_m), - 1); + return dim3( + cute::ceil_div(CollectiveMainloop::wg_tile_n, CollectiveMainloop::sg_tile_n / SubgroupSize), + cute::ceil_div(CollectiveMainloop::wg_tile_m, CollectiveMainloop::sg_tile_m), 1); } CUTLASS_DEVICE - void operator()(Params const ¶ms, char *smem_buf) { + void operator()(Params const& params, char* smem_buf) { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); @@ -201,8 +203,7 @@ class GemmUniversal< CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only - // rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -210,21 +211,12 @@ class GemmUniversal< auto L = get<3>(problem_shape_MNKL); // Preconditions - static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not " - "needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not " - "needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. 
If batch mode is not " - "needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not " - "needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this sub_group -- potential for sub_group - // locality + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); int thread_idy = int(ThreadIdxY()); @@ -255,20 +247,15 @@ class GemmUniversal< make_stride(_1{}, Int{})); // Compute tile residues for predication - auto m_max_coord = - M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord - auto n_max_coord = - N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord - auto k_residue = - K - get<2>(subgroup_shape) * - (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max + auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord + auto n_max_coord = N - get<1>(subgroup_shape) * n_coord; // N - SUB_N * n_coord + auto k_residue = K - get<2>(subgroup_shape) * (K / get<2>(subgroup_shape)); // K - SUB_K * k_coord_max auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; - Tensor accumulators = make_tensor( - Shape, Int, Int>{}); + Tensor accumulators = make_tensor(Shape, Int, Int>{}); clear(accumulators); int k_tile_count = cute::ceil_div(K, CollectiveMainloop::sg_tile_k); From c92adb380ace2f31e5da472a5dbacd02833e5946 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 15 Jul 2024 01:27:21 -0700 Subject: [PATCH 10/36] rm redundancy code --- include/cutlass/gemm/collective/intel_pvc_mma.hpp | 4 ++-- include/cutlass/relatively_equal.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 8c03e9e74c..f59fc10231 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -5,8 +5,8 @@ * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * - * 1. Redistributions of source code must retain the above copyright notice, - *this list of conditions and the following disclaimer. + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. * * 2. 
Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index a3eee0405c..6926921be5 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -71,7 +71,7 @@ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { return true; } else if (a == zero || b == zero || diff < nonzero_floor) { - return diff < (epsilon * nonzero_floor) || (diff / abs_B) < (T)0.001f; + return diff < (epsilon * nonzero_floor); } return diff < epsilon * (abs_A + abs_B); From 6bdda755fe05ab040e691a2a78b3b481b13ed86e Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Tue, 16 Jul 2024 01:26:41 -0700 Subject: [PATCH 11/36] resolve conflict --- build.sh | 8 ++--- .../epilogue/collective/default_epilogue.hpp | 35 ------------------- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 2 -- 3 files changed, 4 insertions(+), 41 deletions(-) diff --git a/build.sh b/build.sh index c483462146..3ceb147e9e 100644 --- a/build.sh +++ b/build.sh @@ -2,8 +2,8 @@ script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) cp ${script_dir}/tools/clang-format/clang-format.hook ${script_dir}/.git/hooks/pre-commit chmod +x ${script_dir}/.git/hooks/pre-commit -# https://github.com/intel/llvm/releases/tag/nightly-2024-05-16 -sycl_compiler_path=/opt/cutlass/compiler/0516/ +# https://github.com/intel/llvm/releases/tag/nightly-2024-07-03 +sycl_compiler_path=/opt/cutlass/compiler/0703/ # https://ubit-gfx.intel.com/build/19168301/artifacts gpu_driver_path=/opt/cutlass/gpu_driver/gfx-driver-ci-comp_igc-25012/extract/ @@ -16,10 +16,10 @@ output=intel_gpu_pvc unset epilogue # epilogue relu -#epilogue+=" -DEPILOGUE_RELU " +# epilogue+=" -DEPILOGUE_RELU " # epilogue softmax -#epilogue+=" -DEPILOGUE_SOFTMAX " +# epilogue+=" -DEPILOGUE_SOFTMAX " export ZE_AFFINITY_MASK=0 export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/ diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index ebb7aabcd5..de24020265 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -147,41 +147,6 @@ class DefaultEpilogue { return epilogue_op.is_source_needed(); } -#ifdef EPILOGUE_RELU - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout> - CUTLASS_HOST_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor & accumulators){ - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - if (epilogue_op.is_source_needed()) { - auto source = make_fragment_like(accumulators); - auto gmem_tiled_copy_c = - make_xe_2d_copy(make_tensor( - params.ptr_C, make_shape(M, N, L), params.dC)); - - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), - make_shape(size<1>(accumulators), size<2>(accumulators), L), - make_stride(size<0>(blk_shape_MNK), size<1>(blk_shape_MNK))); - copy(gmem_tiled_copy_c, tCi(_, _, _, l_coord), source); - epilogue_op(accumulators, source); - } else { - epilogue_op(accumulators); - } - } -#endif - template< class ProblemShapeMNKL, class 
BlockShapeMNK, diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 186c5196c1..ebfd0b5b29 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -197,8 +197,6 @@ class GemmUniversal< CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - // Preconditions CUTE_STATIC_ASSERT(is_static::value); From 349659315250e4ad5a33646844cb9d9542434b77 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 17 Jul 2024 01:22:42 -0700 Subject: [PATCH 12/36] revert the change of nv hpp --- include/cutlass/relatively_equal.h | 2 +- .../include/cutlass/util/reference/device/tensor_compare.h | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 6926921be5..fd900b6605 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -71,7 +71,7 @@ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { return true; } else if (a == zero || b == zero || diff < nonzero_floor) { - return diff < (epsilon * nonzero_floor); + return diff < epsilon * nonzero_floor; } return diff < epsilon * (abs_A + abs_B); diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index de96d53122..34e66e8bc0 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -95,18 +95,15 @@ __global__ void size_t idx = ThreadIdxX() + BlockDimX() * BlockIdxX(); - //for (; idx < capacity; idx += GridDimX() * BlockDimX()) { - if (idx < capacity ){ + for (; idx < capacity; idx += GridDimX() * BlockDimX()) { Element a = cutlass::ReferenceFactory::get(ptr_A, idx); Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { *equal = 0; - //printf("error, idx at: %lu, capacity: %lu, a: %f, b: %f\n", idx, capacity, a, b); return; } } - // } } } // namespace kernel From 69d5c2a179a43617f927c5d09f5df9f13e3537b5 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 17 Jul 2024 01:25:08 -0700 Subject: [PATCH 13/36] Restore invalid changes --- .../include/cutlass/util/reference/device/tensor_compare.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 34e66e8bc0..3c312f5ff8 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -96,6 +96,7 @@ __global__ void size_t idx = ThreadIdxX() + BlockDimX() * BlockIdxX(); for (; idx < capacity; idx += GridDimX() * BlockDimX()) { + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); Element b = cutlass::ReferenceFactory::get(ptr_B, idx); @@ -238,7 +239,7 @@ bool BlockCompareRelativelyEqual( #if defined (CUTLASS_ENABLE_SYCL) block_size = 128; grid_size = (capacity + block_size - 1) / block_size; - //grid_size = (grid_size < 64 ? grid_size : 64); // limit grid size to avoid out_of_resources runtime error. + grid_size = (grid_size < 64 ? grid_size : 64); // limit grid size to avoid out_of_resources runtime error. 
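  // The clamp is safe because the comparison kernel above walks its data with a
  // grid-stride loop (idx += GridDimX() * BlockDimX()), so a smaller grid simply
  // means more iterations per work-item. As a rough illustration, assuming a
  // 4096 x 4096 float tensor: capacity = 16,777,216, so the unclamped grid would
  // be 16,777,216 / 128 = 131,072 work-groups; clamped to 64 work-groups, the
  // 64 * 128 = 8,192 work-items each stride over about 2,048 elements instead.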
#else // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API cudaError_t result = cudaOccupancyMaxPotentialBlockSize( From 962766b08223bec1d2416cea527583709aa51a0c Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 17 Jul 2024 23:07:16 -0700 Subject: [PATCH 14/36] refine gemm interface will codeplay epilogue --- examples/sycl/pvc/pvc_gemm.cpp | 201 +++++------------- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 31 +-- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 75 ++++--- 3 files changed, 111 insertions(+), 196 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 635dca4dd1..a278d5329a 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -33,6 +33,8 @@ #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" +#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" #include "cutlass/gemm/device/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -51,12 +53,7 @@ template static void fill_matrix(std::vector& M) { std::random_device dev; std::mt19937 rng(dev()); - std::uniform_real_distribution dist((T)0.0, -#ifdef EPILOGUE_SOFTMAX - (T)0.1); -#else - (T)1.0); -#endif + std::uniform_real_distribution dist((T)0.0, (T)1.0); std::generate(std::begin(M), std::end(M), [&] { return static_cast(dist(rng)); }); } @@ -209,66 +206,6 @@ template struct ExampleRunner { M * N // batch_stride_D ); -#ifdef EPILOGUE_SOFTMAX - - ElementOutput* ptr = (ElementOutput*)std::malloc(M * N * L * sizeof(ElementOutput)); - syclcompat::memcpy(ptr, block_ref_D.get(), M * N * L * sizeof(ElementOutput)); - syclcompat::wait(); - for (int l = 0; l < L; l++) { - for (int i = 0; i < M; i++) { - auto row_idx = l * M * N + i * N; - auto row_max = ptr[l * M * N + i * N]; - - ElementOutput exp_sum = (ElementOutput)0; - for (int j = 0; j < N; j++) { - auto idx = row_idx + j; - row_max = max(row_max, ptr[idx]); - } - for (int j = 0; j < N; j++) { - auto idx = row_idx + j; - ptr[idx] = ptr[idx] - row_max; - ptr[idx] = exp(ptr[idx]); - exp_sum += ptr[idx]; - } - - for (int j = 0; j < N; j++) { - auto idx = row_idx + j; - ptr[idx] = ptr[idx] / exp_sum; - } - } - } - - syclcompat::memcpy(block_ref_D.get(), ptr, M * N * L * sizeof(ElementOutput)); - syclcompat::wait(); - - std::free(ptr); - -#endif - -#if 0 - ElementOutput *ptr = - (ElementOutput *)std::malloc(M * N * L * sizeof(ElementOutput)); - - syclcompat::memcpy(ptr, block_D.get(), M * N * L * sizeof(ElementOutput)); - - ElementOutput *ptr_refD = - (ElementOutput *)std::malloc((size_t)M * N * L * sizeof(ElementOutput)); - syclcompat::memcpy(ptr_refD, block_ref_D.get(), - (size_t)M * N * L * sizeof(ElementOutput)); - syclcompat::wait(); - for (int b = 0; b < L; b++) { - for (int i = 0; i < M; i++) { - for (int j = 0; j < N; j++) { - int idx = b * M * N + i * N + j; - if (abs(ptr[idx] - ptr_refD[idx]) / ptr_refD[idx] >= 0.01f) - std::cout << "(" << b << ", " << i << ", " << j << "): " << "host: " << ptr[idx] - << " and device: " << ptr_refD[idx] << std::endl; - } - } - } - std::free(ptr); - std::free(ptr_refD); -#endif syclcompat::wait(); // Check if output from CUTLASS kernel and reference kernel are relatively @@ -378,22 +315,10 @@ template struct ExampleRunner { // Verify that the result is correct bool passed = verify(problem_size, 1, 0.f); if (!passed) { - printf("PVC GEMM%s%s Example %s, 
MKNL(%d, %d,%d,%d), Config(%d, " - "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", -#ifdef EPILOGUE_RELU - "-relu" -#else - "" -#endif - , -#ifdef EPILOGUE_SOFTMAX - "-softmax" -#else - "" -#endif - , - (passed ? "Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, - sg_tile_k); + printf("PVC GEMM Example %s, MKNL(%d, %d,%d,%d), Config(%d, " + "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", + (passed ? "Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, + sg_tile_k); // return; } @@ -444,22 +369,14 @@ template struct ExampleRunner { M * N * sizeof(ElementOutput)) * 1e-9; - printf("Collective pvc gemm%s, MKNL(%d, %d, %d, %d), Config(%d, %d, " + printf("Collective pvc gemm, MKNL(%d, %d, %d, %d), Config(%d, %d, " "%d, %d, %d):\n max: (%6.4f)ms, (%4.2f)TFlop/s, " "(%4.2f)GB/s\n min: (%6.4f)ms, (%4.2f)TFlop/s, " "(%4.2f)GB/s\n average: (%6.4f)ms, (%4.2f)TFlop/s, " "(%4.2f)GB/s\n\n\n", -#if defined(EPILOGUE_RELU) - "-relu" -#elif defined(EPILOGUE_SOFTMAX) - "softmax" -#else - "" -#endif - , - M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, best * 1000, - tflops / best, hbm / best, worst * 1000, tflops / worst, hbm / worst, average * 1000, - tflops / average, hbm / average); + M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, best * 1000, + tflops / best, hbm / best, worst * 1000, tflops / worst, hbm / worst, average * 1000, + tflops / average, hbm / average); } } }; @@ -506,6 +423,14 @@ void collective_gemm(int M, int K, int N, int L = 1) { bool passed; + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + // The code section below describes datatype for input, output matrices and // computation between elements in input matrices. @@ -517,54 +442,43 @@ void collective_gemm(int M, int K, int N, int L = 1) { using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; - using TileShape = - Shape, Int, Int, Int, Int>; - - using TiledMma = TiledMMA, Layout>>; - - using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; -#ifdef EPILOGUE_RELU - using EpilogueOp = - cutlass::epilogue::thread::LinearCombinationRelu::value, // <- the number of - // elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear - -#else - using EpilogueOp = - cutlass::epilogue::thread::LinearCombination::value, // <- the number of - // elements per vectorized - // memory access. For a byte, it's 16 - // elements. 
This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear - // combination function -#endif + using TileShape = Shape<_256, _256, _32>; + // using TileShape = + // Shape, Int, Int, Int, Int>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_32,_64,_32>>; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma, ElementInputB, cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, void, cute::identity, // A GmemTiledCopyB, void, void, cute::identity // B >; -#ifdef EPILOGUE_SOFTMAX - using CollectiveEpilogue = cutlass::epilogue::collective::PvcEpilogueTensorSoftmax< - cutlass::gemm::TagToStrideC_t, cutlass::gemm::TagToStrideC_t, EpilogueOp, - cutlass::gemm::EpilogueDefault, CollectiveMainloop::sg_tile_m, - CollectiveMainloop::sg_tile_n / CollectiveMainloop::SubgroupSize>; -#else - using CollectiveEpilogue = - cutlass::epilogue::collective::DefaultEpilogue, - cutlass::gemm::TagToStrideC_t, EpilogueOp, cutlass::gemm::EpilogueDefault>; -#endif + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; using GemmKernel = cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue>; @@ -578,7 +492,6 @@ void collective_gemm(int M, int K, int N, int L = 1) { int main() { auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); -#if !defined(EPILOGUE_RELU) && !defined(EPILOGUE_SOFTMAX) collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192); collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824); @@ -605,20 +518,4 @@ int main() { collective_gemm<256, 256, 32, 64, 32>(32768, 128, 4096, 4); collective_gemm<256, 256, 32, 64, 32>(32768, 4096, 128, 4); collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 128, 32); -#endif - -#if defined(EPILOGUE_SOFTMAX) - // gemm + softmax - collective_gemm<64, 1024, 16, 64, 32>(1024, 64, 1024, 4); - collective_gemm<128, 512, 16, 64, 32>(512, 64, 512, 32); - collective_gemm<64, 1024, 16, 64, 32>(1024, 64, 1024, 16); - collective_gemm<32, 2048, 16, 64, 16>(2048, 64, 2048, 8); - collective_gemm<16, 4096, 16, 64, 32>(4096, 64, 4096, 4); - collective_gemm<8, 8192, 8, 128, 16>(8192, 64, 8192, 2); -#endif - -#if defined(EPILOGUE_RELU) - // gemm + relu - collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); -#endif } diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index f59fc10231..0ddaa0265f 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -83,7 +83,7 @@ struct CollectiveMma< // Type Aliases // using DispatchPolicy = MainloopIntelPVCUnpredicated; - using TileShape = TileShape_; + using WorkgroupTileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; using ElementB = ElementB_; @@ -100,12 +100,16 @@ struct 
CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; - TileShape tile_shape; - static constexpr auto wg_tile_m = decltype(get<0>(tile_shape))::value; - static constexpr auto wg_tile_n = decltype(get<1>(tile_shape))::value; - static constexpr auto sg_tile_m = decltype(get<2>(tile_shape))::value; - static constexpr auto sg_tile_n = decltype(get<3>(tile_shape))::value; - static constexpr auto sg_tile_k = decltype(get<4>(tile_shape))::value; + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + using SubgroupTileShape = decltype(tile_shape(TiledMma())); + WorkgroupTileShape wg_tile_shape; + SubgroupTileShape sg_tile_shape; + + static constexpr auto wg_tile_m = decltype(get<0>(wg_tile_shape))::value; + static constexpr auto wg_tile_n = decltype(get<1>(wg_tile_shape))::value; + static constexpr auto sg_tile_m = decltype(get<0>(sg_tile_shape))::value; + static constexpr auto sg_tile_n = decltype(get<1>(sg_tile_shape))::value; + static constexpr auto sg_tile_k = decltype(get<2>(sg_tile_shape))::value; static constexpr auto sg_per_wg_m = wg_tile_m / sg_tile_m; static constexpr auto sg_per_wg_n = wg_tile_n / sg_tile_n; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; @@ -119,8 +123,9 @@ struct CollectiveMma< static constexpr int DpasK = get<1>( shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per // sub_group for Matrix A + static constexpr uint32_t MaxThreadsPerBlock = + cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; - static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN; static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static constexpr int FragsM = sg_tile_m / DpasM; // A frags per sub_group @@ -204,13 +209,13 @@ struct CollectiveMma< static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor( - Shape, Int<1>>{}); - constexpr int version = is_same_v ? 
1 : 2; - Tensor tBr = make_tensor( - Shape, Int>{}); + + Tensor tAr = make_tensor(Shape, Int<1>>{}); + Tensor tBr = make_tensor(Shape, Int>{}); + + Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index ebfd0b5b29..eed45a77f3 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -66,7 +66,8 @@ class GemmUniversal< // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; + using TileShape = typename CollectiveMainloop::WorkgroupTileShape; + using WorkgroupTileShape = TileShape; using TiledMma = typename CollectiveMainloop::TiledMma; using ArchTag = typename CollectiveMainloop::ArchTag; using ElementA = typename CollectiveMainloop::ElementA; @@ -82,7 +83,7 @@ class GemmUniversal< "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, TileShape, + TileScheduler_, ArchTag, WorkgroupTileShape, cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; @@ -119,6 +120,12 @@ class GemmUniversal< static constexpr int VecC = CollectiveMainloop::VecC; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + // Device side arguments struct Arguments { GemmUniversalMode mode{}; @@ -196,7 +203,7 @@ class GemmUniversal< CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { - + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); @@ -229,6 +236,7 @@ class GemmUniversal< BlockIdxX() * CollectiveMainloop::wg_tile_n + (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; const int l_coord = BlockIdxZ(); + const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}), @@ -261,34 +269,39 @@ class GemmUniversal< // Perform the collective scoped MMA CollectiveMainloop collective_mma; - collective_mma(accumulators, tAi(_, _, _, 0), tBi(_, _, _, 0), accumulators, - k_tile_iter, k_tile_count, residue_mnk, thread_idx, smem_buf, - params.mainloop); - -#ifdef EPILOGUE_RELU - // relu - CollectiveEpilogue collective_relu(params.epilogue); - collective_relu(problem_shape_MNKL, - make_shape(Int{}, Int{}, Int{}), - make_coord(m_coord, n_coord, 0, l_coord), accumulators); -#endif - -#ifdef EPILOGUE_SOFTMAX - // softmax - CollectiveEpilogue collective_softmax; - collective_softmax(accumulators); -#endif - - auto gmem_tiled_copy_c = - make_xe_2d_copy(make_tensor( - params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), - make_shape(Int{}, Int{}, _1{}), - make_stride(Int{}, Int{})); - - copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); + collective_mma( + accumulators, + tAi(_,_,_,0), + tBi(_,_,_,0), + accumulators, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + smem_buf, + params.mainloop + ); + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue( + problem_shape_MNKL, + 
subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + + // auto gmem_tiled_copy_c = + // make_xe_2d_copy(make_tensor( + // params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); + + // Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( + // make_coord(m_coord, n_coord, l_coord), + // make_shape(Int{}, Int{}, _1{}), + // make_stride(Int{}, Int{})); + + // copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); } }; From 7739df61073d736a9001eb5caf7e0a50ba94a531 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 18 Jul 2024 00:38:52 -0700 Subject: [PATCH 15/36] fix the issue of batch gemm --- examples/sycl/pvc/pvc_gemm.cpp | 1 - include/cute/atom/copy_traits_xe.hpp | 6 +++--- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 10 +++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index a278d5329a..93d6dc78de 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -491,7 +491,6 @@ void collective_gemm(int M, int K, int N, int L = 1) { } int main() { - auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192); collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824); diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 7562dbc144..a66f8879f8 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -51,7 +51,7 @@ struct XE_2D_LD_Unpack { int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); auto [y, x, z] = src.data().coord_; - CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y}, + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t {x, y}, &*dst.data()); } @@ -99,7 +99,7 @@ struct XE_2D_PF_Unpack { int H = size<0>(traits.tensor); int W = size<1>(traits.tensor) * sizeof(T); auto [y, x, z] = src.data().coord_; - CopyOp::template copy(traits.tensor.data() + z * W * H / sizeof(T), W, H, W, + CopyOp::template copy(traits.tensor.data() + z, W, H, W, intel::coord_t {static_cast(x), static_cast(y)}); } }; @@ -416,7 +416,7 @@ struct XE_2D_ST_Unpack { * sizeof(typename Copy_Traits::CopyInternalType); auto [y, x, z] = dst.data().coord_; - CopyOp::copy(traits.tensor.data() + z * W * H / sizeof(typename Copy_Traits::CopyInternalType), W, H, W, intel::coord_t {x, y}, + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t {x, y}, &*src.data()); } diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index eed45a77f3..66266f7a44 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -239,7 +239,7 @@ class GemmUniversal< const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, l_coord), make_shape(_1{}, K, _1{}), + make_coord(m_coord, 0, 0), make_shape(_1{}, K, L), make_stride(Int{}, _1{})); constexpr int version = is_same_v{}, _1{}), + make_coord(0, n_coord, 0), + make_shape(K, Int{}, L), make_stride(_1{}, Int{})); // Compute tile residues for predication @@ -271,8 +271,8 @@ class GemmUniversal< CollectiveMainloop collective_mma; collective_mma( accumulators, - tAi(_,_,_,0), - tBi(_,_,_,0), + 
tAi(_,_,_,l_coord), + tBi(_,_,_,l_coord), accumulators, k_tile_iter, k_tile_count, residue_mnk, From 5b1f514dcb26f811d4b30e3ff2bfab3f7bfd1538 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Tue, 23 Jul 2024 00:52:16 -0700 Subject: [PATCH 16/36] rm epilogue and revert gemm example --- build.sh | 40 -- examples/sycl/pvc/pvc_gemm.cpp | 437 +++++++----------- .../epilogue/collective/default_epilogue.hpp | 33 -- .../intel_pvc_epilogue_tensor_softmax.hpp | 157 ------- .../epilogue/thread/linear_combination_relu.h | 11 - 5 files changed, 164 insertions(+), 514 deletions(-) delete mode 100644 build.sh delete mode 100644 include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp diff --git a/build.sh b/build.sh deleted file mode 100644 index 3ceb147e9e..0000000000 --- a/build.sh +++ /dev/null @@ -1,40 +0,0 @@ -script_dir=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cp ${script_dir}/tools/clang-format/clang-format.hook ${script_dir}/.git/hooks/pre-commit -chmod +x ${script_dir}/.git/hooks/pre-commit - -# https://github.com/intel/llvm/releases/tag/nightly-2024-07-03 -sycl_compiler_path=/opt/cutlass/compiler/0703/ - -# https://ubit-gfx.intel.com/build/19168301/artifacts -gpu_driver_path=/opt/cutlass/gpu_driver/gfx-driver-ci-comp_igc-25012/extract/ - -# AOT compile -output=intel_gpu_pvc -# jit compile -#output=spir64 - -unset epilogue - -# epilogue relu -# epilogue+=" -DEPILOGUE_RELU " - -# epilogue softmax -# epilogue+=" -DEPILOGUE_SOFTMAX " - -export ZE_AFFINITY_MASK=0 -export CPATH=$sycl_compiler_path:$sycl_compiler_path/include/:$sycl_compiler_path/include/sycl/ -export LIBRARY_PATH=$gpu_driver_path/usr/lib/x86_64-linux-gnu/:$sycl_compiler_path/lib/ -export LD_LIBRARY_PATH=$LIBRARY_PATH -export IGC_EnableVISANoSchedule=1 -export IGC_ShaderDumpEnable=1 -export IGC_DumpToCustomDir=./mm_dumps -export IGC_VATemp=1 -export ONEAPI_DEVICE_SELECTOR=level_zero:gpu - -target=./examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute -rm -rf * - -cmake .. 
-G Ninja -DCMAKE_CUDA_HOST_COMPILER=${sycl_compiler_path}/bin/clang++ \ --DCUTLASS_ENABLE_SYCL=ON -DDPCPP_SYCL_TARGET=$output -DCMAKE_CXX_COMPILER=${sycl_compiler_path}/bin/clang++ \ --DCMAKE_CXX_FLAGS=" -DPREFETCH_DEFAULT -DSYCL_INTEL_TARGET ${epilogue} " \ -&& ninja -v $target && $target diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 93d6dc78de..51c44d6a79 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -32,77 +32,73 @@ #define CUTLASS_SYCLCOMPAT_PROFILING_ENABLED #include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" #include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" -#include "cutlass/gemm/device/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" #include "cutlass/util/GPU_Clock.hpp" #include #include -#include "cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp" #include "cutlass/util/command_line.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" -template static void fill_matrix(std::vector& M) { - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist((T)0.0, (T)1.0); - std::generate(std::begin(M), std::end(M), [&] { return static_cast(dist(rng)); }); +template +static void fill_matrix(std::vector &vector) +{ + std::generate(std::begin(vector), std::end(vector), [&] { + return static_cast( (rand() / double(RAND_MAX)) ); + }); } template -static void vnni_matrix(T* dst, T const* src, int batch, int numRows, int numCols, int factor) { - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } +static void vnni_matrix( + T* dst, const T* src, + int batch, int numRows, int numCols, int factor) +{ + for (int b = 0; b < batch; b++) { + for (int r = 0; r < numRows / factor; r++) { + for (int c = 0; c < numCols; c++) { + for (int k = 0; k < factor; k++) { + dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = + src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; + } + } } } - } } using namespace cute; -using ElementAccumulator = float; // <- data type of accumulator -using ElementComputeEpilogue = float; // <- data type of epilogue operations -using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A -using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B -using ElementOutput = float; // <- data type of elements in output matrix D - /////////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing struct Options { - bool help; bool error; - int m, n, k, l, iterations; float alpha, beta; - Options() - : help(false), error(false), m(4096), n(4096), k(4096), l(1), iterations(100), alpha(1.f), - beta(0.f) {} + Options(): + help(false), + error(false), + m(4096), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } // Parses the command line - void parse(int argc, char const** args) { + void parse(int 
argc, char const **args) { cutlass::CommandLine cmd(argc, args); if (cmd.check_cmd_line_flag("help")) { help = true; return; } - cmd.get_cmd_line_argument("m", m, 4096); cmd.get_cmd_line_argument("n", n, 4096); cmd.get_cmd_line_argument("k", k, 4096); @@ -113,20 +109,18 @@ struct Options { } /// Prints the usage statement. - std::ostream& print_usage(std::ostream& out) const { + std::ostream & print_usage(std::ostream &out) const { out << "PVC GEMM Example\n\n" - << "Options:\n\n" - << " --help If specified, displays this " - "usage statement\n\n" - << " --m= Sets the M extent of the GEMM\n" - << " --n= Sets the N extent of the GEMM\n" - << " --k= Sets the K extent of the GEMM\n" - << " --l= Sets the L extent (batch count) " - "of the GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Iterations\n\n"; + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; return out; } @@ -134,34 +128,31 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template struct ExampleRunner { +template < + class Gemm +> +struct ExampleRunner { using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; - using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; - using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementAcc = typename Gemm::ElementAccumulator; - using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - // // Data members // - /// Initialization StrideA stride_A; StrideB stride_B; @@ -170,256 +161,181 @@ template struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - // cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_B_vnni; + cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; - static auto constexpr l3_cache_size = 256 * 1024 * 1024; - - size_t PINGPONG_ITER = 1; - size_t pingpong_size_a; - size_t pingpong_size_b; - size_t pingpong_size_d; - - std::vector a; - std::vector b; - std::vector d; // // Methods // - bool verify(ProblemShapeType const& problem_size, ElementCompute alpha, ElementCompute beta) { + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { auto [M, N, K, L] = problem_size; cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C((ElementC*)nullptr /*block_C.get()*/, LayoutC::packed({M, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, 
N})); cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); cutlass::reference::device::GemmComplex( - {M, N, K}, alpha, ref_A, cutlass::ComplexTransform::kNone, ref_B, - cutlass::ComplexTransform::kNone, beta, ref_C, ref_D, ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); syclcompat::wait(); - // Check if output from CUTLASS kernel and reference kernel are relatively - // equal or not need to set a larger error margin for comparison to succeed + // Check if output from CUTLASS kernel and reference kernel are relatively equal or not + // need to set a larger error margin for comparison to succeed + auto epsilon = static_cast(0.1f); + auto nonzero_floor = static_cast(0.1f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), block_D.get(), M * N * L, 0.5f, 0.5f); + block_ref_D.get(), block_D.get(), block_D.size(), + epsilon, nonzero_floor); return passed; } - void init_cache_clear(ProblemShapeType const& problem_size) { + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - pingpong_size_a = max((size_t)M * K * L, l3_cache_size / sizeof(ElementA)); - pingpong_size_b = max((size_t)K * N * L, l3_cache_size / sizeof(ElementB)); - pingpong_size_d = max((size_t)M * N * L, l3_cache_size / sizeof(ElementOutput)); - auto gmem_size = syclcompat::get_current_device().get_global_mem_size(); - PINGPONG_ITER = std::min((size_t)3, - std::max((size_t)1, (size_t)gmem_size / ((pingpong_size_a * sizeof(ElementA) + - pingpong_size_b * sizeof(ElementB) + - pingpong_size_d * sizeof(ElementOutput))) - - 1)); - block_A.reset(pingpong_size_a * PINGPONG_ITER); - block_B.reset(pingpong_size_b * PINGPONG_ITER); - // block_C.reset(M * N * L * ITER); - block_D.reset(pingpong_size_d * PINGPONG_ITER); - - for (int i = 0; i < PINGPONG_ITER; i++) { - syclcompat::memcpy( - block_A.get() + i * pingpong_size_a, a.data(), a.size() * sizeof(ElementA)); - syclcompat::memcpy( - block_B.get() + i * pingpong_size_b, b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy( - block_D.get() + i * pingpong_size_d, d.data(), d.size() * sizeof(ElementC)); - } - // syclcompat::wait(); - } - - /// Initialize operands to be used in the GEMM and reference GEMM - void initialize(ProblemShapeType const& problem_size) { - auto [M, N, K, L] = problem_size; - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - block_A.reset((size_t)M * K * L); - block_B.reset((size_t)K * N * L); - // block_C.reset(M * N * L); - block_D.reset((size_t)M * N * L); - block_ref_D.reset((size_t)max(l3_cache_size / sizeof(ElementOutput), (size_t)M * N * L)); + + block_A.reset(M * K * 
L); + block_B.reset(K * N * L); + block_B_vnni.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); // TODO: Enable initialization on device directly once RNG is // available through SYCL. - a = std::vector((size_t)M * K * L); - b = std::vector((size_t)K * N * L); - d = std::vector((size_t)M * N * L, ElementC{0}); - std::cout << "random generating..." << std::endl; + std::vector a(K * M * L); + std::vector b(K * N * L); + std::vector b_vnni(b.size()); + std::vector c(M * N * L); + std::vector d(M * N * L, ElementC{0}); + fill_matrix(a); fill_matrix(b); + fill_matrix(c); + vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); + syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - // syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); + syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); + syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } - template - void run(int M, int K, int N, int L, cutlass::KernelHardwareInfo const& hw_info) { - static auto constexpr warmup = 10; - static auto constexpr testIterations = 10; - static auto constexpr total_iterations = warmup + testIterations; - ProblemShapeType problem_size = ProblemShapeType{M, N, K, L}; + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + initialize(problem_size); sycl::property_list prop = { sycl::property::queue::in_order(), sycl::property::queue::enable_profiling() }; - auto q = sycl::queue(syclcompat::get_default_context(), syclcompat::get_current_device(), prop); syclcompat::set_default_queue(q); typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B.get(), stride_B}, - {{1, 0.f}, - nullptr /*block_C.get()*/, - stride_C, - block_D.get(), - stride_D}, - hw_info}; - Gemm gemm_op_verify; + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); cutlass::device_memory::allocation workspace(workspace_size); - gemm_op_verify.can_implement(arguments); + gemm_op.can_implement(arguments); - gemm_op_verify.initialize(arguments, workspace.get()); + gemm_op.initialize(arguments, workspace.get()); // Run the GEMM - gemm_op_verify.run(); + gemm_op.run(); + syclcompat::wait(); // Verify that the result is correct - bool passed = verify(problem_size, 1, 0.f); - if (!passed) { - printf("PVC GEMM Example %s, MKNL(%d, %d,%d,%d), Config(%d, " - "%d,%d,%d,%d) !!!!!!!!!!!!!\n\n", - (passed ? 
"Passed" : "Failed"), M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, - sg_tile_k); - // return; - } - - // ================ init cache clear ================ - if constexpr (cache_clear) { - init_cache_clear(problem_size); - } - - // ================ run and collect performance data ================ - if (total_iterations > 0) { - auto total_time = 0.f; - auto best = 999.f; - auto worst = 0.f; - - for (int i = 0; i < testIterations + warmup; ++i) { - typename Gemm::GemmKernel::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get() + (i % PINGPONG_ITER) * pingpong_size_a, stride_A, - block_B.get() + (i % PINGPONG_ITER) * pingpong_size_b, stride_B}, - {{1, 0.f}, nullptr /*block_C.get() + i * M * N * L*/, stride_C, - block_D.get() + (i % PINGPONG_ITER) * pingpong_size_d, stride_D}, - hw_info}; - - Gemm gemm_op; - gemm_op.can_implement(arguments); - gemm_op.initialize(arguments, workspace.get()); - - GPU_Clock timer; - timer.start(); + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + // timer.start(); + for (int i = 0; i < options.iterations; ++i) { + if (i == 10) timer.start(); gemm_op.run(); - syclcompat::wait(); - - auto current_time = timer.seconds(); - if (i >= warmup) { - total_time += current_time; - - best = min(best, current_time); - - worst = max(worst, current_time); - } } + syclcompat::wait(); - float average = total_time / testIterations; - double tflops = (2.0 * M * N * K * L) * 1e-12; - - double hbm = L * - (M * K * sizeof(ElementInputA) + K * N * sizeof(ElementInputB) + - M * N * sizeof(ElementOutput)) * - 1e-9; - - printf("Collective pvc gemm, MKNL(%d, %d, %d, %d), Config(%d, %d, " - "%d, %d, %d):\n max: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n min: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n average: (%6.4f)ms, (%4.2f)TFlop/s, " - "(%4.2f)GB/s\n\n\n", - M, K, N, L, wg_tile_m, wg_tile_n, sg_tile_m, sg_tile_n, sg_tile_k, best * 1000, - tflops / best, hbm / best, worst * 1000, tflops / worst, hbm / worst, average * 1000, - tflops / average, hbm / average); + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); } + + return; } + }; -template -void collective_gemm(int M, int K, int N, int L = 1) { +int main(int argc, const char** argv) +{ // // Parse options // Options options; - // options.parse(argc, argv); + options.parse(argc, argv); if (options.help) { options.print_usage(std::cout) << std::endl; - return; + return 0; } if (options.error) { std::cerr << "Aborting execution." << std::endl; - return; + return -1; } // // Run examples // - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a - // given device ID. This information is used by the underlying kernel. + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. cutlass::KernelHardwareInfo hw_info; - // Change device_id to another value if you are running on a machine with - // multiple GPUs and wish to use a GPU other than that with device ID 0. 
- hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); bool passed; @@ -431,9 +347,6 @@ void collective_gemm(int M, int K, int N, int L = 1) { using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D - // The code section below describes datatype for input, output matrices and - // computation between elements in input matrices. - using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; @@ -443,13 +356,11 @@ void collective_gemm(int M, int K, int N, int L = 1) { using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; using TileShape = Shape<_256, _256, _32>; - // using TileShape = - // Shape, Int, Int, Int, Int>; using TiledMma = TiledMMA, Layout>, Tile<_32,_64,_32>>; - + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; @@ -458,63 +369,43 @@ void collective_gemm(int M, int K, int N, int L = 1) { using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; - - // Mainloop - using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma, ElementInputB, - cutlass::gemm::TagToStrideB_t, TiledMma, GmemTiledCopyA, void, void, - cute::identity, // A - GmemTiledCopyB, void, void, cute::identity // B - >; - using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< - EpilogueDispatchPolicy, - TileShape, - ElementAccumulator, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, - void, void, - XE_2D_U32x8x16x1x1_ST_N, - void, void>; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, - CollectiveMainloop, CollectiveEpilogue>; + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - ExampleRunner runner; + ExampleRunner runner; - runner.template run(M, K, N, L, hw_info); -} + runner.run(options, hw_info); -int main() { - collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 4096); - collective_gemm<256, 256, 32, 64, 32>(8192, 8192, 8192); - collective_gemm<256, 256, 32, 64, 32>(1, 5120, 13824); - collective_gemm<256, 256, 32, 64, 32>(1024, 28672, 8192); - collective_gemm<256, 256, 32, 64, 32>(3072, 4096, 3072); - collective_gemm<256, 256, 32, 64, 32>(4, 4096, 12288); - - // collective shape from habana - collective_gemm<256, 256, 32, 64, 32>(512, 8192, 8192); - collective_gemm<256, 256, 32, 64, 32>(512, 8192, 32768); - collective_gemm<256, 256, 32, 
64, 32>(512, 32768, 8192); - collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 1024); - collective_gemm<256, 256, 32, 64, 32>(16384, 1024, 8192); - collective_gemm<256, 256, 32, 64, 32>(16384, 8192, 4096); - collective_gemm<256, 256, 32, 64, 32>(16384, 4096, 8192); - collective_gemm<256, 256, 32, 64, 32>(4096, 16384, 8192); - collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 4096); - collective_gemm<256, 256, 32, 64, 32>(1024, 16384, 8192); - collective_gemm<256, 256, 32, 64, 32>(8192, 16384, 1024); - - collective_gemm<256, 256, 32, 64, 32>(8, 128, 16384, 4096); - collective_gemm<16, 512, 16, 16, 32>(8, 16384, 128, 4096); - - collective_gemm<256, 256, 32, 64, 32>(32768, 128, 4096, 4); - collective_gemm<256, 256, 32, 64, 32>(32768, 4096, 128, 4); - collective_gemm<256, 256, 32, 64, 32>(4096, 4096, 128, 32); -} + return 0; +} \ No newline at end of file diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index de24020265..bbeeacacd3 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -147,39 +147,6 @@ class DefaultEpilogue { return epilogue_op.is_source_needed(); } - template< - class ProblemShapeMNKL, - class BlockShapeMNK, - class BlockCoordMNKL, - class FrgEngine, class FrgLayout> - CUTLASS_HOST_DEVICE void - operator()( - ProblemShapeMNKL problem_shape_mnkl, - BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, - cute::Tensor & accumulators){ - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - if (epilogue_op.is_source_needed()) { - auto source = make_fragment_like(accumulators); - auto gmem_tiled_copy_c = - make_xe_2d_copy(make_tensor( - params.ptr_C, make_shape(M, N, L), params.dC)); - - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - make_coord(m_coord, n_coord, l_coord), - make_shape(size<1>(accumulators), size<2>(accumulators), L), - make_stride(size<0>(blk_shape_MNK), size<1>(blk_shape_MNK))); - copy(gmem_tiled_copy_c, tCi(_, _, _, l_coord), source); - epilogue_op(accumulators, source); - } else { - epilogue_op(accumulators); - } - } - template< class ProblemShapeMNKL, class BlockShapeMNK, diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp deleted file mode 100644 index 01bd25b7ec..0000000000 --- a/include/cutlass/epilogue/collective/intel_pvc_epilogue_tensor_softmax.hpp +++ /dev/null @@ -1,157 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" - -#include "cute/tensor.hpp" -#include "cutlass/cuda_host_adapter.hpp" -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace collective { -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class PvcEpilogueTensorSoftmax { -public: - using EpilogueSchedule = EpilogueSchedule_; - using DispatchPolicy = EpilogueSchedule_; - - // derived types of output thread level operator - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using ElementD = typename ThreadEpilogueOp::ElementD; - using StrideD = StrideD_; - - using GmemTiledCopyC = void; - using GmemTiledCopyD = void; - - // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread{}; - ElementC const* ptr_C = nullptr; - StrideC dC{}; - ElementD* ptr_D = nullptr; - StrideD dD{}; - }; - - // Device side epilogue params - using Params = Arguments; - - template - static Params constexpr to_underlying_arguments([[maybe_unused]] ProblemShape const& _, - Arguments const& args, - [[maybe_unused]] void* workspace) { - return args; - } - - template CUTLASS_DEVICE void operator()(T& t) { - static_assert(cute::is_same_v && m <= 32); - - auto const& group = sycl::ext::oneapi::experimental::this_nd_item<3>().get_group(); - - static auto constexpr vec_size = 4; - - static_assert((m % vec_size) == 0 && vec_size <= 16); - static auto constexpr loop_cnt = m / vec_size; - - sycl::vec local_max; - sycl::vec local_plus; - - for (int loop = 0; loop < loop_cnt; loop++) { - - auto base_row = loop * vec_size; - // init local max - for (int i = 0; i < vec_size; i++) { - local_max[i] = t[(base_row + i) * n]; - } - - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - local_max[i] = max(local_max[i], t((base_row + i) * n + j)); - } - } - - // get group max - auto group_max = reduce_over_group(group, local_max, sycl::maximum<>()); - - // -max, 
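// Aside: the epilogue being deleted in this patch implements a numerically stable
// row-wise softmax over the accumulator tile -- subtract the row maximum,
// exponentiate, then normalise by the row sum (the group reductions extend this
// across the sub-groups sharing a row). A minimal single-threaded sketch of the
// same computation, assuming a row-major m x n tile (illustrative only):
#include <cmath>
inline void softmax_rows(float* t, int m, int n) {
  for (int r = 0; r < m; ++r) {
    float* row = t + r * n;
    float row_max = row[0];
    for (int c = 1; c < n; ++c) row_max = row[c] > row_max ? row[c] : row_max;
    float row_sum = 0.f;
    for (int c = 0; c < n; ++c) { row[c] = std::exp(row[c] - row_max); row_sum += row[c]; }
    for (int c = 0; c < n; ++c) row[c] /= row_sum;
  }
}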
exp, and get local plus - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - auto offset = (base_row + i) * n + j; - t[offset] -= group_max[i]; - t[offset] = sycl::exp(t[offset]); - - local_plus[i] += t[offset]; - } - } - - // get group plus - auto group_plus = reduce_over_group(group, local_plus, sycl::plus<>()); - - // last div - for (int i = 0; i < vec_size; i++) { - for (int j = 0; j < n; j++) { - auto offset = (base_row + i) * n + j; - t[offset] = t[offset] / group_plus[i]; - // local_sum += t[i * n + j]; - } - } - } - - // printf("verify softmax, local_sum: %f, group_sum: %f\n", local_sum, - // reduce_over_group(group, local_sum, sycl::plus<>())); - // } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 343e2a9ec2..2d66a4e2a8 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -184,17 +184,6 @@ class LinearCombinationRelu { } } - using ElementC = ElementOutput_; - using ElementD = ElementOutput_; - - template - CUTLASS_HOST_DEVICE - void operator()(TensorType &accumulators) const { - for (int i = 0; i < size(accumulators); i++) { - accumulators(i) = accumulators(i) < 0 ? 0 : accumulators(i); - } - } - template CUTLASS_HOST_DEVICE void operator()(TensorDst &accumulators, From f5e23e8dbb90ee9bbee4c55f5db0f1dc93823ff3 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 24 Jul 2024 17:39:31 -0700 Subject: [PATCH 17/36] only keep code changes of gemm --- examples/sycl/pvc/pvc_gemm.cpp | 22 ++-- .../epilogue/thread/linear_combination_relu.h | 9 -- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 67 ++++------- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 109 +++++++----------- 4 files changed, 81 insertions(+), 126 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 51c44d6a79..4b1dacfd47 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -79,8 +79,10 @@ using namespace cute; // Command line options parsing struct Options { + bool help; bool error; + int m, n, k, l, iterations; float alpha, beta; @@ -99,6 +101,7 @@ struct Options { help = true; return; } + cmd.get_cmd_line_argument("m", m, 4096); cmd.get_cmd_line_argument("n", n, 4096); cmd.get_cmd_line_argument("k", k, 4096); @@ -137,22 +140,28 @@ struct ExampleRunner { using StrideB = typename Gemm::GemmKernel::StrideB; using StrideC = typename Gemm::GemmKernel::StrideC; using StrideD = typename Gemm::GemmKernel::StrideD; + using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using LayoutD = typename Gemm::LayoutD; + using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementAcc = typename Gemm::ElementAccumulator; + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; using ElementC = typename Gemm::ElementC; using ElementOutput = typename CollectiveEpilogue::ElementOutput; using ElementCompute = typename CollectiveEpilogue::ElementCompute; using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + using ProblemShapeType = typename 
Gemm::GemmKernel::ProblemShape; + // // Data members // + /// Initialization StrideA stride_A; StrideB stride_B; @@ -161,7 +170,6 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_B_vnni; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -222,7 +230,6 @@ struct ExampleRunner { block_A.reset(M * K * L); block_B.reset(K * N * L); - block_B_vnni.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); @@ -238,11 +245,9 @@ struct ExampleRunner { fill_matrix(a); fill_matrix(b); fill_matrix(c); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } @@ -256,13 +261,14 @@ struct ExampleRunner { sycl::property::queue::in_order(), sycl::property::queue::enable_profiling() }; + auto q = sycl::queue(syclcompat::get_default_context(), syclcompat::get_current_device(), prop); syclcompat::set_default_queue(q); typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {block_A.get(), stride_A, block_B.get(), stride_B}, {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -287,9 +293,8 @@ struct ExampleRunner { if (passed && options.iterations > 0) { GPU_Clock timer; - // timer.start(); + timer.start(); for (int i = 0; i < options.iterations; ++i) { - if (i == 10) timer.start(); gemm_op.run(); } syclcompat::wait(); @@ -355,11 +360,12 @@ int main(int argc, const char** argv) using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; + // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; using TiledMma = TiledMMA, Layout>, - Tile<_32,_64,_32>>; + Tile<_32,_64,_32>>; // Subgroup level-tile using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 2d66a4e2a8..8152a55086 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -184,15 +184,6 @@ class LinearCombinationRelu { } } - template - CUTLASS_HOST_DEVICE - void operator()(TensorDst &accumulators, - TensorSrc const &source) const { - for (int i = 0; i < size(accumulators); i++) { - accumulators(i) = accumulators(i) < 0 ? 
source(i) : accumulators(i) + source(i); - } - } - /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 0ddaa0265f..cdcbbb4fe2 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -34,11 +34,13 @@ #include "cutlass/gemm/dispatch_policy.hpp" #include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -100,44 +102,27 @@ struct CollectiveMma< using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + using MmaAtomShape = typename TiledMma::AtomShape_MNK; using SubgroupTileShape = decltype(tile_shape(TiledMma())); - WorkgroupTileShape wg_tile_shape; - SubgroupTileShape sg_tile_shape; - - static constexpr auto wg_tile_m = decltype(get<0>(wg_tile_shape))::value; - static constexpr auto wg_tile_n = decltype(get<1>(wg_tile_shape))::value; - static constexpr auto sg_tile_m = decltype(get<0>(sg_tile_shape))::value; - static constexpr auto sg_tile_n = decltype(get<1>(sg_tile_shape))::value; - static constexpr auto sg_tile_k = decltype(get<2>(sg_tile_shape))::value; - static constexpr auto sg_per_wg_m = wg_tile_m / sg_tile_m; - static constexpr auto sg_per_wg_n = wg_tile_n / sg_tile_n; - static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - static constexpr int DpasM = get<0>( - shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per - // sub_group for Matrix A - static constexpr int DpasN = get<1>( - shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per - // sub_group for Matrix B - static constexpr int DpasK = get<1>( - shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per - // sub_group for Matrix A - static constexpr uint32_t MaxThreadsPerBlock = - cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; + static constexpr auto sg_per_wg_m = get<0>(WorkgroupTileShape{}) / get<0>(SubgroupTileShape{}); + static constexpr auto sg_per_wg_n = get<1>(WorkgroupTileShape{}) / get<1>(SubgroupTileShape{}); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + static constexpr uint32_t MaxThreadsPerBlock = + cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{}) * SubgroupSize; - static constexpr int FragsM = sg_tile_m / DpasM; // A frags per sub_group - static constexpr int FragsN = sg_tile_n / DpasN; // B frags per sub_group - static constexpr int FragsK = sg_tile_k / DpasK; + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape()); - // Calculate the vector width based on the amount of registers - // required per work item by dividing the total fragment size by + // Calculate the vector width based on the amount of registers + // 
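// Aside: plugging this example's shapes into the fragment-count and vector-width
// formulas in this hunk makes the constants concrete. The numbers below assume the
// 8x16x16 bf16 DPAS atom, the 32x64x32 sub-group tile (from Tile<_32,_64,_32>) and
// a sub-group size of 16 -- all taken as assumptions from this PR, not derived here:
constexpr int AtomM = 8,  AtomN = 16, AtomK = 16;
constexpr int SgM   = 32, SgN   = 64, SgK   = 32, SgSize = 16;
static_assert(SgM / AtomM == 4);               // FragsM
static_assert(SgN / AtomN == 4);               // FragsN
static_assert(SgK / AtomK == 2);               // FragsK
static_assert((AtomN * AtomM) / SgSize == 8);  // VecC
static_assert((AtomM * AtomK) / SgSize == 8);  // VecA
static_assert((AtomN * AtomK) / SgSize == 16); // VecB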
required per work item by dividing the total fragment size by // the sub_group size. - static constexpr int VecC = (DpasN * DpasM) / SubgroupSize; - static constexpr int VecA = (DpasM * DpasK) / SubgroupSize; - static constexpr int VecB = (DpasN * DpasK) / SubgroupSize; + static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize; + static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; + static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; // Host side kernel arguments struct Arguments { @@ -171,7 +156,7 @@ struct CollectiveMma< auto [M,N,K,L] = problem_shape_MNKL; Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB)); + Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -212,10 +197,8 @@ struct CollectiveMma< constexpr int version = is_same_v ? 1 : 2; - Tensor tAr = make_tensor(Shape, Int<1>>{}); - Tensor tBr = make_tensor(Shape, Int>{}); - - + Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, Int<1>>{}); + Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) * version>, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); @@ -237,14 +220,14 @@ struct CollectiveMma< Tensor tAi = make_tensor( make_inttuple_iter( *gA.data() + - make_coord((get_sub_group_id() % sg_per_wg_n % 4) * DpasM, 0)), + make_coord((get_sub_group_id() % sg_per_wg_n % 4) * get<0>(MmaAtomShape{}), 0)), make_layout(make_shape(_1{}, _1{}, K), make_stride(_1{}, E<0>{}, E<1>{}))); Tensor tBi = make_tensor( make_inttuple_iter( *gB.data() + - make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * DpasK, - (get_sub_group_id() / sg_per_wg_n % 2 * 2) * DpasN)), + make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}), + (get_sub_group_id() / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}))), make_layout(make_shape(_1{}, K, _1{}), make_stride(_1{}, E<0>{}, E<1>{}))); // @@ -254,18 +237,18 @@ struct CollectiveMma< for (int i = 0; i < 3; i++) { prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); - prefetch_k += sg_tile_k; + prefetch_k += get<2>(SubgroupTileShape{}); } for (int k_tile = 0, k = 0; k_tile < k_tile_count; - ++k_tile, k += DpasK * FragsK) { + ++k_tile, k += get<2>(SubgroupTileShape{})) { // Copy gmem to rmem for the first k_tile copy(mainloop.gmem_tiled_copy_a, gA(_, _, k), tAr); copy(mainloop.gmem_tiled_copy_b, gB(_, k, _), tBr); prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); - prefetch_k += sg_tile_k; + prefetch_k += get<2>(SubgroupTileShape{}); for (int kl = 0; kl < FragsK; kl++) { cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 66266f7a44..e08cd68c61 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -68,12 +68,12 @@ class GemmUniversal< using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::WorkgroupTileShape; using WorkgroupTileShape = TileShape; - using TiledMma = typename 
CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -83,8 +83,8 @@ class GemmUniversal< "Intel PVC does not support specializing the tile scheduler."); using TileSchedulerTag = TileScheduler_; using TileScheduler = typename detail::TileSchedulerSelector< - TileScheduler_, ArchTag, WorkgroupTileShape, - cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + TileScheduler_, ArchTag, WorkgroupTileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; using TileSchedulerArguments = typename TileScheduler::Arguments; // Epilogue derived types @@ -101,20 +101,14 @@ class GemmUniversal< // MSVC requires the cast to fix a warning-as-error. static constexpr int SharedStorageSize = 0; - static constexpr int SubgroupSize = - CollectiveMainloop::SubgroupSize; // sub_group size - static constexpr uint32_t MaxThreadsPerBlock = - CollectiveMainloop::MaxThreadsPerBlock; - static constexpr uint32_t MinBlocksPerMultiprocessor = - CollectiveMainloop::MinBlocksPerMultiprocessor; + static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; + using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group - static constexpr int DpasM = CollectiveMainloop::DpasM; - static constexpr int DpasN = CollectiveMainloop::DpasN; - static constexpr int DpasK = CollectiveMainloop::DpasK; - static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -180,32 +174,30 @@ class GemmUniversal< return Status::kSuccess; } - static dim3 get_grid_shape(Params const ¶ms) { - auto M = get<0>(params.problem_shape); - auto N = get<1>(params.problem_shape); - auto L = get<3>(params.problem_shape); - - int const sg_m = cute::ceil_div(M, - CollectiveMainloop::wg_tile_m); // sub_groups required to - // process A fragments - int const sg_n = cute::ceil_div(N, - CollectiveMainloop::wg_tile_n); // sub_groups required to - // process B fragments - - return dim3(sg_n, sg_m, L); + static dim3 + get_grid_shape(Params const& params) { + int batch_count = 1; + if constexpr (cute::rank(ProblemShape{}) == 4) { + batch_count = cute::size<3>(params.problem_shape); + } + return dim3( + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), + batch_count + ); } - static dim3 get_block_shape() { - return dim3( - 
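// Aside: a concrete check of the launch-shape arithmetic above for the default
// 4096x4096x4096 problem, assuming the 256x256 work-group tile, 32x64 sub-group
// tile and sub-group size 16 used in this example (tile sizes are assumptions
// taken from this PR; the rest is plain arithmetic):
constexpr int exM = 4096, exN = 4096, exL = 1;
constexpr int WgM = 256, WgN = 256, ExSgM = 32, ExSgN = 64, ExSgSize = 16;
constexpr int grid_x  = (exM + WgM - 1) / WgM;                        // 16
constexpr int grid_y  = (exN + WgN - 1) / WgN;                        // 16
constexpr int threads = (WgM / ExSgM) * (WgN / ExSgN) * ExSgSize;     // 8 * 4 * 16 = 512
static_assert(grid_x == 16 && grid_y == 16 && threads == 512);
// i.e. get_grid_shape() -> dim3(16, 16, 1) and get_block_shape() -> dim3(512, 1, 1).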
cute::ceil_div(CollectiveMainloop::wg_tile_n, CollectiveMainloop::sg_tile_n / SubgroupSize), - cute::ceil_div(CollectiveMainloop::wg_tile_m, CollectiveMainloop::sg_tile_m), 1); + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) { + void + operator()(Params const& params, char* smem_buf) { SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions - CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) @@ -223,24 +215,17 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); - int thread_idy = int(ThreadIdxY()); - - static constexpr auto sg_per_wg_n = - CollectiveMainloop::wg_tile_n / CollectiveMainloop::sg_tile_n; - - auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = - BlockIdxY() * CollectiveMainloop::wg_tile_m + - (get_sub_group_id() / sg_per_wg_n) * CollectiveMainloop::sg_tile_m; - const int n_coord = - BlockIdxX() * CollectiveMainloop::wg_tile_n + - (get_sub_group_id() % sg_per_wg_n) * CollectiveMainloop::sg_tile_n; + constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) + const int m_coord = BlockIdxX() * get<0>(workgroup_shape) + get_sub_group_id() / CollectiveMainloop::sg_per_wg_n * get<0>(subgroup_shape); + const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + get_sub_group_id() % CollectiveMainloop::sg_per_wg_n * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, 0), make_shape(_1{}, K, L), - make_stride(Int{}, _1{})); + make_coord(m_coord, 0, 0), + make_shape(_1{}, K, L), + make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); constexpr int version = is_same_v @@ -248,9 +233,9 @@ class GemmUniversal< : 2; Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, 0), - make_shape(K, Int{}, L), - make_stride(_1{}, Int{})); + make_coord(0, n_coord, 0), + make_shape(K, Int{}, L), + make_stride(_1{}, Int(MmaAtomShape())>{})); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord @@ -264,8 +249,8 @@ class GemmUniversal< Tensor accumulators = make_tensor(Shape, Int, Int>{}); clear(accumulators); - int k_tile_count = cute::ceil_div(K, CollectiveMainloop::sg_tile_k); - auto k_tile_iter = cute::make_coord_iterator(make_shape(k_tile_count)); + auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(subgroup_shape))); + int k_tile_count = K / get<2>(subgroup_shape); // Perform the collective scoped MMA CollectiveMainloop collective_mma; @@ -280,6 +265,7 @@ class GemmUniversal< smem_buf, params.mainloop ); + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; epilogue( problem_shape_MNKL, @@ -291,17 +277,6 @@ class GemmUniversal< thread_idx, smem_buf ); - - // auto gmem_tiled_copy_c = - // make_xe_2d_copy(make_tensor( - // params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - - // Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor( - // make_coord(m_coord, n_coord, l_coord), - // make_shape(Int{}, Int{}, _1{}), 
- // make_stride(Int{}, Int{})); - - // copy(gmem_tiled_copy_c, accumulators, tCi(_, _, _, 0)); } }; From 1c57c36ea9f256f26739e3938116ada7df2bfdd3 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 25 Jul 2024 17:30:50 -0700 Subject: [PATCH 18/36] comments clean --- examples/sycl/pvc/pvc_gemm.cpp | 2 +- include/cute/arch/copy_xe.hpp | 93 +++++++++++++++++-- include/cute/arch/mma_xe.hpp | 22 +++-- include/cute/atom/mma_traits_xe.hpp | 2 +- .../epilogue/thread/linear_combination_relu.h | 2 +- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 12 +-- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 8 +- 7 files changed, 109 insertions(+), 32 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 4b1dacfd47..3bf9cb53be 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -363,7 +363,7 @@ int main(int argc, const char** argv) // Workgroup-level tile using TileShape = Shape<_256, _256, _32>; - using TiledMma = TiledMMA, + using TiledMma = TiledMMA, Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 51200dd08b..4e82b5b8ab 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -38,16 +38,11 @@ namespace cute { #ifdef __SYCL_DEVICE_ONLY__ -#ifdef SYCL_INTEL_TARGET #define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x #else #define SYCL_DEVICE_BUILTIN(x) \ - inline x { CUTE_INVALID_CONTROL_PATH("Trying to use IGC built-in on non-Intel hardware"); } + inline x { assert(false); } #endif -#else -#define SYCL_DEVICE_BUILTIN(x) \ - inline x { CUTE_INVALID_CONTROL_PATH("Trying to use device built-in on host."); } -#endif enum LSC_LDCC { kLSC_LDCC_DEFAULT = 0, @@ -123,19 +118,28 @@ struct XE_2D_U16x8x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort8 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -146,9 +150,13 @@ struct XE_2D_U32x8x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } }; @@ -158,19 +166,28 @@ struct XE_2D_U16x16x16x1x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::uint8 *)dst = 
__builtin_IB_subgroup_block_read_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -181,20 +198,29 @@ struct XE_2D_U16x8x16x4x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort64 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( long(baseoffset), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -205,19 +231,28 @@ struct XE_2D_U16x8x16x2x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( long(baseoffset), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -228,20 +263,29 @@ struct XE_2D_U16x8x16x1x2_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); intel::ushort16 tmp = (intel_subgroup_block_read_u16_m8k16v2( (long)baseoffset, width, height, pitch, coord)); *(intel::ushort16 *)dst = *reinterpret_cast(&tmp); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, 
height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -252,19 +296,28 @@ struct XE_2D_U16x8x16x4x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } struct PREFETCH { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kLSC_LDCC_L1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif } }; }; @@ -275,10 +328,14 @@ struct XE_2D_U32x8x16x2x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); *(intel::uint16 *)dst = *reinterpret_cast(&tmp); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } }; @@ -288,10 +345,14 @@ struct XE_2D_U16x16x16x2x1_LD_N CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, T *dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( long(baseoffset), width - 1, height - 1, pitch - 1, coord); *(intel::uint16 *)dst = *reinterpret_cast(&tmp); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; @@ -301,8 +362,12 @@ struct XE_2D_U16x16x16x2x2_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::uint32*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long(base_address), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } // using PREFETCH = typename XE_2D_U16x8x16x4x2_LD_N::PREFETCH; @@ -313,8 +378,12 @@ struct XE_2D_U16x16x16x1x2_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long(base_address), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; @@ -324,8 +393,12 @@ struct XE_2D_U16x16x16x2x1_V { template CUTE_HOST_DEVICE static 
void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long(base_address), width - 1, height - 1, pitch - 1, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; @@ -335,9 +408,13 @@ struct XE_2D_U16x16x16x1x1_V { template CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: *(intel::int8*) dst = intel_subgroup_block_read_transform_u16_k16((long)base_address, width, height, pitch, coord); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } using PREFETCH = typename XE_2D_U16x16x16x1x1_LD_N::PREFETCH; @@ -348,10 +425,14 @@ struct XE_2D_U32x8x16x1x1_ST_N template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, const T *src) { + #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, *(intel::uint8 *)src); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); + #endif } }; diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 2850576d6e..85033e3653 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -35,16 +35,10 @@ #include #ifdef __SYCL_DEVICE_ONLY__ -#ifdef SYCL_INTEL_TARGET #define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x #else -#define SYCL_DEVICE_OCL(x) \ - inline x { CUTE_INVALID_CONTROL_PATH("Trying to use IGC built-in on non-Intel hardware"); } +#define SYCL_DEVICE_OCL(x) inline x { assert(false); } #endif -#else -#define SYCL_DEVICE_OCL(x) \ - inline x { CUTE_INVALID_CONTROL_PATH("Trying to use device built-in on host."); } -#endif SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc)); SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::intel::int8 b, float acc)); @@ -55,7 +49,7 @@ namespace cute { //# of vector component of a x subgroup-size x function name //float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); //TODO: Is A really not transposed? 
Maybe better a macro than separate define for 1,2,4,8 -struct XE_8x16x16_BF16BF16F32F32_NN +struct XE_8x16x16_F32BF16BF16F32_TT { using DRegisters = intel::float8[1]; using ARegisters = intel::short8[1]; @@ -68,11 +62,15 @@ struct XE_8x16x16_BF16BF16F32F32_NN intel::int8 const& b, intel::float8 const& c) { +#if defined(SYCL_INTEL_TARGET) d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_BF16BF16F32F32_NN on non-PVC hardware"); +#endif } }; //float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) -struct XE_1x16x16_BF16BF16F32F32_NN +struct XE_1x16x16_F32BF16BF16F32_TT { using DRegisters = float[1]; using ARegisters = short[1]; @@ -85,7 +83,11 @@ struct XE_1x16x16_BF16BF16F32F32_NN intel::int8 const& b, float const& c) { +#if defined(SYCL_INTEL_TARGET) d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_BF16BF16F32F32_NN on non-PVC hardware"); +#endif } }; -} //namespace cute +} //namespace cute \ No newline at end of file diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index a5ef6dbec2..8dca2dba55 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -38,7 +38,7 @@ namespace cute { template <> -struct MMA_Traits +struct MMA_Traits { using ValTypeD = float; using ValTypeA = bfloat16_t; diff --git a/include/cutlass/epilogue/thread/linear_combination_relu.h b/include/cutlass/epilogue/thread/linear_combination_relu.h index 8152a55086..07ebdec93d 100644 --- a/include/cutlass/epilogue/thread/linear_combination_relu.h +++ b/include/cutlass/epilogue/thread/linear_combination_relu.h @@ -183,7 +183,7 @@ class LinearCombinationRelu { threshold_ = reinterpret_cast(allones); } } - + /// Computes linear scaling: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index cdcbbb4fe2..12e0981c25 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -36,7 +36,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" #include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,10 +43,6 @@ namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// -#define get_sub_group_id() \ - (sycl::ext::oneapi::experimental::this_nd_item<3>() \ - .get_sub_group() \ - .get_group_id()[0]) template < class TileShape_, @@ -210,6 +205,7 @@ struct CollectiveMma< int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); + auto sub_group_id = ThreadIdxX() / SubgroupSize; /* Cooperative prefetch Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in one row/col use the same tile A/B. Each thread loads sizeof(tile A or B) / numof(sg_per_wg_n or sg_per_wg_m). 
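      // A rough sketch of the sub-group -> tile mapping assumed by this
      // cooperative prefetch (the names below are the ones already used in
      // this series; the arithmetic is only illustrative):
      //
      //   int sub_group_id = ThreadIdxX() / SubgroupSize;
      //   int sg_m = sub_group_id / sg_per_wg_n;   // row of this sub-group
      //   int sg_n = sub_group_id % sg_per_wg_n;   // column of this sub-group
      //
      // Sub-groups sharing sg_m cooperate on the same rows of the A tile and
      // sub-groups sharing sg_n cooperate on the same columns of the B tile,
      // so each sub-group only prefetches roughly a 1/sg_per_wg_n (resp.
      // 1/sg_per_wg_m) slice of the work-group tile.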
@@ -220,14 +216,14 @@ struct CollectiveMma< Tensor tAi = make_tensor( make_inttuple_iter( *gA.data() + - make_coord((get_sub_group_id() % sg_per_wg_n % 4) * get<0>(MmaAtomShape{}), 0)), + make_coord((sub_group_id % sg_per_wg_n % 4) * get<0>(MmaAtomShape{}), 0)), make_layout(make_shape(_1{}, _1{}, K), make_stride(_1{}, E<0>{}, E<1>{}))); Tensor tBi = make_tensor( make_inttuple_iter( *gB.data() + - make_coord((get_sub_group_id() / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}), - (get_sub_group_id() / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}))), + make_coord((sub_group_id / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}), + (sub_group_id / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}))), make_layout(make_shape(_1{}, K, _1{}), make_stride(_1{}, E<0>{}, E<1>{}))); // diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index e08cd68c61..981025b4c8 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -106,9 +106,6 @@ class GemmUniversal< using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; - static constexpr int num_sg = - MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group - static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -215,10 +212,11 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = BlockIdxX() * get<0>(workgroup_shape) + get_sub_group_id() / CollectiveMainloop::sg_per_wg_n * get<0>(subgroup_shape); - const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + get_sub_group_id() % CollectiveMainloop::sg_per_wg_n * get<1>(subgroup_shape); + const int m_coord = BlockIdxX() * get<0>(workgroup_shape) + sub_group_id / CollectiveMainloop::sg_per_wg_n * get<0>(subgroup_shape); + const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + sub_group_id % CollectiveMainloop::sg_per_wg_n * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); From 13ae1a1d33ac7d9673a85be072d9b87c853b5ccb Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 25 Jul 2024 18:24:55 -0700 Subject: [PATCH 19/36] rebase other examples --- benchmarks/common/benchmark_runner.hpp | 36 ++----------------- ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp | 6 ++-- examples/sycl/pvc/pvc_gemm.cpp | 4 +-- .../sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 19 ++++------ 4 files changed, 14 insertions(+), 51 deletions(-) diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/common/benchmark_runner.hpp index 659a3e1436..4b77609730 100644 --- a/benchmarks/common/benchmark_runner.hpp +++ b/benchmarks/common/benchmark_runner.hpp @@ -320,42 +320,10 @@ struct PvcBenchmarkRunner : BenchmarkRunner { using ProblemShapeType = typename Base::ProblemShapeType; - cutlass::DeviceAllocation block_B_vnni; - - template - void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) - { - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - 
dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } - } - void initialize(const ProblemShapeType& problem_size) override { Base::initialize(problem_size); - - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - block_B_vnni.reset(Base::block_B.size()); - - std::vector b(K * N * L); - std::vector b_vnni(b.size()); - - Base::block_B.copy_to_host(b.data()); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); - - block_B_vnni.copy_from_host(b_vnni.data()); } - + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) override { ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; @@ -364,7 +332,7 @@ struct PvcBenchmarkRunner : BenchmarkRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {Base::block_A.get(), Base::stride_A, block_B_vnni.get(), Base::stride_B}, + {Base::block_A.get(), Base::stride_A, Base::block_B.get(), Base::stride_B}, { {options.alpha, options.beta}, Base::block_C.get(), Base::stride_C, Base::block_D.get(), Base::stride_D diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp index ea90d89b7e..3203e7f367 100644 --- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp +++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp @@ -85,15 +85,15 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; + using TileShape = Shape<_256, _256, _32>; using TiledMma = TiledMMA< - MMA_Atom, + MMA_Atom, Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 3bf9cb53be..9cec462caf 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -89,7 +89,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), + m(4096), n(4096), k(4096), l(1), iterations(20), alpha(1.f), beta(0.f) { } @@ -108,7 +108,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("iterations", iterations, 20); } /// Prints the usage statement. 
diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index 2075379580..1352ef6212 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -90,7 +90,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), + m(4096), n(4096), k(4096), l(1), iterations(10), alpha(1.f), beta(0.f) { } @@ -109,7 +109,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("iterations", iterations, 10); } /// Prints the usage statement. @@ -171,7 +171,6 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_B_vnni; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -238,7 +237,6 @@ struct ExampleRunner { block_A.reset(M * K * L); block_B.reset(K * N * L); - block_B_vnni.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); @@ -247,18 +245,15 @@ struct ExampleRunner { // available through SYCL. std::vector a(K * M * L); std::vector b(K * N * L); - std::vector b_vnni(b.size()); std::vector c(M * N * L); std::vector d(M * N * L, ElementC{0}); fill_matrix(a); fill_matrix(b); fill_matrix(c); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } @@ -271,7 +266,7 @@ struct ExampleRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {block_A.get(), stride_A, block_B.get(), stride_B}, {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -302,7 +297,7 @@ struct ExampleRunner { } syclcompat::wait(); - float cute_time = timer.seconds() / options.iterations; + float cute_time = timer.seconds(); double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); @@ -361,12 +356,12 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; + using TileShape = Shape<_256, _256, _32>; - using TiledMma = TiledMMA, + using TiledMma = TiledMMA, Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile From fdb724475b66ab3d5382826b7a0d32984b6651c8 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 25 Jul 2024 18:26:01 -0700 Subject: [PATCH 20/36] rm vnni_matrix func --- examples/sycl/pvc/pvc_gemm.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp 
b/examples/sycl/pvc/pvc_gemm.cpp index 9cec462caf..8cbefb8470 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -55,24 +55,6 @@ static void fill_matrix(std::vector &vector) return static_cast( (rand() / double(RAND_MAX)) ); }); } - -template -static void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) -{ - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } -} - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// From d09da2947897e2eaac168e23a831c8040989e1d5 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 25 Jul 2024 18:27:37 -0700 Subject: [PATCH 21/36] code clean --- examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index 1352ef6212..cfa0365bc9 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -90,7 +90,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(10), + m(4096), n(4096), k(4096), l(1), iterations(100), alpha(1.f), beta(0.f) { } @@ -109,7 +109,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 10); + cmd.get_cmd_line_argument("iterations", iterations, 100); } /// Prints the usage statement. 
@@ -297,7 +297,7 @@ struct ExampleRunner { } syclcompat::wait(); - float cute_time = timer.seconds(); + float cute_time = timer.seconds() / options.iterations; double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); From 5a3d227e3e244799d73e547b295396332eb7f35b Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Sun, 28 Jul 2024 17:35:35 -0700 Subject: [PATCH 22/36] define N-major tensor --- include/cute/atom/copy_traits_xe.hpp | 56 ++++++++++++++----- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 20 +++---- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 6 +- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index a66f8879f8..2bd39d67df 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -35,7 +35,37 @@ #include -namespace cute { +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto get_shape_WHD(cute::Stride, IntT, IntT> , cute::Shape shape_MKL) { + return shape_MKL; +} + +template +CUTE_HOST_DEVICE constexpr +auto get_shape_WHD(cute::Stride, IntT> , cute::Shape shape_MKL) { + return Shape(get<1>(shape_MKL), get<0>(shape_MKL), get<2>(shape_MKL)); +} + +template +CUTE_HOST_DEVICE constexpr +auto get_coordinates(cute::Stride, IntT, IntT> , + Tensor>, SLayout> const &src) { + auto [x, y, z] = src.data().coord_; + return make_coord(x, y, z); +} + +template +CUTE_HOST_DEVICE constexpr +auto get_coordinates(cute::Stride, IntT> , + Tensor>, SLayout> const &src) { + auto [x, y, z] = src.data().coord_; + return make_coord(y, x, z); +} + template struct XE_2D_LD_Unpack { GTensor tensor; @@ -47,11 +77,11 @@ struct XE_2D_LD_Unpack { Tensor>, SLayout> const &src, Tensor &dst) { static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) - * sizeof(typename Copy_Traits::CopyInternalType); - auto [y, x, z] = src.data().coord_; - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t {x, y}, + auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); + int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); + int H = size<1>(shape_whd); + auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data()); } @@ -96,9 +126,10 @@ struct XE_2D_PF_Unpack { Tensor>, SLayout> const &src, Tensor &dst) { using T = typename Copy_Traits::CopyInternalType; - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) * sizeof(T); - auto [y, x, z] = src.data().coord_; + auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); + int W = size<0>(shape_whd) * sizeof(T); + int H = size<1>(shape_whd); + auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); CopyOp::template copy(traits.tensor.data() + z, W, H, W, intel::coord_t {static_cast(x), static_cast(y)}); } @@ -412,12 +443,9 @@ struct XE_2D_ST_Unpack { Tensor>, DLayout> &dst) { static_assert(is_rmem::value); int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) - * sizeof(typename Copy_Traits::CopyInternalType); + int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); auto [y, x, z] = dst.data().coord_; - - CopyOp::copy(traits.tensor.data() + z, W, H, W, 
intel::coord_t {x, y}, - &*src.data()); + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data()); } template diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 12e0981c25..e90083f389 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -198,7 +198,8 @@ struct CollectiveMma< Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}); + Shape, Int, Int>{}, + Stride<_1, Int, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; @@ -222,9 +223,9 @@ struct CollectiveMma< Tensor tBi = make_tensor( make_inttuple_iter( *gB.data() + - make_coord((sub_group_id / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}), - (sub_group_id / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}))), - make_layout(make_shape(_1{}, K, _1{}), + make_coord((sub_group_id / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}), + (sub_group_id / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}))), + make_layout(make_shape(_1{}, _1{}, K), make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop @@ -232,7 +233,7 @@ struct CollectiveMma< int prefetch_k = 0; for (int i = 0; i < 3; i++) { prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); - prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); prefetch_k += get<2>(SubgroupTileShape{}); } @@ -240,16 +241,13 @@ struct CollectiveMma< ++k_tile, k += get<2>(SubgroupTileShape{})) { // Copy gmem to rmem for the first k_tile copy(mainloop.gmem_tiled_copy_a, gA(_, _, k), tAr); - copy(mainloop.gmem_tiled_copy_b, gB(_, k, _), tBr); + copy(mainloop.gmem_tiled_copy_b, gB(_, _, k), tBr); prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); - prefetch(mainloop.gmem_tiled_copy_b, tBi(_, prefetch_k, _)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); prefetch_k += get<2>(SubgroupTileShape{}); - for (int kl = 0; kl < FragsK; kl++) { - cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), - src_accum); - } + cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); } } }; diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 981025b4c8..811540e93b 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -231,9 +231,9 @@ class GemmUniversal< : 2; Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, 0), - make_shape(K, Int{}, L), - make_stride(_1{}, Int(MmaAtomShape())>{})); + make_coord(n_coord, 0, 0), + make_shape(Int{}, K, L), + make_stride(Int(MmaAtomShape())>{}, _1{})); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord From b50574a04d6d4c7f76b1cf109704ea07b45d129f Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 29 Jul 2024 17:35:13 -0700 Subject: [PATCH 23/36] delete useless header --- include/cutlass/gemm/kernel/intel_pvc_gemm.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 811540e93b..05b099bc24 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -34,7 +34,6 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/gemm/gemm.h" 
#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" #include "cute/tensor.hpp" From 2c6d1ba34e9638590084a38c38aa34313291c15b Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 29 Jul 2024 18:28:05 -0700 Subject: [PATCH 24/36] more comments --- include/cutlass/gemm/collective/intel_pvc_mma.hpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index e90083f389..e0f09f5342 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -231,7 +231,10 @@ struct CollectiveMma< // Mainloop // int prefetch_k = 0; - for (int i = 0; i < 3; i++) { + // Manually set the value to 3 + // TODO: Expose to user to set distance + int constexpr prefetch_distance = 3; + for (int i = 0; i < prefetch_distance; i++) { prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); prefetch_k += get<2>(SubgroupTileShape{}); From c97ccd8f60325b35fc1d2321b1cc671c9ba8c1bb Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Mon, 29 Jul 2024 22:37:57 -0700 Subject: [PATCH 25/36] modify comments --- include/cute/arch/copy_xe.hpp | 2 +- include/cutlass/gemm/collective/intel_pvc_mma.hpp | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 4e82b5b8ab..7bc9d4a8cd 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -436,4 +436,4 @@ struct XE_2D_U32x8x16x1x1_ST_N } }; -} // end namespace cute \ No newline at end of file +} // end namespace diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index e0f09f5342..6bbcd129cc 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -198,8 +198,8 @@ struct CollectiveMma< Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}, - Stride<_1, Int, Int>{}); + Shape, Int, Int>{}, + Stride<_1, Int, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; @@ -231,8 +231,9 @@ struct CollectiveMma< // Mainloop // int prefetch_k = 0; - // Manually set the value to 3 - // TODO: Expose to user to set distance + + // Manually set the value to 1 + // TODO: Expose to users like stages parameter int constexpr prefetch_distance = 3; for (int i = 0; i < prefetch_distance; i++) { prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); From ede5c0317159950eeb534240f9da7f07ffad1e6a Mon Sep 17 00:00:00 2001 From: Jiaxingla <109135611+Jiaxingla@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:10:10 +0800 Subject: [PATCH 26/36] Update pvc_gemm Co-authored-by: Mehdi Goli --- examples/sycl/pvc/pvc_gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 8cbefb8470..661016e809 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -396,4 +396,4 @@ int main(int argc, const char** argv) runner.run(options, hw_info); return 0; -} \ No newline at end of file +} From f9aae6ff73883d3dac0a56ebf30df7c6da6617b2 Mon Sep 17 00:00:00 2001 From: Jiaxingla <109135611+Jiaxingla@users.noreply.github.com> Date: Wed, 31 Jul 2024 14:10:27 +0800 Subject: [PATCH 27/36] Update mma_xe 
Co-authored-by: Mehdi Goli --- include/cute/arch/mma_xe.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 85033e3653..7c5ad7b74e 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -90,4 +90,4 @@ struct XE_1x16x16_F32BF16BF16F32_TT #endif } }; -} //namespace cute \ No newline at end of file +} //namespace cute From 7878a7cc40e813f0c9357bd372ce3b90dc80d70b Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 31 Jul 2024 01:17:25 -0700 Subject: [PATCH 28/36] more comments --- include/cute/arch/copy_xe.hpp | 47 ++++++++++++++++++++++++++++ include/cute/atom/copy_traits_xe.hpp | 3 ++ 2 files changed, 50 insertions(+) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 7bc9d4a8cd..0cf9971085 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -112,6 +112,9 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( #undef SYCL_DEVICE_BUILTIN +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 1x1 memory blocks, and each block size is 8x16x16bits struct XE_2D_U16x8x16x1x1_LD_N { template @@ -144,6 +147,9 @@ struct XE_2D_U16x8x16x1x1_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 1x1 memory blocks, and each block size is 8x16x32bits struct XE_2D_U32x8x16x1x1_LD_N { template @@ -160,6 +166,9 @@ struct XE_2D_U32x8x16x1x1_LD_N } }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 1x1 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x1x1_LD_N { template @@ -192,6 +201,9 @@ struct XE_2D_U16x16x16x1x1_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 4x2 memory blocks, and each block size is 8x16x16bits struct XE_2D_U16x8x16x4x2_LD_N { template @@ -225,6 +237,9 @@ struct XE_2D_U16x8x16x4x2_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 4x2 memory blocks, and each block size is 8x16x16bits struct XE_2D_U16x8x16x2x2_LD_N { template @@ -257,6 +272,9 @@ struct XE_2D_U16x8x16x2x2_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 1x2 memory blocks, and each block size is 8x16x16bits struct XE_2D_U16x8x16x1x2_LD_N { template @@ -290,6 +308,9 @@ struct XE_2D_U16x8x16x1x2_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 4x1 memory blocks, and each block size is 8x16x16bits struct XE_2D_U16x8x16x4x1_LD_N { template @@ -322,6 +343,9 @@ struct XE_2D_U16x8x16x4x1_LD_N }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. 
+/// Loads 2x1 memory blocks, and each block size is 8x16x32bits struct XE_2D_U32x8x16x2x1_LD_N { template @@ -339,6 +363,9 @@ struct XE_2D_U32x8x16x2x1_LD_N } }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// Loads 2x1 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x2x1_LD_N { template @@ -358,6 +385,10 @@ struct XE_2D_U16x16x16x2x1_LD_N using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// Loads 2x2 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x2x2_V { template @@ -374,6 +405,10 @@ struct XE_2D_U16x16x16x2x2_V using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// Loads 1x2 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x1x2_V { template @@ -389,6 +424,10 @@ struct XE_2D_U16x16x16x1x2_V using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// Loads 2x1 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x2x1_V { template @@ -404,6 +443,10 @@ struct XE_2D_U16x16x16x2x1_V using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// Loads 1x1 memory blocks, and each block size is 16x16x16bits struct XE_2D_U16x16x16x1x1_V { template @@ -420,6 +463,10 @@ struct XE_2D_U16x16x16x1x1_V using PREFETCH = typename XE_2D_U16x16x16x1x1_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// Loads 1x1 memory blocks, and each block size is 8x16x32bits struct XE_2D_U32x8x16x1x1_ST_N { template diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 2bd39d67df..8bf93166b0 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -97,6 +97,9 @@ struct XE_2D_LD_Unpack { } }; +/// ThrID: How many threads involved. +/// DstLayout: Size<0>(DstLayout) same as thrID, Size<1>(DstLayout) represents data layout that hold by each thread in bits. +/// TODO: SrcLayout just a placeHolder, not used. 
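// A worked reading of these traits (the interpretation, not the numbers, is
// assumed here): ThrID = Layout<_16> says a full sub-group of 16 work-items
// participates, and for the 16-bit 8x16 block loads each work-item's slice of
// DstLayout covers 8 x 16 = 128 bits, i.e. eight 16-bit values -- one element
// from each of the 8 rows of the block, so work-item t effectively holds
// column t of the 8x16 tile returned by the builtin.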
template struct Copy_Traits : XE_2D_LD_Unpack { From 4c426458cce01c4096d403604eb28c4fe498a9a9 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 31 Jul 2024 01:22:45 -0700 Subject: [PATCH 29/36] code clean --- .../sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index cfa0365bc9..521d64f7d5 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -57,23 +57,6 @@ static void fill_matrix(std::vector &vector) }); } -template -static void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) -{ - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } -} - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// From abbbe4f386fe1b529fcd8b2016ff1b57ddaa15cb Mon Sep 17 00:00:00 2001 From: Jiaxingla <109135611+Jiaxingla@users.noreply.github.com> Date: Wed, 31 Jul 2024 20:34:51 +0800 Subject: [PATCH 30/36] fix typo --- include/cutlass/gemm/collective/intel_pvc_mma.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 6bbcd129cc..e9314ace84 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -232,7 +232,7 @@ struct CollectiveMma< // int prefetch_k = 0; - // Manually set the value to 1 + // Manually set the prefetch_distance to 3 // TODO: Expose to users like stages parameter int constexpr prefetch_distance = 3; for (int i = 0; i < prefetch_distance; i++) { From 8e9a84f10a6cbaced1c784fd5b8a7920c85a4582 Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 31 Jul 2024 05:53:38 -0700 Subject: [PATCH 31/36] revert the change of copy_atom --- include/cute/atom/copy_atom.hpp | 4 ---- include/cute/atom/copy_traits_xe.hpp | 20 ++++++++++---------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index a619b725d8..e20bace705 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -111,10 +111,6 @@ struct Copy_Atom, CopyInternalType> // recurse this rank-1 layout by peeling off the mode // ((A,B,C,...)) -> (A,B,C,...) 
return copy(*this, tensor<0>(src), tensor<0>(dst)); - } else if constexpr (is_tuple::engine_type::iterator:: - value_type>::value) { - return copy_unpack(*this, src, dst); } else { static_assert(dependent_false, "No instruction match and no recursion possible."); } diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 8bf93166b0..018f227060 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -150,7 +150,7 @@ struct Copy_Traits using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout - = Layout>, Stride<_16, Stride<_256, _1>>>; + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_16, Stride<_256, _1>>>; @@ -172,7 +172,7 @@ struct Copy_Traits using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit using SrcLayout - = Layout>, Stride<_32, Stride<_512, _1>>>; + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_32, Stride<_512, _1>>>; @@ -193,8 +193,8 @@ struct Copy_Traits // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16, Stride<_256, _1>>>; + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_16, Stride<_256, _1>>>; @@ -215,8 +215,8 @@ struct Copy_Traits // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32, Stride<_512, _1>>>; + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_32, Stride<_512, _1>>>; @@ -236,8 +236,8 @@ struct Copy_Traits // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_32, Stride<_512, _1>>>; + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_32, Stride<_512, _1>>>; @@ -258,8 +258,8 @@ struct Copy_Traits // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>, - Stride<_16, Stride<_256, _1>>>; + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; // Map from (dst-thr,dst-val) to bit using DstLayout = Layout>, Stride<_16, Stride<_256, _1>>>; From ea30c83f4e035aaa527fcc1ba6d4aecd8550427e Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Wed, 31 Jul 2024 22:55:14 -0700 Subject: [PATCH 32/36] rename enum of LSC_LDCC --- include/cute/arch/copy_xe.hpp | 42 +++++++++++++++++------------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 0cf9971085..86d9a79db1 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -44,15 +44,15 @@ namespace cute inline x { assert(false); } #endif -enum LSC_LDCC { - kLSC_LDCC_DEFAULT = 0, - kLSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - kLSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached - kLSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached - kLSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached - kLSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - kLSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached - kLSC_LDCC_L1IAR_L3C = 7, 
// Override to L1 invalidate-after-read, and L3 cached +enum CacheControl { + kDefault = 0, + kL1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + kL1UC_L3C = 2, // Override to L1 uncached and L3 cached + kL1C_L3UC = 3, // Override to L1 cached and L3 uncached + kL1C_L3C = 4, // Override to L1 cached and L3 cached + kL1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + kL1S_L3C = 6, // Override to L1 streaming load and L3 cached + kL1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached }; SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( @@ -93,22 +93,22 @@ SYCL_DEVICE_BUILTIN(intel::int8 intel_subgroup_block_read_transform_u16_k16( int pitch_minus_one, intel::coord_t coord)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum LSC_LDCC cache_control)); + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); #undef SYCL_DEVICE_BUILTIN @@ -138,7 +138,7 @@ struct XE_2D_U16x8x16x1x1_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); @@ -192,7 +192,7 @@ struct XE_2D_U16x16x16x1x1_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); @@ -228,7 +228,7 @@ struct XE_2D_U16x8x16x4x2_LD_N // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to 
use block prefetch on non-PVC hardware"); @@ -263,7 +263,7 @@ struct XE_2D_U16x8x16x2x2_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); @@ -299,7 +299,7 @@ struct XE_2D_U16x8x16x1x2_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); @@ -334,7 +334,7 @@ struct XE_2D_U16x8x16x4x1_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kLSC_LDCC_L1C_L3C); + kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); From 043fbea27ba4a1e987b9432b549ddc01f9722cec Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 1 Aug 2024 01:00:45 -0700 Subject: [PATCH 33/36] fix typo --- include/cute/arch/copy_xe.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 86d9a79db1..fb09498335 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -190,7 +190,7 @@ struct XE_2D_U16x16x16x1x1_LD_N int height, int pitch, intel::coord_t coord) { #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kL1C_L3C); #else @@ -225,7 +225,6 @@ struct XE_2D_U16x8x16x4x2_LD_N int height, int pitch, intel::coord_t coord) { #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); - // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, kL1C_L3C); From abf38bdbc7a6c5a264aa831782165c0515b5308b Mon Sep 17 00:00:00 2001 From: Jiaxingla Date: Thu, 1 Aug 2024 01:11:16 -0700 Subject: [PATCH 34/36] scope enums --- include/cute/arch/copy_xe.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index fb09498335..460cc6193b 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -44,7 +44,7 @@ namespace cute inline x { assert(false); } #endif -enum CacheControl { +enum class CacheControl { kDefault = 0, kL1UC_L3UC = 1, // Override to L1 uncached and L3 uncached kL1UC_L3C = 2, // Override to L1 uncached and L3 cached @@ -138,7 +138,7 @@ struct XE_2D_U16x8x16x1x1_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kL1C_L3C); + CacheControl::kL1C_L3C); #else CUTE_INVALID_CONTROL_PATH( "Trying to use block prefetch on non-PVC hardware"); @@ -192,7 +192,7 @@ struct XE_2D_U16x16x16x1x1_LD_N static_assert(sizeof(T) == 2, "Expected T to have size 2"); __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - kL1C_L3C); + 
+        CacheControl::kL1C_L3C);
 #else
     CUTE_INVALID_CONTROL_PATH(
         "Trying to use block prefetch on non-PVC hardware");
@@ -227,7 +227,7 @@ struct XE_2D_U16x8x16x4x2_LD_N
     static_assert(sizeof(T) == 2, "Expected T to have size 2");
     __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(
         (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
-        kL1C_L3C);
+        CacheControl::kL1C_L3C);
 #else
     CUTE_INVALID_CONTROL_PATH(
         "Trying to use block prefetch on non-PVC hardware");
@@ -262,7 +262,7 @@ struct XE_2D_U16x8x16x2x2_LD_N
     static_assert(sizeof(T) == 2, "Expected T to have size 2");
     __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(
         (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
-        kL1C_L3C);
+        CacheControl::kL1C_L3C);
 #else
     CUTE_INVALID_CONTROL_PATH(
         "Trying to use block prefetch on non-PVC hardware");
@@ -298,7 +298,7 @@ struct XE_2D_U16x8x16x1x2_LD_N
     static_assert(sizeof(T) == 2, "Expected T to have size 2");
     __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(
         (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
-        kL1C_L3C);
+        CacheControl::kL1C_L3C);
 #else
     CUTE_INVALID_CONTROL_PATH(
         "Trying to use block prefetch on non-PVC hardware");
@@ -333,7 +333,7 @@ struct XE_2D_U16x8x16x4x1_LD_N
     static_assert(sizeof(T) == 2, "Expected T to have size 2");
     __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(
         (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
-        kL1C_L3C);
+        CacheControl::kL1C_L3C);
 #else
     CUTE_INVALID_CONTROL_PATH(
         "Trying to use block prefetch on non-PVC hardware");

From 5193329d08f5a47a8b5ada548e8593ba5edfe14f Mon Sep 17 00:00:00 2001
From: Jiaxingla
Date: Thu, 1 Aug 2024 18:07:43 -0700
Subject: [PATCH 35/36] modify comment of copy

---
 include/cute/arch/copy_xe.hpp | 31 +++++++++++++++----------------
 1 file changed, 15 insertions(+), 16 deletions(-)

diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp
index 460cc6193b..6dba92d7f8 100644
--- a/include/cute/arch/copy_xe.hpp
+++ b/include/cute/arch/copy_xe.hpp
@@ -114,7 +114,7 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 1x1 memory blocks, and each block size is 8x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 1x1 blocks.
 struct XE_2D_U16x8x16x1x1_LD_N
 {
   template <class T>
@@ -149,7 +149,7 @@ struct XE_2D_U16x8x16x1x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 1x1 memory blocks, and each block size is 8x16x32bits
+/// The loading block size is 32bitsx8x16, with a total of 1x1 blocks.
 struct XE_2D_U32x8x16x1x1_LD_N
 {
   template <class T>
@@ -168,7 +168,7 @@ struct XE_2D_U32x8x16x1x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 1x1 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 1x1 blocks.
 struct XE_2D_U16x16x16x1x1_LD_N
 {
   template <class T>
@@ -203,7 +203,7 @@ struct XE_2D_U16x16x16x1x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 4x2 memory blocks, and each block size is 8x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 4x2 blocks.
 struct XE_2D_U16x8x16x4x2_LD_N
 {
   template <class T>
@@ -238,7 +238,7 @@ struct XE_2D_U16x8x16x4x2_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 4x2 memory blocks, and each block size is 8x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 2x2 blocks.
 struct XE_2D_U16x8x16x2x2_LD_N
 {
   template <class T>
@@ -273,7 +273,7 @@ struct XE_2D_U16x8x16x2x2_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 1x2 memory blocks, and each block size is 8x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 1x2 blocks.
 struct XE_2D_U16x8x16x1x2_LD_N
 {
   template <class T>
@@ -309,7 +309,7 @@ struct XE_2D_U16x8x16x1x2_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 4x1 memory blocks, and each block size is 8x16x16bits
+/// The loading block size is 16bitsx8x16, with a total of 4x1 blocks.
 struct XE_2D_U16x8x16x4x1_LD_N
 {
   template <class T>
@@ -344,7 +344,7 @@ struct XE_2D_U16x8x16x4x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 2x1 memory blocks, and each block size is 8x16x32bits
+/// The loading block size is 32bitsx8x16, with a total of 2x1 blocks.
 struct XE_2D_U32x8x16x2x1_LD_N
 {
   template <class T>
@@ -364,7 +364,7 @@ struct XE_2D_U32x8x16x2x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// Loads 2x1 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx16x16, with a total of 2x1 blocks.
 struct XE_2D_U16x16x16x2x1_LD_N
 {
   template <class T>
@@ -387,7 +387,7 @@ struct XE_2D_U16x16x16x2x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
 /// From flat format in memory transform to VNNI format in registers.
-/// Loads 2x2 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx16x16, with a total of 2x2 blocks
 struct XE_2D_U16x16x16x2x2_V
 {
   template <class T>
@@ -407,7 +407,7 @@ struct XE_2D_U16x16x16x2x2_V
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
 /// From flat format in memory transform to VNNI format in registers.
-/// Loads 1x2 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx16x16, with a total of 1x2 blocks
 struct XE_2D_U16x16x16x1x2_V
 {
   template <class T>
@@ -426,7 +426,7 @@ struct XE_2D_U16x16x16x1x2_V
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
 /// From flat format in memory transform to VNNI format in registers.
-/// Loads 2x1 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx16x16, with a total of 2x1 blocks
 struct XE_2D_U16x16x16x2x1_V
 {
   template <class T>
@@ -445,7 +445,7 @@ struct XE_2D_U16x16x16x2x1_V
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
 /// From flat format in memory transform to VNNI format in registers.
-/// Loads 1x1 memory blocks, and each block size is 16x16x16bits
+/// The loading block size is 16bitsx16x16, with a total of 1x1 blocks
 struct XE_2D_U16x16x16x1x1_V
 {
   template <class T>
@@ -463,9 +463,8 @@ struct XE_2D_U16x16x16x1x1_V
 };
 
 /// @brief This function loads data from 2D memory surface.
-/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// From flat format in memory transform to VNNI format in registers.
-/// Loads 1x1 memory blocks, and each block size is 8x16x32bits
+/// Stores an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from registers into global memory.
+/// The storing block size is 32bitsx8x16, with a total of 1x1 blocks
 struct XE_2D_U32x8x16x1x1_ST_N
 {
   template <class T>

From b854995e8e50d95497b9692fb12136a427e38bcf Mon Sep 17 00:00:00 2001
From: Jiaxingla
Date: Thu, 1 Aug 2024 18:29:00 -0700
Subject: [PATCH 36/36] remove useless copy

---
 include/cute/arch/copy_xe.hpp        | 50 ++++++++++++----------------
 include/cute/atom/copy_traits_xe.hpp | 22 ++----------
 2 files changed, 25 insertions(+), 47 deletions(-)

diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp
index 6dba92d7f8..20767c5a05 100644
--- a/include/cute/arch/copy_xe.hpp
+++ b/include/cute/arch/copy_xe.hpp
@@ -61,6 +61,9 @@ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(
 SYCL_DEVICE_BUILTIN(intel::ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(
     long baseoffset, int width_minus_one, int height_minus_one,
     int pitch_minus_one, intel::coord_t coord));
+SYCL_DEVICE_BUILTIN(intel::ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(
+    long baseoffset, int width_minus_one, int height_minus_one,
+    int pitch_minus_one, intel::coord_t coord));
 SYCL_DEVICE_BUILTIN(intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(
     long baseoffset, int width_minus_one, int height_minus_one,
     int pitch_minus_one, intel::coord_t coord));
@@ -168,8 +171,8 @@ struct XE_2D_U32x8x16x1x1_LD_N
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// The loading block size is 16bitsx8x16, with a total of 1x1 blocks.
-struct XE_2D_U16x16x16x1x1_LD_N
+/// The loading block size is 16bitsx8x16, with a total of 2x1 blocks.
+struct XE_2D_U16x8x16x2x1_LD_N
 {
   template <class T>
   CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
@@ -177,20 +180,19 @@ struct XE_2D_U16x16x16x1x1_LD_N
                                     T *dst) {
 #if defined(SYCL_INTEL_TARGET)
     static_assert(sizeof(T) == 2, "Expected T to have size 2");
-    *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(
+    *(intel::ushort16 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(
         (long)baseoffset, width - 1, height - 1, pitch - 1, coord);
 #else
     CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
 #endif
   }
-
   struct PREFETCH {
     template <class T>
     CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
                                       int height, int pitch, intel::coord_t coord) {
 #if defined(SYCL_INTEL_TARGET)
       static_assert(sizeof(T) == 2, "Expected T to have size 2");
-      __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(
+      __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(
           (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
           CacheControl::kL1C_L3C);
 #else
@@ -362,28 +364,6 @@ struct XE_2D_U32x8x16x2x1_LD_N
   }
 };
 
-/// @brief This function loads data from 2D memory surface.
-/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
-/// The loading block size is 16bitsx16x16, with a total of 2x1 blocks.
-struct XE_2D_U16x16x16x2x1_LD_N
-{
-  template <class T>
-  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
-                                    int height, int pitch, intel::coord_t coord,
-                                    T *dst) {
-  #if defined(SYCL_INTEL_TARGET)
-    static_assert(sizeof(T) == 2, "Expected T to have size 2");
-    intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(
-        long(baseoffset), width - 1, height - 1, pitch - 1, coord);
-    *(intel::uint16 *)dst = *reinterpret_cast(&tmp);
-  #else
-    CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
-  #endif
-  }
-
-  using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH;
-};
-
 /// @brief This function loads data from 2D memory surface.
 /// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers.
 /// From flat format in memory transform to VNNI format in registers.
@@ -459,7 +439,21 @@ struct XE_2D_U16x16x16x1x1_V
 #endif
   }
 
-  using PREFETCH = typename XE_2D_U16x16x16x1x1_LD_N::PREFETCH;
+  struct PREFETCH {
+    template <class T>
+    CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                      int height, int pitch, intel::coord_t coord) {
+#if defined(SYCL_INTEL_TARGET)
+      static_assert(sizeof(T) == 2, "Expected T to have size 2");
+      __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(
+          (long)baseoffset, width - 1, height - 1, pitch - 1, coord,
+          CacheControl::kL1C_L3C);
+#else
+      CUTE_INVALID_CONTROL_PATH(
+          "Trying to use block prefetch on non-PVC hardware");
+#endif
+    }
+  };
 };
 
 /// @brief This function loads data from 2D memory surface.
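A usage sketch (illustrative only, not part of the patch series): the copy atoms above encode their footprint in the name as element width in bits, the rows x columns of one block, and the block count, so XE_2D_U16x8x16x2x1_LD_N reads a 16x16 tile of 16-bit elements as two stacked 8x16 blocks per subgroup. In the sketch below the function name, pointer, extents, and coordinate values are hypothetical placeholders, and the include path and namespace qualification are assumed; only the atom, its nested PREFETCH, and CacheControl::kL1C_L3C come from the code above.

#include <cstdint>
#include "cute/arch/copy_xe.hpp"

using namespace cute;  // brings the Xe copy atoms and intel:: vector types into scope

// Load one 16-bit 16x16 tile (2x1 blocks of 8x16) into per-work-item registers,
// then issue an L1-cached/L3-cached prefetch for the next tile. Intended to be
// called from device code; all argument values here are placeholders.
inline void load_and_prefetch_tile(const void *base, int width, int height,
                                   int pitch, intel::coord_t cur,
                                   intel::coord_t next) {
  intel::ushort16 frag;  // 16 ushorts per work-item, written by the block read
  XE_2D_U16x8x16x2x1_LD_N::copy(base, width, height, pitch, cur,
                                reinterpret_cast<uint16_t *>(&frag));
  XE_2D_U16x8x16x2x1_LD_N::PREFETCH::copy<uint16_t>(base, width, height,
                                                    pitch, next);
}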
diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp
index 018f227060..261654506a 100644
--- a/include/cute/atom/copy_traits_xe.hpp
+++ b/include/cute/atom/copy_traits_xe.hpp
@@ -247,12 +247,12 @@ struct Copy_Traits
 };
 
 template
-struct Copy_Traits
-    : XE_2D_PF_Unpack {
+struct Copy_Traits
+    : XE_2D_PF_Unpack {
   template
   CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits)
-      : XE_2D_PF_Unpack(
+      : XE_2D_PF_Unpack(
            traits.tensor) {}
 
   // Logical thread id to thread idx
@@ -359,22 +359,6 @@ struct Copy_Traits
   using CopyInternalType = uint;
 };
 
-template
-struct Copy_Traits
-    : XE_2D_LD_Unpack {
-  // Logical thread id to thread idx
-  using ThrID = Layout<_16>;
-  // Map from (src-thr,src-val) to bit
-  using SrcLayout = Layout, Stride<_0, _1>>;
-  // Map from (dst-thr,dst-val) to bit
-  using DstLayout = Layout>,
-                    Stride<_32, Stride<_512, _1>>>;
-  // Reference map from (thr,val) to bit
-  using RefLayout = DstLayout;
-  // 32 bits register file
-  using CopyInternalType = uint;
-};
-
 template
 struct Copy_Traits
     : XE_2D_LD_Unpack {