From a158e1b54786d06a1b73c963aca55f2e1d1aaaae Mon Sep 17 00:00:00 2001 From: OuadiElfarouki Date: Fri, 28 Jun 2024 07:50:19 +0100 Subject: [PATCH] Enabled more types for oneMKL's gemm_batch API --- ggml-sycl.cpp | 29 ++--------------------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 89a46ebe9a7bd..02cf7e79cb702 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -2490,6 +2490,7 @@ namespace dpct b, ldb, beta, c, ldc, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): @@ -2522,7 +2523,6 @@ namespace dpct batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): @@ -2659,6 +2659,7 @@ namespace dpct stride_c, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): @@ -2687,7 +2688,6 @@ namespace dpct beta, c, ldc, stride_c, batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): @@ -15026,9 +15026,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, SYCL_CHECK(ggml_sycl_set_device(g_main_device)); dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda || - main_stream->get_backend() == sycl::backend::ext_oneapi_hip; - SYCL_CHECK( CHECK_TRY_ERROR(g_sycl_handles[g_main_device] = main_stream)); @@ -15059,10 +15056,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; dpct::library_data_t cu_data_type = dpct::library_data_t::real_float; - if (no_mixed_dtypes) { - cu_compute_type = dpct::library_data_t::real_half; - cu_data_type = dpct::library_data_t::real_half; - } // dst strides size_t nbd2 = dst->nb[2]; @@ -15076,21 +15069,8 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const void * alpha = &alpha_f32; const void * beta = &beta_f32; - if (no_mixed_dtypes) { - alpha = &alpha_f16; - beta = &beta_f16; - } - - // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway - // when oneMKL open source supports half, half, float, float: datatypes dst_t = (char *) dst_ddf; - if (no_mixed_dtypes) { - dst_t = (char *) dst_f16.alloc(ne_dst); - - nbd2 /= sizeof(float) / sizeof(sycl::half); - nbd3 /= sizeof(float) / sizeof(sycl::half); - } GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -15152,11 +15132,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type))); } - - if (no_mixed_dtypes) { - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16); - to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream); - } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__