Skip to content

Commit

Permalink
[SYCL][Matrix] Add support for missing matrix combinations for half a…
Browse files Browse the repository at this point in the history
…nd bfloat16 types (intel#15540)

Spec added in intel#15547
The new combinations are now added as comments. We will uncomment these
once IGC support becomes available
  • Loading branch information
dkhaldi authored Oct 15, 2024
1 parent ca5cc18 commit 851a90a
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 16 deletions.
82 changes: 82 additions & 0 deletions sycl/source/detail/device_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,14 +850,96 @@ struct get_device_info_impl<
matrix_type::sint32, matrix_type::sint32},
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp32},
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp16},
{8, 0, 0, 0, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp16},
{0, 0, 0, 16, 16, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp16},
{0, 0, 0, 1, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 32, 64, 16, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp32},
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp16},
{0, 0, 0, 1, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp32},
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp32, matrix_type::fp16},
{0, 0, 0, 32, 64, 32, matrix_type::fp16, matrix_type::fp16,
matrix_type::fp16, matrix_type::fp16},
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{8, 0, 0, 0, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 16, 16, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 1, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 32, 64, 16, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 1, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::fp32},
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::fp32},
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::fp32, matrix_type::bf16},
{0, 0, 0, 32, 64, 32, matrix_type::bf16, matrix_type::bf16,
matrix_type::bf16, matrix_type::bf16},
{8, 0, 0, 0, 16, 8, matrix_type::tf32, matrix_type::tf32,
matrix_type::fp32, matrix_type::fp32},
};
Expand Down
8 changes: 5 additions & 3 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,17 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
if constexpr (std::is_same_v<Ta, bfloat16> &&
std::is_same_v<Tc, float>)
acc += make_fp32(va[i]) * make_fp32(vb[i]);
else if constexpr (std::is_same_v<Ta, sycl::half> &&
std::is_same_v<Tc, float>)
acc += (float)va[i] * (float)vb[i];
else if constexpr (std::is_same_v<Ta, float> &&
std::is_same_v<Tc, float> ||
std::is_integral_v<Ta> && std::is_integral_v<Tc> ||
(std::is_same_v<Ta, bfloat16> ||
std::is_same_v<Ta, sycl::half>) ||
(std::is_same_v<Ta, double> &&
std::is_same_v<Tc, double>))
acc += va[i] * vb[i];
else if constexpr (std::is_same_v<Ta, sycl::half> &&
std::is_same_v<Tc, float>)
acc += (float)va[i] * (float)vb[i];
else
assert(false && "Unsupported type in matrix_multiply_ref.");
}
Expand Down
28 changes: 17 additions & 11 deletions sycl/test-e2e/Matrix/joint_matrix_bfloat16_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
big_matrix<T2, K / 2, N * 2> &B) {
size_t NDRangeM = M / TM;
size_t NDRangeN = N / TN;
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
buffer<T1, 2> bufC((T1 *)C.get_data(), range<2>(M, N));

queue q;
size_t sg_size = get_sg_size<imatrix<T1, TM, TN, TK>>(q);
q.submit([&](handler &cgh) {
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
accessor accA{bufA, cgh};
accessor accB{bufB, cgh};
accessor accC{bufC, cgh};

cgh.parallel_for<imatrix<T1, TM, TN, TK>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
Expand All @@ -41,13 +41,11 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
sub_a;
joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
// For B, we assume B has been already VNNIed.
joint_matrix<sub_group, bfloat16, use::b, TK, TN,
layout::ext_intel_packed>
joint_matrix<sub_group, T2, use::b, TK, TN, layout::ext_intel_packed>
sub_b;
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;

joint_matrix_load(
sg, sub_c,
Expand Down Expand Up @@ -122,13 +120,21 @@ int main() {

if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
test<bfloat16, float, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
// test<bfloat16, bfloat16, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();

// This combination is not currently supported for sub group size = 32 in
// IGC
#if (!defined(SG_SZ) || SG_SZ != 32)
test<bfloat16, float, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
// test<bfloat16, bfloat16, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
// test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
// test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
// test<bfloat16, float, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
// test<bfloat16, bfloat16, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
// test<bfloat16, float, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
// test<bfloat16, bfloat16, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
#endif
break;
}
Expand Down
20 changes: 18 additions & 2 deletions sycl/test-e2e/Matrix/joint_matrix_half_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ void matrix_multiply(big_matrix<TResult, M, N> &C, big_matrix<T, M, K> &A,
buffer<TResult, 2> bufC(C.get_data(), range<2>(M, N));

queue q;
size_t sg_size = get_sg_size<mult<T, TM, TN, TK>>(q);
size_t sg_size = get_sg_size<mult<TResult, TM, TN, TK>>(q);
q.submit([&](handler &cgh) {
accessor accA{bufA, cgh};
accessor accB{bufB, cgh};
accessor accC{bufC, cgh};

cgh.parallel_for<mult<T, TM, TN, TK>>(
cgh.parallel_for<mult<TResult, TM, TN, TK>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
Expand Down Expand Up @@ -122,6 +122,22 @@ int main() {

if (combinations[i].nsize == 16) { // architecture::intel_gpu_pvc
test<float, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();
// test<half, half, 2, /*TM*/ 8, /*TN*/ 16, /*TK*/ 16>();

// This combination is not currently supported for sub group size = 32 in
// IGC
#if (!defined(SG_SZ) || SG_SZ != 32)
// test<float, half, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
// test<half, half, /*TM*/ 16, /*TN*/ 16, /*TK*/ 16>();
// test<float, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
// test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 16>();
// test<float, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
// test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 16>();
// test<float, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
// test<half, half, /*TM*/ 1, /*TN*/ 64, /*TK*/ 32>();
// test<float, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
// test<half, half, /*TM*/ 32, /*TN*/ 64, /*TK*/ 32>();
#endif
break;
}

Expand Down

0 comments on commit 851a90a

Please sign in to comment.