diff --git a/benchmarks/pvc/gemm_configuration.hpp b/benchmarks/pvc/gemm_configuration.hpp index 332f37395d..cde3c70d92 100644 --- a/benchmarks/pvc/gemm_configuration.hpp +++ b/benchmarks/pvc/gemm_configuration.hpp @@ -76,12 +76,12 @@ struct Gemm_OperandB; template<> struct Gemm_OperandA { - using GmemTiledCopy = XE_2D_U16x8x16x4x2_LD_N; + using GmemTiledCopy = XE_2D_U16x8x16_LD_N; }; template<> struct Gemm_OperandB { - using GmemTiledCopy = XE_2D_U16x16x16x2x2_V; + using GmemTiledCopy = XE_2D_U16x16x16_LD_V; }; } // namespace details @@ -93,12 +93,12 @@ struct GemmConfiguration< bfloat16_t, LayoutB, float, LayoutC, float> { - using TileShape = Shape<_256, _256, _32>; + using TileShape = Shape<_256, _256, _16>; using DispatchPolicy = MainloopIntelPVC<3>;; using TiledMma = TiledMMA< MMA_Atom, - Layout>, - Tile<_32,_64,_32>>; + Layout>, + Tile<_64,_128,_16>>; // A using OperandA = detail::Gemm_OperandA; @@ -132,9 +132,9 @@ struct GemmConfiguration< float, TagToStrideC_t, FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, + XE_2D_U32x8x16_LD_N, void, void, - XE_2D_U32x8x16x1x1_ST_N, + XE_2D_U32x8x16_ST_N, void, void>; using GemmKernel = kernel::GemmUniversal< diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index 982956ba3c..0c8ffd9168 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -306,15 +306,15 @@ int main(int argc, const char** argv) using LayoutC = cutlass::layout::RowMajor; using LayoutD = cutlass::layout::RowMajor; - using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; - using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V; + using GmemTiledCopyA = XE_2D_U16x8x16_LD_N; + using GmemTiledCopyB = XE_2D_U16x16x16_LD_V; // Workgroup-level tile - using TileShape = Shape<_256, _256, _32>; + using TileShape = Shape<_256, _128, _16>; using TiledMma = TiledMMA, - Layout>, - Tile<_32,_64,_32>>; // Subgroup level-tile + Layout>, + Tile<_64,_32,_16>>; // Subgroup level-tile constexpr int PipelineStages = 3; using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; @@ -333,9 +333,9 @@ int main(int argc, const char** argv) ElementOutput, cutlass::gemm::TagToStrideC_t, FusionCallBacks, - XE_2D_U32x8x16x1x1_LD_N, + XE_2D_U32x8x16_LD_N, void, void, - XE_2D_U32x8x16x1x1_ST_N, + XE_2D_U32x8x16_ST_N, void, void>; // Mainloop diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 7c34d49461..d8879bdc3f 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -29,450 +29,218 @@ * **************************************************************************************************/ #pragma once - -#include -#include -#include - -namespace cute -{ - +#include +#include +#include +#include #ifdef __SYCL_DEVICE_ONLY__ #define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x #else -#define SYCL_DEVICE_BUILTIN(x) \ - inline x { CUTE_INVALID_CONTROL_PATH("Trying to use XE built-in on non-XE hardware"); } +#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); } #endif -enum class CacheControl { - kDefault = 0, - kL1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - kL1UC_L3C = 2, // Override to L1 uncached and L3 cached - kL1C_L3UC = 3, // Override to L1 cached and L3 uncached - kL1C_L3C = 4, // Override to L1 cached and L3 cached - kL1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - kL1S_L3C = 6, // Override to L1 streaming load and L3 cached - kL1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached -}; - -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( 
- long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, intel::uint8 data)); -SYCL_DEVICE_BUILTIN(intel::ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::ushort16 intel_subgroup_block_read_u16_m8k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::int16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::int16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(intel::int8 intel_subgroup_block_read_transform_u16_k16( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, 
intel::coord_t coord, enum CacheControl cache_control)); -SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( - long baseoffset, int width_minus_one, int height_minus_one, - int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); - +using namespace cute; + +// prefetch +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uchar( + const __attribute__((opencl_global)) uint8_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ushort( + const __attribute__((opencl_global)) uint16_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint( + const __attribute__((opencl_global)) uint32_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint2( + const __attribute__((opencl_global)) uint32_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint4( + const __attribute__((opencl_global)) uint32_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_uint8( + const __attribute__((opencl_global)) uint32_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong( + const __attribute__((opencl_global)) uint64_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong2( + const __attribute__((opencl_global)) uint64_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong4( + const __attribute__((opencl_global)) uint64_t *base, int immElemOff, + enum CacheControl cacheOpt)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_lsc_prefetch_global_ulong8( + const __attribute__((opencl_global)) uint64_t *base, int immElemOff, + enum CacheControl cacheOpt)); #undef SYCL_DEVICE_BUILTIN -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 1x1 blocks. -struct XE_2D_U16x8x16x1x1_LD_N +namespace cute { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::ushort8 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct XE_ATOMIC { + using SRegisters = S[1]; + using DRegisters = D[1]; - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; -}; + CUTE_STATIC_ASSERT(is_same_v || is_same_v || is_same_v); -/// @brief This function loads data from 2D memory surface. 
-/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 32bitsx8x16, with a total of 1x1 blocks. -struct XE_2D_U32x8x16x1x1_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { + template + CUTE_HOST_DEVICE static void + copy(S_ const& src, D_ & dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 4, "Expected T to have size 4"); - *(intel::uint8 *)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord); + auto v = sycl::atomic_ref(*&dst); + v += static_cast(*&src); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } -}; - -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 2x1 blocks. -struct XE_2D_U16x8x16x2x1_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::ushort16 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif + #endif } - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; }; -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 4x2 blocks. 
-struct XE_2D_U16x8x16x4x2_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::ushort64 *)dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( - long(baseoffset), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct XE_1D_LDSM { + using SRegisters = S[1]; + using DRegisters = D[1]; - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; -}; + CUTE_STATIC_ASSERT(sizeof(D) % sizeof(S) == 0, + "dst failed to vectorize into registers"); + static constexpr size_t N = sizeof(D) / sizeof(S); + CUTE_STATIC_ASSERT(N == 1 || N == 2 || N == 4 || N == 8, + "register vector only supports 1, 2, 4, 8"); -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 2x2 blocks. -struct XE_2D_U16x8x16x2x2_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { + template + CUTE_HOST_DEVICE static void + copy(const S_ &src, D_ &dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( - long(baseoffset), width - 1, height - 1, pitch - 1, coord); + CUTE_STATIC_ASSERT(sizeof(S_) == sizeof(S)); + auto sg = sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group(); + *(sycl::vec*)(&dst) + = sg.load(sycl::address_space_cast(&*&src)); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif + #endif } - - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; }; -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 1x2 blocks. 
-struct XE_2D_U16x8x16x1x2_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - intel::ushort16 tmp = (intel_subgroup_block_read_u16_m8k16v2( - (long)baseoffset, width, height, pitch, coord)); - *(intel::ushort16 *)dst = *reinterpret_cast(&tmp); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct PREFETCH { + using SRegisters = S[1]; + using DRegisters = D[1]; - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { + template + CUTE_HOST_DEVICE static void copy(const S_ &src, D_ &dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); + if constexpr(sizeof(D) == 1) { + __builtin_IB_lsc_prefetch_global_uchar( + (const __attribute__((opencl_global)) uint8_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 2) { + __builtin_IB_lsc_prefetch_global_ushort( + (const __attribute__((opencl_global)) uint16_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 4) { + __builtin_IB_lsc_prefetch_global_uint( + (const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 8) { + __builtin_IB_lsc_prefetch_global_uint2( + (const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 16) { + __builtin_IB_lsc_prefetch_global_uint4( + (const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 32) { + __builtin_IB_lsc_prefetch_global_uint8( + (const __attribute__((opencl_global)) uint32_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } + else if constexpr(sizeof(D) == 64) { + __builtin_IB_lsc_prefetch_global_ulong8( + (const __attribute__((opencl_global)) uint64_t *)(&*&src), 0, CacheControl::kL1C_L3C); + } #else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); #endif - } - }; + } }; -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 16bitsx8x16, with a total of 4x1 blocks. 
-struct XE_2D_U16x8x16x4x1_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::ushort32*) dst = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( - long(baseoffset), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct XE_1D_LOAD_GLOBAL { + using SRegisters = S[1]; + using DRegisters = D[1]; - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; -}; + CUTE_STATIC_ASSERT(sizeof(D) % sizeof(S) == 0, + "dst failed to vectorize into registers"); + static constexpr size_t N = sizeof(D) / sizeof(S); + CUTE_STATIC_ASSERT(N == 1 || N == 2 || N == 4 || N == 8, + "register vector only supports 1, 2, 4, 8"); -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The loading block size is 32bitsx8x16, with a total of 2x1 blocks. -struct XE_2D_U32x8x16x2x1_LD_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord, - T *dst) { + template + CUTE_HOST_DEVICE static void + copy(const S_ &src, D_ &dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 4, "Expected T to have size 4"); - intel::uint16 tmp = __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( - long(baseoffset), width - 1, height - 1, pitch - 1, coord); - *(intel::uint16 *)dst = *reinterpret_cast(&tmp); + CUTE_STATIC_ASSERT(sizeof(S_) == sizeof(S)); + CUTE_STATIC_ASSERT(sizeof(D_) == sizeof(D)); + auto sg = sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group(); + *(sycl::vec*)(&dst) + = sg.load(sycl::address_space_cast(&*&src)); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif + #endif } -}; -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// From flat format in memory transform to VNNI format in registers. -/// The loading block size is 16bitsx16x16, with a total of 2x2 blocks -struct XE_2D_U16x16x16x2x2_V -{ - template - CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::uint32*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long(base_address), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } + using PREFETCH = PREFETCH; - // using PREFETCH = typename XE_2D_U16x8x16x4x2_LD_N::PREFETCH; - using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; }; -/// @brief This function loads data from 2D memory surface. 
-/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// From flat format in memory transform to VNNI format in registers. -/// The loading block size is 16bitsx16x16, with a total of 1x2 blocks -struct XE_2D_U16x16x16x1x2_V -{ - template - CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long(base_address), width - 1, height - 1, pitch - 1, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct XE_1D_STSM { + using SRegisters = S[1]; + using DRegisters = D[1]; - using PREFETCH = typename XE_2D_U16x8x16x2x2_LD_N::PREFETCH; -}; + CUTE_STATIC_ASSERT(sizeof(S) % sizeof(D) == 0, + "src failed to vectorize into registers"); + static constexpr size_t N = sizeof(S) / sizeof(D); + CUTE_STATIC_ASSERT(N == 1 || N == 2 || N == 4 || N == 8, + "register vector only supports 1, 2, 4, 8"); -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// From flat format in memory transform to VNNI format in registers. -/// The loading block size is 16bitsx16x16, with a total of 2x1 blocks -struct XE_2D_U16x16x16x2x1_V -{ - template - CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { + template + CUTE_HOST_DEVICE static void + copy(S_ const& src, D_ & dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - *(intel::int16*) dst = __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long(base_address), width - 1, height - 1, pitch - 1, coord); + auto sg = sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group(); + sg.store(sycl::address_space_cast(&*&dst), *(sycl::vec*)(&src)); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } - - using PREFETCH = typename XE_2D_U16x8x16x4x1_LD_N::PREFETCH; }; -/// @brief This function loads data from 2D memory surface. -/// Loads an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// From flat format in memory transform to VNNI format in registers. 
-/// The loading block size is 16bitsx16x16, with a total of 1x1 blocks -struct XE_2D_U16x16x16x1x1_V -{ - template - CUTE_HOST_DEVICE static void copy(const void *base_address, int width, int height, int pitch, intel::coord_t coord, T* dst) { - #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: - *(intel::int8*) dst = intel_subgroup_block_read_transform_u16_k16((long)base_address, width, height, pitch, coord); - #else - CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); - #endif - } +template +struct XE_1D_STORE_GLOBAL { + using SRegisters = S[1]; + using DRegisters = D[1]; - struct PREFETCH { - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, - int height, int pitch, intel::coord_t coord) { -#if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 2, "Expected T to have size 2"); - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - CacheControl::kL1C_L3C); -#else - CUTE_INVALID_CONTROL_PATH( - "Trying to use block prefetch on non-PVC hardware"); -#endif - } - }; -}; + CUTE_STATIC_ASSERT(sizeof(S) % sizeof(D) == 0, + "src failed to vectorize into registers"); + static constexpr size_t N = sizeof(S) / sizeof(D); + CUTE_STATIC_ASSERT(N == 1 || N == 2 || N == 4 || N == 8, + "register vector only supports 1, 2, 4, 8"); -/// @brief This function loads data from 2D memory surface. -/// Stores an array of rectangular regions coord(X,Y)..coord(X+W,Y+H) from global memory into registers. -/// The storing block size is 32bitsx8x16, with a total of 1x1 blocks -struct XE_2D_U32x8x16x1x1_ST_N -{ - template - CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, - int pitch, intel::coord_t coord, const T *src) { + template + CUTE_HOST_DEVICE static void + copy(S_ const& src, D_ &dst) { #if defined(SYCL_INTEL_TARGET) - static_assert(sizeof(T) == 4, "Expected T to have size 4"); - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( - (long)baseoffset, width - 1, height - 1, pitch - 1, coord, - *(intel::uint8 *)src); + auto sg = sycl::ext::oneapi::experimental::this_nd_item<3>().get_sub_group(); + sg.store(sycl::address_space_cast(&*&dst), *(sycl::vec*)(&src)); #else CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); #endif } }; - -} // end namespace +} // end namespace cute diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index c4a72b05e2..7d86c44544 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -40,8 +40,33 @@ #define SYCL_DEVICE_OCL(x) inline x { CUTE_INVALID_CONTROL_PATH("Trying to use XE built-in on non-XE hardware"); } #endif +// mma_bf16 SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short8 a, cute::intel::int8 b, cute::intel::float8 acc)); +SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short4 a, cute::intel::int8 b, cute::intel::float4 acc)); +SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_bf16_bf16_matrix_mad_k16(cute::intel::short2 a, cute::intel::int8 b, cute::intel::float2 acc)); SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::intel::int8 b, float acc)); +// mma_half +SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_f16_f16_matrix_mad_k16(cute::intel::short8 a, 
cute::intel::int8 b, cute::intel::float8 acc));
+SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_f16_f16_matrix_mad_k16(cute::intel::short4 a, cute::intel::int8 b, cute::intel::float4 acc));
+SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_f16_f16_matrix_mad_k16(cute::intel::short2 a, cute::intel::int8 b, cute::intel::float2 acc));
+SYCL_DEVICE_OCL(float intel_sub_group_f16_f16_matrix_mad_k16(short a, cute::intel::int8 b, float acc));
+// mma_s8
+SYCL_DEVICE_OCL(cute::intel::int8 intel_sub_group_i8_i8_matrix_mad_k32(cute::intel::short8 a, cute::intel::int8 b, cute::intel::int8 acc));
+SYCL_DEVICE_OCL(cute::intel::int4 intel_sub_group_i8_i8_matrix_mad_k32(cute::intel::short4 a, cute::intel::int8 b, cute::intel::int4 acc));
+SYCL_DEVICE_OCL(cute::intel::int2 intel_sub_group_i8_i8_matrix_mad_k32(cute::intel::short2 a, cute::intel::int8 b, cute::intel::int2 acc));
+SYCL_DEVICE_OCL(int intel_sub_group_i8_i8_matrix_mad_k32(short a, cute::intel::int8 b, int acc));
+// mma_u8
+SYCL_DEVICE_OCL(cute::intel::int8 intel_sub_group_u8_u8_matrix_mad_k32(cute::intel::ushort8 a, cute::intel::uint8 b, cute::intel::int8 acc));
+SYCL_DEVICE_OCL(cute::intel::int4 intel_sub_group_u8_u8_matrix_mad_k32(cute::intel::ushort4 a, cute::intel::uint8 b, cute::intel::int4 acc));
+SYCL_DEVICE_OCL(cute::intel::int2 intel_sub_group_u8_u8_matrix_mad_k32(cute::intel::ushort2 a, cute::intel::uint8 b, cute::intel::int2 acc));
+SYCL_DEVICE_OCL(int intel_sub_group_u8_u8_matrix_mad_k32(ushort a, cute::intel::uint8 b, int acc));
+// mma_tf32
+SYCL_DEVICE_OCL(cute::intel::float8 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(cute::intel::float4 a, cute::intel::float8 b, cute::intel::float8 acc));
+SYCL_DEVICE_OCL(cute::intel::float4 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(cute::intel::float2 a, cute::intel::float8 b, cute::intel::float4 acc));
+SYCL_DEVICE_OCL(cute::intel::float2 intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, cute::intel::float8 b, cute::intel::float2 acc));
+SYCL_DEVICE_OCL(float intel_sub_group_tf32_tf32_matrix_mad_k8_f32(float a, cute::intel::float8 b, float acc));
+
+
 #undef SYCL_DEVICE_OCL

 namespace cute {
@@ -65,7 +90,47 @@ struct XE_8x16x16_F32BF16BF16F32_TT
 #if defined(SYCL_INTEL_TARGET)
     d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
 #else
-    CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_BF16BF16F32F32_NN on non-PVC hardware");
+    CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_F32BF16BF16F32_TT on non-PVC hardware");
+#endif
+  }
+};
+struct XE_4x16x16_F32BF16BF16F32_TT
+{
+  using DRegisters = intel::float4[1];
+  using ARegisters = intel::short4[1];
+  using BRegisters = intel::int8[1];
+  using CRegisters = intel::float4[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(intel::float4 & d,
+      intel::short4 const& a,
+      intel::int8 const& b,
+      intel::float4 const& c)
+  {
+#if defined(SYCL_INTEL_TARGET)
+    d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x16_F32BF16BF16F32_TT on non-PVC hardware");
+#endif
+  }
+};
+struct XE_2x16x16_F32BF16BF16F32_TT
+{
+  using DRegisters = intel::float2[1];
+  using ARegisters = intel::short2[1];
+  using BRegisters = intel::int8[1];
+  using CRegisters = intel::float2[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(intel::float2 & d,
+      intel::short2 const& a,
+      intel::int8 const& b,
+      intel::float2 const& c)
+  {
+#if defined(SYCL_INTEL_TARGET)
+    d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x16_F32BF16BF16F32_TT on 
non-PVC hardware"); #endif } }; @@ -86,7 +151,351 @@ struct XE_1x16x16_F32BF16BF16F32_TT #if defined(SYCL_INTEL_TARGET) d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); #else - CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_BF16BF16F32F32_NN on non-PVC hardware"); + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_F32BF16BF16F32_TT on non-PVC hardware"); +#endif + } +}; + +//MxNxK_A,B,C,D +//# of vector component of a x subgroup-size x function name +//float8 intel_sub_group_f16_f16_matrix_mad_k16(short8 a, int8 b, int8 acc); +//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8 +struct XE_8x16x16_F32F16F16F32_TT +{ + using DRegisters = intel::float8[1]; + using ARegisters = intel::short8[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::float8[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float8 & d, + intel::short8 const& a, + intel::int8 const& b, + intel::float8 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x16_F32F16F16F32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_4x16x16_F32F16F16F32_TT +{ + using DRegisters = intel::float4[1]; + using ARegisters = intel::short4[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::float4[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float4 & d, + intel::short4 const& a, + intel::int8 const& b, + intel::float4 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x16_F32F16F16F32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_2x16x16_F32F16F16F32_TT +{ + using DRegisters = intel::float2[1]; + using ARegisters = intel::short2[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::float2[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float2 & d, + intel::short2 const& a, + intel::int8 const& b, + intel::float2 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x16_F32F16F16F32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_1x16x16_F32F16F16F32_TT +{ + using DRegisters = float[1]; + using ARegisters = short[1]; + using BRegisters = intel::int8[1]; + using CRegisters = float[1]; + + CUTE_HOST_DEVICE static void + fma(float & d, + short const& a, + intel::int8 const& b, + float const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_f16_f16_matrix_mad_k16(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x16_F32F16F16F32_TT on non-PVC hardware"); +#endif + } +}; + +//MxNxK_A,B,C,D +//# of vector component of a x subgroup-size x function name +//float8 intel_sub_group_i8_i8_matrix_mad_k16(short8 a, int8 b, float8 acc); +//TODO: Is A really not transposed? 
Maybe better a macro than separate define for 1,2,4,8 +struct XE_8x16x32_S32S8S8S32_TT +{ + using DRegisters = intel::int8[1]; + using ARegisters = intel::short8[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::int8[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int8 & d, + intel::short8 const& a, + intel::int8 const& b, + intel::int8 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x32_S32S8S8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_4x16x32_S32S8S8S32_TT +{ + using DRegisters = intel::int4[1]; + using ARegisters = intel::short4[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::int4[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int4 & d, + intel::short4 const& a, + intel::int8 const& b, + intel::int4 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x32_S32S8S8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_2x16x32_S32S8S8S32_TT +{ + using DRegisters = intel::int2[1]; + using ARegisters = intel::short2[1]; + using BRegisters = intel::int8[1]; + using CRegisters = intel::int2[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int2 & d, + intel::short2 const& a, + intel::int8 const& b, + intel::int2 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x32_S32S8S8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_1x16x32_S32S8S8S32_TT +{ + using DRegisters = int[1]; + using ARegisters = short[1]; + using BRegisters = intel::int8[1]; + using CRegisters = int[1]; + + CUTE_HOST_DEVICE static void + fma(int & d, + short const& a, + intel::int8 const& b, + int const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_i8_i8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x32_S32S8S8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_8x16x32_S32U8U8S32_TT +{ + using DRegisters = intel::int8[1]; + using ARegisters = intel::ushort8[1]; + using BRegisters = intel::uint8[1]; + using CRegisters = intel::int8[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int8 & d, + intel::ushort8 const& a, + intel::uint8 const& b, + intel::int8 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x32_S32U8U8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_4x16x32_S32U8U8S32_TT +{ + using DRegisters = intel::int4[1]; + using ARegisters = intel::ushort4[1]; + using BRegisters = intel::uint8[1]; + using CRegisters = intel::int4[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int4 & d, + intel::ushort4 const& a, + intel::uint8 const& b, + intel::int4 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x32_S32U8U8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_2x16x32_S32U8U8S32_TT +{ + using DRegisters = intel::int2[1]; + using ARegisters = intel::ushort2[1]; + using BRegisters = intel::uint8[1]; + using CRegisters = intel::int2[1]; + + CUTE_HOST_DEVICE static void + fma(intel::int2 & d, + intel::ushort2 const& a, + intel::uint8 const& b, + intel::int2 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = 
intel_sub_group_u8_u8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x32_S32U8U8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_1x16x32_S32U8U8S32_TT +{ + using DRegisters = int[1]; + using ARegisters = ushort[1]; + using BRegisters = intel::uint8[1]; + using CRegisters = int[1]; + + CUTE_HOST_DEVICE static void + fma(int & d, + ushort const& a, + intel::uint8 const& b, + int const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_u8_u8_matrix_mad_k32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x32_S32U8U8S32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_8x16x8_F32TF32TF32F32_TT +{ + using DRegisters = intel::float8[1]; + using ARegisters = intel::float4[1]; + using BRegisters = intel::float8[1]; + using CRegisters = intel::float8[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float8 & d, + intel::float4 const& a, + intel::float8 const& b, + intel::float8 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_8x16x8_F32TF32TF32F32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_4x16x8_F32TF32TF32F32_TT +{ + using DRegisters = intel::float4[1]; + using ARegisters = intel::float2[1]; + using BRegisters = intel::float8[1]; + using CRegisters = intel::float4[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float4 & d, + intel::float2 const& a, + intel::float8 const& b, + intel::float4 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_4x16x8_F32TF32TF32F32_TT on non-PVC hardware"); +#endif + } +}; + +struct XE_2x16x8_F32TF32TF32F32_TT +{ + using DRegisters = intel::float2[1]; + using ARegisters = float[1]; + using BRegisters = intel::float8[1]; + using CRegisters = intel::float2[1]; + + CUTE_HOST_DEVICE static void + fma(intel::float2 & d, + float const& a, + intel::float8 const& b, + intel::float2 const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_2x16x8_F32TF32TF32F32_TT on non-PVC hardware"); +#endif + } +}; +//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +struct XE_1x16x8_F32TF32TF32F32_TT +{ + using DRegisters = float[1]; + using ARegisters = float[1]; + using BRegisters = intel::float8[1]; + using CRegisters = float[1]; + + CUTE_HOST_DEVICE static void + fma(float & d, + float const& a, + intel::float8 const& b, + float const& c) + { +#if defined(SYCL_INTEL_TARGET) + d = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(a, b, c); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use XE_1x16x8_F32TF32TF32F32_TT on non-PVC hardware"); #endif } }; diff --git a/include/cute/arch/xe_copy_1B.hpp b/include/cute/arch/xe_copy_1B.hpp new file mode 100644 index 0000000000..214f787ab7 --- /dev/null +++ b/include/cute/arch/xe_copy_1B.hpp @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x +#else +#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); } +#endif + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x +#else +#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#endif + +using namespace cute; + +// 8bits No transform No transpose +SYCL_DEVICE_BUILTIN(ushort __builtin_IB_subgroup_block_read_flat_u8_m1k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort2 __builtin_IB_subgroup_block_read_flat_u8_m2k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort4 __builtin_IB_subgroup_block_read_flat_u8_m4k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort8 __builtin_IB_subgroup_block_read_flat_u8_m8k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort16 __builtin_IB_subgroup_block_read_flat_u8_m16k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort32 __builtin_IB_subgroup_block_read_flat_u8_m32k32v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +SYCL_DEVICE_BUILTIN( + intel::ushort2 __builtin_IB_subgroup_block_read_flat_u8_m1k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort4 __builtin_IB_subgroup_block_read_flat_u8_m2k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort8 
__builtin_IB_subgroup_block_read_flat_u8_m4k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort16 __builtin_IB_subgroup_block_read_flat_u8_m8k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort32 __builtin_IB_subgroup_block_read_flat_u8_m16k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort64 __builtin_IB_subgroup_block_read_flat_u8_m32k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + + +// 8bits VNNI transform No transpose +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_transform_u8_k32( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_transform_u8_k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint32 __builtin_IB_subgroup_block_read_flat_transform_u8_k32v4( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 8bits No transform No transpose +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m1k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uchar data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m2k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uchar2 data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m4k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uchar4)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uchar8)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u8_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uchar8)); +#undef SYCL_DEVICE_BUILTIN + +#undef __global +#define __global __attribute__((opencl_global)) +// 8 bits No transform No transpose +SYCL_DEVICE_OCL(ushort intel_sub_group_block_read_8b_1r32c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort2 intel_sub_group_block_read_8b_2r32c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort4 intel_sub_group_block_read_8b_4r32c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort8 intel_sub_group_block_read_8b_8r32c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort16 intel_sub_group_block_read_8b_16r32c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +SYCL_DEVICE_OCL(intel::ushort2 intel_sub_group_block_read_8b_1r32x2c( + const __global void *base_address, int width, int height, 
int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort4 intel_sub_group_block_read_8b_2r32x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort8 intel_sub_group_block_read_8b_4r32x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort16 intel_sub_group_block_read_8b_8r32x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort32 intel_sub_group_block_read_8b_16r32x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort64 intel_sub_group_block_read_8b_32r32x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 8bits VNNI transform No transpose +SYCL_DEVICE_OCL(intel::uint8 intel_sub_group_block_read_transform_8b_32r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_transform_8b_32r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint32 intel_sub_group_block_read_transform_8b_32r16x4c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 8bits store +SYCL_DEVICE_OCL(void intel_sub_group_block_write_8b_1r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uchar data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_8b_2r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uchar2 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_8b_4r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uchar4 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_8b_8r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uchar8 data)); + + +// 2D prefetch +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_1r32x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_2r32x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_4r32x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_8r32x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_8b_32r16x1c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +#undef SYCL_DEVICE_OCL + +namespace cute +{ +struct XE_2D_U8x1x32_LD_N { + + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m1k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + 
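// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: how a caller might drive one of
// the new 8-bit 2D block-load atoms defined in this file. The wrapper
// function, its parameter names, and the assumption that intel::coord_t is
// constructible from two ints are mine; only XE_2D_U8x1x32_LD_N::copy and the
// header name come from this change. As the atom body above shows, width,
// height and pitch are forwarded with 1 subtracted, matching the minus-one
// encoding of the __builtin_IB_subgroup_block_read_flat_u8_* builtins, and
// the underlying builtin returns a ushort, i.e. 2 bytes of the 1x32 block per
// work-item of a 16-wide sub-group.
#include <cstdint>
#include <cute/arch/xe_copy_1B.hpp>

inline void load_u8_1x32_block(const uint8_t *base, int width, int height,
                               int pitch, int x, int y, uint8_t *frag) {
  // frag must provide at least 2 bytes: this work-item's slice of the block.
  cute::XE_2D_U8x1x32_LD_N::copy(base, width, height, pitch,
                                 cute::intel::coord_t{x, y}, frag);
}
// ---------------------------------------------------------------------------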
+struct XE_2D_U8x2x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m2k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x2x32_ST_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u16_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::ushort2 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x4x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m4k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x8x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m8k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x16x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m16k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x32x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m32k32v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x1x64_LD_N { + template + 
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m1k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_8b_1r32x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x2x64_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m2k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_8b_2r32x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x4x64_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m4k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_8b_4r32x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x8x64_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m8k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_8b_8r32x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use 
block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x16x64_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m16k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x32x64_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u8_m32k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + + + +struct XE_2D_U8x32x16_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u8_k32( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_8b_32r16x1c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U8x32x32_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u8_k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x32x64_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T 
*dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u8_k32v4( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x1x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u8_m1k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uchar *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x2x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u8_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uchar2 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x4x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u8_m4k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uchar4 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x8x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u8_m8k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uchar8 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U8x8x32_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 1, "Expected T to have size 1"); + __builtin_IB_subgroup_block_write_flat_u8_m8k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uchar8 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; +} // end namespace cute diff --git a/include/cute/arch/xe_copy_2B.hpp b/include/cute/arch/xe_copy_2B.hpp new file mode 100644 index 0000000000..83ba9af4d1 --- /dev/null +++ b/include/cute/arch/xe_copy_2B.hpp @@ -0,0 +1,798 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x +#else +#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); } +#endif + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x +#else +#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#endif + +using namespace cute; + +SYCL_DEVICE_BUILTIN(intel::ushort16 intel_subgroup_block_read_u16_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +SYCL_DEVICE_BUILTIN(intel::int8 intel_subgroup_block_read_transform_u16_k16( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum CacheControl cache_control)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, enum 
CacheControl cache_control)); + +// 16 bits No transform No transpose +SYCL_DEVICE_BUILTIN(ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +SYCL_DEVICE_BUILTIN( + intel::ushort2 __builtin_IB_subgroup_block_read_flat_u16_m1k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort4 __builtin_IB_subgroup_block_read_flat_u16_m2k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort8 __builtin_IB_subgroup_block_read_flat_u16_m4k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort16 __builtin_IB_subgroup_block_read_flat_u16_m8k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 16bits VNNI transform No transpose +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_transform_u16_k16( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 16bits +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u16_m1k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, ushort data)); 
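+// (The m1k16v1 write above stores a single 16-element row, with each of the
+//  16 lanes in the sub-group supplying one ushort. As with every block
+//  builtin declared in this header, width/height/pitch are passed biased by
+//  -1 (the copy atoms further down apply that bias themselves) and coord is
+//  the {x, y} (column, row) of the block's top-left element. A rough sketch,
+//  assuming the usual Intel 2D block-IO convention of width/pitch in bytes
+//  and height in rows, for a row-major M x N ushort matrix D:
+//
+//      __builtin_IB_subgroup_block_write_flat_u16_m1k16v1(
+//          (long)D, N * 2 - 1, M - 1, N * 2 - 1, {col, row}, lane_value);
+//
+//  D, M, N, col, row and lane_value are illustrative placeholders.)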
+SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u16_m2k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::ushort2 data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u16_m4k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::ushort4 data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u16_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::ushort8 data)); +#undef SYCL_DEVICE_BUILTIN + +#undef __global__ +#define __global __attribute__((opencl_global)) +// 16bits No transform No transpose +SYCL_DEVICE_OCL(ushort intel_sub_group_block_read_16b_1r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort2 intel_sub_group_block_read_16b_2r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort4 intel_sub_group_block_read_16b_4r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort8 intel_sub_group_block_read_16b_8r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort16 intel_sub_group_block_read_16b_16r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort32 intel_sub_group_block_read_16b_32r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +SYCL_DEVICE_OCL(intel::ushort2 intel_sub_group_block_read_16b_1r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort4 intel_sub_group_block_read_16b_2r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort8 intel_sub_group_block_read_16b_4r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort16 intel_sub_group_block_read_16b_8r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort32 intel_sub_group_block_read_16b_16r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::ushort64 intel_sub_group_block_read_16b_32r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 16bits VNNI transform No transpose +SYCL_DEVICE_OCL(intel::uint8 intel_sub_group_block_read_transform_16b_16r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_transform_16b_32r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_transform_16b_16r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint32 intel_sub_group_block_read_transform_16b_32r16x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 16bits store 
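+// (Block writes take their data as a per-work-item fragment: each of the 16
+//  lanes in the sub-group supplies one column of the block, so an
+//  intel::ushort8 per lane covers an 8-row x 16-column tile. As a rough usage
+//  sketch, with D, M, N, col, row as illustrative placeholders and
+//  width/pitch assumed to be byte counts, the XE_2D_U16x8x16_ST_N atom
+//  defined later in this file would be driven as:
+//
+//      intel::ushort8 frag = ...;   // this lane's column of the 8x16 tile
+//      XE_2D_U16x8x16_ST_N::copy(D, N * 2, M, N * 2, {col, row},
+//                                reinterpret_cast<uint16_t *>(&frag));
+//
+//  where the atom applies the -1 bias before calling the builtin.)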
+SYCL_DEVICE_OCL(void intel_sub_group_block_write_16b_1r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, ushort data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_16b_2r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::ushort2 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_16b_4r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::ushort4 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_16b_8r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::ushort8 data)); + +// 2D prefetch +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_1r16x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_2r16x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_4r16x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_8r16x2c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_16b_16r16x1c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +#undef SYCL_DEVICE_OCL + +namespace cute +{ +struct XE_2D_U16x1x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m1k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x2x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x4x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m4k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x8x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m8k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x16x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m16k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x32x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m32k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x1x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m1k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_16b_1r16x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x2x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + 
static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m2k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_16b_2r16x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x4x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m4k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_16b_4r16x2c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x8x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m8k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x16x32_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m16k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x32x32_LD_N { + 
template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u16_m32k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + // __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2( + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x16x16_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u16_k16( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x32x16_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u16_k32( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x16x32_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if 
defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x32x32_LD_V { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2( + (long)baseoffset, width - 1, height - 1, pitch - 1, coord, + CacheControl::kL1C_L3C); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U16x16x8_LD_T { + using inst_dtype = uint32_t; + + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k4( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x16x16_LD_T { + using inst_dtype = uint32_t; + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 2, "Expected T to have size 2"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k8( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x1x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_write_flat_u16_m1k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(ushort *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x2x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_write_flat_u16_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::ushort2 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x4x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if 
defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_write_flat_u16_m4k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::ushort4 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U16x8x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 2, "Expected T to have size 2"); + __builtin_IB_subgroup_block_write_flat_u16_m8k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::ushort8 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; +} // end namespace cute diff --git a/include/cute/arch/xe_copy_4B.hpp b/include/cute/arch/xe_copy_4B.hpp new file mode 100644 index 0000000000..0f3a90529f --- /dev/null +++ b/include/cute/arch/xe_copy_4B.hpp @@ -0,0 +1,697 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x +#else +#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); } +#endif + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x +#else +#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#endif + +enum class CacheControl { + kDefault = 0, + kL1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + kL1UC_L3C = 2, // Override to L1 uncached and L3 cached + kL1C_L3UC = 3, // Override to L1 cached and L3 uncached + kL1C_L3C = 4, // Override to L1 cached and L3 cached + kL1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + kL1S_L3C = 6, // Override to L1 streaming load and L3 cached + kL1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached +}; + +using namespace cute; + +// 32bits specific for tf32 No transform No transpose +SYCL_DEVICE_BUILTIN( + uint __builtin_IB_subgroup_block_read_flat_u32_m1k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + uint __builtin_IB_subgroup_block_read_flat_u32_m2k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m4k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint4 __builtin_IB_subgroup_block_read_flat_u32_m8k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m16k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_u32_m32k8v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +SYCL_DEVICE_BUILTIN( + intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m1k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint32 __builtin_IB_subgroup_block_read_flat_u32_m32k8v2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 32bits No transform No transpose +SYCL_DEVICE_BUILTIN(uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + 
int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint32 __builtin_IB_subgroup_block_read_flat_u32_m32k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 32bits No transform Transpose +SYCL_DEVICE_BUILTIN(uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint2 __builtin_IB_subgroup_block_read_flat_transpose_u32_k2( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint4 __builtin_IB_subgroup_block_read_flat_transpose_u32_k4( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); +SYCL_DEVICE_BUILTIN( + intel::uint8 __builtin_IB_subgroup_block_read_flat_transpose_u32_k8( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord)); + +// 32bits +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, uint data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uint2 data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uint4 data)); +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( + long baseoffset, int width_minus_one, int height_minus_one, + int pitch_minus_one, intel::coord_t coord, intel::uint8 data)); + +#undef SYCL_DEVICE_BUILTIN + +#undef __global +#define __global __attribute__((opencl_global)) +// 32bits specific for tf32 No transform No transpose +SYCL_DEVICE_OCL(uint intel_sub_group_block_read_32b_1r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(uint intel_sub_group_block_read_32b_2r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint2 intel_sub_group_block_read_32b_4r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint4 intel_sub_group_block_read_32b_8r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint8 
intel_sub_group_block_read_32b_16r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_32b_32r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +SYCL_DEVICE_OCL(intel::uint2 intel_sub_group_block_read_32b_1r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint2 intel_sub_group_block_read_32b_2r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint4 intel_sub_group_block_read_32b_4r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint8 intel_sub_group_block_read_32b_8r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_32b_16r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint32 intel_sub_group_block_read_32b_32r8x2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 32bits No transform No transpose +SYCL_DEVICE_OCL(uint intel_sub_group_block_read_32b_1r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint2 intel_sub_group_block_read_32b_2r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint4 intel_sub_group_block_read_32b_4r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint8 intel_sub_group_block_read_32b_8r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint16 intel_sub_group_block_read_32b_16r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint32 intel_sub_group_block_read_32b_32r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 32bits No transform Transpose +SYCL_DEVICE_OCL(uint intel_sub_group_block_read_transpose_32b_16r1c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint2 intel_sub_group_block_read_transpose_32b_16r2c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint4 intel_sub_group_block_read_transpose_32b_16r4c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); +SYCL_DEVICE_OCL(intel::uint8 intel_sub_group_block_read_transpose_32b_16r8c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord)); + +// 32bits store +SYCL_DEVICE_OCL(void intel_sub_group_block_write_32b_1r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, uint data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_32b_2r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uint2 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_32b_4r16c( + const __global void *base_address, int width, int height, int pitch, + 
intel::coord_t coord, intel::uint4 data)); +SYCL_DEVICE_OCL(void intel_sub_group_block_write_32b_8r16c( + const __global void *base_address, int width, int height, int pitch, + intel::coord_t coord, intel::uint8 data)); +SYCL_DEVICE_OCL(void intel_sub_group_2d_block_prefetch_32b_16r8x1c( + __global void* base_address, int width, int height, int pitch, + intel::coord_t coord)); +#undef SYCL_DEVICE_OCL + +namespace cute +{ +struct XE_2D_U32x1x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m1k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x2x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x4x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m4k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x8x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m8k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x16x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m16k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x32x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m32k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x1x8_LD_N { + template + 
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m1k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x2x8_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m2k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x4x8_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m4k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x8x8_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m8k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x16x8_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m16k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x32x8_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m32k8v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x1x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m1k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct 
XE_2D_TF32x2x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m2k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x4x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m4k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x8x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m8k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x16x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m16k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_TF32x32x16_LD_N { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_u32_m32k8v2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + + +struct XE_2D_U32x16x1_LD_T { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x16x2_LD_T { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k2( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on 
non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x16x4_LD_T { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k4( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x16x8_LD_T { + template + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, intel::coord_t coord, + T *dst) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + *reinterpret_cast(dst) = + __builtin_IB_subgroup_block_read_flat_transpose_u32_k8( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, + int height, int pitch, + intel::coord_t coord) { +#if defined(SYCL_INTEL_TARGET) + intel_sub_group_2d_block_prefetch_32b_16r8x1c( + (__global void*)baseoffset, width - 1, height - 1, pitch - 1, coord); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use block prefetch on non-PVC hardware"); +#endif + } + }; +}; + +struct XE_2D_U32x1x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 4, "Expected T to have size 4"); + __builtin_IB_subgroup_block_write_flat_u32_m1k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(uint *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x2x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + __builtin_IB_subgroup_block_write_flat_u32_m2k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uint2 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x4x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + static_assert(sizeof(T) == 4, "Expected T to have size 4"); + __builtin_IB_subgroup_block_write_flat_u32_m4k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uint4 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +struct XE_2D_U32x8x16_ST_N { + template + CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + int pitch, intel::coord_t coord, + const T *src) { +#if defined(SYCL_INTEL_TARGET) + // static_assert(sizeof(T) == 4, "Expected T to have size 4"); + __builtin_IB_subgroup_block_write_flat_u32_m8k16v1( + (long)(baseoffset), width - 1, height - 1, pitch - 1, coord, + *(intel::uint8 *)(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware"); +#endif + } +}; + +} // end namespace cute diff 
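The load and store structs above are thin, stateless wrappers over the IGC 2D block copy intrinsics: copy() takes the surface base pointer, the surface width and pitch in bytes, the height in rows, the (x, y) origin of the block, and a pointer to the per-work-item register fragment. A minimal device-side sketch of the store path, assuming a 16-wide subgroup; C_ptr, M, N, ldc, block_x and block_y are illustrative names, not taken from this patch:

    // Write the 8x16 tile of fp32 values held in registers to a row-major C
    // surface. Width and pitch are in bytes, height is in rows, and the
    // coordinate is the (x = column, y = row) origin of the block, matching
    // the {n, m} order that copy_unpack() uses later in this diff.
    float acc[8];   // per-work-item fragment backing the 8x16 subgroup tile
    XE_2D_U32x8x16_ST_N::copy(C_ptr,
                              N * sizeof(float),    // width in bytes
                              M,                    // height in rows
                              ldc * sizeof(float),  // pitch in bytes
                              intel::coord_t{block_x, block_y},
                              acc);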
--git a/include/cute/arch/xe_copy_8B.hpp b/include/cute/arch/xe_copy_8B.hpp
new file mode 100644
index 0000000000..5a665563ac
--- /dev/null
+++ b/include/cute/arch/xe_copy_8B.hpp
@@ -0,0 +1,130 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include
+#include
+#include
+
+#ifdef __SYCL_DEVICE_ONLY__
+#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x
+#else
+#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); }
+#endif
+
+#ifdef __SYCL_DEVICE_ONLY__
+#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x
+#else
+#define SYCL_DEVICE_OCL(x) inline x { assert(false); }
+#endif
+
+using namespace cute;
+
+// 64bits No transform Transpose
+SYCL_DEVICE_BUILTIN(
+    intel::ulong __builtin_IB_subgroup_block_read_flat_transpose_u64_k1(
+        long baseoffset, int width_minus_one, int height_minus_one,
+        int pitch_minus_one, intel::coord_t coord));
+SYCL_DEVICE_BUILTIN(
+    intel::ulong2 __builtin_IB_subgroup_block_read_flat_transpose_u64_k2(
+        long baseoffset, int width_minus_one, int height_minus_one,
+        int pitch_minus_one, intel::coord_t coord));
+SYCL_DEVICE_BUILTIN(
+    intel::ulong4 __builtin_IB_subgroup_block_read_flat_transpose_u64_k4(
+        long baseoffset, int width_minus_one, int height_minus_one,
+        int pitch_minus_one, intel::coord_t coord));
+#undef SYCL_DEVICE_BUILTIN
+
+#undef __global
+#define __global __attribute__((opencl_global))
+
+// 64bits No transform Transpose
+SYCL_DEVICE_OCL(ulong intel_sub_group_block_read_transpose_64b_8r1c(
+    const __global void *base_address, int width, int height, int pitch,
+    intel::coord_t coord));
+SYCL_DEVICE_OCL(intel::ulong2 intel_sub_group_block_read_transpose_64b_8r2c(
+    const __global void *base_address, int width, int height, int pitch,
+    intel::coord_t coord));
+SYCL_DEVICE_OCL(intel::ulong4 intel_sub_group_block_read_transpose_64b_8r4c(
+    const __global void *base_address, int width, int height, int pitch,
+    intel::coord_t coord));
+#undef SYCL_DEVICE_OCL
+
+namespace cute
+{
+struct XE_2D_U64x8x1_LD_T {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, intel::coord_t coord,
+                                    T *dst) {
+#if defined(SYCL_INTEL_TARGET)
+    static_assert(sizeof(T) == 8, "Expected T to have size 8");
+    *reinterpret_cast<intel::ulong *>(dst) =
+        __builtin_IB_subgroup_block_read_flat_transpose_u64_k1(
+            (long)(baseoffset), width - 1, height - 1, pitch - 1, coord);
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
+#endif
+  }
+};
+
+struct XE_2D_U64x8x2_LD_T {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, intel::coord_t coord,
+                                    T *dst) {
+#if defined(SYCL_INTEL_TARGET)
+    static_assert(sizeof(T) == 8, "Expected T to have size 8");
+    *reinterpret_cast<intel::ulong2 *>(dst) =
+        __builtin_IB_subgroup_block_read_flat_transpose_u64_k2(
+            (long)(baseoffset), width - 1, height - 1, pitch - 1, coord);
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
+#endif
+  }
+};
+
+struct XE_2D_U64x8x4_LD_T {
+  template <class T>
+  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
+                                    int height, int pitch, intel::coord_t coord,
+                                    T *dst) {
+#if defined(SYCL_INTEL_TARGET)
+    static_assert(sizeof(T) == 8, "Expected T to have size 8");
+    *reinterpret_cast<intel::ulong4 *>(dst) =
+        __builtin_IB_subgroup_block_read_flat_transpose_u64_k4(
+            (long)(baseoffset), width - 1, height - 1, pitch - 1, coord);
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
+#endif
+  }
+};
+} // end namespace cute
diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp
index 261654506a..d645cb91b5 100644
---
a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -33,440 +33,1904 @@ #include #include +#include #include -namespace cute +namespace cute { + +namespace detail { + template + struct is_transpose : bool_constant {}; -template -CUTE_HOST_DEVICE constexpr -auto get_shape_WHD(cute::Stride, IntT, IntT> , cute::Shape shape_MKL) { - return shape_MKL; -} + template<> + struct is_transpose : bool_constant{}; -template -CUTE_HOST_DEVICE constexpr -auto get_shape_WHD(cute::Stride, IntT> , cute::Shape shape_MKL) { - return Shape(get<1>(shape_MKL), get<0>(shape_MKL), get<2>(shape_MKL)); -} + template<> + struct is_transpose : bool_constant{}; -template -CUTE_HOST_DEVICE constexpr -auto get_coordinates(cute::Stride, IntT, IntT> , - Tensor>, SLayout> const &src) { - auto [x, y, z] = src.data().coord_; - return make_coord(x, y, z); -} + template<> + struct is_transpose : bool_constant{}; -template -CUTE_HOST_DEVICE constexpr -auto get_coordinates(cute::Stride, IntT> , - Tensor>, SLayout> const &src) { - auto [x, y, z] = src.data().coord_; - return make_coord(y, x, z); -} + template<> + struct is_transpose : bool_constant{}; -template -struct XE_2D_LD_Unpack { - GTensor tensor; - - using Copy_Traits = Copy_Traits; - template - CUTE_HOST_DEVICE friend constexpr void copy_unpack( - Copy_Traits const &traits, - Tensor>, SLayout> const &src, - Tensor &dst) { - static_assert(is_rmem::value); - auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); - int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); - int H = size<1>(shape_whd); - auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, - &*dst.data()); - } + template<> + struct is_transpose : bool_constant{}; + + template<> + struct is_transpose : bool_constant{}; + + template<> + struct is_transpose : bool_constant{}; + + template<> + struct is_transpose : bool_constant{}; + + template constexpr bool has_inst_dtype = false; + + template + constexpr bool has_inst_dtype> = true; +} // namespace detail end + +template struct XE_2D_LD_Unpack { + const void *base_ptr; + uint32_t width; + uint32_t height; + uint32_t pitch; - template - CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, - GShape const &shape, GStride const &stride_mul) const { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), - get<2>(shape)), - make_stride(_1 {}, E<0> {} * get<0>(stride_mul), - E<1> {} * get<1>(stride_mul), - E<2> {} * get<2>(stride(tensor))))); + XE_2D_LD_Unpack(const void *ptr, uint32_t const &w, + uint32_t const &h, uint32_t const &p) + : base_ptr(ptr), width(w), height(h), pitch(p) {} + + template + XE_2D_LD_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch) {} + + XE_2D_LD_Unpack() {} + + using Traits_LD_t = Copy_Traits; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Traits_LD_t const &traits, Tensor const &src, + Tensor &dst) { + static_assert(is_rmem::value); + + using dtype = typename Tensor::value_type; + + dtype *base_addr = (dtype *)traits.base_ptr; + + auto [m, n, l] = src.data().coord_; + + auto inst_size = sizeof(dtype); + + if constexpr (detail::has_inst_dtype) { + inst_size = sizeof(typename CopyOp::inst_dtype); } + + CopyOp::copy(base_addr + l * traits.width * traits.height, + traits.width * sizeof(dtype), traits.height, + 
traits.pitch * sizeof(dtype), + intel::coord_t{(int)(n * sizeof(dtype) / inst_size), (int)(m)}, + &*dst.data()); + } + + template + CUTE_HOST_DEVICE friend constexpr void + prefetch(Copy_Atom const &atom, + Tensor const &src) { + static_assert(detail::has_prefetch); + + using dtype = typename Copy_Atom::ValType; + + dtype *base_addr = (dtype *)atom.base_ptr; + + auto [m, n, l] = src.data().coord_; + + CopyOp::PREFETCH::copy((void *)(base_addr + l * atom.width * atom.height), + atom.width * sizeof(dtype), atom.height, + atom.pitch * sizeof(dtype), + intel::coord_t{(int)n, (int)m}); + } + + template {})> + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, + GStride const &stride, + Basis const & basis = {}) const { + + auto R = rank(GShape{}); + static_assert(R == 3 || R == 4, "mismatch rank"); + auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape)); + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, stride, [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(coord), + make_layout(t_shape, t_stride)); + } + + template + static constexpr auto with(T1 && arg1, T2 && arg2, TraitsArgs&&... args) { + return Traits_LD_t{arg1, arg2, args...}; + } }; -/// ThrID: How many threads involved. -/// DstLayout: Size<0>(DstLayout) same as thrID, Size<1>(DstLayout) represents data layout that hold by each thread in bits. -/// TODO: SrcLayout just a placeHolder, not used. -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout - = Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; -}; - -template -struct XE_2D_PF_Unpack { - GTensor tensor; - - template - CUTE_HOST_DEVICE XE_2D_PF_Unpack(G_Tensor const &t) : tensor(t) {} - - using Copy_Traits = Copy_Traits; - template - CUTE_HOST_DEVICE friend constexpr void copy_unpack( - Copy_Traits const &traits, - Tensor>, SLayout> const &src, - Tensor &dst) { - using T = typename Copy_Traits::CopyInternalType; - auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); - int W = size<0>(shape_whd) * sizeof(T); - int H = size<1>(shape_whd); - auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); - CopyOp::template copy(traits.tensor.data() + z, W, H, W, - intel::coord_t {static_cast(x), static_cast(y)}); - } +template struct XE_2D_ST_Unpack { + const void *base_ptr; + uint32_t width; + uint32_t height; + uint32_t pitch; + + XE_2D_ST_Unpack(const void *ptr, uint32_t const &w, + uint32_t const &h, uint32_t const &p) + : base_ptr(ptr), width(w), height(h), pitch(p) {} + + template + XE_2D_ST_Unpack(TraitsArgs const &traits) : base_ptr(traits.base_ptr), + width(traits.width), height(traits.height), pitch(traits.pitch) {} + + XE_2D_ST_Unpack() {} + + using Traits_ST_t = Copy_Traits; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Traits_ST_t const &traits, Tensor const &src, + Tensor &dst) { + static_assert(is_rmem::value); + + using dtype = typename Tensor::value_type; + + dtype *base_addr = (dtype *)traits.base_ptr; + + auto [m, n, l] = dst.data().coord_; + + CopyOp::copy(base_addr + l * traits.width * traits.height, + (int)(traits.width * sizeof(dtype)), (int)(traits.height), + (int)(traits.pitch * sizeof(dtype)), + 
intel::coord_t{(int)n, (int)m}, &*src.data()); + } + + template {})> + CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, + GShape const &shape, + GStride const &stride, + Basis const & basis = {}) const { + + auto R = rank(GShape{}); + static_assert(R == 3 || R == 4, "mismatch rank"); + auto t_shape = cute::tuple_cat(make_shape(_1{}), take<1, R>(shape)); + auto t_stride = cute::tuple_cat(make_stride(_1{}), transform(basis, stride, [&](auto i, auto s){ + return E{} * s; + })); + return make_tensor(make_inttuple_iter(coord), + make_layout(t_shape, t_stride)); + } + + template + static constexpr auto with(T1 && arg1, T2 && arg2, TraitsArgs&&... args) { + return Traits_ST_t{arg1, arg2, args...}; + } }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} +// clang-format off - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout - = Layout>, Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_16, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout - = Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) 
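// Usage sketch: with this rewrite a Copy_Traits no longer wraps a cute::Tensor;
// it carries the raw 2D surface descriptor (base_ptr, width, height, pitch) and
// is built either from the four-argument constructor or through the with()
// factory defined in XE_2D_LD_Unpack / XE_2D_ST_Unpack. Roughly, for a row-major
// M x K operand A (A_ptr, M, K and lda are placeholder names, not from this
// patch):
//
//   using TraitsA = Copy_Traits<XE_2D_U16x8x16_LD_N>;   // any XE_2D_*_LD_* op
//   auto traits_a = TraitsA::with(A_ptr, K /*width*/, M /*height*/, lda /*pitch*/);
//
// width, height and pitch are stored in elements; copy_unpack() scales width and
// pitch by sizeof(dtype) before calling the block intrinsic, and prefetch() is
// available for any CopyOp that provides a nested PREFETCH type.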
{} }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _64>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} }; -template -struct Copy_Traits - : XE_2D_PF_Unpack { - template - CUTE_HOST_DEVICE Copy_Traits(Copy_Traits const &traits) - : XE_2D_PF_Unpack( - traits.tensor) {} +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _64>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_16, Stride<_16, _1>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout - = Layout>, Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = uint; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _64>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _64>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride<_1,_8,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>, - Stride<_16, Stride<_512, Stride<_1, _256>>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_16, Stride<_256, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { - // Logical thread id to thread idx - using ThrID = Layout<_16>; - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>, - Stride<_32, Stride<_512, _1>>>; - // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - // 32 bits register file - using CopyInternalType = uint; +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_16, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + using CopyInternalType = ushort; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_32>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_32>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1, _512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _8>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_128,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_256,_128,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _64>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride< _8,Stride<_1,_512,_128,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_256,_512>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_32, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_16,Stride< _1,_512,_256,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8,_16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_128,Stride< _1,_16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_16,_16>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_256,Stride< _1,_16>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_LD_Unpack(args...) {} +}; + +// template +// struct Copy_Traits +// : XE_2D_LD_Unpack { +// // Logical thread id to thread idx +// using ThrID = Layout<_16>; +// // Map from (src-thr,src-val) to bit +// using SrcLayout = Layout, +// Stride< _0, _1>>; +// // Map from (dst-thr,dst-val) to bit +// using DstLayout = Layout, +// Stride<_32, _1>>; +// // Reference map from (thr,val) to bit +// using RefLayout = DstLayout; +// }; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2,_16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_64,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4,_16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_64,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8,_16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_128,Stride< _1,_32>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_8,_16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _16>>, + Stride,Stride< _1,_256>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_1,_8>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_64>, + Stride, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_2,_8>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_LD_Unpack { + using Shape_MN = Shape<_4,_8>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _4>>, + Stride,Stride< _1,_64>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_LD_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_2,_32>; + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_1, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _8,_1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_2, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_ST_Unpack(args...) 
{} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_4, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride< _8,Stride<_1,_128>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint8_t; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_8, _32>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = + Layout>, Stride<_0, Stride<_0, _1>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0,_1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + using CopyInternalType = uint8_t; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_1, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_16, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_2, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_4, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_16,Stride< _1,_256>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) 
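  // For the XE_2D_*_ST_* (store) traits the roles are mirrored: SrcLayout
  // describes the per-lane register fragment being written out (hence
  // RefLayout = SrcLayout), while DstLayout collapses the thread mode to
  // stride 0 because one block-store message covers the whole subgroup.
  // Where CopyInternalType is defined, it is the element width used by
  // XE_2D_ST_Unpack::copy_unpack to turn the column count into the byte
  // width/pitch of the 2D surface.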
{} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_1, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_2, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_4, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgT... args) + : XE_2D_ST_Unpack(args...) {} +}; + +template +struct Copy_Traits + : XE_2D_ST_Unpack { + using Shape_MN = Shape<_8, _16>; + // Logical thread id to thread idx + using ThrID = Layout<_16>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>, + Stride<_32,Stride< _1,_512>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride< _0, _1>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + Copy_Traits(ArgTs... args) + : XE_2D_ST_Unpack(args...) 
{} +}; + +template +auto make_xe_2d_copy(Tensor gtensor) { + using GTensor = Tensor; + using Traits = Copy_Traits; + // Traits traits {gtensor}; + return Copy_Atom{gtensor}; +} + +template +struct Copy_Traits> { + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> { // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout::value>>, Stride<_0, _1>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout< - Shape<_16, Shape, Shape<_16, _2, _2>>>, - Stride<_16, Stride, Stride<_1, _512, _256>>>>; + using DstLayout = Layout::value / sizeof_bits::value>, Int::value>>>, + Stride::value>, Stride::value * 16>, _1>>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; - using CopyInternalType = ushort; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { +template +struct Copy_Traits> { // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout::value>>, Stride<_0, _1>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>, - Stride<_16, Stride<_1024, Stride<_1, _512, _256>>>>; + using DstLayout = Layout::value / sizeof_bits::value>, Int::value>>>, + Stride::value>, Stride::value * 16>, _1>>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; - using CopyInternalType = ushort; }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { +template +struct Copy_Traits> { // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout::value>>, Stride<_0, _1>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; + using DstLayout = Layout::value>>, + Stride::value>, _1>>; // Reference map from (thr,val) to bit using RefLayout = DstLayout; - using CopyInternalType = ushort; + + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) {} }; -template -struct Copy_Traits - : XE_2D_LD_Unpack { +template +struct Copy_Traits> { // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout::value / sizeof_bits::value>, Int::value>>>, + Stride::value>, Stride::value * 16>, _1>>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>>, - Stride<_16, Stride<_512, Stride<_1, _256>>>>; + using DstLayout = Layout::value>>, Stride<_0, _1>>; // Reference map from (thr,val) to bit - using RefLayout = DstLayout; - using CopyInternalType = ushort; -}; - -template -struct XE_2D_ST_Unpack { - GTensor tensor; - using Copy_Traits = Copy_Traits; - template - CUTE_HOST_DEVICE friend constexpr void copy_unpack( - Copy_Traits const &traits, Tensor const &src, - Tensor>, DLayout> &dst) { - static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); - auto [y, x, z] = dst.data().coord_; - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data()); - } 
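  // copy_unpack above maps one CuTe copy onto a single 2D block message:
  //   H     = size<0>(tensor)                             surface rows
  //   W     = size<1>(tensor) * sizeof(CopyInternalType)  row width in bytes
  //   pitch = W                                           bytes between rows
  //   (x,y) = top-left coordinate of the block, taken from the destination
  //           iterator's coord_ (its third coordinate, z, offsets the base
  //           pointer).
  //
  // Standalone host-side sketch (not part of the copy path): the (thr,val)
  // layouts used by these traits can be inspected with CuTe itself. The
  // shape/stride below is an illustrative example, not one of the layouts
  // defined in this file.

  #include <cute/tensor.hpp>
  #include <cstdio>

  int main() {
    using namespace cute;
    // 16 lanes, 8 values per lane; each lane owns 8 consecutive fragment entries.
    auto frag_layout = Layout<Shape<_16, _8>, Stride<_8, _1>>{};
    print_layout(frag_layout);                                    // 2-D table: (thr,val) -> offset
    std::printf("lane 3, value 2 -> %d\n", int(frag_layout(3, 2)));  // prints 26
  }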
- - template - CUTE_HOST_DEVICE constexpr auto get_pvc_tensor(GCoord const &coord, - GShape const &shape, GStride const &stride_mul) const { - return make_tensor(make_inttuple_iter(coord), - make_layout(make_shape(_1 {}, get<0>(shape), get<1>(shape), - get<2>(shape)), - make_stride(_1 {}, E<0> {} * get<0>(stride_mul), - E<1> {} * get<1>(stride_mul), - E<2> {} * get<2>(stride(tensor))))); - } + using RefLayout = SrcLayout; }; -template -struct Copy_Traits - : XE_2D_ST_Unpack { +template +struct Copy_Traits> { // Logical thread id to thread idx using ThrID = Layout<_16>; // Map from (src-thr,src-val) to bit - using SrcLayout - = Layout>, Stride<_32, Stride<_512, _1>>>; + using SrcLayout = Layout::value / sizeof_bits::value>, Int::value>>>, + Stride::value>, Stride::value * 16>, _1>>>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout, Stride<_0, _1>>; + using DstLayout = Layout::value>>, Stride<_0, _1>>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; - using CopyInternalType = uint; }; -template -auto make_xe_2d_copy(Tensor gtensor) { - using GTensor = Tensor; - using Traits = Copy_Traits; - Traits traits {gtensor}; - return Copy_Atom {traits}; -} } // end namespace cute diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index 8dca2dba55..862483a7ca 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -47,8 +47,299 @@ struct MMA_Traits using Shape_MNK = Shape<_8,_16,_16>; using ThrID = Layout<_16>; - using ALayout = Layout, Stride<_8, _1>>; // (T16,V8) -> (m,n) - using BLayout = Layout, Stride<_16, _1>>; - using CLayout = Layout, Stride<_8, _1>>; + + using ALayout = Layout, Stride<_8, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_8, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_4,_16,_16>; + using ThrID = Layout<_16>; + + using ALayout = Layout, Stride<_4, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_4, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_16,_16>; + using ThrID = Layout<_16>; + + using ALayout = Layout, Stride<_2, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_2, _1>>; }; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_16,_16>; + using ThrID = Layout<_16>; + + using ALayout = Layout, Stride<_1, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_1, _1>>; +}; + + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_16,_16>; + using ThrID = Layout<_16>; + using ALayout = Layout, Stride<_8, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_8, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_4,_16,_16>; + using ThrID = Layout<_16>; + using ALayout = Layout, Stride<_4, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout 
= Layout, Stride<_4, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_16,_16>; + using ThrID = Layout<_16>; + using ALayout = Layout, Stride<_2, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_2, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_16,_16>; + using ThrID = Layout<_16>; + using ALayout = Layout, Stride<_1, _1>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_1, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_8,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_16, Stride<_8, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_8, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_4,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_8, Stride<_4, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_4, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_2,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_4, Stride<_2, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_2, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_1,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_2, Stride<_1, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_1, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_8,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_16, Stride<_8, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_8, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_4,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_8, Stride<_4, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_4, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_2,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_4, Stride<_2, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_2, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int; + + using Shape_MNK = Shape<_1,_16,_32>; + using ThrID = Layout<_16>; + using ALayout = Layout>, Stride<_2, Stride<_1, _1>>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = 
Layout, Stride<_1, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_16,_8>; + using ThrID = Layout<_16>; + using ALayout = Layout, _4>, Stride, _2>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_8, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_4,_16,_8>; + using ThrID = Layout<_16>; + using ALayout = Layout, _2>, Stride, _2>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_4, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_16,_8>; + using ThrID = Layout<_16>; + using ALayout = Layout, _1>, Stride, _0>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_2, _1>>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_16,_8>; + using ThrID = Layout<_16>; + using ALayout = Layout, _1>, Stride, _0>>; + using BLayout = Layout, Stride<_1, _16>>; + using CLayout = Layout, Stride<_1, _1>>; +}; + } diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index 3e7279dd69..729c7667c2 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -43,18 +43,41 @@ template using vector_t = T __attribute__((ext_vector_type(N))) template using vector_t = sycl::marray; #endif +typedef unsigned long ulong; +typedef unsigned char uchar; + +using uchar2 = vector_t; +using uchar4 = vector_t; +using uchar8 = vector_t; + +using float2 = vector_t; +using float4 = vector_t; using float8 = vector_t; + +using short2 = vector_t; +using short4 = vector_t; using short8 = vector_t; + +using int2 = vector_t; +using int4 = vector_t; using int8 = vector_t; using int16 = vector_t; + +using uint2 = vector_t; +using uint4 = vector_t; using uint8 = vector_t; using uint16 = vector_t; +using uint32 = vector_t; +using ushort2 = vector_t; +using ushort4 = vector_t; using ushort8 = vector_t; using ushort16 = vector_t; using ushort32 = vector_t; using ushort64 = vector_t; -using uint32 = vector_t; + +using ulong2 = vector_t; +using ulong4 = vector_t; using coord_t = vector_t; } // namespace intel end diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index fffa8db117..c9802d8401 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -146,6 +146,21 @@ int canonical_warp_group_idx() { #endif } +#if defined(SYCL_INTEL_TARGET) +CUTLASS_DEVICE +auto get_sub_group_id() { + return sycl::ext::oneapi::experimental::this_nd_item<3>() + .get_sub_group() + .get_group_id()[0]; +} + +CUTLASS_DEVICE +auto get_sub_group_local_id() { + return sycl::ext::oneapi::experimental::this_nd_item<3>() + .get_sub_group() + .get_local_id()[0]; +} +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index 5e85c2cf15..2d5efab207 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ 
b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -94,8 +94,8 @@ template < Tile<_32,_64,_32>>; // Subgroup level-tile using DispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - using CopyOpG2R = XE_2D_U32x8x16x1x1_LD_N; - using CopyOpR2G = XE_2D_U32x8x16x1x1_ST_N; + using CopyOpG2R = XE_2D_U32x8x16_LD_N; + using CopyOpR2G = XE_2D_U32x8x16_ST_N; // Intel Epilogue with Linear Combination does not use shared memory using SmemLayoutAtomC_ = void; diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index 70878bcae7..b8fe974068 100644 --- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -118,6 +118,18 @@ class CollectiveEpilogue< static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + using Trait_C = Copy_Traits; + using XE_Copy_C = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_C::Shape_MN{}), + get<1>(typename Trait_C::Shape_MN{}) / Int{})))); + using Trait_D = Copy_Traits; + using XE_Copy_D = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr),int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_D::Shape_MN{}), + get<1>(typename Trait_D::Shape_MN{}) / Int{})))); private: constexpr static bool is_source_supported = not cute::is_void_v; constexpr static bool is_destination_supported = not cute::is_void_v; @@ -154,13 +166,6 @@ class CollectiveEpilogue< // Device side epilogue params struct Params { - using XE_Copy_C = decltype(make_xe_2d_copy( - make_tensor(static_cast(nullptr), - repeat_like(StrideC{}, int32_t(0)), StrideC{}))); - using XE_Copy_D = decltype(make_xe_2d_copy( - make_tensor(static_cast(nullptr), - repeat_like(StrideD{}, int32_t(0)), StrideD{}))); - typename FusionCallbacks::Params thread{}; XE_Copy_C xe_load_c; XE_Copy_D xe_store_d; @@ -180,16 +185,22 @@ class CollectiveEpilogue< auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M, N, K, L] = problem_shape_MNKL; - typename Params::XE_Copy_C xe_load_c = {}; + XE_Copy_C xe_load_c = {}; if constexpr (is_source_supported) { - Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); - xe_load_c = make_xe_2d_copy(tensor_c); + xe_load_c = make_tiled_copy(Copy_Atom, ElementC>{}.with( + args.ptr_C, N, M, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_C::Shape_MN{}), + get<1>(typename Trait_C::Shape_MN{}) / Int{}))); } - typename Params::XE_Copy_D xe_store_d = {}; + XE_Copy_D xe_store_d = {}; if constexpr (is_destination_supported) { - Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); - xe_store_d = make_xe_2d_copy(tensor_d); + xe_store_d = make_tiled_copy(Copy_Atom, ElementD>{}.with( + args.ptr_D, N, M, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename Trait_D::Shape_MN{}), + get<1>(typename Trait_D::Shape_MN{}) / Int{}))); } return { @@ -255,7 +266,18 @@ class CollectiveEpilogue< using namespace cute; using MmaAtomShape = typename TiledMma::AtomShape_MNK; - using SubgroupTileShape = decltype(tile_shape(TiledMma())); + static constexpr auto BLK_M = get<0>(CtaTileMNK{}); + static constexpr auto BLK_N = get<1>(CtaTileMNK{}); + static constexpr auto BLK_K = get<2>(CtaTileMNK{}); + // static_assert(is_same_v, 
"assertation fail"); + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group @@ -265,15 +287,18 @@ class CollectiveEpilogue< // Indexing variables auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + auto l_offset = l_coord; bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); Tensor trC = make_tensor(Shape>{}); Tensor trD = make_tensor(Shape>{}); Tensor tOuti = params.xe_store_d.get_pvc_tensor( - make_coord(m_coord, n_coord, 0), - make_shape(Int{}, Int{}, L), - make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{})); + make_coord(m_offset, n_offset, l_offset), + make_shape(_, Int{}, Int{}, L), + make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{}, _1{})); Tensor rw_coord = tOuti(_,_,_,l_coord); Tensor mD_crd = make_identity_tensor(make_shape(M,N)); @@ -324,7 +349,6 @@ class CollectiveEpilogue< } cst_callbacks.end(); - } private: diff --git a/include/cutlass/gemm/collective/xe_mma.hpp b/include/cutlass/gemm/collective/xe_mma.hpp index 2412f6a8f7..61fc7867fc 100644 --- a/include/cutlass/gemm/collective/xe_mma.hpp +++ b/include/cutlass/gemm/collective/xe_mma.hpp @@ -101,24 +101,35 @@ struct CollectiveMma< static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using MmaAtomShape = typename TiledMma::AtomShape_MNK; - using SubgroupTileShape = decltype(tile_shape(TiledMma())); - static constexpr auto sg_per_wg_m = get<0>(WorkgroupTileShape{}) / get<0>(SubgroupTileShape{}); - static constexpr auto sg_per_wg_n = get<1>(WorkgroupTileShape{}) / get<1>(SubgroupTileShape{}); - - static constexpr uint32_t MaxThreadsPerBlock = - cute::size(WorkgroupTileShape{}) / cute::size(SubgroupTileShape{}) * SubgroupSize; - - static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group - static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group - static constexpr int FragsK = get<2>(SubgroupTileShape{}) / get<2>(MmaAtomShape()); - - // Calculate the vector width based on the amount of registers - // required per work item by dividing the total fragment size by - // the sub_group size. 
- static constexpr int VecC = (get<1>(MmaAtomShape()) * get<0>(MmaAtomShape())) / SubgroupSize; - static constexpr int VecA = (get<0>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; - static constexpr int VecB = (get<1>(MmaAtomShape()) * get<2>(MmaAtomShape())) / SubgroupSize; + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + using XE_Copy_A = decltype(make_tiled_copy(atom_load_A{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{})))); + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + using XE_Copy_B = decltype(make_tiled_copy(atom_load_B{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{})))); // Host side kernel arguments struct Arguments { @@ -129,10 +140,6 @@ struct CollectiveMma< }; struct Params { - using XE_Copy_A = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), - repeat_like(StrideA{}, int32_t(0)), StrideA{}))); - using XE_Copy_B = decltype(make_xe_2d_copy(make_tensor(static_cast(nullptr), - repeat_like(StrideB{}, int32_t(0)), StrideB{}))); XE_Copy_A gmem_tiled_copy_a; XE_Copy_B gmem_tiled_copy_b; }; @@ -151,11 +158,14 @@ struct CollectiveMma< auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); - - typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); - typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); + XE_Copy_A copyA = make_tiled_copy(Copy_Atom, ElementA>{}.with(args.ptr_A, K, M, K), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + XE_Copy_B copyB = make_tiled_copy(Copy_Atom, ElementB>{}.with(args.ptr_B, N, K, N), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); return Params{copyA, copyB}; } @@ -180,76 +190,78 @@ struct CollectiveMma< char *smem_buf, Params const& mainloop) { - (void)residue_mnk; - (void)thread_idx; - (void)smem_buf; - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_tuple::value, "A tensor must be a tuple iterator."); - static_assert(is_tuple::value, "B tensor must be a tuple iterator."); static_assert(is_rmem::value, "C tensor must be rmem resident."); - // Tensor to hold input data - constexpr int version = - is_same_v ? 
1 : 2; - - Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, Int<1>>{}); - Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) * version>, Int>{}); - - Tensor tAr_view = make_tensor(static_cast(tAr).data(), - Shape, Int, Int>{}); - Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}, - Stride<_1, Int, Int>{}); + (void)residue_mnk; + (void)thread_idx; + (void)smem_buf; - // Instantiate the M MA object + // Instantiate the MMA object TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_slice(thread_idx); + Tensor tCrA = thread_mma.partition_fragment_A(gA(_, _, 0)); + Tensor tCrB = thread_mma.partition_fragment_B(gB(_, _, 0)); + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_A = mainloop.gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_B = mainloop.gmem_tiled_copy_b.get_slice(thread_idx); + + auto tCrA_copy_view = gmem_thr_copy_A.retile_D(tCrA); + auto tCrB_copy_view = gmem_thr_copy_B.retile_D(tCrB); + + #if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + print(" gA : "); print(gA); print("\n"); + print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); + print(" tCrA : "); print(tCrA); print("\n"); + + print("===================== B :\n"); + print(" gB : "); print(gB); print("\n"); + print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); + print(" tCrB : "); print(tCrB); print("\n"); + + print("===================== Config: \n"); + print(" threads per workgroup : "); print(MaxThreadsPerBlock); print("\n"); + print(" SubgroupTileShape : "); print(SubgroupTileShape{}); print("\n"); + } + #endif - int K = size<1>(mainloop.gmem_tiled_copy_a.tensor); - - auto sub_group_id = ThreadIdxX() / SubgroupSize; - /* Cooperative prefetch - Divice the thread space to sg_per_wg_m x sg_per_wg_n, all the threads in one row/col use the same tile A/B. - Each thread loads sizeof(tile A or B) / numof(sg_per_wg_n or sg_per_wg_m). - - Currently, sg_per_wg_m x sg_per_wg_n = 4 x 8 is the most efficient - */ - // TODO: Replace the demo cooperative prefetch with more general way. 
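  // The rewritten mainloop below replaces the hard-coded subgroup/prefetch
  // bookkeeping being removed here. A standalone sketch of the schedule it
  // implements, with illustrative sizes only (BLK_*, ATOM_*, Stages and
  // k_tile_count are assumptions, not values taken from this patch):

  #include <cstdio>

  int main() {
    // 1) Each subgroup owns an SG_M x SG_N tile of the workgroup tile.
    constexpr int BLK_M = 256, BLK_N = 256;      // workgroup tile (illustrative)
    constexpr int ATOM_M = 8, ATOM_N = 4;        // subgroups along M and N (assumed)
    constexpr int SG_M = BLK_M / ATOM_M;         // 32 rows per subgroup
    constexpr int SG_N = BLK_N / ATOM_N;         // 64 columns per subgroup
    for (int sg = 0; sg < ATOM_M * ATOM_N; ++sg)
      std::printf("subgroup %2d -> tile origin (%3d, %3d)\n",
                  sg, (sg / ATOM_N) * SG_M, (sg % ATOM_N) * SG_N);

    // 2) The K dimension is walked with a Stages-deep prefetch pipeline.
    constexpr int Stages = 3, k_tile_count = 8;  // illustrative
    auto prefetch = [](int k) { std::printf("prefetch k-tile %d\n", k); };
    auto load     = [](int k) { std::printf("load     k-tile %d\n", k); };
    auto mma      = [](int k) { std::printf("mma on   k-tile %d\n", k); };

    for (int k = 0; k < Stages; ++k)             // warm-up prefetches
      prefetch(k);
    for (int k = 0; k < k_tile_count; ++k) {
      load(k);                                   // gmem -> register fragments
      if (k + Stages < k_tile_count)
        prefetch(k + Stages);                    // stay Stages tiles ahead
      mma(k);                                    // accumulate on loaded fragments
    }
    return 0;
  }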
- Tensor tAi = make_tensor( - make_inttuple_iter( - *gA.data() + - make_coord((sub_group_id % sg_per_wg_n % 4) * get<0>(MmaAtomShape{}), 0)), - make_layout(make_shape(_1{}, _1{}, K), - make_stride(_1{}, E<0>{}, E<1>{}))); - Tensor tBi = make_tensor( - make_inttuple_iter( - *gB.data() + - make_coord((sub_group_id / sg_per_wg_n % 2 * 2) * get<1>(MmaAtomShape{}), - (sub_group_id / sg_per_wg_n / 2 % 2) * get<2>(MmaAtomShape{}))), - make_layout(make_shape(_1{}, _1{}, K), - make_stride(_1{}, E<0>{}, E<1>{}))); // // Mainloop // - int prefetch_k = 0; - + const int m_coord = BlockIdxY() * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = BlockIdxX() * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + const int l_coord = BlockIdxZ(); + Tensor iter_a = mainloop.gmem_tiled_copy_a.get_pvc_tensor( + make_coord(m_coord, 0, l_coord), make_shape(_, size<1>(tCrA_copy_view.shape()), size<2>(tCrA_copy_view.shape()), k_tile_count), + append<3>(typename XE_Copy_A::Shape_MN{}, BLK_K), seq<0,1,1>{}); + Tensor iter_b = mainloop.gmem_tiled_copy_b.get_pvc_tensor( + make_coord(0, n_coord, l_coord), make_shape(_, size<2>(tCrB_copy_view.shape()), size<1>(tCrB_copy_view.shape()), k_tile_count), + append<3>(typename XE_Copy_B::Shape_MN{}, BLK_K), seq<0,1,0>{}); +#pragma unroll for (int i = 0; i < DispatchPolicy::Stages; i++) { - prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); - prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); - prefetch_k += get<2>(SubgroupTileShape{}); + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,i)); + } + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,i)); + } } - - for (int k_tile = 0, k = 0; k_tile < k_tile_count; - ++k_tile, k += get<2>(SubgroupTileShape{})) { +#pragma unroll + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { // Copy gmem to rmem for the first k_tile - copy(mainloop.gmem_tiled_copy_a, gA(_, _, k), tAr); - copy(mainloop.gmem_tiled_copy_b, gB(_, _, k), tBr); - - prefetch(mainloop.gmem_tiled_copy_a, tAi(_, _, prefetch_k)); - prefetch(mainloop.gmem_tiled_copy_b, tBi(_, _, prefetch_k)); - prefetch_k += get<2>(SubgroupTileShape{}); - - cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); + copy(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile), tCrA_copy_view); + copy(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile), tCrB_copy_view); + if(k_tile + DispatchPolicy::Stages < k_tile_count) { + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_a, iter_a(_,_,_,k_tile + DispatchPolicy::Stages)); + } + if constexpr(cute::detail::has_prefetch) { + prefetch(mainloop.gmem_tiled_copy_b, iter_b(_,_,_,k_tile + DispatchPolicy::Stages)); + } + } + cute::gemm(tiled_mma, accum, tCrA, tCrB, src_accum); } } }; diff --git a/include/cutlass/gemm/kernel/xe_gemm.hpp b/include/cutlass/gemm/kernel/xe_gemm.hpp index 298357fca6..9e4968acad 100644 --- a/include/cutlass/gemm/kernel/xe_gemm.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm.hpp @@ -105,11 +105,6 @@ class GemmUniversal< using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; - static constexpr int FragsM = CollectiveMainloop::FragsM; - static constexpr int FragsN = CollectiveMainloop::FragsN; - - static constexpr int VecC = CollectiveMainloop::VecC; - // Kernel level shared memory storage struct SharedStorage { using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -211,28 
+206,22 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); + auto blk_shape = TileShape{}; + auto m_coord = BlockIdxY(); + auto n_coord = BlockIdxX(); + auto l_coord = BlockIdxZ(); + auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); int sub_group_id = thread_idx / SubgroupSize; constexpr auto workgroup_shape = WorkgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - constexpr auto subgroup_shape = SubgroupTileShape{}; // (SUB_M,SUB_N,SUB_K) - const int m_coord = BlockIdxY() * get<0>(workgroup_shape) + sub_group_id / CollectiveMainloop::sg_per_wg_n * get<0>(subgroup_shape); - const int n_coord = BlockIdxX() * get<1>(workgroup_shape) + sub_group_id % CollectiveMainloop::sg_per_wg_n * get<1>(subgroup_shape); - const int l_coord = BlockIdxZ(); - const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); - - Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( - make_coord(m_coord, 0, 0), - make_shape(_1{}, K, L), - make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); - constexpr int version = - is_same_v - ? 1 - : 2; - - Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(n_coord, 0, 0), - make_shape(Int{}, K, L), - make_stride(Int(MmaAtomShape())>{}, _1{})); + constexpr auto subgroup_shape = SubgroupTileShape{}; + + Tensor mA_mkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(M,K,L), StrideA{}); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(N,K,L), StrideB{}); //(n,k,l) + Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) + Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) + + auto gA = local_tile(mA_mk, blk_shape, take<0, 3>(blk_coord_mnkl), Step<_1, X, _1>{}); + auto gB = local_tile(mB_nk, blk_shape, take<0, 3>(blk_coord_mnkl), Step< X, _1, _1>{}); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord @@ -243,18 +232,18 @@ class GemmUniversal< // Allocate the tiled_mma and the accumulators for the (M,N) subgroup_shape TiledMma tiled_mma; - Tensor accumulators = make_tensor(Shape, Int, Int>{}); + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); clear(accumulators); - auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(subgroup_shape))); - int k_tile_count = K / get<2>(subgroup_shape); + auto k_tile_iter = cute::make_coord_iterator(make_shape(K / get<2>(workgroup_shape))); + int k_tile_count = K / get<2>(workgroup_shape); // Perform the collective scoped MMA CollectiveMainloop collective_mma; collective_mma( accumulators, - tAi(_,_,_,l_coord), - tBi(_,_,_,l_coord), + gA, + gB, accumulators, k_tile_iter, k_tile_count, residue_mnk, @@ -267,7 +256,7 @@ class GemmUniversal< epilogue( problem_shape_MNKL, subgroup_shape, - tile_coord, + blk_coord_mnkl, accumulators, tiled_mma, residue_mnk, diff --git a/test/unit/cute/CMakeLists.txt b/test/unit/cute/CMakeLists.txt index 01f1819398..1723c1f4a4 100644 --- a/test/unit/cute/CMakeLists.txt +++ b/test/unit/cute/CMakeLists.txt @@ -64,6 +64,10 @@ else() add_subdirectory(core) add_subdirectory(layout) add_subdirectory(msvc_compilation) + + if(SYCL_INTEL_TARGET) + add_subdirectory(intel_xe) + endif() if(SYCL_NVIDIA_TARGET) @@ -105,6 +109,7 @@ else() cutlass_test_unit_cute_layout cutlass_test_unit_cute_core cutlass_test_unit_cute_msvc_compilation + cutlass_test_unit_cute_intel_xe ) add_custom_target( @@ -114,6 +119,7 @@ else() test_unit_cute_core 
test_unit_cute_msvc_compilation #Intel Tests + test_unit_cute_intel_xe ) endif() diff --git a/test/unit/cute/intel_xe/CMakeLists.txt b/test/unit/cute/intel_xe/CMakeLists.txt new file mode 100755 index 0000000000..807c847430 --- /dev/null +++ b/test/unit/cute/intel_xe/CMakeLists.txt @@ -0,0 +1,48 @@ +# Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if(SYCL_INTEL_TARGET) +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_intel_xe + copy_1d.cpp + copy_subgroup_block.cpp + copy_block.cpp + copy_scatter.cpp + mma.cpp + gemm_partition_src_dst.cpp + gemm_partition_fragment_abc.cpp + gemm_tiled_copy_abc.cpp + gemm_row_col.cpp + gemm_col_row.cpp + gemm_col_col.cpp +) +else() +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_intel_xe +) +endif() diff --git a/test/unit/cute/intel_xe/copy_1d.cpp b/test/unit/cute/intel_xe/copy_1d.cpp new file mode 100644 index 0000000000..4ef31c9915 --- /dev/null +++ b/test/unit/cute/intel_xe/copy_1d.cpp @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +template +void copy_kernel_vectorized(TensorS tile_S, TensorD tile_D) { + using namespace cute; + + using Element = typename TensorS::value_type; + + // Shared memory buffers + auto smem = syclcompat::local_mem(); + Tensor sTensor = make_tensor(make_smem_ptr(smem), tile_S.layout()); + + // Define `AccessType` which controls the size of the actual memory access. + // using AccessType = cutlass::AlignedArray; + + // A copy atom corresponds to one hardware memory access. + using traits_load = Copy_Traits>::type, cutlass::uint128_t>>; + using Atom_load = Copy_Atom; + using traits_store = Copy_Traits>::type>>; + using Atom_store = Copy_Atom; + + using traits_ldsm = + Copy_Traits>::type, + cutlass::uint128_t>>; + using Atom_ldsm = Copy_Atom; + using traits_stsm = + Copy_Traits>::type>>; + using Atom_stsm = Copy_Atom; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous + // data in GMEM. Alternative thread layouts are possible but may result in + // uncoalesced reads. Alternative vector layouts are also possible, though + // incompatible layouts will result in compile time errors. + + auto VecLayout = make_layout( + make_shape(_1{}, Int{}), + Stride, _1>{}); + auto ThreadLayout = make_layout(make_shape(_1{}, _16{})); + auto tiled_copy_load = make_tiled_copy(Atom_load{}, // access size + ThreadLayout, // thread layout + VecLayout); // vector layout (e.g. 4x1) + auto tiled_copy_store = + make_tiled_copy(Atom_store{}, // access size + ThreadLayout, // thread layout + VecLayout); // vector layout (e.g. 4x1) + + auto tiled_ldsm = make_tiled_copy(Atom_ldsm{}, // access size + ThreadLayout, // thread layout + VecLayout); // vector layout (e.g. 4x1) + auto tiled_stsm = make_tiled_copy(Atom_stsm{}, // access size + ThreadLayout, // thread layout + VecLayout); // vector layout (e.g. 4x1) + + // Construct a Tensor corresponding to each thread's slice. 
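  // partition_S / partition_D below carve out the slice of the tile owned by
  // this work-item, and make_fragment_like allocates a matching
  // register-backed tensor, so the data path exercised by this test is:
  // gmem -> registers -> shared memory -> registers -> gmem.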
+ auto thr_copy_load = + tiled_copy_load.get_thread_slice(ThreadIdxX()); + auto thr_copy_store = + tiled_copy_store.get_thread_slice(ThreadIdxX()); + + auto thr_copy_ldsm = tiled_ldsm.get_thread_slice(ThreadIdxX()); + auto thr_copy_stsm = tiled_stsm.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_load_S = + thr_copy_load.partition_S(tile_S); // (CopyOp, CopyM, CopyN) + Tensor thr_tile_store_D = + thr_copy_store.partition_D(tile_D); // (CopyOp, CopyM, CopyN) + + Tensor thr_tile_ldsm_S = + thr_copy_ldsm.partition_S(sTensor); // (CopyOp, CopyM, CopyN) + Tensor thr_tile_stsm_D = + thr_copy_stsm.partition_D(sTensor); // (CopyOp, CopyM, CopyN) + + // Construct a register-backed Tensor with the same shape as each thread's + // partition Use make_fragment because the first mode is the instruction-local + // mode + Tensor fragment = make_fragment_like( + thr_copy_load.partition_D(tile_S)); // (CopyOp, CopyM, CopyN) + +#if 0 + if (thread(0, 0)) { + print("loading to registers from src ========================\n"); + print("tile_S:"); + print(tile_S.layout()); + print("\n"); + + print("thr_tile_load_S: "); + print(thr_tile_load_S.layout()); + print("\n"); + print(thr_tile_load_S.data()); + print("\n"); + + print("thr_tile_store_D: "); + print(thr_tile_store_D.layout()); + print("\n"); + + print("fragment: "); + print(fragment.layout()); + print("\n"); + } +#endif + + // Copy from GMEM to RMEM and from RMEM to GMEM + prefetch(tiled_copy_load, thr_tile_load_S); + copy(tiled_copy_load, thr_tile_load_S, fragment); + copy(tiled_stsm, fragment, thr_tile_stsm_D); + clear(fragment); + copy(tiled_ldsm, thr_tile_ldsm_S, fragment); + copy(tiled_copy_store, fragment, thr_tile_store_D); +} + +TEST(PVC_1d_copy, copy_double) { + { + constexpr int M = 1; + constexpr int N = 128; + using Element = double; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output(M * N); + + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(subgroup_size); + + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + // printf("%d %d\n", int(h_in[i]), int(h_out[i])); + EXPECT_EQ(host_output[i], host_src[i]); + } + } + + { + constexpr int M = 1; + constexpr int N = 128; + using Element = float; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output(M * N); + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(subgroup_size); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), 
blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } + + { + constexpr int M = 1; + constexpr int N = 128; + using Element = uint16_t; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output(M * N); + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(subgroup_size); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +} diff --git a/test/unit/cute/intel_xe/copy_block.cpp b/test/unit/cute/intel_xe/copy_block.cpp new file mode 100644 index 0000000000..3eed445c05 --- /dev/null +++ b/test/unit/cute/intel_xe/copy_block.cpp @@ -0,0 +1,351 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +template +void copy_kernel_vectorized(TensorS S, TensorD D, TiledLoad load, + TiledStore store) { + const int m_coord = 0; + const int n_coord = 0; + const int l_coord = BlockIdxZ(); + + // ========== load ========== + auto thr_copy_load = load.get_thread_slice(ThreadIdxX()); + auto thr_tile_load_D = thr_copy_load.partition_D(S); + auto fragment = make_fragment_like(thr_tile_load_D); + auto ld_tensor = + load.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + fragment.shape(), typename TiledLoad::Shape_MN{}); + if constexpr (cute::detail::has_prefetch) prefetch(load, ld_tensor); + copy(load, ld_tensor, fragment); + + // ========== store ========== + auto thr_copy_store = store.get_thread_slice(ThreadIdxX()); + Tensor frag_view = + make_tensor(static_cast(fragment).data(), + thr_copy_store.partition_S(D).shape()); + auto st_tensor = store.get_pvc_tensor(make_coord(m_coord, n_coord, l_coord), + frag_view.shape(), + typename TiledStore::Shape_MN{}); + copy(store, frag_view, st_tensor); + +#if 0 + if (thread(1)) { + print("fragment: "); + print(fragment.layout()); + print("\n"); + + print("ld_tensor: "); + print(ld_tensor.layout()); + print("\n"); + + print("frag_view: "); + print(frag_view.layout()); + print("\n"); + + print("st_tensor: "); + print(st_tensor.layout()); + print("\n"); + } +#endif +} +template +struct copy_op; + +template +struct copy_op { + void operator()() { + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = + make_tiled_copy( + Copy_Atom, dtype>{}.with(device_src.data(), N, M, + N), + Layout>, Stride<_0, _1>>{}, + Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); + auto tiled_store = make_tiled_copy( + Copy_Atom, dtype>{}.with(device_output.data(), N, + M, N), + Layout>, Stride<_0, _1>>{}, + Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_store); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +}; + +template +struct copy_op { + void operator()() { + // + // Allocate and initialize + // + using dtype = char; + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + 
make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_tiled_copy( + Copy_Atom, dtype>{}.with(device_src.data(), N, M, + N), + Layout, Stride<_0, _1>>{}, + make_layout(shape<1>(typename Copy_Atom, dtype>::ValLayoutDst{}))); + auto tiled_store = make_tiled_copy( + Copy_Atom, dtype>{}.with(device_output.data(), N, M, + N), + Layout, Stride<_0, _1>>{}, + Layout, Stride<_2, _1>>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_store); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +}; + +template +struct copy_op{ + void operator()() { + // + // Allocate and initialize + // + using dtype = uint16_t; + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_tiled_copy( + Copy_Atom, dtype>{}.with(device_src.data(), N, M, + N), + Layout>, Stride<_0, _1>>{}, + Layout, _2>, Stride<_1, _2>>{}); + auto tiled_store = make_tiled_copy( + Copy_Atom, uint16_t>{}.with(device_output.data(), N / 2, + M * 2, N / 2), + Layout, Stride<_0, _1>>{}, + Layout, Stride<_1, _0>>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_store); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * 2; ++i) { + for (int j = 0; j < N / 2; ++j) { + EXPECT_EQ(host_output[i * N / 2 + j], + host_src[(i % M) * N + j + (i / M) * N / 2]); + } + } + } +}; + +template +struct copy_op { + void operator()() { + // + // Allocate and initialize + // + using dtype = uint32_t; + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = + make_tiled_copy( + Copy_Atom, dtype>{}.with(device_src.data(), N, M, + N), + Layout, _1>, Stride<_1, _0>>{}, + Layout(typename Copy_Traits::Shape_MN{}))>, Stride<_0, _1>>{}); + auto tiled_store = make_tiled_copy( + Copy_Atom, dtype>{}.with(device_output.data(), M, N, + M), + Layout>, Stride<_0, _1>>{}, + Layout(typename Copy_Traits::Shape_MN{})), _1>, Stride<_1, _0>>{}); + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_store); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int 
i = 0; i < N; ++i) { + for (int j = 0; j < M; ++j) { + EXPECT_EQ(host_output[i * M + j], host_src[j * N + i]); + } + } + } +}; + +TEST(PVC_CuTe_Xe, block_2d_16bits_n) { + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); +} + +TEST(PVC_CuTe_Xe, block_2d_32bits_n) { + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); +} + +TEST(PVC_CuTe_Xe, block_2d_8bits_n) { + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); +} + +TEST(PVC_CuTE_Xe, block_2d_16bits_n_v2) { + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); + copy_op{}(); +} + +TEST(PVC_CuTe_Xe, block_2d_16bits_vnni) { + copy_op{}(); + copy_op{}(); +} + +TEST(PVC_CuTe_Xe, block_2d_32bits_transpose) { + copy_op{}(); + copy_op{}(); + copy_op{}(); +} diff --git a/test/unit/cute/intel_xe/copy_scatter.cpp b/test/unit/cute/intel_xe/copy_scatter.cpp new file mode 100644 index 0000000000..4373e884c6 --- /dev/null +++ b/test/unit/cute/intel_xe/copy_scatter.cpp @@ -0,0 +1,409 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +template +void copy_kernel_global(TensorS S, TensorD D, TiledLoad load, TiledStore store) { + + auto thr_copy_load = load.get_thread_slice(ThreadIdxX()); + Tensor thr_tile_load_S = thr_copy_load.partition_S(S); + Tensor thr_tile_load_D = thr_copy_load.partition_D(S); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition Use make_fragment because the first mode is the instruction-local + // mode + Tensor fragment = + make_fragment_like(thr_tile_load_D); // (CopyOp, CopyM, CopyN) + + copy(load, thr_tile_load_S, fragment); + + auto thr_copy_store = store.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_store_D = + thr_copy_store.partition_D(D); // (CopyOp, CopyM, CopyN) + + Tensor frag_view = + make_tensor(static_cast(fragment).data(), + thr_copy_store.partition_S(D).shape()); + +#if 0 + if (thread(0)) { + print("thr_tile_load_S: "); + print(thr_tile_load_S.layout()); + print("\n"); + + print("thr_tile_load_D: "); + print(thr_tile_load_D.layout()); + print("\n"); + + print("fragment: "); + print(fragment.layout()); + print("\n"); + + print("thr_tile_store_D: "); + print(thr_tile_store_D.layout()); + print("\n"); + + print("frag_view: "); + print(frag_view.layout()); + print("\n\n"); + } +#endif + + copy(store, frag_view, thr_tile_store_D); +} + +TEST(PVC_2d_copy, load_store_global) { + { + constexpr int M = 8; + constexpr int N = 16; + using Element = uint16_t; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(size(tiled_copy)); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_copy, tiled_copy); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +} + +TEST(PVC_2d_copy, load_store_global_V) { + { + constexpr int M = 16; + constexpr int N = 16; + using Element = uint16_t; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + 
Layout, Stride<_1, _8>>{}); + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(size(tiled_copy)); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_copy, tiled_copy); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +} + +template +void copy_kernel_local(TensorS S, TensorD D, TiledCopy Op) { + + // Shared memory buffers + using Element = typename TensorS::value_type; + ; + auto smem = syclcompat::local_mem(); + Tensor sTensor = make_tensor(make_smem_ptr(smem), S.layout()); + + auto thr_copy = Op.get_thread_slice(ThreadIdxX()); + Tensor thr_global_S = thr_copy.partition_S(S); + Tensor thr_global_D = thr_copy.partition_D(D); + Tensor thr_local_S = thr_copy.partition_S(sTensor); + Tensor thr_local_D = thr_copy.partition_D(sTensor); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition Use make_fragment because the first mode is the instruction-local + // mode + Tensor fragment = make_fragment_like(thr_global_D); // (CopyOp, CopyM, CopyN) + + copy(Op, thr_global_S, fragment); + copy(Op, fragment, thr_local_D); + clear(fragment); + copy(Op, thr_local_S, fragment); + copy(Op, fragment, thr_global_D); +} + +TEST(PVC_2d_copy, load_store_local) { + { + constexpr int M = 8; + constexpr int N = 16; + using Element = uint16_t; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_copy = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(size(tiled_copy)); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_copy); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], host_src[i]); + } + } +} + +template +void copy_kernel_atomic(TensorS S, TensorD D, TiledLoad load, TiledStore store) { + + auto thr_copy_load = load.get_thread_slice(ThreadIdxX()); + Tensor thr_tile_load_S = thr_copy_load.partition_S(S); + Tensor thr_tile_load_D = thr_copy_load.partition_D(S); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition Use make_fragment because the first mode is the instruction-local + // mode + Tensor fragment = + make_fragment_like(thr_tile_load_D); // (CopyOp, CopyM, CopyN) + + copy(load, thr_tile_load_S, fragment); + + auto thr_copy_store = store.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_store_D = + thr_copy_store.partition_D(D); // (CopyOp, CopyM, CopyN) + +#if 0 + if (thread(0)) { + print("thr_tile_load_S: "); + print(thr_tile_load_S.layout()); + print("\n"); + + print("thr_tile_load_D: "); + print(thr_tile_load_D.layout()); + print("\n"); + + print("fragment: "); + print(fragment.layout()); + 
print("\n"); + + print("thr_tile_store_D: "); + print(thr_tile_store_D.layout()); + print("\n"); + } +#endif + + copy(store, fragment, thr_tile_store_D); + copy(store, fragment, thr_tile_store_D); +} + +TEST(PVC_2d_copy, load_store_stomic_float) { + { + constexpr int M = 8; + constexpr int N = 16; + using Element = float; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + auto tiled_atom = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_atom); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], 2 * host_src[i]); + } + } +} + +TEST(PVC_2d_copy, load_store_stomic_int) { + { + constexpr int M = 8; + constexpr int N = 16; + using Element = int; + // + // Allocate and initialize + // + cutlass::host_vector host_src(M * N); + cutlass::host_vector host_output(M * N); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + Tensor S = + make_tensor(make_gmem_ptr(device_src.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + Tensor D = + make_tensor(make_gmem_ptr(device_output.data()), + make_layout(Shape, Int>{}, Stride, _1>{})); + + auto tiled_load = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + auto tiled_atom = make_tiled_copy(Copy_Atom, Element>{}, + Layout, Stride<_16, _1>>{}, + Layout, Stride<_1, _8>>{}); + static constexpr auto subgroup_size = 16; + auto blockDim = syclcompat::dim3(size(tiled_load)); + // + // Launch the kernel + // + launch>( + launch_policy{ + syclcompat::dim3(1), blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + S, D, tiled_load, tiled_atom); + + syclcompat::wait_and_throw(); + host_output = device_output; + for (int i = 0; i < M * N; ++i) { + EXPECT_EQ(host_output[i], 2 * host_src[i]); + } + } +} diff --git a/test/unit/cute/intel_xe/copy_subgroup_block.cpp b/test/unit/cute/intel_xe/copy_subgroup_block.cpp new file mode 100644 index 0000000000..1cafa48c5b --- /dev/null +++ b/test/unit/cute/intel_xe/copy_subgroup_block.cpp @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +template +void copy_kernel_vectorized(TensorS S, TensorD D, uint32_t M, uint32_t N) { + using namespace cute; + + using Element = typename TensorS::value_type; + + Tensor tiled_tensor_S = tiled_divide( + S, Shape, Int>{}); // ((M, N), m', n') + Tensor tiled_tensor_D = tiled_divide( + D, Shape, Int>{}); // ((M, N), m', n') + + // Slice work group. + Tensor tile_wg_S = + tiled_tensor_S(make_coord(_, _), BlockIdxX(), + BlockIdxY()); + Tensor tile_wg_D = + tiled_tensor_D(make_coord(_, _), BlockIdxX(), + BlockIdxY()); + + // Slice subgroup. + auto SubgroupShape = Shape, Int>{}; + auto sg_id = cutlass::get_sub_group_id(); + Tensor tile_sg_S = local_tile(tile_wg_S, SubgroupShape, sg_id); + Tensor tile_sg_D = local_tile(tile_wg_D, SubgroupShape, sg_id); + +#if 0 + if (thread(1)) { + print("tile_wg_S:"); + print(tile_wg_S.layout()); + print("\n"); + + print("tile_sg_S:"); + print(tile_sg_S.layout()); + print("\n"); + } +#endif + + using traits_load = Copy_Traits; + using Atom_load = Copy_Atom; + auto VecLayout = make_layout( + make_shape(get<0>(typename traits_load::Shape_MN{}), + get<1>(typename traits_load::Shape_MN{}) / _16{}), + Stride<_1, _0>{}); + auto tiled_copy_load = make_tiled_copy(Atom_load{}.with(&*S.data(), N, M, N), + Layout>{}, VecLayout); + + // Construct a Tensor corresponding to each thread's slice. 
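+  // Editor's note (descriptive comment, not from the original patch): the
+  // pattern below is the usual CuTe per-thread partitioning step. Each
+  // work-item asks the tiled copy for its own slice; partition_S gives the
+  // source coordinates this work-item owns within the subgroup tile and
+  // partition_D the matching destination coordinates. make_fragment_like
+  // then allocates a register-backed fragment of the same shape, so the
+  // subsequent copy moves exactly this work-item's share of the
+  // (sg_tile_m x sg_tile_n) block from global memory into registers.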
+ auto thr_copy_load = + tiled_copy_load.get_thread_slice(cutlass::get_sub_group_local_id()); + Tensor thr_tile_load_S = thr_copy_load.partition_S(tile_sg_S); + Tensor thr_tile_load_D = thr_copy_load.partition_D(tile_sg_S); + + // Construct a register-backed Tensor with the same shape as each thread's + // partition Use make_fragment because the first mode is the instruction-local + // mode + Tensor fragment = make_fragment_like(thr_tile_load_D); + +#if 0 + if (thread(1)) { + print("thr_tile_load_S: "); + print(thr_tile_load_S.layout()); + print("\n"); + + print("thr_tile_load_D: "); + print(thr_tile_load_D.layout()); + print("\n"); + + print("fragment: "); + print(fragment.layout()); + print("\n"); + } +#endif + + static constexpr auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = + BlockIdxX() * wg_tile_m + (cutlass::get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = + BlockIdxY() * wg_tile_n + (cutlass::get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + // Copy from GMEM to RMEM and from RMEM to GMEM + auto blk_load_S = tiled_copy_load.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), fragment.shape(), + typename traits_load::Shape_MN{}); + copy(tiled_copy_load, blk_load_S, fragment); + + using traits_store = Copy_Traits; + using Atom_store = Copy_Atom; + + auto tiled_copy_store = + make_tiled_copy(Atom_store{}.with(&*D.data(), N, M, N), + Layout, Stride<_0, _1>>{}, VecLayout); + auto thr_copy_store = + tiled_copy_store.get_thread_slice(ThreadIdxX()); + + Tensor thr_tile_store_D = thr_copy_store.partition_D(tile_sg_D); + +#if 0 + if (thread(1)) { + print("storing to dst from registers ========================\n"); + print("tile_sg_D:"); + print(tile_sg_D.layout()); + print("\n"); + + print("thr_tile_store_D: "); + print(thr_tile_store_D.layout()); + print("\n"); + } +#endif + + auto blk_store_D = tiled_copy_store.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), fragment.shape(), + typename traits_store::Shape_MN{}); + + // onlt run first subgroup + if (syclcompat::global_id::x() < 16 && !syclcompat::global_id::y() && + !syclcompat::global_id::z()) { + copy(tiled_copy_store, fragment, blk_store_D); + } +} + +template +bool copy(uint32_t M, uint32_t N) { + using namespace cute; + // + // Given a 2D shape, perform an efficient copy + // + + auto tensor_shape = make_shape(M, N); + auto block_shape = make_shape(Int{}, Int{}); + auto subgroup_shape = make_shape(Int{}, Int{}); + + // + // Allocate and initialize + // + cutlass::host_vector host_src(size(tensor_shape)); + cutlass::host_vector host_output(size(tensor_shape)); + + for (size_t i = 0; i < host_src.size(); ++i) { + host_src[i] = static_cast(i); + } + + cutlass::device_vector device_src = host_src; + cutlass::device_vector device_output = host_output; + + + // + // Make tensors + // + + Tensor tensor_S = make_tensor(make_gmem_ptr(device_src.data()), + make_layout(tensor_shape, make_stride(N, 1))); + Tensor tensor_D = make_tensor(make_gmem_ptr(device_output.data()), + make_layout(tensor_shape, make_stride(N, 1))); + + // + // Tile tensors + // + + // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile + // shape, and modes (m', n') correspond to the number of tiles. + // + // These will be used to determine the CUDA kernel grid dimensions. 
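+  // Worked example (illustrative only, assuming a 16 x 64 work-group tile;
+  // the actual tile sizes are template parameters of this test): for
+  // M = 32, N = 128 the ceil_div calls below give a grid of
+  // syclcompat::dim3(2, 2), i.e. one work-group per (wg_tile_m x wg_tile_n)
+  // tile of the tensor, while the block dimension is simply the number of
+  // work-items in the subgroup thread layout defined next.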
+ + // Thread arrangement + + static constexpr auto subgroup_size = 16; + + Layout thr_layout = + make_layout(Shape, + Int>{}); + + // + // Determine grid and block dimensions + // + + auto gridDim = syclcompat::dim3(cute::ceil_div(M, wg_tile_m), + cute::ceil_div(N, wg_tile_n)); + auto blockDim = syclcompat::dim3(size(thr_layout)); + + // + // Launch the kernel + // + launch>( + launch_policy{gridDim, blockDim, + kernel_properties{sycl_exp::sub_group_size}}, + tensor_S, tensor_D, M, N); + + syclcompat::wait_and_throw(); + + // + // Verify + // + + host_output = device_output; + + auto surface_pitch = N; + for (int i = 0; i < sg_tile_m && i < M; i++) { + for (int j = 0; j < sg_tile_n && j < N; j++) { + EXPECT_EQ(host_output[surface_pitch * i + j], surface_pitch * i + j); + } + } + + for (int i = sg_tile_m; i < sg_tile_m + 1 && i < M; i++) { + for (int j = 0; j < sg_tile_n && j < N; j++) { + EXPECT_NE(host_output[surface_pitch * i + j], surface_pitch * i + j); + } + } + + for (int i = 0; i < sg_tile_m && i < M; i++) { + for (int j = sg_tile_n; j < sg_tile_n + 1 && j < N; j++) { + EXPECT_NE(host_output[surface_pitch * i + j], surface_pitch * i + j); + } + } + return true; +} + +TEST(PVC_CuTe_Xe, block_2d_float_aligned) { + copy(8, 16); + copy(32, 128); + copy(32, 128); + copy(32, 128); + copy(1024, 4096); +} + +TEST(PVC_CuTe_Xe, block_2d_float_unaligned) { + copy(1024, 4098); + copy(1026, 4096); + copy(1026, 4098); +} diff --git a/test/unit/cute/intel_xe/gemm_col_col.cpp b/test/unit/cute/intel_xe/gemm_col_col.cpp new file mode 100644 index 0000000000..d9a0c4c5bd --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_col_col.cpp @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_col_col { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = false; + static constexpr bool is_b_row_major = false; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(1, m))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(1, k))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, m, k, m), Layout, _1>>{}, + make_layout(make_shape(get<1>(typename traits_load_A::Shape_MN{})/ Int{}, + get<0>(typename traits_load_A::Shape_MN{})))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, k, n, k), Layout, _1>>{}, + make_layout(make_shape(get<1>(typename traits_load_B::Shape_MN{})/ Int{}, + get<0>(typename traits_load_B::Shape_MN{})))); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{})/ Int{}))); + auto thread_idx = ThreadIdxX(); + auto mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int, _1>>{}); + auto thr_mma = mma.get_thread_slice(thread_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tCrC = thr_mma.partition_fragment_C(gC); + + auto tiled_copy_A = make_tiled_copy_A(copy_a, mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); + auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); + auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); + auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); + auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); + + clear(tCrC); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); + print(" tCrA : "); print(tCrA); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tCrB_copy_view : "); 
print(tCrB_copy_view); print("\n"); + print(" tCrB : "); print(tCrB); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); + print(" tCrC : "); print(tCrC); print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( + make_coord(k_tile * sg_tile_k, m_coord, l_coord), + tCrA_copy_view.shape(), + typename traits_load_A::Shape_MN{}, seq<1,0>{}); + Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( + make_coord(n_coord, k_tile * sg_tile_k, l_coord), + tCrB_copy_view.shape(), + typename traits_load_B::Shape_MN{}); + + copy(tiled_copy_A, blk_tgA, tCrA_copy_view); + copy(tiled_copy_B, blk_tgB, tCrB_copy_view); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); + } + } + + Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), + typename traits_store_C::Shape_MN{}); + copy(copy_c, tCrC_copy_view, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_32x128x64) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_16x256x64) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_64x1024x64) { + run>(64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_128x128x64) { + run>(128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_32x1024x1024) { + run>(32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_4096x4096x256) { + run>(4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1024x2048x512) { + run>(1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1026x2048x512) { + run>(1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1024x2050x512) { + run>(1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_1026x2050x256) { + run>(1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, gemm_col_col_bf16_bf16_float_512x1024x512) { + run>(512, 1024, 512); +} diff --git a/test/unit/cute/intel_xe/gemm_col_row.cpp b/test/unit/cute/intel_xe/gemm_col_row.cpp new file mode 100644 index 0000000000..a3c1722472 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_col_row.cpp @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_col_row { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = false; + static constexpr bool is_b_row_major = true; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(1, m))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(n, 1))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, m, k, m), Layout, _1>>{}, + make_layout(make_shape(get<1>(typename traits_load_A::Shape_MN{}), + get<0>(typename traits_load_A::Shape_MN{}) / Int{}))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, n, k, n), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + + auto thread_idx = ThreadIdxX(); + auto mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int, _1>>{}); + auto thr_mma = mma.get_thread_slice(thread_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tCrC = thr_mma.partition_fragment_C(gC); + + auto 
tiled_copy_A = make_tiled_copy_A(copy_a, mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); + auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); + auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); + auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); + auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); + + clear(tCrC); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); + print(" tCrA : "); print(tCrA); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); + print(" tCrB : "); print(tCrB); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); + print(" tCrC : "); print(tCrC); print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( + make_coord(k_tile * sg_tile_k, m_coord, l_coord), + tCrA_copy_view.shape(), typename traits_load_A::Shape_MN{}, seq<1,0>{}); + Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( + make_coord(k_tile * sg_tile_k, n_coord, l_coord), + tCrB_copy_view.shape(), typename traits_load_B::Shape_MN{}, seq<1,0>{}); + + copy(tiled_copy_A, blk_tgA, tCrA_copy_view); + copy(tiled_copy_B, blk_tgB, tCrB_copy_view); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); + } + } + + Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), + typename traits_store_C::Shape_MN{}); + copy(copy_c, tCrC_copy_view, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_32x128x64) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_16x256x64) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_64x1024x64) { + run>(64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_128x128x64) { + run>(128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_32x1024x1024) { + run>(32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_4096x4096x256) { + run>(4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1024x2048x512) { + run>(1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1026x2048x512) { + run>(1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1024x2050x512) { + run>(1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_1026x2050x256) { + run>(1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, gemm_col_row_bf16_bf16_float_512x1024x512) { + run>(512, 1024, 512); +} diff --git 
a/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp b/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp new file mode 100755 index 0000000000..6224e4a6e7 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_partition_fragment_abc.cpp @@ -0,0 +1,249 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_partition_fragment_abc { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = true; + static constexpr bool is_b_row_major = true; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + using namespace cute; + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(k, 1))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(n, 1))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); // (m,n,k) + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, k, m, k), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, n, k, n), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + + auto thread_idx = ThreadIdxX(); + + TiledMMA mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int>>{}); + auto thrd_mma = mma.get_thread_slice(thread_idx); + + Tensor fragment_A = thrd_mma.partition_fragment_A(gA(_, _, 0)); + Tensor fragment_temp = thrd_mma.partition_fragment_B(gB(_, _, 0)); + Tensor fragment_B = make_tensor( + static_cast(fragment_temp).data(), + make_shape(size<0>(fragment_temp.shape()), + size<2>(fragment_temp.shape()), + size<1>(fragment_temp.shape()))); + Tensor fragment_C = thrd_mma.partition_fragment_C(gC); + + ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); + auto copy_view_A = thr_copy_a.retile_D(fragment_A); + + ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); + auto copy_view_B = thr_copy_b.retile_D(fragment_B); + + ThrCopy thr_copy_c = copy_c.get_slice(thread_idx); + auto copy_view_C = thr_copy_c.retile_D(fragment_C); + + clear(fragment_C); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print(" fragment_A : "); print(fragment_A); print("\n"); + print(" copy_view_A : "); print(copy_view_A); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); 
print("\n"); + print(" gB : "); print(gB); print("\n"); + print(" fragment_B : "); print(fragment_B); print("\n"); + print(" copy_view_B : "); print(copy_view_B); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print(" fragment_C : "); print(fragment_C); print("\n"); + print(" copy_view_C : "); print(copy_view_C); print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + Tensor blk_tgA = copy_a.get_pvc_tensor( + make_coord(m_coord, 0, l_coord), append<4>(copy_view_A.shape(), k_tile_max), + append<3>(typename traits_load_A::Shape_MN{}, sg_tile_k), seq<0, 1, 1>{}); + Tensor blk_tgB = copy_b.get_pvc_tensor( + make_coord(0, n_coord, l_coord), append<4>(copy_view_B.shape(), k_tile_max), + append<3>(typename traits_load_B::Shape_MN{}, sg_tile_k), seq<0, 1, 0>{}); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { + print("blk_tgA : "); print(blk_tgA); print("\n"); + print("blk_tgB : "); print(blk_tgB); print("\n"); + } +#endif + + // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors + copy(copy_a, blk_tgA(_, _, _, k_tile), copy_view_A); + copy(copy_b, blk_tgB(_, _, _, k_tile), copy_view_B); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, fragment_A(_, _, i), fragment_B(_, i, _), fragment_C); + } + } + + Tensor blk_tgC = copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), fragment_C.shape(), + typename traits_store_C::Shape_MN{}); + + copy(copy_c, fragment_C, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_32x128x64) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_16x256x64) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_64x1024x64) { + run>(64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_128x128x64) { + run>(128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_32x1024x1024) { + run>(32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_4096x4096x256) { + run>(4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_1024x2048x512) { + run>(1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_1026x2048x512) { + run>(1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_1024x2050x512) { + run>(1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_1026x2050x256) { + run>(1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, gemm_partition_fragment_abc_bf16_bf16_float_512x1024x512) { + run>(512, 1024, 512); +} diff --git a/test/unit/cute/intel_xe/gemm_partition_src_dst.cpp b/test/unit/cute/intel_xe/gemm_partition_src_dst.cpp new file mode 100755 index 0000000000..4ed1bb40e0 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_partition_src_dst.cpp @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_partition_sd { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = true; + static constexpr bool is_b_row_major = true; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(k, 1))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(n, 1))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); // (m,n,k) + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + auto sg_id = get_sub_group_id(); + Tensor sgA = local_tile( + gA, make_shape(Int{}, Int{}, k / sg_tile_k), + sg_id / sg_per_wg_x); + Tensor sgB = local_tile( + gB, make_shape(Int{}, Int{}, k / sg_tile_k), + sg_id % sg_per_wg_x); + Tensor sgC = + local_tile(gC, make_shape(Int{}, Int{}), + make_coord(sg_id / sg_per_wg_x, sg_id % sg_per_wg_x)); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, k, m, k), Layout>>{}, + 
make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, n, k, n), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + TiledMMA mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int>>{}); + + auto thread_idx = get_sub_group_local_id(); + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); + Tensor tgA = thr_copy_a.partition_D(sgA); + Tensor fragment_A = make_fragment_like(tgA(_, _, _, 0)); + + ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); + Tensor tgB = thr_copy_b.partition_D(sgB); + Tensor fragment_B = make_fragment_like(tgB(_, _, _, 0)); + + ThrCopy thr_copy_c = copy_c.get_slice(thread_idx); + Tensor tgC = thr_copy_c.partition_S(sgC); + Tensor fragment_C = make_fragment_like(tgC); + clear(fragment_C); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tgA : "); print(tgA); print("\n"); + print("fragment_A : "); print(fragment_A); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tgB : "); print(tgB); print("\n"); + print("fragment_B : "); print(fragment_B); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tgC : "); print(tgC); print("\n"); + print("fragment_C : "); print(fragment_C); print("\n"); + } +#endif + + auto k_tile_max = size<3>(tgA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + + Tensor blk_tgA = copy_a.get_pvc_tensor( + make_coord(m_coord, k_tile * sg_tile_k, l_coord), fragment_A.shape(), + typename traits_load_A::Shape_MN{}); + Tensor blk_tgB = copy_b.get_pvc_tensor( + make_coord(k_tile * sg_tile_k, n_coord, l_coord), fragment_B.shape(), + typename traits_load_B::Shape_MN{}, seq<1,0>{}); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP) && k_tile == 1) { + print("blk_tgA : "); print(blk_tgA); print("\n"); + print("blk_tgB : "); print(blk_tgB); print("\n"); + } +#endif + + // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors + copy(copy_a, blk_tgA, fragment_A); + copy(copy_b, blk_tgB, fragment_B); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, fragment_A(_, _, i), fragment_B(_, _, i), fragment_C); + } + } + + Tensor blk_tgC = copy_c.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), fragment_C.shape(), + typename traits_store_C::Shape_MN{}); + + copy(copy_c, fragment_C, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_32x128x64) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, 
gemm_partition_sd_bf16_bf16_float_16x256x64) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_64x1024x64) { + run>(64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_128x128x64) { + run>(128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_32x1024x1024) { + run>(32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_4096x4096x256) { + run>(4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_1024x2048x512) { + run>(1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_1026x2048x512) { + run>(1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_1024x2050x512) { + run>(1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_1026x2050x256) { + run>(1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, gemm_partition_sd_bf16_bf16_float_512x1024x512) { + run>(512, 1024, 512); +} diff --git a/test/unit/cute/intel_xe/gemm_row_col.cpp b/test/unit/cute/intel_xe/gemm_row_col.cpp new file mode 100755 index 0000000000..57bd808d45 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_row_col.cpp @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_row_col { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = true; + static constexpr bool is_b_row_major = false; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(k, 1))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(1, k))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, k, m, k), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, k, n, k), Layout, _1>>{}, + make_layout(make_shape(get<1>(typename traits_load_B::Shape_MN{})/ Int{}, + get<0>(typename traits_load_B::Shape_MN{})))); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + + auto thread_idx = ThreadIdxX(); + auto mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int, _1>>{}); + auto thr_mma = mma.get_thread_slice(thread_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tCrC = thr_mma.partition_fragment_C(gC); + + auto tiled_copy_A = make_tiled_copy_A(copy_a, mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); + auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); + auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); + auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); + auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); + + clear(tCrC); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); + print(" tCrA : "); print(tCrA); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tCrB_copy_view : "); 
print(tCrB_copy_view); print("\n"); + print(" tCrB : "); print(tCrB); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); + print(" tCrC : "); print(tCrC); print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( + make_coord(m_coord, k_tile * sg_tile_k, l_coord), + tCrA_copy_view.shape(), + typename traits_load_A::Shape_MN{}); + Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( + make_coord(n_coord, k_tile * sg_tile_k, l_coord), + tCrB_copy_view.shape(), + typename traits_load_B::Shape_MN{}); + + copy(tiled_copy_A, blk_tgA, tCrA_copy_view); + copy(tiled_copy_B, blk_tgB, tCrB_copy_view); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); + } + } + + Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), + typename traits_store_C::Shape_MN{}); + copy(copy_c, tCrC_copy_view, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_32x128x64) { + run>(32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_16x256x64) { + run>(16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_64x1024x64) { + run>(64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_128x128x64) { + run>(128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_32x1024x1024) { + run>(32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_4096x4096x256) { + run>(4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1024x2048x512) { + run>(1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1026x2048x512) { + run>(1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1024x2050x512) { + run>(1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_1026x2050x256) { + run>(1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, gemm_row_col_bf16_bf16_float_512x1024x512) { + run>(512, 1024, 512); +} diff --git a/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp b/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp new file mode 100755 index 0000000000..a979ce3d28 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_tiled_copy_abc.cpp @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "gemm_utils.hpp" + +template +struct gemm_device_tiled_copy_abc { + using TA = dtype_a; + using TB = dtype_b; + using TC = dtype_c; + + static constexpr bool is_a_row_major = true; + static constexpr bool is_b_row_major = true; + + static constexpr uint32_t wg_tile_m = wg_m; + static constexpr uint32_t wg_tile_n = wg_n; + static constexpr uint32_t sg_tile_m = sg_m; + static constexpr uint32_t sg_tile_n = sg_n; + static constexpr uint32_t sg_tile_k = sg_k; + + static void func(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(k, 1))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(k, n), make_stride(n, 1))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + TiledCopy copy_a = make_tiled_copy( + atom_load_A{}.with(A, k, m, k), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_A::Shape_MN{}), + get<1>(typename traits_load_A::Shape_MN{}) / Int{}))); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + TiledCopy copy_b = make_tiled_copy( + atom_load_B{}.with(B, n, k, n), Layout>>{}, + make_layout(make_shape(get<0>(typename traits_load_B::Shape_MN{}), + get<1>(typename traits_load_B::Shape_MN{}) / Int{}))); + + using traits_store_C = Copy_Traits; + using atom_store_C = Copy_Atom; + TiledCopy copy_c = make_tiled_copy( + atom_store_C{}.with(C, n, m, n), + Layout>>{}, + make_layout(make_shape(get<0>(typename traits_store_C::Shape_MN{}), + get<1>(typename traits_store_C::Shape_MN{}) / Int{}))); + auto thread_idx = ThreadIdxX(); + auto mma = make_tiled_mma( + MMA_Atom{}, + Layout, + Int, _1>>{}); + auto thr_mma = mma.get_thread_slice(thread_idx); + auto tCrA = thr_mma.partition_fragment_A(gA(_, _, 0)); + auto tCrB = thr_mma.partition_fragment_B(gB(_, _, 0)); + auto tCrC = thr_mma.partition_fragment_C(gC); + + auto 
tiled_copy_A = make_tiled_copy_A(copy_a, mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(thread_idx); + auto tCrA_copy_view = thr_copy_A.retile_D(tCrA); + + auto tiled_copy_B = make_tiled_copy_B(copy_b, mma); + auto thr_copy_B = tiled_copy_B.get_thread_slice(thread_idx); + auto tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + auto tiled_copy_C = make_tiled_copy_C(copy_c, mma); + auto thr_copy_C = tiled_copy_C.get_thread_slice(thread_idx); + auto tCrC_copy_view = thr_copy_C.retile_D(tCrC); + + clear(tCrC); + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD, LOG_GROUP)) { + print("===================== A :\n"); + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tCrA_copy_view : "); print(tCrA_copy_view); print("\n"); + print(" tCrA : "); print(tCrA); print("\n"); + + print("===================== B :\n"); + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tCrB_copy_view : "); print(tCrB_copy_view); print("\n"); + print(" tCrB : "); print(tCrB); print("\n"); + + print("===================== C :\n"); + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tCrC_copy_view : "); print(tCrC_copy_view); print("\n"); + print(" tCrC : "); print(tCrC); print("\n"); + } +#endif + + auto sg_per_wg_x = wg_tile_n / sg_tile_n; + const int m_coord = BlockIdxX() * wg_tile_m + + (get_sub_group_id() / sg_per_wg_x) * sg_tile_m; + const int n_coord = BlockIdxY() * wg_tile_n + + (get_sub_group_id() % sg_per_wg_x) * sg_tile_n; + const int l_coord = BlockIdxZ(); + + auto k_tile_max = size<2>(gA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + Tensor blk_tgA = tiled_copy_A.get_pvc_tensor( + make_coord(m_coord, k_tile * sg_tile_k, l_coord), + tCrA_copy_view.shape(), + typename traits_load_A::Shape_MN{}); + Tensor blk_tgB = tiled_copy_B.get_pvc_tensor( + make_coord(k_tile * sg_tile_k, n_coord, l_coord), + tCrB_copy_view.shape(), + typename traits_load_B::Shape_MN{}, seq<1,0>{}); + + copy(tiled_copy_A, blk_tgA, tCrA_copy_view); + copy(tiled_copy_B, blk_tgB, tCrB_copy_view); + + // Compute gemm on mma-partitioned smem + for (int i = 0; i < sg_tile_k / SUBGROUP_SIZE; i++) { + gemm(mma, tCrA(_, _, i), tCrB(_, _, i), tCrC); + } + } + + Tensor blk_tgC = tiled_copy_C.get_pvc_tensor( + make_coord(m_coord, n_coord, l_coord), tCrC_copy_view.shape(), + typename traits_store_C::Shape_MN{}); + copy(copy_c, tCrC_copy_view, blk_tgC); + } +}; + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_32x128x64) { + run>( + 32, 128, 64); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_16x256x64) { + run>( + 16, 256, 64); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_64x1024x64) { + run>( + 64, 1024, 64); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_128x128x64) { + run>( + 128, 128, 64); +} +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_32x1024x1024) { + run>( + 32, 1024, 1024); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_4096x4096x256) { + run>( + 4096, 4096, 256); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_1024x2048x512) { + run>( + 1024, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_1026x2048x512) { + run>( + 1026, 2048, 512); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_1024x2050x512) { + run>( + 1024, 2050, 512); +} + +TEST(PVC_CuTe_Xe, gemm_tiled_copy_abc_bf16_bf16_float_1026x2050x256) { + run>( + 1026, 2050, 256); +} + +TEST(PVC_CuTe_Xe, 
gemm_tiled_copy_abc_bf16_bf16_float_512x1024x512) { + run>( + 512, 1024, 512); +} diff --git a/test/unit/cute/intel_xe/gemm_utils.hpp b/test/unit/cute/intel_xe/gemm_utils.hpp new file mode 100644 index 0000000000..3742d457a5 --- /dev/null +++ b/test/unit/cute/intel_xe/gemm_utils.hpp @@ -0,0 +1,119 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +#define CUTLASS_ENABLE_DEBUG_PRINTS (0) +#define LOG_GROUP (0) +#define LOG_THREAD (0) + +template +void verify(uint32_t m, uint32_t n, uint32_t k, atype *A, btype *B, ctype *C, + bool row_a, bool row_b) { + int cnt = 0; + bool is_normal = true; + + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + ctype expect = ctype(0); + for (int z = 0; z < k; z++) { + auto a = row_a ? A[i * k + z] : A[i + z * m]; + auto b = row_b ? 
B[z * n + j] : B[z + j * k]; + expect += a * b; + } + + ctype val = C[i * n + j]; + + if (isnormal(val) && isnormal(expect)) { + auto error = std::abs((expect - val) / val); + if (error > 0.01f) { + cnt++; + } + } else { + is_normal = false; + } + } + } + + EXPECT_EQ(cnt, 0); + EXPECT_EQ(is_normal, true); +} + +template static void fill_matrix(cutlass::host_vector &M) { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist((T)0.0, (T)1.0); + for (int i = 0; i < M.size(); i++) M[i] = static_cast(dist(rng)); +} + +template void run(uint32_t m, uint32_t n, uint32_t k) { + + using TA = typename kernel::TA; + using TB = typename kernel::TB; + using TC = typename kernel::TC; + + cutlass::host_vector h_A(m * k); + cutlass::host_vector h_B(n * k); + cutlass::host_vector h_C(m * n); + + fill_matrix(h_A); + fill_matrix(h_B); + + cutlass::device_vector d_A = h_A; + cutlass::device_vector d_B = h_B; + cutlass::device_vector d_C = h_C; + + auto dimBlock = syclcompat::dim3( + ceil_div(kernel::wg_tile_m, kernel::sg_tile_m), + SUBGROUP_SIZE * ceil_div(kernel::wg_tile_n, kernel::sg_tile_n)); + auto dimGrid = syclcompat::dim3(size(ceil_div(m, kernel::wg_tile_m)), + size(ceil_div(n, kernel::wg_tile_n))); + + launch( + launch_policy{dimGrid, dimBlock, + kernel_properties{sycl_exp::sub_group_size}}, + d_A.data(), d_B.data(), d_C.data(), m, n, k); + + syclcompat::wait(); + h_C = d_C; + verify(m, n, k, h_A.data(), h_B.data(), h_C.data(), + kernel::is_a_row_major, kernel::is_b_row_major); +} diff --git a/test/unit/cute/intel_xe/mma.cpp b/test/unit/cute/intel_xe/mma.cpp new file mode 100755 index 0000000000..5d80d86135 --- /dev/null +++ b/test/unit/cute/intel_xe/mma.cpp @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include +#include +#include + +#include "cutlass_unit_test.h" + +using namespace cute; +using namespace cutlass; +using namespace syclcompat::experimental; + +#define SUBGROUP_SIZE (16) + +template +void gemm_device(TA const *A, TB const *B, TC *C, uint32_t m, uint32_t n, + uint32_t k) { + using namespace cute; + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), + make_layout(make_shape(m, k), make_stride(k, 1))); + Tensor mB = make_tensor(make_gmem_ptr(B), + make_layout(make_shape(n, k), make_stride(1, n))); + Tensor mC = make_tensor(make_gmem_ptr(C), + make_layout(make_shape(m, n), make_stride(n, 1))); + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(BlockIdxX(), + BlockIdxY(), _); // (m,n,k) + + auto cta_tiler = + make_shape(Int{}, Int{}, Int{}); + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X, _1>{}); + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step{}); + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1, _1, X>{}); + + TiledMMA mma = make_tiled_mma( + MMA_Atom{}, + Layout< // Require: subgroup_layout + Shape, + Int, _1>>{}); + + ThrMMA thrd_mma = mma.get_slice(ThreadIdxX()); + + Tensor tgA = thrd_mma.partition_A(gA); + Tensor fragment_A = + thrd_mma.make_fragment_A(tgA(_, _, _, 0)); // (MMA, MMA_M, MMA_K) + + Tensor tgB = thrd_mma.partition_B(gB); + Tensor fragment_B = + thrd_mma.make_fragment_B(tgB(_, _, _, 0)); // (MMA, MMA_N, MMA_K) + + Tensor tgC = thrd_mma.partition_C(gC); + Tensor fragment_C = thrd_mma.make_fragment_C(tgC); // (MMA, MMA_M, MMA_N) + clear(fragment_C); + +#define CUTLASS_ENABLE_DEBUG_PRINTS (0) + +#define LOG_THREAD (16) + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD)) { + print("===================== A :\n"); + + print(" mA : "); print(mA); print("\n"); + print(" gA : "); print(gA); print("\n"); + print("tgA : "); print(tgA); print("\n"); + print("fragment_A : "); print(fragment_A); print("\n\n"); + } +#endif + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD)) { + print("===================== B :\n"); + + print(" mB : "); print(mB); print("\n"); + print(" gB : "); print(gB); print("\n"); + print("tgB : "); print(tgB); print("\n"); + print("fragment_B : "); print(fragment_B); print("\n\n"); + } +#endif + +#if CUTLASS_ENABLE_DEBUG_PRINTS + if (thread(LOG_THREAD)) { + print("===================== C :\n"); + + print(" mC : "); print(mC); print("\n"); + print(" gC : "); print(gC); print("\n"); + print("tgC : "); print(tgC); print("\n"); + print("fragment_C : "); print(fragment_C); print("\n\n"); + } +#endif + + auto k_tile_max = size<3>(tgA); + for (int k_tile = 0; k_tile < k_tile_max; ++k_tile) { + auto kA = tgA(_, _, _, k_tile); + auto kB = tgB(_, _, _, k_tile); + // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors + copy(kA, fragment_A); + copy(kB, fragment_B); + + // Compute gemm on mma-partitioned smem + gemm(mma, fragment_A, fragment_B, fragment_C); + } + + copy(fragment_C, tgC); +} + +// Setup params for a NT GEMM +template +void gemm(int m, int n, int k, TA *A, TB *B, TC *C) { + using namespace cute; + + auto dimBlock = syclcompat::dim3(SUBGROUP_SIZE * (wg_tile_m * wg_tile_n) / + (sg_tile_m * sg_tile_n)); + auto dimGrid = syclcompat::dim3(size(ceil_div(m, wg_tile_m)), + size(ceil_div(n, wg_tile_n))); + + launch>( + launch_policy{dimGrid, dimBlock, + kernel_properties{sycl_exp::sub_group_size}}, + A, B, C, m, n, k); +} + 
+template +void verify(uint32_t m, uint32_t n, uint32_t k, atype *A, btype *B, ctype *C, + ctype *D) { + std::vector h_D(m * n); + + syclcompat::memcpy(h_D.data(), D, m * n); + + int cnt = 0; + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + for (int z = 0; z < k; z++) { + C[i * n + j] += A[i * k + z] * B[z * n + j]; + } + + auto error = abs((C[i * n + j] - h_D.data()[i * n + j]) / + (float)h_D.data()[i * n + j]); + if (error > 0.01f) { + cnt++; + } + } + } + + EXPECT_EQ(cnt, 0); +} + +template static void fill_matrix(std::vector &M) { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist((T)0.0, (T)4.0); + std::generate(std::begin(M), std::end(M), + [&] { return static_cast(dist(rng)); }); +} + +template +void MMA_Test(int m, int n, int k) { + std::vector h_A(m * k); + std::vector h_B(n * k); + std::vector h_C(m * n); + h_C.clear(); + + fill_matrix(h_A); + fill_matrix(h_B); + + auto d_A = syclcompat::malloc(m * k); + auto d_B = syclcompat::malloc(k * n); + auto d_C = syclcompat::malloc(m * n); + + syclcompat::memcpy(d_A, h_A.data(), m * k); + syclcompat::memcpy(d_B, h_B.data(), k * n); + syclcompat::memcpy(d_C, h_C.data(), m * n); + + gemm(m, n, k, d_A, + d_B, d_C); + syclcompat::wait(); + + verify(m, n, k, h_A.data(), h_B.data(), h_C.data(), d_C); +} + +TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32S8S8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32S8S8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32S8S8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32S8S8S32_TT) { + MMA_Test( + 512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_8x16x32_S32U8U8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_4x16x32_S32U8U8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_2x16x32_S32U8U8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_1x16x32_S32U8U8S32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32BF16BF16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32BF16BF16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32BF16BF16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32BF16BF16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_8x16x16_F32F16F16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_4x16x16_F32F16F16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_2x16x16_F32F16F16F32_TT) { + MMA_Test(512, 512, 256); +} + +TEST(PVC_CuTe_Xe, MMA_XE_1x16x16_F32F16F16F32_TT) { + MMA_Test( + 512, 512, 256); +} + +TEST(PVC_CuTe_Xe, FMA_XE_UniversalFMA_F32F32F32F32) { + MMA_Test, 64, 64, 8, 16, 16, float, + float, float>(512, 512, 256); +}
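
Editor's note: the two recurring pieces of host-side logic in these new tests are (1) the launch-geometry arithmetic in gemm_utils.hpp's run() — one 16-wide sub-group per (sg_m x sg_n) sub-tile, work-groups tiling C in (wg_m x wg_n) blocks — and (2) the reference GEMM with a 1% relative-error tolerance in verify(). The following is a minimal, standalone plain-C++ sketch of both, with no SYCL or CuTe dependency so it compiles anywhere. The helper names make_geometry and count_mismatches, and the tile/problem sizes used, are illustrative only and are not part of this patch.

// Standalone sketch (plain C++17). Mirrors the dimBlock/dimGrid arithmetic in
// run() and the tolerance check in verify(); names and sizes are hypothetical.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

static constexpr uint32_t kSubgroupSize = 16;  // matches SUBGROUP_SIZE in gemm_utils.hpp

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// dimBlock = (wg_m / sg_m, SUBGROUP_SIZE * wg_n / sg_n)
// dimGrid  = (ceil_div(m, wg_m), ceil_div(n, wg_n))
// so each (sg_m x sg_n) sub-tile of the work-group tile is owned by one sub-group.
struct LaunchGeometry {
  uint32_t block_x, block_y;  // work-items per work-group
  uint32_t grid_x, grid_y;    // work-groups tiling the C matrix
};

LaunchGeometry make_geometry(uint32_t m, uint32_t n,
                             uint32_t wg_m, uint32_t wg_n,
                             uint32_t sg_m, uint32_t sg_n) {
  return {ceil_div(wg_m, sg_m), kSubgroupSize * ceil_div(wg_n, sg_n),
          ceil_div(m, wg_m), ceil_div(n, wg_n)};
}

// Reference check in the spirit of verify(): accumulate in double, then compare
// the single-precision result element-wise against a 1% relative-error bound.
int count_mismatches(const std::vector<float>& A, const std::vector<float>& B,
                     const std::vector<float>& C,
                     uint32_t m, uint32_t n, uint32_t k) {
  int mismatches = 0;
  for (uint32_t i = 0; i < m; ++i) {
    for (uint32_t j = 0; j < n; ++j) {
      double expect = 0.0;
      for (uint32_t z = 0; z < k; ++z)
        expect += double(A[i * k + z]) * double(B[z * n + j]);
      double val = C[i * n + j];
      if (!std::isnormal(val) || !std::isnormal(expect) ||
          std::abs((expect - val) / val) > 0.01) {
        ++mismatches;
      }
    }
  }
  return mismatches;
}

int main() {
  // Hypothetical problem and tile sizes, in the same ballpark as the tests above.
  const uint32_t m = 64, n = 128, k = 64;
  const uint32_t wg_m = 32, wg_n = 64, sg_m = 8, sg_n = 16;

  LaunchGeometry g = make_geometry(m, n, wg_m, wg_n, sg_m, sg_n);
  std::printf("grid  = %u x %u work-groups\n", (unsigned)g.grid_x, (unsigned)g.grid_y);
  std::printf("block = %u x %u work-items (%u sub-groups of %u)\n",
              (unsigned)g.block_x, (unsigned)g.block_y,
              (unsigned)(g.block_x * g.block_y / kSubgroupSize), (unsigned)kSubgroupSize);

  // Fill A and B the way fill_matrix() does, then compute C on the host in
  // single precision to stand in for the device result being checked.
  std::mt19937 rng(2024);
  std::uniform_real_distribution<float> dist(0.f, 1.f);
  std::vector<float> A(m * k), B(k * n), C(m * n, 0.f);
  for (auto& x : A) x = dist(rng);
  for (auto& x : B) x = dist(rng);
  for (uint32_t i = 0; i < m; ++i)
    for (uint32_t j = 0; j < n; ++j) {
      float acc = 0.f;
      for (uint32_t z = 0; z < k; ++z) acc += A[i * k + z] * B[z * n + j];
      C[i * n + j] = acc;
    }

  assert(count_mismatches(A, B, C, m, n, k) == 0);
  return 0;
}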