forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #23 from rolandschulz/add-xe-atoms
Add XE MMA/copy atom
- Loading branch information
Showing
7 changed files
with
212 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include <cute/config.hpp> | ||
#include <cute/arch/copy.hpp> | ||
#include <cute/util/sycl_vec.hpp> | ||
|
||
#ifdef __SYCL_DEVICE_ONLY__ | ||
#define SYCL_DEVICE_BUILTIN(x) SYCL_EXTERNAL extern "C" x | ||
#else | ||
#define SYCL_DEVICE_BUILTIN(x) \ | ||
inline x { assert(false); } | ||
#endif | ||
|
||
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord, uint8 data)); | ||
SYCL_DEVICE_BUILTIN(ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); | ||
SYCL_DEVICE_BUILTIN(uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2_ coord)); | ||
#undef SYCL_DEVICE_BUILTIN | ||
|
||
struct XE_2D_LOAD //m8k16 | ||
{ | ||
template<class T> | ||
CUTE_HOST_DEVICE static void copy(const void* baseoffset, int width, int height, int pitch, int2_ coord, T* dst) | ||
{ | ||
if constexpr(sizeof(T)==sizeof(ushort)) { | ||
*(ushort8*)dst = __builtin_IB_subgroup_block_read_flat_u16_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); | ||
} else if constexpr(sizeof(T)==sizeof(uint)) { | ||
*(uint8*)dst = __builtin_IB_subgroup_block_read_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord); | ||
} else { | ||
static_assert(false); | ||
} | ||
} | ||
}; | ||
|
||
struct XE_2D_SAVE //m8k16 | ||
{ | ||
template<class T> | ||
CUTE_HOST_DEVICE static void copy(void* baseoffset, int width, int height, int pitch, int2_ coord, const T* src) | ||
{ | ||
if constexpr(sizeof(T)==sizeof(uint)) { | ||
__builtin_IB_subgroup_block_write_flat_u32_m8k16v1((long)baseoffset, width - 1, height - 1, pitch - 1, coord, *(uint8*)src); | ||
} else { | ||
static_assert(false); | ||
} | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
|
||
#pragma once | ||
|
||
#include <cute/config.hpp> | ||
#include <cute/arch/mma.hpp> | ||
#include <cute/util/sycl_vec.hpp> | ||
|
||
#ifdef __SYCL_DEVICE_ONLY__ | ||
#define SYCL_DEVICE_OCL(x) SYCL_EXTERNAL x | ||
#else | ||
#define SYCL_DEVICE_OCL(x) inline x { assert(false); } | ||
#endif | ||
|
||
SYCL_DEVICE_OCL(float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc)); | ||
SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc)); | ||
#undef SYCL_DEVICE_OCL | ||
|
||
namespace cute { | ||
//MxNxK_A,B,C,D | ||
//# of vector component of a x subgroup-size x function name | ||
//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); | ||
//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8 | ||
struct XE_8x16x16_BF16BF16F32F32_NN | ||
{ | ||
using DRegisters = float8[1]; | ||
using ARegisters = short8[1]; | ||
using BRegisters = int8[1]; | ||
using CRegisters = float8[1]; | ||
|
||
CUTE_HOST_DEVICE static void | ||
fma(float8 & d, | ||
short8 const& a, | ||
int8 const& b, | ||
float8 const& c) | ||
{ | ||
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); | ||
} | ||
}; | ||
//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) | ||
struct XE_1x16x16_BF16BF16F32F32_NN | ||
{ | ||
using DRegisters = float[1]; | ||
using ARegisters = short[1]; | ||
using BRegisters = int8[1]; | ||
using CRegisters = float[1]; | ||
|
||
CUTE_HOST_DEVICE static void | ||
fma(float & d, | ||
short const& a, | ||
int8 const& b, | ||
float const& c) | ||
{ | ||
d = intel_sub_group_bf16_bf16_matrix_mad_k16(a, b, c); | ||
} | ||
}; | ||
} //namespace cute |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#pragma once | ||
|
||
#include <cute/atom/copy_traits.hpp> | ||
#include <cute/atom/copy_atom.hpp> | ||
|
||
#include <cute/arch/copy_xe.hpp> | ||
|
||
namespace cute | ||
{ | ||
template <class GTensor> | ||
struct Copy_Traits<XE_2D_LOAD, GTensor> | ||
{ | ||
// using ThrID = Layout<_16>; //TODO: I think it should be 16 (copy is per subgroup) - but static_assert fails | ||
using ThrID = Layout<_1>; | ||
using NumBits = Int<sizeof(typename GTensor::engine_type::value_type) * 8>; // hacky: does vec of 8 | ||
// Map from (src-thr,src-val) to bit | ||
using SrcLayout = Layout<Shape<_1, NumBits>>; // TODO: is _1 correct? | ||
// Map from (dst-thr,dst-val) to bit | ||
using DstLayout = Layout<Shape<_1, NumBits>>; | ||
// Reference map from (thr,val) to bit | ||
using RefLayout = SrcLayout; | ||
|
||
GTensor tensor; | ||
|
||
template <class TS, class SLayout, | ||
class TD, class DLayout> | ||
CUTE_HOST_DEVICE friend constexpr void | ||
copy_unpack(Copy_Traits const &traits, | ||
Tensor<ViewEngine<ArithmeticTupleIterator<TS>>, SLayout> const &src, | ||
Tensor<TD, DLayout> &dst) | ||
{ | ||
static_assert(is_rmem<TD>::value); | ||
int H = size<0>(traits.tensor); | ||
// int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type); | ||
int W = size<1>(traits.tensor) * sizeof(typename TD::value_type); //TODO: inconsistent to give the size in elements but use vector for copy | ||
auto [y, x] = src.data().coord_; | ||
XE_2D_LOAD::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*dst.data()); | ||
} | ||
|
||
template <class TS, class SLayout, | ||
class TD, class DLayout> | ||
CUTE_HOST_DEVICE friend constexpr void | ||
copy_unpack(Copy_Traits const &traits, | ||
Tensor<TS, SLayout> const &src, | ||
Tensor<ViewEngine<ArithmeticTupleIterator<TD>>, DLayout> &dst) | ||
{ | ||
static_assert(is_rmem<TS>::value); | ||
int H = size<0>(traits.tensor); | ||
int W = size<1>(traits.tensor) * sizeof(typename decltype(traits.tensor)::engine_type::value_type); | ||
auto [y, x] = dst.data().coord_; | ||
XE_2D_SAVE::copy(traits.tensor.data().get(), W, H, W, int2_{x, y}, &*src.data()); | ||
} | ||
}; | ||
|
||
template <class GEngine, class GLayout> | ||
auto make_xe_2d_copy(Tensor<GEngine, GLayout> gtensor) | ||
{ | ||
using GTensor = Tensor<GEngine, GLayout>; | ||
using Traits = Copy_Traits<XE_2D_LOAD, GTensor>; | ||
Traits traits{gtensor}; | ||
return Copy_Atom<Traits, typename GEngine::value_type>{traits}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
|
||
#include <cute/arch/mma_xe.hpp> | ||
#include <cute/atom/mma_traits.hpp> | ||
|
||
#include <cute/layout.hpp> | ||
|
||
namespace cute | ||
{ | ||
template <> | ||
struct MMA_Traits<XE_8x16x16_BF16BF16F32F32_NN> | ||
{ | ||
using ValTypeD = float; | ||
using ValTypeA = sycl::ext::oneapi::bfloat16; | ||
using ValTypeB = sycl::ext::oneapi::bfloat16; | ||
using ValTypeC = float; | ||
|
||
using Shape_MNK = Shape<_8,_16,_16>; | ||
using ThrID = Layout<_16>; | ||
using ALayout = Layout<Shape<_8, _16>, Stride<_8, _1>>; // (T16,V8) -> (m,n) | ||
using BLayout = Layout<Shape<_16, _16>, Stride<_16, _1>>; | ||
using CLayout = Layout<Shape<_8, _16>, Stride<_8, _1>>; | ||
}; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#pragma once | ||
|
||
//fwd declare OCL function and OCL types | ||
#include <sycl.hpp> //for sycl::vec | ||
|
||
// vector_t<T, N> is the vector type used in the XE builtin signatures:
//  - with __SYCL_DEVICE_ONLY__ defined: sycl::vec<T,N>::vector_t
//  - otherwise: sycl::vec<T,N> itself
#ifdef __SYCL_DEVICE_ONLY__
template<class T, int N> using vector_t = typename sycl::vec<T,N>::vector_t;
#else
template<class T, int N> using vector_t = sycl::vec<T,N>;
#endif

// OpenCL-style short names. Element types are spelled with the standard
// `unsigned short` / `unsigned int` instead of the non-standard `ushort` /
// `uint` typedefs, which not every platform provides.
using float8  = vector_t<float, 8>;
using short8  = vector_t<short, 8>;
using ushort8 = vector_t<unsigned short, 8>;
using int2_   = vector_t<int, 2>; //conflicts with vector_types
using int8    = vector_t<int, 8>;
using uint8   = vector_t<unsigned int, 8>;