From 0b5c9117ad69888de46058c196c60dfd21e25bda Mon Sep 17 00:00:00 2001 From: Jiaxingla <109135611+Jiaxingla@users.noreply.github.com> Date: Fri, 2 Aug 2024 17:00:25 +0800 Subject: [PATCH] Intel gpu backend gemm pipeline (#89) Enable the prefetch by copy atom. --------- Co-authored-by: Mehdi Goli --- benchmarks/common/benchmark_runner.hpp | 36 +- ...ench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp | 6 +- examples/sycl/pvc/pvc_gemm.cpp | 36 +- .../sycl/pvc/pvc_gemm_with_epilogue_relu.cpp | 30 +- include/cute/arch/copy_xe.hpp | 217 ++++++- include/cute/arch/mma_xe.hpp | 9 +- include/cute/atom/copy_traits_xe.hpp | 572 +++++++++++------- include/cute/atom/mma_traits_xe.hpp | 2 +- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 71 ++- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 20 +- 10 files changed, 634 insertions(+), 365 deletions(-) diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/common/benchmark_runner.hpp index 659a3e1436..4b77609730 100644 --- a/benchmarks/common/benchmark_runner.hpp +++ b/benchmarks/common/benchmark_runner.hpp @@ -320,42 +320,10 @@ struct PvcBenchmarkRunner : BenchmarkRunner { using ProblemShapeType = typename Base::ProblemShapeType; - cutlass::DeviceAllocation block_B_vnni; - - template - void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) - { - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } - } - void initialize(const ProblemShapeType& problem_size) override { Base::initialize(problem_size); - - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - block_B_vnni.reset(Base::block_B.size()); - - std::vector b(K * N * L); - std::vector b_vnni(b.size()); - - Base::block_B.copy_to_host(b.data()); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); - - block_B_vnni.copy_from_host(b_vnni.data()); } - + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) override { ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; @@ -364,7 +332,7 @@ struct PvcBenchmarkRunner : BenchmarkRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {Base::block_A.get(), Base::stride_A, block_B_vnni.get(), Base::stride_B}, + {Base::block_A.get(), Base::stride_A, Base::block_B.get(), Base::stride_B}, { {options.alpha, options.beta}, Base::block_C.get(), Base::stride_C, Base::block_D.get(), Base::stride_D diff --git a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp index ea90d89b7e..3203e7f367 100644 --- a/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp +++ b/benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp @@ -85,15 +85,15 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; + using TileShape = Shape<_256, _256, _32>; using TiledMma = TiledMMA< - MMA_Atom, + MMA_Atom, Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; 
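// Note on this configuration: the workgroup now covers a 256x256x32 tile, and with the
// Tile<_32,_64,_32> subgroup tile above that is 8 x 4 = 32 subgroups (512 work-items with a
// 16-wide subgroup) per workgroup. XE_2D_U16x16x16x2x2_V performs the VNNI transform inside the
// 2D block load itself, which is why the host-side vnni_matrix() pre-packing of B is removed
// throughout this patch.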
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 5141a084cd..661016e809 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -55,24 +55,6 @@ static void fill_matrix(std::vector &vector) return static_cast( (rand() / double(RAND_MAX)) ); }); } - -template -static void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) -{ - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } -} - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -89,7 +71,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), + m(4096), n(4096), k(4096), l(1), iterations(20), alpha(1.f), beta(0.f) { } @@ -108,7 +90,7 @@ struct Options { cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 100); + cmd.get_cmd_line_argument("iterations", iterations, 20); } /// Prints the usage statement. @@ -170,7 +152,6 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_B_vnni; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -231,7 +212,6 @@ struct ExampleRunner { block_A.reset(M * K * L); block_B.reset(K * N * L); - block_B_vnni.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); @@ -247,11 +227,9 @@ struct ExampleRunner { fill_matrix(a); fill_matrix(b); fill_matrix(c); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } @@ -272,7 +250,7 @@ struct ExampleRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {block_A.get(), stride_A, block_B.get(), stride_B}, {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -362,14 +340,14 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; + using TileShape = Shape<_256, _256, _32>; - using TiledMma = TiledMMA, + using TiledMma = TiledMMA, Layout>, - Tile<_32,_64,_32>>; // Subgroup level-tile + Tile<_32,_64,_32>>; // Subgroup level-tile using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp 
b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index 2075379580..521d64f7d5 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -57,23 +57,6 @@ static void fill_matrix(std::vector &vector) }); } -template -static void vnni_matrix( - T* dst, const T* src, - int batch, int numRows, int numCols, int factor) -{ - for (int b = 0; b < batch; b++) { - for (int r = 0; r < numRows / factor; r++) { - for (int c = 0; c < numCols; c++) { - for (int k = 0; k < factor; k++) { - dst[((b * (numRows / factor) + r) * numCols + c) * factor + k] = - src[((b * (numRows / factor) + r) * factor + k) * numCols + c]; - } - } - } - } -} - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -171,7 +154,6 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_B_vnni; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_ref_D; @@ -238,7 +220,6 @@ struct ExampleRunner { block_A.reset(M * K * L); block_B.reset(K * N * L); - block_B_vnni.reset(K * N * L); block_C.reset(M * N * L); block_D.reset(M * N * L); block_ref_D.reset(M * N * L); @@ -247,18 +228,15 @@ struct ExampleRunner { // available through SYCL. std::vector a(K * M * L); std::vector b(K * N * L); - std::vector b_vnni(b.size()); std::vector c(M * N * L); std::vector d(M * N * L, ElementC{0}); fill_matrix(a); fill_matrix(b); fill_matrix(c); - vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2); syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB)); syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); } @@ -271,7 +249,7 @@ struct ExampleRunner { typename Gemm::GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_A.get(), stride_A, block_B_vnni.get(), stride_B}, + {block_A.get(), stride_A, block_B.get(), stride_B}, {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; @@ -361,12 +339,12 @@ int main(int argc, const char** argv) using LayoutD = cutlass::layout::RowMajor; using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; // Workgroup-level tile - using TileShape = Shape<_32, _256, _32>; + using TileShape = Shape<_256, _256, _32>; - using TiledMma = TiledMMA, + using TiledMma = TiledMMA, Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 3bfc5c8535..20767c5a05 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -44,12 +44,26 @@ namespace cute inline x { assert(false); } #endif +enum class CacheControl { + kDefault = 0, + kL1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + kL1UC_L3C = 2, // Override to L1 uncached and L3 cached + kL1C_L3UC = 3, // Override to L1 cached and L3 uncached + kL1C_L3C = 4, // Override to L1 cached and L3 cached + kL1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + kL1S_L3C = 6, // Override to L1 streaming load and L3 cached + kL1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached +}; 
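// The prefetch builtins declared below take a CacheControl argument; the prefetch copy atoms in
// this file pass CacheControl::kL1C_L3C, so prefetched lines are kept in both L1 and L3.
// A minimal usage sketch, assuming width and pitch are given in bytes (the copy_unpack() helpers
// multiply the element count by sizeof(CopyInternalType)):
//   XE_2D_U16x8x16x4x2_LD_N::PREFETCH::copy<cute::bfloat16_t>(
//       base_ptr, width_in_bytes, height_in_rows, pitch_in_bytes, intel::coord_t{x, y});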
+ SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord, intel::uint8 data)); SYCL_DEVICE_BUILTIN(intel::ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN(intel::ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); SYCL_DEVICE_BUILTIN(intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord)); @@ -80,10 +94,30 @@ SYCL_DEVICE_BUILTIN(intel::int16 __builtin_IB_subgroup_block_read_flat_transform SYCL_DEVICE_BUILTIN(intel::int8 intel_subgroup_block_read_transform_u16_k16( long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, intel::coord_t coord)); - +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); #undef SYCL_DEVICE_BUILTIN +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 16bitsx8x16, with a total of 1x1 blocks. struct XE_2D_U16x8x16x1x1_LD_N { template @@ -98,8 +132,27 @@ struct XE_2D_U16x8x16x1x1_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. 
+/// The loading block size is 32bitsx8x16, with a total of 1x1 blocks. struct XE_2D_U32x8x16x1x1_LD_N { template @@ -116,7 +169,10 @@ struct XE_2D_U32x8x16x1x1_LD_N } }; -struct XE_2D_U16x16x16x1x1_LD_N +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 16bitsx8x16, with a total of 2x1 blocks. +struct XE_2D_U16x8x16x2x1_LD_N { template CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, @@ -124,14 +180,32 @@ struct XE_2D_U16x16x16x1x1_LD_N T *dst) { #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( + *(intel::ushort16 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( (long)baseoffset, width - 1, height - 1, pitch - 1, coord); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 16bitsx8x16, with a total of 4x2 blocks. struct XE_2D_U16x8x16x4x2_LD_N { template @@ -146,8 +220,27 @@ struct XE_2D_U16x8x16x4x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 16bitsx8x16, with a total of 2x2 blocks. struct XE_2D_U16x8x16x2x2_LD_N { template @@ -162,8 +255,27 @@ struct XE_2D_U16x8x16x2x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. 
+/// The loading block size is 16bitsx8x16, with a total of 1x2 blocks. struct XE_2D_U16x8x16x1x2_LD_N { template @@ -179,8 +291,27 @@ struct XE_2D_U16x8x16x1x2_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 16bitsx8x16, with a total of 4x1 blocks. struct XE_2D_U16x8x16x4x1_LD_N { template @@ -195,8 +326,27 @@ struct XE_2D_U16x8x16x4x1_LD_N CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The loading block size is 32bitsx8x16, with a total of 2x1 blocks. struct XE_2D_U32x8x16x2x1_LD_N { template @@ -214,23 +364,10 @@ struct XE_2D_U32x8x16x2x1_LD_N } }; -struct XE_2D_U16x16x16x2x1_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( - long(baseoffset), width - 1, height - 1, pitch - 1, coord); - *(intel::uint16 *)dst = *reinterpret_cast(&tmp); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } -}; - +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// The loading block size is 16bitsx16x16, with a total of 2x2 blocks struct XE_2D_U16x16x16x2x2_V { template @@ -242,8 +379,15 @@ struct XE_2D_U16x16x16x2x2_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + // using PREFETCH = typename XE_2D_U16x8x16x4x2_LD_N::PREFETCH; + using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. 
+/// The loading block size is 16bitsx16x16, with a total of 1x2 blocks struct XE_2D_U16x16x16x1x2_V { template @@ -255,8 +399,14 @@ struct XE_2D_U16x16x16x1x2_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// The loading block size is 16bitsx16x16, with a total of 2x1 blocks struct XE_2D_U16x16x16x2x1_V { template @@ -268,8 +418,14 @@ struct XE_2D_U16x16x16x2x1_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; +/// @brief This function loads data from 2D memory surface. +/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// From flat format in memory transform to VNNI format in registers. +/// The loading block size is 16bitsx16x16, with a total of 1x1 blocks struct XE_2D_U16x16x16x1x1_V { template @@ -282,8 +438,27 @@ struct XE_2D_U16x16x16x1x1_V CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } + + struct PREFETCH { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; }; +/// @brief This function loads data from 2D memory surface. +/// Stores an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. +/// The storing block size is 32bitsx8x16, with a total of 1x1 blocks struct XE_2D_U32x8x16x1x1_ST_N { template @@ -300,4 +475,4 @@ struct XE_2D_U32x8x16x1x1_ST_N } }; -} // end namespace cute +} // end namespace diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 3d1bfb8f68..7c5ad7b74e 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -45,7 +45,11 @@ SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::i #undef SYCL_DEVICE_OCL namespace cute { -struct XE_8x16x16_F32BF16BF16F32_TN +//MxNxK_A,B,C,D +//# of vector component of a x subgroup-size x function name +//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); +//TODO: Is A really not transposed? 
Maybe better a macro than separate define for 1,2,4,8 +struct XE_8x16x16_F32BF16BF16F32_TT { using DRegisters = intel::float8[1]; using ARegisters = intel::short8[1]; @@ -65,7 +69,8 @@ struct XE_8x16x16_F32BF16BF16F32_TN #endif } }; -struct XE_1x16x16_F32BF16BF16F32_TN +//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +struct XE_1x16x16_F32BF16BF16F32_TT { using DRegisters = float[1]; using ARegisters = short[1]; diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index 63eae73dd4..261654506a 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -35,11 +35,11 @@ #include -namespace cute +namespace cute { template -CUTE_HOST_DEVICE constexpr +CUTE_HOST_DEVICE constexpr auto get_shape_WHD(cute::Stride, IntT, IntT> , cute::Shape shape_MKL) { return shape_MKL; } @@ -67,286 +67,406 @@ auto get_coordinates(cute::Stride, IntT> , } template -struct XE_2D_LD_Unpack -{ - GTensor tensor; - - using Copy_Traits = Copy_Traits; - - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const &traits, - Tensor>, SLayout> const &src, - Tensor &dst) - { - static_assert(is_rmem::value); - auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); - int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); - int H = size<1>(shape_whd); - auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data()); - } - - template - CUTE_HOST_DEVICE constexpr auto - get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const - { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)), - make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor))))); - } +struct XE_2D_LD_Unpack { + GTensor tensor; + + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, + Tensor>, SLayout> const &src, + Tensor &dst) { + static_assert(is_rmem::value); + auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); + int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); + int H = size<1>(shape_whd); + auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, + &*dst.data()); + } + + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, GStride const &stride_mul) const { + return make_tensor(make_inttuple_iter(coord), + make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), + get<2>(shape)), + make_stride(_1 {}, E<0> {} * get<0>(stride_mul), + E<1> {} * get<1>(stride_mul), + E<2> {} * get<2>(stride(tensor))))); + } }; +/// ThrID: How many threads involved. +/// DstLayout: Size<0>(DstLayout) same as thrID, Size<1>(DstLayout) represents data layout that hold by each thread in bits. +/// TODO: SrcLayout just a placeHolder, not used. 
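/// As a concrete reading for the first specialization below (XE_2D_U16x8x16x1x1_LD_N): ThrID is
/// Layout<_16>, one 16-wide subgroup, and DstLayout assigns each work-item 8 fragments of 16 bits,
/// i.e. 16 x 8 x 16 = 2048 bits in total -- exactly the 8-row x 16-column block of 16-bit elements
/// that the m8k16v1 builtin returns as one intel::ushort8 per work-item.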
template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct XE_2D_PF_Unpack { + GTensor tensor; + + template + CUTE_HOST_DEVICE XE_2D_PF_Unpack(G_Tensor const &t) : tensor(t) {} + + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, + Tensor>, SLayout> const &src, + Tensor &dst) { + using T = typename Copy_Traits::CopyInternalType; + auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); + int W = size<0>(shape_whd) * sizeof(T); + int H = size<1>(shape_whd); + auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); + CopyOp::template copy(traits.tensor.data() + z, W, H, W, + intel::coord_t {static_cast(x), static_cast(y)}); + } +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + 
Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_PF_Unpack { + + template + CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) + : XE_2D_PF_Unpack( + traits.tensor) {} + + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_16, Stride<_16, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = uint; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = uint; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>>; // one coordinate - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + 
// Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>, - Stride<_16, Stride<_512, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_512, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16, Stride<_256, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - // 32 bits register file - using CopyInternalType = uint; -}; - -template -struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_1>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>>; // expected 4 coordinates - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - // 32 bits register file - using CopyInternalType = uint; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32, Stride<_512, _1>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + // 32 bits register file + using CopyInternalType = uint; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape< Shape< _8, _2>, Shape<_16, _2, _2>>>, - Stride<_16, Stride, Stride< _1, _512, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = 
Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout< + Shape<_16, Shape, Shape<_16, _2, _2>>>, + Stride<_16, Stride, Stride<_1, _512, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape< _8, Shape<_16, _2, _2>>>, - Stride<_16, Stride<_1024, Stride< _1, _512, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_1024, Stride<_1, _512, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_16, _2>>>, + Stride<_16, Stride, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template struct Copy_Traits - : XE_2D_LD_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout< Shape<_16, Shape<_8, Shape<_16, _2>>>, - Stride<_16, Stride<_512, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + : XE_2D_LD_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>>, + Stride<_16, Stride<_512, Stride<_1, _256>>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; template -struct XE_2D_ST_Unpack -{ - GTensor tensor; - using Copy_Traits = Copy_Traits; - template - CUTE_HOST_DEVICE friend constexpr void - copy_unpack(Copy_Traits const &traits, - Tensor const &src, - Tensor>, DLayout> &dst) - { - static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); - auto [y, x, z] = dst.data().coord_; - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data()); - } - - template - CUTE_HOST_DEVICE constexpr auto - 
get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const - { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)), - make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor))))); - } +struct XE_2D_ST_Unpack { + GTensor tensor; + using Copy_Traits = Copy_Traits; + template + CUTE_HOST_DEVICE friend constexpr void copy_unpack( + Copy_Traits const &traits, Tensor const &src, + Tensor>, DLayout> &dst) { + static_assert(is_rmem::value); + int H = size<0>(traits.tensor); + int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); + auto [y, x, z] = dst.data().coord_; + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data()); + } + + template + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, GStride const &stride_mul) const { + return make_tensor(make_inttuple_iter(coord), + make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), + get<2>(shape)), + make_stride(_1 {}, E<0> {} * get<0>(stride_mul), + E<1> {} * get<1>(stride_mul), + E<2> {} * get<2>(stride(tensor))))); + } }; template struct Copy_Traits - : XE_2D_ST_Unpack -{ - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Stride<_0, _1>>; - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; - using CopyInternalType = uint; + : XE_2D_ST_Unpack { + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout + = Layout>, Stride<_32, Stride<_512, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint; }; template -auto make_xe_2d_copy(Tensor gtensor) -{ +auto make_xe_2d_copy(Tensor gtensor) { using GTensor = Tensor; using Traits = Copy_Traits; - Traits traits{gtensor}; - return Copy_Atom{traits}; + Traits traits {gtensor}; + return Copy_Atom {traits}; } } // end namespace cute diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index 1cbefc872d..8dca2dba55 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -38,7 +38,7 @@ namespace cute { template <> -struct MMA_Traits +struct MMA_Traits { using ValTypeD = float; using ValTypeA = bfloat16_t; diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index f69ae7bdf0..e9314ace84 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -36,11 +36,10 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" #include "cute/tensor_predicate.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// - + namespace cutlass::gemm::collective { using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -103,8 +102,11 @@ struct CollectiveMma< using MmaAtomShape = typename TiledMma::AtomShape_MNK; using SubgroupTileShape = decltype(tile_shape(TiledMma())); + static constexpr 
auto sg_per_wg_m = get<0>(WorkgroupTileShape{}) / get<0>(SubgroupTileShape{}); + static constexpr auto sg_per_wg_n = get<1>(WorkgroupTileShape{}) / get<1>(SubgroupTileShape{}); + static constexpr uint32_t MaxThreadsPerBlock = - cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{})* SubgroupSize; + cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{}) * SubgroupSize; static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group @@ -187,29 +189,70 @@ struct CollectiveMma< static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, _1>{}); - Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) / 2>, Int>{}); + constexpr int version = + is_same_v ? 1 : 2; + + Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, Int<1>>{}); + Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) * version>, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), Shape, Int, Int>{}, - Stride<_1, Int(SubgroupTileShape{}) / 2>, Int>{}); + Stride<_1, Int, Int>{}); // Instantiate the MMA object TiledMma tiled_mma; + int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); + + auto sub_group_id = ThreadIdxX() / SubgroupSize; + /* Cooperative prefetch Divide the thread space into sg_per_wg_m x sg_per_wg_n; all the subgroups in one row/column use the same A/B tile. Each thread loads sizeof(tile A or B) divided by sg_per_wg_n or sg_per_wg_m, respectively. + + Currently, sg_per_wg_m x sg_per_wg_n = 4 x 8 is the most efficient configuration. + */ + // TODO: Replace the demo cooperative prefetch with a more general scheme.
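// For the configuration used by the examples in this patch (workgroup tile 256x256x32, subgroup
// tile 32x64x32) this yields 256/32 = 8 subgroups along M and 256/64 = 4 along N, i.e. 32
// subgroups and MaxThreadsPerBlock = 32 * 16 = 512 work-items cooperating on the prefetch.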
+ Tensor tAi = make_tensor( + make_inttuple_iter( + *gA.data() + + make_coord((sub_group_id % sg_per_wg_n % 4) * get<0>(MmaAtomShape{}), 0)), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, E<0>{}, E<1>{}))); + Tensor tBi = make_tensor( + make_inttuple_iter( + *gB.data() + + make_coord((sub_group_id / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}), + (sub_group_id / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}))), + make_layout(make_shape(_1{}, _1{}, K), + make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop // - for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(MmaAtomShape()) * FragsK) - { - // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr); - copy(mainloop.gmem_tiled_copy_b, gB(_,_,k/2), tBr); - - cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); - } + int prefetch_k = 0; + + // Manually set the prefetch_distance to 3 + // TODO: Expose to users like stages parameter + int constexpr prefetch_distance = 3; + for (int i = 0; i < prefetch_distance; i++) { + prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); + prefetch_k += get<2>(SubgroupTileShape{}); + } + + for (int k_tile = 0, k = 0; k_tile < k_tile_count; + ++k_tile, k += get<2>(SubgroupTileShape{})) { + // Copy gmem to rmem for the first k_tile + copy(mainloop.gmem_tiled_copy_a, gA(_, _, k), tAr); + copy(mainloop.gmem_tiled_copy_b, gB(_, _, k), tBr); + + prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); + prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); + prefetch_k += get<2>(SubgroupTileShape{}); + + cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); + } } }; diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 6e7aee895b..05b099bc24 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -102,7 +102,6 @@ class GemmUniversal< static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; - using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; @@ -177,7 +176,6 @@ class GemmUniversal< if constexpr (cute::rank(ProblemShape{}) == 4) { batch_count = cute::size<3>(params.problem_shape); } - return dim3( cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(WorkgroupTileShape{}))), cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(WorkgroupTileShape{}))), @@ -193,9 +191,7 @@ class GemmUniversal< CUTLASS_DEVICE void operator()(Params const& params, char* smem_buf) { - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - // Preconditions CUTE_STATIC_ASSERT(is_static::value); @@ -215,10 +211,11 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = BlockIdxX() * get<0>(subgroup_shape); - const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); + const int m_coord = BlockIdxX() * get<0>(workgroup_shape) + sub_group_id 
/ CollectiveMainloop::sg_per_wg_n * get<0>(subgroup_shape); + const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + sub_group_id % CollectiveMainloop::sg_per_wg_n * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); @@ -226,11 +223,16 @@ class GemmUniversal< make_coord(m_coord, 0, 0), make_shape(_1{}, K, L), make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); + constexpr int version = + is_same_v + ? 1 + : 2; Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( make_coord(n_coord, 0, 0), - make_shape(Int{}, K / 2, L), - make_stride(get<1>(MmaAtomShape()), _1{})); + make_shape(Int{}, K, L), + make_stride(Int(MmaAtomShape())>{}, _1{})); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord @@ -271,7 +273,7 @@ class GemmUniversal< residue_mnk, thread_idx, smem_buf - ); + ); } };
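// Net effect of the mainloop changes: each k iteration issues the 2D block loads for the current
// A/B tiles and, via the copy atoms' PREFETCH builtins (CacheControl::kL1C_L3C), prefetches the
// tiles prefetch_distance = 3 k-steps ahead, which is intended to overlap global-memory latency
// with the DPAS work on the tiles already in flight.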