From 4f3e97d1cffe015b25b0852c75a10d742482b06f Mon Sep 17 00:00:00 2001 From: Roland Schulz Date: Wed, 10 Apr 2024 10:58:09 -0700 Subject: [PATCH 1/3] Add XE MMA/copy atom --- include/cute/arch/copy_xe.hpp | 45 ++++++++++++++++++++ include/cute/arch/mma_xe.hpp | 56 +++++++++++++++++++++++++ include/cute/atom/copy_atom.hpp | 2 + include/cute/atom/copy_traits_xe.hpp | 63 ++++++++++++++++++++++++++++ include/cute/atom/mma_atom.hpp | 1 + include/cute/atom/mma_traits_xe.hpp | 24 +++++++++++ include/cute/util/sycl_vec.hpp | 17 ++++++++ 7 files changed, 208 insertions(+) create mode 100644 include/cute/arch/copy_xe.hpp create mode 100644 include/cute/arch/mma_xe.hpp create mode 100644 include/cute/atom/copy_traits_xe.hpp create mode 100644 include/cute/atom/mma_traits_xe.hpp create mode 100644 include/cute/util/sycl_vec.hpp diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp new file mode 100644 index 0000000000..aaf956e50a --- /dev/null +++ b/include/cute/arch/copy_xe.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x +#else +#define SYCL_DEVICE_BUILTIN(x) \ + inline x { assert(false); } +#endif + +SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord, uint8 data)); +SYCL_DEVICE_BUILTIN(ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); +SYCL_DEVICE_BUILTIN(uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); +#undef SYCL_DEVICE_BUILTIN + +struct XE_2D_LOAD //m8k16 +{ + template + CUTE_HOST_DEVICE static void copy(const void* baseoffset, int width, int height, int pitch, int2_ coord, T* dst) + { + if constexpr(sizeof(T)==sizeof(ushort)) { + *(ushort8*)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); + } else if constexpr(sizeof(T)==sizeof(uint)) { + *(uint8*)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); + } else { + static_assert(false); + } + } +}; + +struct XE_2D_SAVE //m8k16 +{ + template + CUTE_HOST_DEVICE static void copy(void* baseoffset, int width, int height, int pitch, int2_ coord, const T* src) + { + if constexpr(sizeof(T)==sizeof(uint)) { + __builtin_IB_subgroup_block_write_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord, *(uint8*)src); + } else { + static_assert(false); + } + } +}; diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp new file mode 100644 index 0000000000..f523d550f3 --- /dev/null +++ b/include/cute/arch/mma_xe.hpp @@ -0,0 +1,56 @@ + +#pragma once + +#include +#include +#include + +#ifdef __SYCL_DEVICE_ONLY__ +#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x +#else +#define SYCL_DEVICE_OCL(x) inline x { assert(false); } +#endif + +SYCL_DEVICE_OCL(float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc)); +SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)); +#undef SYCL_DEVICE_OCL + +namespace cute { +//MxNxK_A,B,C,D +//# of vector component of a x subgroup-size x function name +//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); +//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8 +struct XE_8x16x16_BF16BF16F32F32_NN +{ + using DRegisters = float8[1]; + using ARegisters = short8[1]; + using BRegisters = int8[1]; + using CRegisters = float8[1]; + + CUTE_HOST_DEVICE static void + fma(float8 & d, + short8 const& a, + int8 const& b, + float8 const& c) + { + d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); + } +}; +//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +struct XE_1x16x16_BF16BF16F32F32_NN +{ + using DRegisters = float[1]; + using ARegisters = short[1]; + using BRegisters = int8[1]; + using CRegisters = float[1]; + + CUTE_HOST_DEVICE static void + fma(float & d, + short const& a, + int8 const& b, + float const& c) + { + d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); + } +}; +} //namespace cute \ No newline at end of file diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index d1cd3d4b71..825becfbff 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -769,4 +769,6 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and #include #endif +#include + //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp new file mode 100644 index 0000000000..e0e13c5f45 --- /dev/null +++ b/include/cute/atom/copy_traits_xe.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include +#include + +#include + +namespace cute +{ + template + struct Copy_Traits + { + // using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails + using ThrID = Layout<_1>; + using NumBits = Int; // hacky: does vec of 8 + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; // TODO: is _1 correct? + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + GTensor tensor; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const &traits, + Tensor>, SLayout> const &src, + Tensor &dst) + { + static_assert(is_rmem::value); + int H = size<0>(traits.tensor); + // int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type); + int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy + auto [y, x] = src.data().coord_; + XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data()); + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const &traits, + Tensor const &src, + Tensor>, DLayout> &dst) + { + static_assert(is_rmem::value); + int H = size<0>(traits.tensor); + int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type); + auto [y, x] = dst.data().coord_; + XE_2D_SAVE::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*src.data()); + } + }; + + template + auto make_xe_2d_copy(Tensor gtensor) + { + using GTensor = Tensor; + using Traits = Copy_Traits; + Traits traits{gtensor}; + return Copy_Atom{traits}; + } +} \ No newline at end of file diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 674e3519e8..9421b1505b 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -946,4 +946,5 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and #include #include #include +#include //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp new file mode 100644 index 0000000000..b70263fe00 --- /dev/null +++ b/include/cute/atom/mma_traits_xe.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include + +namespace cute +{ +template <> +struct MMA_Traits +{ + using ElementDVal = float; + using ElementAVal = sycl::ext::oneapi::bfloat16; + using ElementBVal = sycl::ext::oneapi::bfloat16; + using ElementCVal = float; + + using Shape_MNK = Shape<_8,_16,_16>; + using ThrID = Layout<_16>; + using ALayout = Layout, Stride<_8, _1>>; // (T16,V8) -> (m,n) + using BLayout = Layout, Stride<_16, _1>>; + using CLayout = Layout, Stride<_8, _1>>; +}; +} \ No newline at end of file diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp new file mode 100644 index 0000000000..8202cc2a84 --- /dev/null +++ b/include/cute/util/sycl_vec.hpp @@ -0,0 +1,17 @@ +#pragma once + +//fwd declare OCL function and OCL types +#include //for sycl::vec + +#ifdef __SYCL_DEVICE_ONLY__ +template using vector_t = typename sycl::vec::vector_t; +#else +template using vector_t = sycl::vec; +#endif + +using float8 = vector_t; +using short8 = vector_t; +using ushort8 = vector_t; +using int2_ = vector_t; //conflicts with vector_types +using int8 = vector_t; +using uint8 = vector_t; \ No newline at end of file From 2fd2d841b3812132108964a12e4ef6baa82880e1 Mon Sep 17 00:00:00 2001 From: Roland Schulz Date: Wed, 10 Apr 2024 14:00:39 -0700 Subject: [PATCH 2/3] Update to 3.5 API --- include/cute/atom/mma_traits_xe.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index b70263fe00..423ee48cf9 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -10,10 +10,10 @@ namespace cute template <> struct MMA_Traits { - using ElementDVal = float; - using ElementAVal = sycl::ext::oneapi::bfloat16; - using ElementBVal = sycl::ext::oneapi::bfloat16; - using ElementCVal = float; + using ValTypeD = float; + using ValTypeA = sycl::ext::oneapi::bfloat16; + using ValTypeB = sycl::ext::oneapi::bfloat16; + using ValTypeC = float; using Shape_MNK = Shape<_8,_16,_16>; using ThrID = Layout<_16>; From 4d2b315a655e8a4eb34a37ab4b8cbbce5030f05e Mon Sep 17 00:00:00 2001 From: rolandschulz Date: Thu, 11 Apr 2024 21:03:35 -0700 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Mehdi Goli --- include/cute/arch/mma_xe.hpp | 2 +- include/cute/atom/copy_atom.hpp | 2 ++ include/cute/atom/copy_traits_xe.hpp | 2 +- include/cute/atom/mma_atom.hpp | 2 ++ include/cute/atom/mma_traits_xe.hpp | 2 +- include/cute/util/sycl_vec.hpp | 2 +- 6 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index f523d550f3..e0a9e27a3c 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -53,4 +53,4 @@ struct XE_1x16x16_BF16BF16F32F32_NN d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); } }; -} //namespace cute \ No newline at end of file +} //namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index 825becfbff..5627b722af 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -769,6 +769,8 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and #include #endif +#if defined(CUTLASS_ENABLE_SYCL) #include +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index e0e13c5f45..b4023c0b40 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -60,4 +60,4 @@ namespace cute Traits traits{gtensor}; return Copy_Atom{traits}; } -} \ No newline at end of file +} diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 9421b1505b..ffb6a08b0c 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -946,5 +946,7 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and #include #include #include +#if defined(CUTLASS_ENABLE_SYCL) #include +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index 423ee48cf9..d90389be99 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -21,4 +21,4 @@ struct MMA_Traits using BLayout = Layout, Stride<_16, _1>>; using CLayout = Layout, Stride<_8, _1>>; }; -} \ No newline at end of file +} diff --git a/include/cute/util/sycl_vec.hpp b/include/cute/util/sycl_vec.hpp index 8202cc2a84..7c38d9c83a 100644 --- a/include/cute/util/sycl_vec.hpp +++ b/include/cute/util/sycl_vec.hpp @@ -14,4 +14,4 @@ using short8 = vector_t; using ushort8 = vector_t; using int2_ = vector_t; //conflicts with vector_types using int8 = vector_t; -using uint8 = vector_t; \ No newline at end of file +using uint8 = vector_t;