Skip to content

Commit

Permalink
Merge pull request #23 from rolandschulz/add-xe-atoms
Browse files Browse the repository at this point in the history
Add XE MMA/copy atom
  • Loading branch information
rolandschulz authored Apr 17, 2024
2 parents ae6989a + 4d2b315 commit ddbba2f
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 0 deletions.
45 changes: 45 additions & 0 deletions include/cute/arch/copy_xe.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include <cute/config.hpp>
#include <cute/arch/copy.hpp>
#include <cute/util/sycl_vec.hpp>

// Declarations of the 2D subgroup block read/write builtins used by the XE
// copy operations below.
//
// With __SYCL_DEVICE_ONLY__ defined, each builtin is declared as an external
// "C" symbol. Otherwise each one becomes an inline stub whose whole body is
// assert(false).
// NOTE(review): the non-void stubs have no return statement, so they rely on
// never being called when assertions are compiled out -- confirm.
#if defined(__SYCL_DEVICE_ONLY__)
#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x
#else
#define SYCL_DEVICE_BUILTIN(x) inline x { assert(false); }
#endif

// Block reads (16-bit and 32-bit element variants).
SYCL_DEVICE_BUILTIN(
    ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(
        long baseoffset, int width_minus_one, int height_minus_one,
        int pitch_minus_one, int2_ coord));
SYCL_DEVICE_BUILTIN(
    uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(
        long baseoffset, int width_minus_one, int height_minus_one,
        int pitch_minus_one, int2_ coord));

// Block write (32-bit element variant only).
SYCL_DEVICE_BUILTIN(
    void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(
        long baseoffset, int width_minus_one, int height_minus_one,
        int pitch_minus_one, int2_ coord, uint8 data));

#undef SYCL_DEVICE_BUILTIN

/// 2D block load operation (m8k16).
///
/// Dispatches on sizeof(T) to the 16-bit or 32-bit
/// __builtin_IB_subgroup_block_read_flat_*_m8k16v1 builtin and stores the
/// returned 8-component vector through `dst`.
struct XE_2D_LOAD //m8k16
{
  /// @param baseoffset  Base address handed to the builtin (converted to long).
  /// @param width       Forwarded to the builtin as `width - 1`.
  /// @param height      Forwarded to the builtin as `height - 1`.
  /// @param pitch       Forwarded to the builtin as `pitch - 1`.
  /// @param coord       2D coordinate forwarded unchanged.
  /// @param dst         Destination; written as one ushort8 (16-bit T) or one
  ///                    uint8 (32-bit T).
  ///                    NOTE(review): assumes dst has room and alignment for
  ///                    that vector type -- confirm at the call sites.
  template <class T>
  CUTE_HOST_DEVICE static void copy(const void* baseoffset, int width, int height, int pitch, int2_ coord, T* dst)
  {
    // A dependent condition is required here: a plain static_assert(false) in
    // the discarded branch is rejected by compilers that predate CWG2518,
    // even when T is a supported type.
    static_assert(sizeof(T) == sizeof(unsigned short) || sizeof(T) == sizeof(unsigned int),
                  "XE_2D_LOAD::copy supports only 16-bit and 32-bit element types");

    if constexpr (sizeof(T) == sizeof(unsigned short)) {
      *reinterpret_cast<ushort8*>(dst) = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(
          reinterpret_cast<long>(baseoffset), width - 1, height - 1, pitch - 1, coord);
    } else if constexpr (sizeof(T) == sizeof(unsigned int)) {
      *reinterpret_cast<uint8*>(dst) = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(
          reinterpret_cast<long>(baseoffset), width - 1, height - 1, pitch - 1, coord);
    }
  }
};

/// 2D block store operation (m8k16).
///
/// Forwards to __builtin_IB_subgroup_block_write_flat_u32_m8k16v1, passing the
/// data read through `src` as one uint8 vector. Only 32-bit element types are
/// supported.
struct XE_2D_SAVE //m8k16
{
  /// @param baseoffset  Base address handed to the builtin (converted to long).
  /// @param width       Forwarded to the builtin as `width - 1`.
  /// @param height      Forwarded to the builtin as `height - 1`.
  /// @param pitch       Forwarded to the builtin as `pitch - 1`.
  /// @param coord       2D coordinate forwarded unchanged.
  /// @param src         Source; read as one uint8 vector.
  ///                    NOTE(review): assumes src is sized and aligned for
  ///                    that vector type -- confirm at the call sites.
  template <class T>
  CUTE_HOST_DEVICE static void copy(void* baseoffset, int width, int height, int pitch, int2_ coord, const T* src)
  {
    // A dependent condition is required here: a plain static_assert(false) in
    // a discarded branch is rejected by compilers that predate CWG2518, even
    // when T is a supported type.
    static_assert(sizeof(T) == sizeof(unsigned int),
                  "XE_2D_SAVE::copy supports only 32-bit element types");

    // Read through a pointer-to-const; the vector is passed by value, so the
    // source's constness never has to be cast away.
    __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(
        reinterpret_cast<long>(baseoffset), width - 1, height - 1, pitch - 1, coord,
        *reinterpret_cast<const uint8*>(src));
  }
};
56 changes: 56 additions & 0 deletions include/cute/arch/mma_xe.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

#pragma once

#include <cute/config.hpp>
#include <cute/arch/mma.hpp>
#include <cute/util/sycl_vec.hpp>

// Declarations of the two intel_sub_group_bf16_bf16_matrix_mad_k16 overloads
// used by the XE MMA operations in this header.
//
// With __SYCL_DEVICE_ONLY__ defined they are declared SYCL_EXTERNAL; otherwise
// each becomes an inline stub whose whole body is assert(false).
// NOTE(review): the stubs return non-void but contain no return statement, so
// they rely on never being called when assertions are compiled out -- confirm.
#if defined(__SYCL_DEVICE_ONLY__)
#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x
#else
#define SYCL_DEVICE_OCL(x) inline x { assert(false); }
#endif

// Overload taking a scalar `a` / scalar accumulator.
SYCL_DEVICE_OCL(
    float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc));
// Overload taking an 8-component `a` / 8-component accumulator.
SYCL_DEVICE_OCL(
    float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc));

#undef SYCL_DEVICE_OCL

namespace cute {
//MxNxK_A,B,C,D
//# of vector component of a x subgroup-size x function name
//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc);
//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8
/// MMA operation backed by the float8 overload of
/// intel_sub_group_bf16_bf16_matrix_mad_k16.
struct XE_8x16x16_BF16BF16F32F32_NN
{
  // Register fragment types: one vector per operand.
  using ARegisters = short8[1];
  using BRegisters = int8[1];
  using CRegisters = float8[1];
  using DRegisters = float8[1];

  /// Assigns intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c) to `d`.
  CUTE_HOST_DEVICE static void
  fma(float8& d, short8 const& a, int8 const& b, float8 const& c)
  {
    d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
  }
};
//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)
/// MMA operation backed by the scalar-`a` overload of
/// intel_sub_group_bf16_bf16_matrix_mad_k16.
struct XE_1x16x16_BF16BF16F32F32_NN
{
  // Register fragment types: A, C and D are scalars, B is an 8-component vector.
  using ARegisters = short[1];
  using BRegisters = int8[1];
  using CRegisters = float[1];
  using DRegisters = float[1];

  /// Assigns intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c) to `d`.
  CUTE_HOST_DEVICE static void
  fma(float& d, short const& a, int8 const& b, float const& c)
  {
    d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c);
  }
};
} //namespace cute
4 changes: 4 additions & 0 deletions include/cute/atom/copy_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,4 +769,8 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
#include <cute/atom/copy_traits_sm90_tma.hpp>
#endif

#if defined(CUTLASS_ENABLE_SYCL)
#include <cute/atom/copy_traits_xe.hpp>
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
63 changes: 63 additions & 0 deletions include/cute/atom/copy_traits_xe.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#pragma once

#include <cute/atom/copy_traits.hpp>
#include <cute/atom/copy_atom.hpp>

#include <cute/arch/copy_xe.hpp>

namespace cute
{
/// Copy traits for XE_2D_LOAD, parameterized on the type of the tensor whose
/// data pointer and extents are forwarded to the 2D block copy.
///
/// Two copy_unpack overloads are provided: one taking a coordinate-carrying
/// source (dispatches to XE_2D_LOAD::copy) and one taking a coordinate-carrying
/// destination (dispatches to XE_2D_SAVE::copy).
template <class GTensor>
struct Copy_Traits<XE_2D_LOAD, GTensor>
{
  // TODO: ThrID should probably be Layout<_16> (the copy is per subgroup), but
  // a static_assert fails with that value.
  using ThrID = Layout<_1>;

  // Bit width of one element of GTensor. Hacky: the copy itself moves a vec of 8.
  using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>;

  // Map from (src-thr,src-val) to bit. TODO: is _1 correct?
  using SrcLayout = Layout<Shape<_1, NumBits>>;
  // Map from (dst-thr,dst-val) to bit
  using DstLayout = Layout<Shape<_1, NumBits>>;
  // Reference map from (thr,val) to bit
  using RefLayout = SrcLayout;

  // Tensor whose data pointer, mode-0 size and mode-1 size parameterize the copy.
  GTensor tensor;

  /// Load path: `src` supplies the coordinate, `dst` receives the data.
  /// The mode-0 size of the traits tensor is passed as the height; the mode-1
  /// size scaled by the destination element size is passed as both width and
  /// pitch.
  template <class TS, class SLayout, class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits const &traits,
              Tensor<ViewEngine<ArithmeticTupleIterator<TS>>, SLayout> const &src,
              Tensor<TD, DLayout> &dst)
  {
    static_assert(is_rmem<TD>::value);

    // TODO: inconsistent to give the size in elements but use vector for copy.
    // (The store path below scales by the GTensor element size instead.)
    using DstElement = typename TD::value_type;

    int height = size<0>(traits.tensor);
    int width = size<1>(traits.tensor) * sizeof(DstElement);

    // The coordinate is stored as (y, x) but the copy takes it as (x, y).
    auto [row, col] = src.data().coord_;
    XE_2D_LOAD::copy(traits.tensor.data().get(), width, height, width, int2_{col, row}, &*dst.data());
  }

  /// Store path: `dst` supplies the coordinate, `src` supplies the data.
  /// The mode-0 size of the traits tensor is passed as the height; the mode-1
  /// size scaled by the GTensor element size is passed as both width and pitch.
  template <class TS, class SLayout, class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits const &traits,
              Tensor<TS, SLayout> const &src,
              Tensor<ViewEngine<ArithmeticTupleIterator<TD>>, DLayout> &dst)
  {
    static_assert(is_rmem<TS>::value);

    using GElement = typename GTensor::engine_type::value_type;

    int height = size<0>(traits.tensor);
    int width = size<1>(traits.tensor) * sizeof(GElement);

    // The coordinate is stored as (y, x) but the copy takes it as (x, y).
    auto [row, col] = dst.data().coord_;
    XE_2D_SAVE::copy(traits.tensor.data().get(), width, height, width, int2_{col, row}, &*src.data());
  }
};

/// Builds a Copy_Atom for the XE 2D copy over `gtensor`.
///
/// The atom's traits hold a copy of `gtensor`, and the atom's value type is
/// the tensor engine's value_type.
template <class GEngine, class GLayout>
auto make_xe_2d_copy(Tensor<GEngine, GLayout> gtensor)
{
  using Element = typename GEngine::value_type;
  using AtomTraits = Copy_Traits<XE_2D_LOAD, Tensor<GEngine, GLayout>>;

  AtomTraits atom_traits{gtensor};
  return Copy_Atom<AtomTraits, Element>{atom_traits};
}
}
3 changes: 3 additions & 0 deletions include/cute/atom/mma_atom.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,4 +946,7 @@ print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and
#include <cute/atom/mma_traits_sm80.hpp>
#include <cute/atom/mma_traits_sm90.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>
#if defined(CUTLASS_ENABLE_SYCL)
#include <cute/atom/mma_traits_xe.hpp>
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
24 changes: 24 additions & 0 deletions include/cute/atom/mma_traits_xe.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once

#include <cute/arch/mma_xe.hpp>
#include <cute/atom/mma_traits.hpp>

#include <cute/layout.hpp>

namespace cute
{
/// MMA traits for XE_8x16x16_BF16BF16F32F32_NN: bfloat16 A/B operands, float
/// C/D operands, an 8x16x16 (M,N,K) shape and a 16-entry thread-ID layout.
template <>
struct MMA_Traits<XE_8x16x16_BF16BF16F32F32_NN>
{
  // Logical element types of the four operands.
  using ValTypeA = sycl::ext::oneapi::bfloat16;
  using ValTypeB = sycl::ext::oneapi::bfloat16;
  using ValTypeC = float;
  using ValTypeD = float;

  using Shape_MNK = Shape<_8, _16, _16>;
  using ThrID = Layout<_16>;

  // (T16,V8) -> (m,n)
  // NOTE(review): the comment above says (T16,V8) but the shape is written
  // (8,16) -- confirm which is intended.
  using ALayout = Layout<Shape<_8, _16>,
                         Stride<_8, _1>>;
  using BLayout = Layout<Shape<_16, _16>,
                         Stride<_16, _1>>;
  using CLayout = Layout<Shape<_8, _16>,
                         Stride<_8, _1>>;
};
}
17 changes: 17 additions & 0 deletions include/cute/util/sycl_vec.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#pragma once

//fwd declare OCL function and OCL types
#include <sycl.hpp> //for sycl::vec

// vector_t<T, N> is the N-component vector type used in the builtin
// signatures: sycl::vec<T, N>::vector_t when __SYCL_DEVICE_ONLY__ is defined,
// and sycl::vec<T, N> itself otherwise.
#ifdef __SYCL_DEVICE_ONLY__
template<class T, int N> using vector_t = typename sycl::vec<T,N>::vector_t;
#else
template<class T, int N> using vector_t = sycl::vec<T,N>;
#endif

// OpenCL-style short names for the vector types used by the XE atoms.
// The unsigned element types are spelled with the built-in type names instead
// of `ushort` / `uint`: those are non-standard typedefs that only some
// platform headers provide, while the built-in spellings name the same types
// everywhere.
using float8 = vector_t<float, 8>;
using short8 = vector_t<short, 8>;
using ushort8 = vector_t<unsigned short, 8>;
using int2_ = vector_t<int, 2>; //conflicts with vector_types
using int8 = vector_t<int, 8>;
using uint8 = vector_t<unsigned int, 8>;

0 comments on commit ddbba2f

Please sign in to comment.