Skip to content

Commit

Permalink
[SYCL][Joint Matrix] Update apply to make both matrices read/write (#…
Browse files Browse the repository at this point in the history
…16155)

Spec change was added in #13153
It states that the overload of joint_matrix_apply that takes two
matrices can modify both matrices.
I also updated the test to reflect the change.
  • Loading branch information
dkhaldi authored and KornevNikita committed Feb 25, 2025
1 parent 3d6b04b commit f324c55
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 45 deletions.
36 changes: 20 additions & 16 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,35 +112,39 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
return;
}

template <typename Group, typename T, use Use, size_t M, size_t N,
template <typename Group, typename T0, typename T1, use Use, size_t M, size_t N,
layout Layout, typename F>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jmsrc,
joint_matrix<Group, T, Use, M, N, Layout> &jmdest,
joint_matrix_apply(Group sg, joint_matrix<Group, T0, Use, M, N, Layout> &jm0,
joint_matrix<Group, T1, Use, M, N, Layout> &jm1,
F &&lambda) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__) || defined(__HIP_PLATFORM_AMD_MFMA__)
std::ignore = sg;
for (int i = 0; i < jmsrc.matrix_impl.wi_marray.size(); i++) {
lambda(jmsrc.matrix_impl.wi_marray[i], jmdest.matrix_impl.wi_marray[i]);
for (int i = 0; i < jm0.matrix_impl.wi_marray.size(); i++) {
lambda(jm0.matrix_impl.wi_marray[i], jm1.matrix_impl.wi_marray[i]);
}
#else // NVPTX
using storage_element_type =
using storage_element_type0 =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jmsrc);
auto wi_data_d = sycl::ext::oneapi::detail::get_wi_data(sg, jmdest);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type elementsrc = wi_data_c[i];
storage_element_type elementdest = wi_data_d[i];
lambda(elementsrc, elementdest);
wi_data_d[i] = elementdest;
T0>::storage_element_type;
using storage_element_type1 =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T1>::storage_element_type;
auto wi_data_0 = sycl::ext::oneapi::detail::get_wi_data(sg, jm0);
auto wi_data_1 = sycl::ext::oneapi::detail::get_wi_data(sg, jm1);
for (int i = 0; i < wi_data_0.length(); i++) {
storage_element_type0 element0 = wi_data_0[i];
storage_element_type1 element1 = wi_data_1[i];
lambda(element0, element1);
wi_data_0[i] = element0;
wi_data_1[i] = element1;
}
#endif
#else
std::ignore = sg;
std::ignore = jmsrc;
std::ignore = jmdest;
std::ignore = jm0;
std::ignore = jm1;
std::ignore = lambda;
throw exception(make_error_code(errc::runtime),
"joint matrix is not supported on host.");
Expand Down
9 changes: 8 additions & 1 deletion sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ void matrix_copy(unsigned int rows, unsigned int cols, T *src, T *dst) {
}
}

template <typename F, typename T>
void matrix_apply(unsigned int rows, unsigned int cols, T *mat, F op) {
for (unsigned int i = 0; i < rows; i++)
for (unsigned int j = 0; j < cols; j++)
mat[i * cols + j] = op(mat[i * cols + j]);
}

template <typename T1, typename T2, bool exact = false>
bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
for (int i = 0; i < rows; i++) {
Expand All @@ -173,7 +180,7 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
<< ", Epsilon: " << FLOAT_EPSILON << "\n";
return false;
}
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
} else if constexpr (exact || std::is_integral_v<T1>) {
if (src[i * cols + j] != ref[i * cols + j]) {
std::cout << "Incorrect result in matrix."
<< "i: " << i << ", j: " << j
Expand Down
70 changes: 42 additions & 28 deletions sycl/test-e2e/Matrix/joint_matrix_apply_two_matrices_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,26 @@
//===----------------------------------------------------------------------===//
#include <sycl/usm.hpp>

template <typename Tc, typename Ta, size_t M, size_t N>
bool apply_verify(Tc *C, Tc *D, Ta *A, Ta *Ar) {
for (size_t i = 0; i < M; i++)
for (size_t j = 0; j < N; j++) {
Tc diffc = D[i * N + j] - C[i * N + j] * 2;
Ta diffa = Ar[i * N + j] - (A[i * N + j] + 42);
if constexpr (std::is_same_v<Ta, bfloat16>) {
if (std::fabs(diffc) > FLOAT_EPSILON ||
std::fabs(diffa) > FLOAT_EPSILON || std::isnan(C[i * N + j]) ||
std::isnan(A[i * N + j])) {
return false;
}
} else {
if (std::abs(diffc) > 0 || std::abs(diffa) > 0) {
return false;
}
}
}
return true;
template <typename T> T mul2(T x) { return x * 2; }

template <typename T> T add5(T x) { return x + 5; }

template <typename Tc, size_t M, size_t N>
bool apply_verify(Tc *C, Tc *D, Tc *ref) {
Tc *refcopy = (Tc *)std::malloc(M * N * sizeof(Tc));
memcpy(refcopy, ref, M * N * sizeof(Tc));
matrix_apply(M, N, ref, mul2<Tc>);
bool res = matrix_compare(M, N, D, ref);

matrix_apply(M, N, refcopy, add5<Tc>);
res &= matrix_compare(M, N, C, refcopy);
return res;
}

template <typename Tc, typename Ta, size_t TM, size_t TN, size_t TK, size_t M,
size_t N, size_t K, class kernel_name>
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, Tc *Cref, Ta *Aref,
queue q) {
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;

Expand Down Expand Up @@ -70,22 +67,33 @@ bool apply_two_matrices(Tc *C, Tc *D, Ta *A, Ta *Ar, queue q) {
joint_matrix_load(
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_apply(sg, sub_c, sub_d,
[](const Tc &x, Tc &y) { y = x * 2; });
joint_matrix_apply(sg, sub_c, sub_d, [](Tc &x, Tc &y) {
y = mul2(x);
x = add5(x);
});
joint_matrix_store(
sg, sub_d, pD + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_store(
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / sg_size * TN,
N, layout::row_major);
joint_matrix_load(
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
K);
joint_matrix_apply(sg, sub_a, sub_ar,
[](const Ta &x, Ta &y) { y = x + 42; });
joint_matrix_apply(sg, sub_a, sub_ar, [](Ta &x, Ta &y) {
y = mul2(x);
x = add5(x);
});
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_ar,
pAr + (sg_startx * TM) * K + sg_starty / sg_size * TK, K);
ext::intel::experimental::matrix::joint_matrix_store(
sg, sub_a, pA + (sg_startx * TM) * K + sg_starty / sg_size * TK,
K);
}); // parallel for
}).wait();
return apply_verify<Tc, Ta, M, N>(C, D, A, Ar);
return apply_verify<Tc, M, N>(C, D, Cref) &&
apply_verify<Ta, M, N>(A, Ar, Aref);
}

template <typename Ta, typename Tc, size_t TM, size_t TN, size_t TK,
Expand All @@ -96,16 +104,20 @@ bool test() {
static constexpr size_t K = TK * 2;
queue q;

Tc *Cref = malloc_shared<Tc>(M * N, q);
Ta *Aref = malloc_shared<Ta>(M * K, q);
Tc *C = malloc_shared<Tc>(M * N, q);
Tc *D = malloc_shared<Tc>(M * N, q);
Ta *A = malloc_shared<Ta>(M * K, q);
Ta *Ar = malloc_shared<Ta>(M * K, q);

matrix_rand(M, N, (Tc *)C, (Tc)100);
matrix_rand(M, K, (Ta *)A, (Ta)100);
matrix_rand(M, N, (Tc *)Cref, (Tc)100);
matrix_rand(M, K, (Ta *)Aref, (Ta)100);
matrix_copy(M, N, Cref, C);
matrix_copy(M, K, Aref, A);

bool res = apply_two_matrices<Tc, Ta, TM, TN, TK, M, N, K, kernel_name>(
C, D, A, Ar, q);
C, D, A, Ar, Cref, Aref, q);

if constexpr (std::is_same_v<Ta, bfloat16>)
std::cout << "bfloat16 " << TM << "x" << TN << "x" << TK << ": "
Expand All @@ -117,6 +129,8 @@ bool test() {
free(D, q);
free(A, q);
free(Ar, q);
free(Cref, q);
free(Aref, q);

return res;
}
Expand Down

0 comments on commit f324c55

Please sign in to comment.