sycl: Fix and disable more configurations of mul_mat #15151

Open · wants to merge 1 commit into base: master

30 changes: 23 additions & 7 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2705,9 +2705,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                " : converting src1 to fp16");
 
     // iterate tensor dims and find the slowest moving dim and stride
-    int64_t last_dim=0;
-    int64_t last_str=0;
-    int64_t largest_str=0;
+    int last_dim=0;
+    int last_str=0;
+    size_t largest_str=0;
     for(int i = 0; i< 4; i++){
         // last stride is always the largest
         if(src1->nb[i] == largest_str){
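
A note on the type change above: ggml stores tensor byte strides as `size_t` (`tensor->nb`), so keeping the running maximum in an `int64_t` forced a signed/unsigned comparison against `src1->nb[i]`. Below is a minimal, self-contained sketch of a slowest-moving-dimension scan under that convention; `toy_tensor` and `slowest_dim` are illustrative stand-ins, not the PR's exact loop.

```cpp
// A minimal sketch, assuming ggml's convention that ne[] holds element
// counts and nb[] holds byte strides (nb is size_t in ggml_tensor).
// toy_tensor and slowest_dim are illustrative, not the PR's code.
#include <cstddef>
#include <cstdint>

struct toy_tensor {
    int64_t ne[4]; // elements per dimension
    size_t  nb[4]; // byte stride per dimension
};

// Return the slowest-moving dimension, i.e. the one with the largest
// byte stride; the accumulator is size_t to match nb[].
static int slowest_dim(const toy_tensor & t) {
    int    last_dim    = 0;
    size_t largest_str = 0;
    for (int i = 0; i < 4; i++) {
        if (t.nb[i] >= largest_str) { // on ties, prefer the later dim
            largest_str = t.nb[i];
            last_dim    = i;
        }
    }
    return last_dim;
}
```

For a canonical (non-permuted) packed layout the scan returns dimension 3, consistent with the code comment that the last stride is always the largest.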
@@ -2783,7 +2783,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
         const sycl::half *src1, float *dst,
         int64_t a0, int64_t a1, int64_t batcha,
-        int64_t b0, int64_t b1, int64_t batchb,
+        int64_t /*b0*/, int64_t b1, int64_t batchb,
         int64_t sa0, int64_t sa1, int64_t sa2,
         int64_t sb0, int64_t sb1, int64_t sb2,
         int64_t sd2) {
@@ -2832,14 +2832,26 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         }
     };
 
-    bool cont_batches_a = nb02 * ne02 == nb03;
-    bool cont_batches_b = nb12 * ne12 == nb13;
-    if (cont_batches_a && cont_batches_b) {
+    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+    if (cont_batches_dim2_a && cont_batches_dim2_b) {
         // A batch is considered contiguous if the dimension 2 is not strided
         int64_t batches0 = ne02 * ne03;
         int64_t batches1 = ne12 * ne13;
         launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
             ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
             str_b2, nb2 / sizeof(float));
+    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+        // This case is similar to the one above with the difference that only the batch in dimension 3 is used and the dimension 2 is of size 1.
+        int64_t batches0 = ne02 * ne03;
+        int64_t batches1 = ne12 * ne13;
+        int64_t str_a3 = nb03 / type_size_src0;
+        int64_t str_b3 = nb13 / type_size_src1;
+        launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+            ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+            str_b3, nb2 / sizeof(float));
     } else {
         for (int64_t b_a = 0; b_a < ne03; b_a++) {
             const sycl::half *src0_f16_shifted
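
The predicates above decide whether all batches can be handed off as a single strided batched GEMM: either the batches are packed along dimension 2 (`nb02 * ne02 == nb03`), or dimension 2 has size 1 and the batches are packed along dimension 3, in which case the dimension-3 strides (`nb03`, `nb13`) become the batch strides. Below is a self-contained sketch of the dispatch with hypothetical shape (`ne*`) and byte-stride (`nb*`) values; only the predicate expressions come from the diff, the numbers are made up.

```cpp
// Self-contained sketch of the two contiguity predicates, evaluated on
// hypothetical values. All ne*/nb* numbers below are invented for the
// example; only the predicate expressions come from the diff.
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical fp16 src0/src1: ne = {64, 32, 1, 8}, with dimension 3
    // packed directly after dimension 1 (a permuted layout, ne02 == 1).
    const int64_t ne01 = 32, ne02 = 1, ne11 = 32, ne12 = 1;
    const int64_t nb02 = 64 * 32 * 2;  // dim-2 byte stride
    const int64_t nb03 = nb02 * ne01;  // dim-3 byte stride of this layout
    const int64_t nb12 = nb02, nb13 = nb03;

    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;

    if (cont_batches_dim2_a && cont_batches_dim2_b) {
        printf("one strided GEMM over dim-2 batches\n");
    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
        printf("one strided GEMM over dim-3 batches (ne02 == 1)\n");
    } else {
        printf("fallback: one GEMM per batch\n");
    }
    return 0;
}
```

With these values only the dimension-3 predicates hold, so the sketch takes the new single-GEMM path instead of the per-batch fallback.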
@@ -4215,6 +4227,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                 return false;
             }
+            // TODO: The configuration below needs more work to be supported with oneDNN
+            if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                return false;
+            }
             return true;
         }
         case GGML_OP_OUT_PROD:
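
This guard reports mul_mat as unsupported when `src0` is permuted, non-contiguous, and batched in both dimensions 2 and 3, letting ggml route such ops to another backend rather than the unfinished oneDNN path. The sketch below mirrors the shape of the check with simplified stand-ins for `ggml_is_permuted` and `ggml_is_contiguous` (the real ggml helpers also account for type and block size); everything here is illustrative.

```cpp
// Illustrative stand-ins for ggml_is_permuted / ggml_is_contiguous;
// the real ggml helpers also account for type and block size. The
// tensor values below are hypothetical.
#include <cstddef>
#include <cstdint>
#include <cstdio>

struct toy_tensor {
    int64_t ne[4]; // elements per dimension
    size_t  nb[4]; // byte stride per dimension
};

// Simplified: a canonical ggml layout has non-decreasing strides.
static bool is_permuted(const toy_tensor & t) {
    return t.nb[0] > t.nb[1] || t.nb[1] > t.nb[2] || t.nb[2] > t.nb[3];
}

// Simplified: packed fp16 layout (2 bytes per element) with no holes.
static bool is_contiguous(const toy_tensor & t) {
    size_t expect = 2;
    for (int i = 0; i < 4; i++) {
        if (t.nb[i] != expect) {
            return false;
        }
        expect *= t.ne[i];
    }
    return true;
}

int main() {
    // A 4D fp16 tensor with dims 1 and 2 swapped, batched in dims 2 and 3.
    toy_tensor a = { { 64, 32, 16, 8 }, { 2, 64 * 16 * 2, 64 * 2, 64 * 16 * 32 * 2 } };
    const bool rejected = is_permuted(a) && !is_contiguous(a) && a.ne[2] > 1 && a.ne[3] > 1;
    printf("rejected by the new guard: %s\n", rejected ? "yes" : "no");
    return 0;
}
```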