[SYCL][Matrix] Syntax changes as preparation before moving joint matrix from the experimental namespace (intel#11215)

As part of the effort to move joint matrix from the experimental namespace to
supported, a review of the API is being done as part of intel#7964. This
results in the following syntax changes:
1. Add Td to joint_matrix_mad, since Tc can differ from Td on the GPU; D is
now an argument to mad rather than its return value (see the usage sketch
after this list).
2. Change "packed" to ext_intel_packed.
3. Move the element-wise operations (get_wi_data, wi_element, get_coord) to
the detail namespace.
4. Add const to the joint_matrix arguments of store and mad.
5. Add a joint_matrix_copy/assignment function.
6. Add apply with coordinates (and change the existing tests accordingly).
7. Change the get_coord coordinate type from uint32_t to size_t.
8. Explicitly delete both the copy-assignment operator and the copy
constructor.
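
For reference, a minimal sketch (not part of this commit) of how the updated
API fits together. The 8x16x16 bfloat16 shape, buffer names, strides, and
sub-group size are illustrative assumptions; real code should query the
device's supported combinations.

```c++
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;
using bf16 = sycl::ext::oneapi::bfloat16;

// Illustrative shape; query the device for supported combinations.
constexpr size_t M = 8, N = 16, K = 16;

void mad_sketch(queue &q, bf16 *A, bf16 *B, float *C, float *D) {
  q.parallel_for(
      nd_range<2>({1, 16}, {1, 16}),
      [=](nd_item<2> it) [[sycl::reqd_sub_group_size(16)]] {
        sub_group sg = it.get_sub_group();
        auto pA = address_space_cast<access::address_space::global_space,
                                     access::decorated::no>(A);
        auto pB = address_space_cast<access::address_space::global_space,
                                     access::decorated::no>(B);
        auto pC = address_space_cast<access::address_space::global_space,
                                     access::decorated::no>(C);
        auto pD = address_space_cast<access::address_space::global_space,
                                     access::decorated::no>(D);

        joint_matrix<sub_group, bf16, use::a, M, K, layout::row_major> tA;
        // Change 2: "packed" is now spelled layout::ext_intel_packed.
        joint_matrix<sub_group, bf16, use::b, K, N, layout::ext_intel_packed>
            tB;
        joint_matrix<sub_group, float, use::accumulator, M, N> tC, tD;

        joint_matrix_load(sg, tA, pA, K);
        joint_matrix_load(sg, tB, pB, N * 2); // packed B: two rows per stride
        joint_matrix_load(sg, tC, pC, N, layout::row_major);
        // Change 1: D is an output argument, not the return value, so its
        // element type Td may differ from Tc.
        joint_matrix_mad(sg, tD, tA, tB, tC);
        joint_matrix_store(sg, tD, pD, N, layout::row_major);
      });
}
```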
yubingex007-a11y authored Oct 12, 2023
1 parent f605df6 commit 687f579
Showing 53 changed files with 375 additions and 646 deletions.
67 changes: 49 additions & 18 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
@@ -29,10 +29,6 @@
namespace sycl {
inline namespace _V1 {
namespace ext {
namespace intel::experimental::matrix::layout {
constexpr sycl::ext::oneapi::experimental::matrix::layout packed =
static_cast<sycl::ext::oneapi::experimental::matrix::layout>(2);
}
namespace oneapi {
namespace experimental {
namespace matrix {
@@ -48,8 +44,7 @@ template <layout Layout> struct spv_matrix_layout_traits {

SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor)
SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor)
SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed,
__spv::MatrixLayout::Packed)
SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed)
SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic)

template <use Use> struct spv_matrix_use_traits {
@@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait<
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
using storage_element_type = float;
};
} // namespace detail
} // namespace oneapi

namespace intel::experimental::matrix {

using namespace sycl::ext::oneapi::experimental::matrix;
// Begin wi_element definition
@@ -121,12 +112,12 @@ class wi_element {
std::size_t i)
: M(Mat), idx(i) {}

inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
#if defined(__SYCL_DEVICE_ONLY__)
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
const size_t row = coord[0];
const size_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
@@ -196,7 +187,7 @@ class wi_element {

#if __SYCL_DEVICE_ONLY__
#define OP(op) \
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
static_cast<storage_element_type>( \
@@ -211,7 +202,7 @@
}
#else // __SYCL_DEVICE_ONLY__
#define OP(op) \
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
@@ -315,7 +306,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

#if __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
__spirv_VectorExtractDynamic< \
@@ -328,7 +319,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
}
#else // __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
@@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
}

// End wi_data definition
} // namespace detail
} // namespace oneapi

namespace intel::experimental::matrix {
template <
typename Group, typename T, typename Tp,
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
@@ -490,7 +484,7 @@ template <
bool> = true>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, Tp, Use, NumRows, NumCols, Layout> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
@@ -526,6 +520,43 @@ joint_matrix_store(Group,
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T,
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout,
typename F>
inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
Group sg,
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, Rows,
Cols, Layout> &jm,
F &&lambda) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
std::ignore = sg;
for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) {
lambda(jm.cuda_impl.wi_marray[i]);
}
#else // NVPTX
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type element = wi_data_c[i];
auto [row, col] = wi_data_c[i].get_coord();
lambda(element, row, col);
wi_data_c[i] = element;
}
#endif
#else
std::ignore = sg;
std::ignore = jm;
std::ignore = lambda;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif
}

} // namespace intel::experimental::matrix

} // namespace ext
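For reference, a hedged sketch of how the coordinate-aware joint_matrix_apply
defined above might be called; `sg` and `tC` are assumed to be the sub-group
and accumulator matrix from the sketch in the commit description.

```c++
// Hypothetical call site inside the kernel from the earlier sketch.
// Per change 7, the row/col coordinates arrive as size_t.
namespace intelmat = sycl::ext::intel::experimental::matrix;
intelmat::joint_matrix_apply(sg, tC,
                             [](float &element, size_t row, size_t col) {
                               if (row == col) // e.g. bias the diagonal
                                 element += 1.0f;
                             });
```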
7 changes: 6 additions & 1 deletion sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp
@@ -16,7 +16,12 @@ namespace matrix {

enum class use { a, b, accumulator };

enum class layout { row_major = 0, col_major = 1, dynamic = 3 };
enum class layout {
row_major = 0,
col_major = 1,
ext_intel_packed = 2,
dynamic = 3
};

namespace precision {
class tf32 {
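As a sanity check (my addition, not part of the commit): the new enumerator
keeps the encoding that the removed extension constant produced via
static_cast, so the rename is binary-compatible.

```c++
using matrix_layout = sycl::ext::oneapi::experimental::matrix::layout;

// Before: sycl::ext::intel::experimental::matrix::layout::packed was
// defined as static_cast<layout>(2); the new enumerator keeps that value.
static_assert(static_cast<int>(matrix_layout::ext_intel_packed) == 2);
```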
111 changes: 60 additions & 51 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
@@ -40,7 +40,8 @@ struct joint_matrix {

#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout>
mutable sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols,
Layout>
cuda_impl;
#elif defined(__SPIR__)
__spv::__spirv_JointMatrixINTEL<
@@ -61,19 +62,8 @@ struct joint_matrix {
}
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__SPIR__)
// Generate a non-trivial assignment operator and copy c'tor that prevents
// memcpy from being generated.
// TODO: to remove, when either IGC can handle alloca JointMatrix or
// combination of InstCombine + SROA + mem2reg can remove it
joint_matrix(const joint_matrix &other) {
spvm = other.spvm;
return *this;
}

joint_matrix &operator=(const joint_matrix &rhs) {
spvm = rhs.spvm;
return *this;
}
joint_matrix(const joint_matrix &other) = delete;
joint_matrix &operator=(const joint_matrix &rhs) = delete;
#endif // defined(__SPIR__)
#endif
};
@@ -97,10 +87,6 @@ class wi_data {
size_t length() {
#if defined(__NVPTX__)
return jm.cuda_impl.wi_marray.size();
#else
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
#endif
};

@@ -109,9 +95,6 @@ class wi_data {
return (jm.cuda_impl.wi_marray[i]);
#else
std::ignore = i;
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
#endif
};
};
@@ -139,9 +122,8 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
"use joint_matrix_apply() instead.")
#else
__attribute__((unavailable(
"get_wi_data can't be used on intel device, please use "
"sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
__attribute__((unavailable("get_wi_data() has been removed from the API and "
"replaced with joint_matrix_apply!")))
#endif
#endif
inline __SYCL_ALWAYS_INLINE decltype(auto)
@@ -177,7 +159,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type element = wi_data_c[i];
lambda(element);
@@ -260,7 +242,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case sycl::ext::intel::experimental::matrix::layout::packed:
case layout::ext_intel_packed:
res.spvm = __spirv_JointMatrixLoadINTEL<
DecorT, S, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
@@ -322,8 +304,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Group,
joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
@@ -355,7 +338,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case sycl::ext::intel::experimental::matrix::layout::packed:
case layout::ext_intel_packed:
__spirv_JointMatrixStoreINTEL<
DecorT, T, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
@@ -375,51 +358,77 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
inline __SYCL_ALWAYS_INLINE
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
joint_matrix_mad(
Group, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&C) {
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
layout LayoutB>
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
Group,
joint_matrix<Group, Td, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
const joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&C) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
if constexpr (std::is_same<Ta, Tb>::value) {
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
D;
sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
LayoutB>(
D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
return D;
} else {
assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
"requires that joint_matrix data types Ta and Tb match");
}
#else
joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res;
if constexpr (std::is_same<Ta, uint16_t>::value &&
std::is_same<Tb, uint16_t>::value &&
std::is_same<Tc, float>::value)
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
else
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
return res;
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
#endif // defined(__NVPTX__)
#else
std::ignore = A;
std::ignore = B;
std::ignore = C;
std::ignore = D;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
use Use1, use Use2, layout Layout1, layout Layout2>
void joint_matrix_copy(
Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
std::ignore = sg;
for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) {
dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i];
}
#else
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T2>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
for (int i = 0; i < wi_data_c.length(); i++) {
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
}
#endif // defined(__NVPTX__)
#else
std::ignore = sg;
std::ignore = dst;
std::ignore = src;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
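Since the copy constructor and copy-assignment operator are now deleted on
SPIR-V targets (change 8), duplicating a tile goes through the new
joint_matrix_copy (change 5). A minimal sketch, reusing `sg` and `tC` from the
earlier sketch:

```c++
// Hypothetical call site inside the kernel from the earlier sketch.
joint_matrix<sub_group, float, use::accumulator, M, N> tCopy;
joint_matrix_copy(sg, tC, tCopy); // replaces the now-deleted "tCopy = tC;"
```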
26 changes: 0 additions & 26 deletions sycl/test-e2e/Matrix/XMX8/element_wise_irreg_sum_rows.cpp

This file was deleted.

(Diffs for the remaining changed files not shown.)
