From a2e77e60d6d1e208096aae27e24a23ff9821c58b Mon Sep 17 00:00:00 2001
From: "Meng, Hengyu"
Date: Sun, 31 Mar 2024 11:00:52 +0000
Subject: [PATCH] clang-format

---
 ggml-sycl/backend.hpp     |    4 +-
 ggml-sycl/common.cpp      |  300 +-
 ggml-sycl/common.hpp      |  458 +--
 ggml-sycl/dpct/helper.hpp | 6129 ++++++++++++++++++++-----------------
 ggml-sycl/mmq.cpp         | 5596 +++++++++++++++++++--------------
 ggml-sycl/mmq.hpp         |   17 +-
 ggml-sycl/mmvq.cpp        | 1678 +++++-----
 ggml-sycl/mmvq.hpp        |   17 +-
 ggml-sycl/vecdotq.hpp     | 1735 ++++++-----
 9 files changed, 8888 insertions(+), 7046 deletions(-)

diff --git a/ggml-sycl/backend.hpp b/ggml-sycl/backend.hpp
index a343783bd01c0..9e15a5f184e4e 100644
--- a/ggml-sycl/backend.hpp
+++ b/ggml-sycl/backend.hpp
@@ -14,8 +14,8 @@
 #define GGML_SYCL_BACKEND_HPP
 
 #include "common.hpp"
-#include "vecdotq.hpp"
-#include "mmvq.hpp"
 #include "mmq.hpp"
+#include "mmvq.hpp"
+#include "vecdotq.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
\ No newline at end of file
diff --git a/ggml-sycl/common.cpp b/ggml-sycl/common.cpp
index 708d787e866e4..c16c63a07ea49 100644
--- a/ggml-sycl/common.cpp
+++ b/ggml-sycl/common.cpp
@@ -12,169 +12,191 @@
 
 #include "common.hpp"
 
-int get_main_device(){
-    return g_main_device;
+int get_main_device() {
+  return g_main_device;
 }
 
 void check_allow_gpu_index(const int device_index) {
-    if (device_index >= g_device_count) {
-        char error_buf[256];
-        snprintf(error_buf, sizeof(error_buf),
-                 "%s error: device_index:%d is out of range: [0-%d]", __func__,
-                 device_index, g_device_count - 1);
-        fprintf(stderr, "%s\n", error_buf);
-        assert(false);
-    }
+  if (device_index >= g_device_count) {
+    char error_buf[256];
+    snprintf(
+        error_buf,
+        sizeof(error_buf),
+        "%s error: device_index:%d is out of range: [0-%d]",
+        __func__,
+        device_index,
+        g_device_count - 1);
+    fprintf(stderr, "%s\n", error_buf);
+    assert(false);
+  }
 }
 
 void check_allow_gpu_id(const int device_id) {
-    if (!g_sycl_gpu_mgr->is_allowed_gpu(device_id)) {
-        char error_buf[256];
-        snprintf(error_buf, sizeof(error_buf),
-                 "error: cannot set device=%d, which is not allowed. Please "
-                 "set GPU ID in: [%s]",
-                 device_id, g_sycl_gpu_mgr->gpus_list.c_str());
-        fprintf(stderr, "%s\n", error_buf);
-        throw std::invalid_argument(error_buf);
-    }
+  if (!g_sycl_gpu_mgr->is_allowed_gpu(device_id)) {
+    char error_buf[256];
+    snprintf(
+        error_buf,
+        sizeof(error_buf),
+        "error: cannot set device=%d, which is not allowed. Please "
Please " + "set GPU ID in: [%s]", + device_id, + g_sycl_gpu_mgr->gpus_list.c_str()); + fprintf(stderr, "%s\n", error_buf); + throw std::invalid_argument(error_buf); + } } int get_current_device_id() { - return dpct::dev_mgr::instance().current_device_id(); + return dpct::dev_mgr::instance().current_device_id(); } -void log_ggml_var_device(const char*name, float *src, size_t total_elements, bool src_on_device){ - if(!g_ggml_sycl_debug) return; - if(!src){ - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - char filename[1024]; - sprintf(filename, "%s.txt", name); - printf("GGML Tensor:%s save to %s\n", name, filename); - - size_t total_size = total_elements*sizeof(float); - float *local_buf = NULL; - if(src_on_device) { - local_buf = (float *) ggml_sycl_host_malloc(total_size); - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - main_stream->memcpy(local_buf, src, total_size).wait(); - } - else { - local_buf = (float *)src; - } - - std::ofstream logfile; - logfile.open(filename); - for(size_t i=0; imemcpy(local_buf, src, total_size).wait(); + } else { + local_buf = (float*)src; + } + + std::ofstream logfile; + logfile.open(filename); + for (size_t i = 0; i < total_elements; i++) { + logfile << local_buf[i] << " "; + if ((i + 1) % 20 == 0) + logfile << std::endl; + } + logfile << std::endl; + logfile.close(); + + if (src_on_device) + ggml_sycl_host_free(local_buf); } -void log_ggml_var_device_fp16(const char*name, sycl::half *src, size_t total_elements, bool src_on_device){ - if(!g_ggml_sycl_debug) return; - if(!src){ - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - char filename[1024]; - sprintf(filename, "%s.txt", name); - printf("GGML Tensor:%s save to %s\n", name, filename); - - size_t total_size = total_elements*sizeof(sycl::half); - sycl::half *local_buf = NULL; - if(src_on_device) { - local_buf = (sycl::half *) ggml_sycl_host_malloc(total_size); - ggml_sycl_set_device(g_main_device); - dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - main_stream->memcpy(local_buf, src, total_size).wait(); - } - else { - local_buf = (sycl::half *)src; - } - - std::ofstream logfile; - logfile.open(filename); - for(size_t i=0; imemcpy(local_buf, src, total_size).wait(); + } else { + local_buf = (sycl::half*)src; + } + + std::ofstream logfile; + logfile.open(filename); + for (size_t i = 0; i < total_elements; i++) { + logfile << local_buf[i] << " "; + if ((i + 1) % 20 == 0) + logfile << std::endl; + } + logfile << std::endl; + logfile.close(); + + if (src_on_device) + ggml_sycl_host_free(local_buf); } -void print_ggml_tensor(const char*name, struct ggml_tensor *src){ - if(!g_ggml_sycl_debug) return; - if(!src){ - printf("GGML Tensor:%s skip to save for NULL pointer\n", name); - return; - } - - size_t total_elements = ggml_nelements(src); - - const bool src_on_device = src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT; - float *src_data =NULL; - if(src_on_device) { - ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra; - src_data = (float*)src_extra->data_device[g_main_device]; - } - else { - src_data = (float *)src->data; - } - - log_ggml_var_device(name, src_data, total_elements, src_on_device); +void print_ggml_tensor(const char* name, struct ggml_tensor* src) { + if (!g_ggml_sycl_debug) + return; + if (!src) { + printf("GGML Tensor:%s skip to save for NULL pointer\n", name); + return; + } + + size_t total_elements = 
ggml_nelements(src); + + const bool src_on_device = src->backend == GGML_BACKEND_TYPE_GPU || + src->backend == GGML_BACKEND_TYPE_GPU_SPLIT; + float* src_data = NULL; + if (src_on_device) { + ggml_tensor_extra_gpu* src_extra = (ggml_tensor_extra_gpu*)src->extra; + src_data = (float*)src_extra->data_device[g_main_device]; + } else { + src_data = (float*)src->data; + } + + log_ggml_var_device(name, src_data, total_elements, src_on_device); } -void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt) { - stop_cnt = 4; - if(log_file_name_idx>=stop_cnt) return; - char filename[1280]; - sprintf(filename, "%s_%07d", name, log_file_name_idx); - log_file_name_idx++; - print_ggml_tensor(filename, src); +void log_tensor_with_cnt( + const char* name, + struct ggml_tensor* src, + int stop_cnt) { + stop_cnt = 4; + if (log_file_name_idx >= stop_cnt) + return; + char filename[1280]; + sprintf(filename, "%s_%07d", name, log_file_name_idx); + log_file_name_idx++; + print_ggml_tensor(filename, src); } -void *ggml_sycl_host_malloc(size_t size) try { - if (getenv("GGML_SYCL_NO_PINNED") != nullptr) { - return nullptr; - } - - void * ptr = nullptr; - //allow to use dpct::get_in_order_queue() for host malloc - dpct::err0 err = CHECK_TRY_ERROR( - ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue())); - - if (err != 0) { - // clear the error - fprintf( - stderr, - "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", - size / 1024.0 / 1024.0, - "syclGetErrorString is not supported"); - return nullptr; - } - - return ptr; -} -catch (sycl::exception const &exc) { +void* ggml_sycl_host_malloc(size_t size) try { + if (getenv("GGML_SYCL_NO_PINNED") != nullptr) { + return nullptr; + } + + void* ptr = nullptr; + // allow to use dpct::get_in_order_queue() for host malloc + dpct::err0 err = CHECK_TRY_ERROR( + ptr = (void*)sycl::malloc_host(size, dpct::get_in_order_queue())); + + if (err != 0) { + // clear the error + fprintf( + stderr, + "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", + size / 1024.0 / 1024.0, + "syclGetErrorString is not supported"); + return nullptr; + } + + return ptr; +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -void ggml_sycl_host_free(void *ptr) try { - //allow to use dpct::get_in_order_queue() for host malloc - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue()))); -} -catch (sycl::exception const &exc) { +void ggml_sycl_host_free(void* ptr) try { + // allow to use dpct::get_in_order_queue() for host malloc + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue()))); +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); diff --git a/ggml-sycl/common.hpp b/ggml-sycl/common.hpp index 38d54f50a9ce0..7b72246d02e78 100644 --- a/ggml-sycl/common.hpp +++ b/ggml-sycl/common.hpp @@ -13,8 +13,8 @@ #ifndef GGML_SYCL_COMMON_HPP #define GGML_SYCL_COMMON_HPP -#include #include +#include #include "dpct/helper.hpp" @@ -22,22 +22,27 @@ #define GGML_COMMON_IMPL_SYCL #include "ggml-common.h" -void * ggml_sycl_host_malloc(size_t size); -void ggml_sycl_host_free(void * ptr); - -static int g_ggml_sycl_debug=0; -#define GGML_SYCL_DEBUG(...) 
-
-#define CHECK_TRY_ERROR(expr)                                                  \
-    [&]() {                                                                    \
-        try {                                                                  \
-            expr;                                                              \
-            return dpct::success;                                              \
-        } catch (std::exception const &e) {                                    \
-            std::cerr << e.what()<< "\nException caught at file:" << __FILE__  \
-                << ", line:" << __LINE__ <<", func:"<<__func__<< std::endl;    \
-            return dpct::default_error;                                        \
-        }                                                                      \
+void* ggml_sycl_host_malloc(size_t size);
+void ggml_sycl_host_free(void* ptr);
+
+static int g_ggml_sycl_debug = 0;
+#define GGML_SYCL_DEBUG(...)           \
+  do {                                 \
+    if (g_ggml_sycl_debug)             \
+      fprintf(stderr, __VA_ARGS__);    \
+  } while (0)
+
+#define CHECK_TRY_ERROR(expr)                                             \
+  [&]() {                                                                 \
+    try {                                                                 \
+      expr;                                                               \
+      return dpct::success;                                               \
+    } catch (std::exception const& e) {                                   \
+      std::cerr << e.what() << "\nException caught at file:" << __FILE__  \
+                << ", line:" << __LINE__ << ", func:" << __func__         \
+                << std::endl;                                             \
+      return dpct::default_error;                                         \
+    }                                                                     \
   }()
 
 // #define DEBUG_SYCL_MALLOC
@@ -46,24 +51,22 @@
 static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardward optimize.
-#define VER_GEN9 700 //todo for hardward optimize.
-#define VER_GEN12 1000000 //todo for hardward optimize.
-#define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
+#define VER_4VEC 610 // todo for hardward optimize.
+#define VER_GEN9 700 // todo for hardward optimize.
+#define VER_GEN12 1000000 // todo for hardward optimize.
+#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize.
 
-#define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
+#define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares
 
-
-//define for XMX in Intel GPU
-//TODO: currently, it's not used for XMX really.
+// define for XMX in Intel GPU
+// TODO: currently, it's not used for XMX really.
#define SYCL_USE_XMX // max batch size to use MMQ kernels when tensor cores are available #define XMX_MAX_BATCH_SIZE 32 - #if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data +#pragma warning(disable : 4244 4267) // possible loss of data #endif // dmmv = dequantize_mul_mat_vec @@ -75,29 +78,40 @@ static int g_work_group_size = 0; #endif enum ggml_sycl_backend_gpu_mode { - SYCL_UNSET_GPU_MODE = -1, - SYCL_SINGLE_GPU_MODE = 0, - SYCL_MUL_GPU_MODE + SYCL_UNSET_GPU_MODE = -1, + SYCL_SINGLE_GPU_MODE = 0, + SYCL_MUL_GPU_MODE }; static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -static void crash(){ - int *ptr = NULL; - *ptr = 0; +static void crash() { + int* ptr = NULL; + *ptr = 0; } -static void ggml_sycl_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) { - fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg); - fprintf(stderr, " in function %s at %s:%d\n", func, file, line); - GGML_ASSERT(!"SYCL error"); +static void ggml_sycl_error( + const char* stmt, + const char* func, + const char* file, + const int line, + const char* msg) { + fprintf(stderr, "SYCL error: %s: %s\n", stmt, msg); + fprintf(stderr, " in function %s at %s:%d\n", func, file, line); + GGML_ASSERT(!"SYCL error"); } -#define SYCL_CHECK(err) do { \ - auto err_ = (err); if (err_ != 0) ggml_sycl_error( \ - #err, __func__, __FILE__, __LINE__, \ - "Meet error in this line code!"); \ -} while (0) +#define SYCL_CHECK(err) \ + do { \ + auto err_ = (err); \ + if (err_ != 0) \ + ggml_sycl_error( \ + #err, \ + __func__, \ + __FILE__, \ + __LINE__, \ + "Meet error in this line code!"); \ + } while (0) #if DPCT_COMPAT_RT_VERSION >= 11100 #define GGML_SYCL_ASSUME(x) __builtin_assume(x) @@ -111,11 +125,12 @@ typedef sycl::half2 dfloat2; #else typedef float dfloat; // dequantize float typedef sycl::float2 dfloat2; -#endif //GGML_SYCL_F16 - +#endif // GGML_SYCL_F16 #define WARP_SIZE 32 -#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses +#define MATRIX_ROW_PADDING \ + 512 // last row of quant. 
matrices is a multiple of this to avoid + // out-of-bounds memory accesses #define SYCL_GELU_BLOCK_SIZE 256 #define SYCL_SILU_BLOCK_SIZE 256 @@ -152,7 +167,9 @@ typedef sycl::float2 dfloat2; #ifndef K_QUANTS_PER_ITERATION #define K_QUANTS_PER_ITERATION 2 #else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); +static_assert( + K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, + "K_QUANTS_PER_ITERATION must be 1 or 2"); #endif #ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE @@ -166,135 +183,139 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA static dpct::queue_ptr g_syclStreams[SYCL_MAX_DEVICES][MAX_STREAMS] = {{0}}; struct ggml_tensor_extra_gpu { - void * data_device[SYCL_MAX_DEVICES]; // 1 pointer for each device for split tensors - dpct::event_ptr - events[SYCL_MAX_DEVICES] - [MAX_STREAMS]; // events for synchronizing multiple GPUs + void* data_device[SYCL_MAX_DEVICES]; // 1 pointer for each device for split + // tensors + dpct::event_ptr events[SYCL_MAX_DEVICES] + [MAX_STREAMS]; // events for synchronizing multiple GPUs }; class sycl_gpu_mgr { - public: - std::vector gpus; - std::vector devices; - sycl::queue *first_queue; - sycl::context co_ctx; - int max_compute_units = 0; - int work_group_size = 0; - std::string gpus_list = ""; - - /* - Use all GPUs with same top max compute units - */ - sycl_gpu_mgr() { - detect_sycl_gpu_list_with_max_cu(); - get_allow_gpus(); - create_context_with_gpus(); - } - - /* - Only use the assigned GPU - */ - sycl_gpu_mgr(int main_gpu_id) { - sycl::device device = dpct::dev_mgr::instance().get_device(main_gpu_id); - dpct::device_info prop; - dpct::get_device_info(prop, device); - gpus.push_back(main_gpu_id); - devices.push_back(device); - work_group_size = prop.get_max_work_group_size(); - max_compute_units = prop.get_max_compute_units(); - - get_allow_gpus(); - create_context_with_gpus(); - } - - void create_context_with_gpus() { - sycl::context ctx = sycl::context(devices); - assert(gpus.size() > 0); - first_queue = dpct::get_current_device().create_queue(ctx, devices[0]); - co_ctx = first_queue->get_context(); - } - - sycl::context &get_co_ctx() { return co_ctx; } - - void get_allow_gpus() { - gpus_list = ""; - for (size_t i = 0; i < gpus.size(); ++i) { - gpus_list += std::to_string(gpus[i]); - gpus_list += ","; - } - if (gpus_list.length() > 1) { - gpus_list.pop_back(); - } - } - - bool is_allowed_gpu(int device_id) { - return std::find(gpus.begin(), gpus.end(), device_id) != gpus.end(); - } - - void detect_sycl_gpu_list_with_max_cu() try { - int device_count = dpct::dev_mgr::instance().device_count(); - - for (int id = 0; id < device_count; id++) { - sycl::device device = dpct::dev_mgr::instance().get_device(id); - if (!device.is_gpu()) - continue; - dpct::device_info prop; - dpct::get_device_info(prop, device); - if (max_compute_units < prop.get_max_compute_units()) - max_compute_units = prop.get_max_compute_units(); - } - - for (int id = 0; id < device_count; id++) { - sycl::device device = dpct::dev_mgr::instance().get_device(id); - if (!device.is_gpu()) - continue; - dpct::device_info prop; - dpct::get_device_info(prop, device); - if (max_compute_units == prop.get_max_compute_units() && - is_ext_oneapi_device(device)) { - gpus.push_back(id); - devices.push_back(device); - work_group_size = prop.get_max_work_group_size(); - } - } - return; - } catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ 
- << ", line:" << __LINE__ << std::endl; - std::exit(1); - } - - int get_gpu_count() { return (int)gpus.size(); } - - int get_index(int id) { - for (int i = 0; i < (int)gpus.size(); i++) { - if (gpus[i] == id) - return i; - } - printf("miss to get device index by id=%d\n", id); - GGML_ASSERT(false); - } - - int get_next_index(int id) { - int cur_index = get_index(id); - for (int i = cur_index + 1; i < (int)gpus.size(); i++) { - if (gpus[i] == id) - return i; - } - GGML_ASSERT(false); - } - - bool is_ext_oneapi_device(const sycl::device &dev) { - sycl::backend dev_backend = dev.get_backend(); - if (dev_backend == sycl::backend::ext_oneapi_level_zero || - dev_backend == sycl::backend::ext_oneapi_cuda || - dev_backend == sycl::backend::ext_oneapi_hip) - return true; - return false; - } + public: + std::vector gpus; + std::vector devices; + sycl::queue* first_queue; + sycl::context co_ctx; + int max_compute_units = 0; + int work_group_size = 0; + std::string gpus_list = ""; + + /* + Use all GPUs with same top max compute units + */ + sycl_gpu_mgr() { + detect_sycl_gpu_list_with_max_cu(); + get_allow_gpus(); + create_context_with_gpus(); + } + + /* + Only use the assigned GPU + */ + sycl_gpu_mgr(int main_gpu_id) { + sycl::device device = dpct::dev_mgr::instance().get_device(main_gpu_id); + dpct::device_info prop; + dpct::get_device_info(prop, device); + gpus.push_back(main_gpu_id); + devices.push_back(device); + work_group_size = prop.get_max_work_group_size(); + max_compute_units = prop.get_max_compute_units(); + + get_allow_gpus(); + create_context_with_gpus(); + } + + void create_context_with_gpus() { + sycl::context ctx = sycl::context(devices); + assert(gpus.size() > 0); + first_queue = dpct::get_current_device().create_queue(ctx, devices[0]); + co_ctx = first_queue->get_context(); + } + + sycl::context& get_co_ctx() { + return co_ctx; + } + + void get_allow_gpus() { + gpus_list = ""; + for (size_t i = 0; i < gpus.size(); ++i) { + gpus_list += std::to_string(gpus[i]); + gpus_list += ","; + } + if (gpus_list.length() > 1) { + gpus_list.pop_back(); + } + } + + bool is_allowed_gpu(int device_id) { + return std::find(gpus.begin(), gpus.end(), device_id) != gpus.end(); + } + + void detect_sycl_gpu_list_with_max_cu() try { + int device_count = dpct::dev_mgr::instance().device_count(); + + for (int id = 0; id < device_count; id++) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + if (!device.is_gpu()) + continue; + dpct::device_info prop; + dpct::get_device_info(prop, device); + if (max_compute_units < prop.get_max_compute_units()) + max_compute_units = prop.get_max_compute_units(); + } + + for (int id = 0; id < device_count; id++) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + if (!device.is_gpu()) + continue; + dpct::device_info prop; + dpct::get_device_info(prop, device); + if (max_compute_units == prop.get_max_compute_units() && + is_ext_oneapi_device(device)) { + gpus.push_back(id); + devices.push_back(device); + work_group_size = prop.get_max_work_group_size(); + } + } + return; + } catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + + int get_gpu_count() { + return (int)gpus.size(); + } + + int get_index(int id) { + for (int i = 0; i < (int)gpus.size(); i++) { + if (gpus[i] == id) + return i; + } + printf("miss to get device index by id=%d\n", id); + GGML_ASSERT(false); + } + + int get_next_index(int id) { + int cur_index = 
get_index(id); + for (int i = cur_index + 1; i < (int)gpus.size(); i++) { + if (gpus[i] == id) + return i; + } + GGML_ASSERT(false); + } + + bool is_ext_oneapi_device(const sycl::device& dev) { + sycl::backend dev_backend = dev.get_backend(); + if (dev_backend == sycl::backend::ext_oneapi_level_zero || + dev_backend == sycl::backend::ext_oneapi_cuda || + dev_backend == sycl::backend::ext_oneapi_hip) + return true; + return false; + } }; -static sycl_gpu_mgr *g_sycl_gpu_mgr = NULL; +static sycl_gpu_mgr* g_sycl_gpu_mgr = NULL; static int g_device_count = -1; static int g_all_sycl_device_count = -1; static int g_main_device = -1; @@ -305,22 +326,24 @@ static std::array g_default_tensor_split = {}; static float g_tensor_split[SYCL_MAX_DEVICES] = {0}; -static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode = SYCL_UNSET_GPU_MODE; +static ggml_sycl_backend_gpu_mode g_ggml_sycl_backend_gpu_mode = + SYCL_UNSET_GPU_MODE; struct sycl_device_capabilities { - int cc; // compute capability - bool vmm; // virtual memory support - size_t vmm_granularity; // granularity of virtual memory - int device_id; + int cc; // compute capability + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory + int device_id; }; -static sycl_device_capabilities g_device_caps[SYCL_MAX_DEVICES] = { {0, false, 0, -1} }; +static sycl_device_capabilities g_device_caps[SYCL_MAX_DEVICES] = { + {0, false, 0, -1}}; struct sycl_device_id2index { - int index; + int index; }; -static void * g_scratch_buffer = nullptr; +static void* g_scratch_buffer = nullptr; static size_t g_scratch_size = 0; // disabled by default static size_t g_scratch_offset = 0; @@ -328,14 +351,13 @@ static dpct::queue_ptr g_sycl_handles[SYCL_MAX_DEVICES] = {nullptr}; int get_main_device(); -[[noreturn]] -static void bad_arch(const sycl::stream &stream_ct1) { - stream_ct1 << "ERROR: ggml-sycl was compiled without support for the " - "current GPU architecture.\n"; - // __trap(); - std::exit(1); +[[noreturn]] static void bad_arch(const sycl::stream& stream_ct1) { + stream_ct1 << "ERROR: ggml-sycl was compiled without support for the " + "current GPU architecture.\n"; + // __trap(); + std::exit(1); - (void) bad_arch; // suppress unused function warning + (void)bad_arch; // suppress unused function warning } /* @@ -353,35 +375,45 @@ void check_allow_gpu_id(const int device_id); int get_current_device_id(); inline dpct::err0 ggml_sycl_set_device(const int device) try { - - int device_id = g_sycl_gpu_mgr->gpus[device]; - check_allow_gpu_id(device_id); - - int current_device_id; - SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id())); - - // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d, - // current_device_id=%d\n", device, current_device); - if (device_id == current_device_id) { - return 0; - } - - return CHECK_TRY_ERROR(dpct::select_device(device_id)); -} catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - crash(); - std::exit(1); + int device_id = g_sycl_gpu_mgr->gpus[device]; + check_allow_gpu_id(device_id); + + int current_device_id; + SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id())); + + // GGML_SYCL_DEBUG("ggml_sycl_set_device device_id=%d, + // current_device_id=%d\n", device, current_device); + if (device_id == current_device_id) { + return 0; + } + + return CHECK_TRY_ERROR(dpct::select_device(device_id)); +} catch (sycl::exception const& exc) { + std::cerr << exc.what() 
<< "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + crash(); + std::exit(1); } -void log_ggml_var_device(const char*name, float *src, size_t total_elements, bool src_on_device); - -void log_ggml_var_device_fp16(const char*name, sycl::half *src, size_t total_elements, bool src_on_device); - -//todo: debug for crash in some case -void print_ggml_tensor(const char*name, struct ggml_tensor *src); - -static int log_file_name_idx=0; -void log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt); +void log_ggml_var_device( + const char* name, + float* src, + size_t total_elements, + bool src_on_device); + +void log_ggml_var_device_fp16( + const char* name, + sycl::half* src, + size_t total_elements, + bool src_on_device); + +// todo: debug for crash in some case +void print_ggml_tensor(const char* name, struct ggml_tensor* src); + +static int log_file_name_idx = 0; +void log_tensor_with_cnt( + const char* name, + struct ggml_tensor* src, + int stop_cnt); #endif // GGML_SYCL_COMMON_HPP \ No newline at end of file diff --git a/ggml-sycl/dpct/helper.hpp b/ggml-sycl/dpct/helper.hpp index c0a1745a9d125..602ec7941bf38 100644 --- a/ggml-sycl/dpct/helper.hpp +++ b/ggml-sycl/dpct/helper.hpp @@ -13,9 +13,9 @@ #ifndef GGML_SYCL_DPCT_HELPER_HPP #define GGML_SYCL_DPCT_HELPER_HPP -#include -#include #include +#include +#include #include #include "ggml.h" @@ -32,8 +32,8 @@ #endif #if defined(__linux__) -#include #include +#include #endif #if defined(_WIN64) #ifndef NOMINMAX @@ -58,2884 +58,3487 @@ #define __dpct_noinline__ __attribute__((noinline)) #endif -inline std::string get_device_type_name(const sycl::device &Device) { - auto DeviceType = Device.get_info(); - switch (DeviceType) { +inline std::string get_device_type_name(const sycl::device& Device) { + auto DeviceType = Device.get_info(); + switch (DeviceType) { case sycl::info::device_type::cpu: - return "cpu"; + return "cpu"; case sycl::info::device_type::gpu: - return "gpu"; + return "gpu"; case sycl::info::device_type::host: - return "host"; + return "host"; case sycl::info::device_type::accelerator: - return "acc"; + return "acc"; default: - return "unknown"; - } -} - -inline std::string get_device_backend_and_type(const sycl::device &device) { - std::stringstream device_type; - sycl::backend backend = device.get_backend(); - device_type << backend << ":" << get_device_type_name(device); - return device_type.str(); -} - -namespace dpct -{ - typedef sycl::queue *queue_ptr; - typedef sycl::event *event_ptr; - typedef char *device_ptr; - typedef uint8_t byte_t; - typedef sycl::buffer buffer_t; - - /// SYCL default exception handler - inline auto exception_handler = [](sycl::exception_list exceptions) - { - for (std::exception_ptr const &e : exceptions) - { - try - { - std::rethrow_exception(e); - } - catch (sycl::exception const &e) - { - std::cerr << "Caught asynchronous SYCL exception:" << std::endl - << e.what() << std::endl - << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - } - } - }; - - enum error_code - { - success = 0, - default_error = 999 - }; - - enum memcpy_direction - { - host_to_host, - host_to_device, - device_to_host, - device_to_device, - automatic - }; - - enum memory_region - { - global = 0, // device global memory - constant, // device constant memory - local, // device local memory - shared, // memory which can be accessed by host and device - }; - - enum class library_data_t : unsigned char - { - real_float = 0, - complex_float, - 
real_double, - complex_double, - real_half, - complex_half, - real_bfloat16, - complex_bfloat16, - real_int4, - complex_int4, - real_uint4, - complex_uint4, - real_int8, - complex_int8, - real_uint8, - complex_uint8, - real_int16, - complex_int16, - real_uint16, - complex_uint16, - real_int32, - complex_int32, - real_uint32, - complex_uint32, - real_int64, - complex_int64, - real_uint64, - complex_uint64, - real_int8_4, - real_int8_32, - real_uint8_4, - library_data_t_size - }; - - template - struct DataType - { - using T2 = T; - }; - template - struct DataType> - { - using T2 = std::complex; - }; - - static void destroy_event(event_ptr event) - { - delete event; - } - - static inline unsigned int get_tid() - { + return "unknown"; + } +} + +inline std::string get_device_backend_and_type(const sycl::device& device) { + std::stringstream device_type; + sycl::backend backend = device.get_backend(); + device_type << backend << ":" << get_device_type_name(device); + return device_type.str(); +} + +namespace dpct { +typedef sycl::queue* queue_ptr; +typedef sycl::event* event_ptr; +typedef char* device_ptr; +typedef uint8_t byte_t; +typedef sycl::buffer buffer_t; + +/// SYCL default exception handler +inline auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const& e : exceptions) { + try { + std::rethrow_exception(e); + } catch (sycl::exception const& e) { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } +}; + +enum error_code { success = 0, default_error = 999 }; + +enum memcpy_direction { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic +}; + +enum memory_region { + global = 0, // device global memory + constant, // device constant memory + local, // device local memory + shared, // memory which can be accessed by host and device +}; + +enum class library_data_t : unsigned char { + real_float = 0, + complex_float, + real_double, + complex_double, + real_half, + complex_half, + real_bfloat16, + complex_bfloat16, + real_int4, + complex_int4, + real_uint4, + complex_uint4, + real_int8, + complex_int8, + real_uint8, + complex_uint8, + real_int16, + complex_int16, + real_uint16, + complex_uint16, + real_int32, + complex_int32, + real_uint32, + complex_uint32, + real_int64, + complex_int64, + real_uint64, + complex_uint64, + real_int8_4, + real_int8_32, + real_uint8_4, + library_data_t_size +}; + +template +struct DataType { + using T2 = T; +}; +template +struct DataType> { + using T2 = std::complex; +}; + +static void destroy_event(event_ptr event) { + delete event; +} + +static inline unsigned int get_tid() { #if defined(__linux__) - return syscall(SYS_gettid); + return syscall(SYS_gettid); #elif defined(_WIN64) - return GetCurrentThreadId(); + return GetCurrentThreadId(); #else #error "Only support Windows and Linux." #endif - } +} - namespace detail - { - static void get_version(const sycl::device &dev, int &major, int &minor) - { - // Version string has the following format: - // a. OpenCL - // b. - // c. e.g gfx1030 - std::string ver; - ver = dev.get_info(); - std::string::size_type i = 0; - while (i < ver.size()) { - if (isdigit(ver[i])) - break; - i++; - } - major = std::stoi(&(ver[i])); - while (i < ver.size()) { - if (ver[i] == '.') - break; - i++; - } - if (i < ver.size()) { - // a. and b. - i++; - minor = std::stoi(&(ver[i])); - } else { - // c. 
- minor = 0; - } - } +namespace detail { +static void get_version(const sycl::device& dev, int& major, int& minor) { + // Version string has the following format: + // a. OpenCL + // b. + // c. e.g gfx1030 + std::string ver; + ver = dev.get_info(); + std::string::size_type i = 0; + while (i < ver.size()) { + if (isdigit(ver[i])) + break; + i++; + } + major = std::stoi(&(ver[i])); + while (i < ver.size()) { + if (ver[i] == '.') + break; + i++; + } + if (i < ver.size()) { + // a. and b. + i++; + minor = std::stoi(&(ver[i])); + } else { + // c. + minor = 0; + } +} - template - class generic_error_type - { - public: - generic_error_type() = default; - generic_error_type(T value) : value{value} {} - operator T() const { return value; } +template +class generic_error_type { + public: + generic_error_type() = default; + generic_error_type(T value) : value{value} {} + operator T() const { + return value; + } + + private: + T value; +}; + +} // namespace detail + +/// Pitched 2D/3D memory data. +class pitched_data { + public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void* data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void* get_data_ptr() { + return _data; + } + void set_data_ptr(void* data) { + _data = data; + } + + size_t get_pitch() { + return _pitch; + } + void set_pitch(size_t pitch) { + _pitch = pitch; + } + + size_t get_x() { + return _x; + } + void set_x(size_t x) { + _x = x; + }; + + size_t get_y() { + return _y; + } + void set_y(size_t y) { + _y = y; + } + + private: + void* _data; + size_t _pitch, _x, _y; +}; + +class device_info { + public: + // get interface + const char* get_name() const { + return _name; + } + char* get_name() { + return _name; + } + template < + typename WorkItemSizesTy = sycl::range<3>, + std::enable_if_t< + std::is_same_v> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const { + if constexpr (std::is_same_v>) + return sycl::range<3>( + _max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else { + return _max_work_item_sizes_i; + } + } + template < + typename WorkItemSizesTy = sycl::range<3>, + std::enable_if_t< + std::is_same_v> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() { + if constexpr (std::is_same_v>) + return sycl::range<3>( + _max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else { + return _max_work_item_sizes_i; + } + } + bool get_host_unified_memory() const { + return _host_unified_memory; + } + int get_major_version() const { + return _major; + } + int get_minor_version() const { + return _minor; + } + int get_integrated() const { + return _integrated; + } + int get_max_clock_frequency() const { + return _frequency; + } + int get_max_compute_units() const { + return _max_compute_units; + } + int get_max_work_group_size() const { + return _max_work_group_size; + } + int get_max_sub_group_size() const { + return _max_sub_group_size; + } + int get_max_work_items_per_compute_unit() const { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const { + return _max_register_size_per_work_group; + } + template < + typename NDRangeSizeTy = size_t*, + std::enable_if_t< + std::is_same_v || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template < + typename NDRangeSizeTy = size_t*, + std::enable_if_t< + std::is_same_v 
|| + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { + return _global_mem_size; + } + size_t get_local_mem_size() const { + return _local_mem_size; + } + size_t get_max_mem_alloc_size() const { + return _max_mem_alloc_size; + } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { + return _memory_clock_rate; + } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { + return _memory_bus_width; + } + uint32_t get_device_id() const { + return _device_id; + } + std::array get_uuid() const { + return _uuid; + } + /// Returns global memory cache size in bytes. + unsigned int get_global_mem_cache_size() const { + return _global_mem_cache_size; + } + + // set interface + void set_name(const char* name) { + size_t length = strlen(name); + if (length < 256) { + std::memcpy(_name, name, length + 1); + } else { + std::memcpy(_name, name, 255); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) { + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void set_max_work_item_sizes( + const sycl::id<3> max_work_item_sizes) { + for (int i = 0; i < 3; ++i) { + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { + _major = major; + } + void set_minor_version(int minor) { + _minor = minor; + } + void set_integrated(int integrated) { + _integrated = integrated; + } + void set_max_clock_frequency(int frequency) { + _frequency = frequency; + } + void set_max_compute_units(int max_compute_units) { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) { + _local_mem_size = local_mem_size; + } + void set_max_mem_alloc_size(size_t max_mem_alloc_size) { + _max_mem_alloc_size = max_mem_alloc_size; + } + void set_max_work_group_size(int max_work_group_size) { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) { + _max_sub_group_size = max_sub_group_size; + } + void set_max_work_items_per_compute_unit( + int max_work_items_per_compute_unit) { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) { + for (int i = 0; i < 3; i++) { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) { + _memory_bus_width = memory_bus_width; + } + void set_max_register_size_per_work_group( + int max_register_size_per_work_group) { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) { + _device_id = device_id; + } + void set_uuid(std::array uuid) { + _uuid = std::move(uuid); + } + void 
set_global_mem_cache_size(unsigned int global_mem_cache_size) { + _global_mem_cache_size = global_mem_cache_size; + } + + private: + char _name[256]; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. + unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_mem_alloc_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; +}; + +static int get_major_version(const sycl::device& dev) { + int major, minor; + detail::get_version(dev, major, minor); + return major; +} - private: - T value; - }; +static int get_minor_version(const sycl::device& dev) { + int major, minor; + detail::get_version(dev, major, minor); + return minor; +} - } // namespace detail - - /// Pitched 2D/3D memory data. - class pitched_data - { - public: - pitched_data() : pitched_data(nullptr, 0, 0, 0) {} - pitched_data(void *data, size_t pitch, size_t x, size_t y) - : _data(data), _pitch(pitch), _x(x), _y(y) {} - - void *get_data_ptr() { return _data; } - void set_data_ptr(void *data) { _data = data; } - - size_t get_pitch() { return _pitch; } - void set_pitch(size_t pitch) { _pitch = pitch; } - - size_t get_x() { return _x; } - void set_x(size_t x) { _x = x; }; - - size_t get_y() { return _y; } - void set_y(size_t y) { _y = y; } - - private: - void *_data; - size_t _pitch, _x, _y; - }; - - class device_info - { - public: - // get interface - const char *get_name() const { return _name; } - char *get_name() { return _name; } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() const - { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else - { - return _max_work_item_sizes_i; - } - } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() - { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else - { - return _max_work_item_sizes_i; - } - } - bool get_host_unified_memory() const { return _host_unified_memory; } - int get_major_version() const { return _major; } - int get_minor_version() const { return _minor; } - int get_integrated() const { return _integrated; } - int get_max_clock_frequency() const { return _frequency; } - int get_max_compute_units() const { return _max_compute_units; } - int get_max_work_group_size() const { return _max_work_group_size; } - int get_max_sub_group_size() const { return _max_sub_group_size; } - int get_max_work_items_per_compute_unit() const - { - return _max_work_items_per_compute_unit; - } - int get_max_register_size_per_work_group() const - { - return _max_register_size_per_work_group; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() const - { - if constexpr (std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() - { - if constexpr 
(std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - size_t get_global_mem_size() const { return _global_mem_size; } - size_t get_local_mem_size() const { return _local_mem_size; } - size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; } - /// Returns the maximum clock rate of device's global memory in kHz. If - /// compiler does not support this API then returns default value 3200000 kHz. - unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } - /// Returns the maximum bus width between device and memory in bits. If - /// compiler does not support this API then returns default value 64 bits. - unsigned int get_memory_bus_width() const { return _memory_bus_width; } - uint32_t get_device_id() const { return _device_id; } - std::array get_uuid() const { return _uuid; } - /// Returns global memory cache size in bytes. - unsigned int get_global_mem_cache_size() const - { - return _global_mem_cache_size; - } +static void get_device_info(device_info& out, const sycl::device& dev) { + device_info prop; + prop.set_name(dev.get_info().c_str()); - // set interface - void set_name(const char *name) - { - size_t length = strlen(name); - if (length < 256) - { - std::memcpy(_name, name, length + 1); - } - else - { - std::memcpy(_name, name, 255); - _name[255] = '\0'; - } - } - void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) - { - for (int i = 0; i < 3; ++i) - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - [[deprecated]] void - set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) - { - for (int i = 0; i < 3; ++i) - { - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - } - void set_host_unified_memory(bool host_unified_memory) - { - _host_unified_memory = host_unified_memory; - } - void set_major_version(int major) { _major = major; } - void set_minor_version(int minor) { _minor = minor; } - void set_integrated(int integrated) { _integrated = integrated; } - void set_max_clock_frequency(int frequency) { _frequency = frequency; } - void set_max_compute_units(int max_compute_units) - { - _max_compute_units = max_compute_units; - } - void set_global_mem_size(size_t global_mem_size) - { - _global_mem_size = global_mem_size; - } - void set_local_mem_size(size_t local_mem_size) - { - _local_mem_size = local_mem_size; - } - void set_max_mem_alloc_size(size_t max_mem_alloc_size) - { - _max_mem_alloc_size = max_mem_alloc_size; - } - void set_max_work_group_size(int max_work_group_size) - { - _max_work_group_size = max_work_group_size; - } - void set_max_sub_group_size(int max_sub_group_size) - { - _max_sub_group_size = max_sub_group_size; - } - void - set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) - { - _max_work_items_per_compute_unit = max_work_items_per_compute_unit; - } - void set_max_nd_range_size(int max_nd_range_size[]) - { - for (int i = 0; i < 3; i++) - { - _max_nd_range_size[i] = max_nd_range_size[i]; - _max_nd_range_size_i[i] = max_nd_range_size[i]; - } - } - void set_memory_clock_rate(unsigned int memory_clock_rate) - { - _memory_clock_rate = memory_clock_rate; - } - void set_memory_bus_width(unsigned int memory_bus_width) - { - _memory_bus_width = memory_bus_width; - } - void - set_max_register_size_per_work_group(int max_register_size_per_work_group) - { - _max_register_size_per_work_group = max_register_size_per_work_group; - } - void set_device_id(uint32_t device_id) - { - _device_id = device_id; - } - void set_uuid(std::array uuid) - { - _uuid = 
std::move(uuid); - } - void set_global_mem_cache_size(unsigned int global_mem_cache_size) - { - _global_mem_cache_size = global_mem_cache_size; - } + int major, minor; + detail::get_version(dev, major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); - private: - char _name[256]; - int _max_work_item_sizes_i[3]; - bool _host_unified_memory = false; - int _major; - int _minor; - int _integrated = 0; - int _frequency; - // Set estimated value 3200000 kHz as default value. - unsigned int _memory_clock_rate = 3200000; - // Set estimated value 64 bits as default value. - unsigned int _memory_bus_width = 64; - unsigned int _global_mem_cache_size; - int _max_compute_units; - int _max_work_group_size; - int _max_sub_group_size; - int _max_work_items_per_compute_unit; - int _max_register_size_per_work_group; - size_t _global_mem_size; - size_t _local_mem_size; - size_t _max_mem_alloc_size; - size_t _max_nd_range_size[3]; - int _max_nd_range_size_i[3]; - uint32_t _device_id; - std::array _uuid; - }; - - static int get_major_version(const sycl::device &dev) - { - int major, minor; - detail::get_version(dev, major, minor); - return major; - } - - static int get_minor_version(const sycl::device &dev) - { - int major, minor; - detail::get_version(dev, major, minor); - return minor; - } - - static void get_device_info(device_info &out, const sycl::device &dev) - { - device_info prop; - prop.set_name(dev.get_info().c_str()); - - int major, minor; - detail::get_version(dev, major, minor); - prop.set_major_version(major); - prop.set_minor_version(minor); - - prop.set_max_work_item_sizes( + prop.set_max_work_item_sizes( #if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) - // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes - // is an enum class element - dev.get_info()); + // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes + // is an enum class element + dev.get_info()); #else - // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by - // an int - dev.get_info>()); + // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by + // an int + dev.get_info>()); #endif - prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); + prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); - prop.set_max_clock_frequency( - dev.get_info() * 1000); + prop.set_max_clock_frequency( + dev.get_info() * 1000); - prop.set_max_compute_units( - dev.get_info()); - prop.set_max_work_group_size( - dev.get_info()); - prop.set_global_mem_size(dev.get_info()); - prop.set_local_mem_size(dev.get_info()); - prop.set_max_mem_alloc_size(dev.get_info()); + prop.set_max_compute_units( + dev.get_info()); + prop.set_max_work_group_size( + dev.get_info()); + prop.set_global_mem_size(dev.get_info()); + prop.set_local_mem_size(dev.get_info()); + prop.set_max_mem_alloc_size( + dev.get_info()); #if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) - if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) - { - unsigned int tmp = - dev.get_info(); - if (tmp != 0) - prop.set_memory_clock_rate(1000 * tmp); - } - if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) - { - prop.set_memory_bus_width( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_id)) - { - prop.set_device_id( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) - { - prop.set_uuid(dev.get_info()); - } + if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) { 
+ unsigned int tmp = + dev.get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) { + prop.set_memory_bus_width( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_id)) { + prop.set_device_id( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) { + prop.set_uuid(dev.get_info()); + } #elif defined(_MSC_VER) && !defined(__clang__) -#pragma message("get_device_info: querying memory_clock_rate and \ +#pragma message( \ + "get_device_info: querying memory_clock_rate and \ memory_bus_width are not supported by the compiler used. \ Use 3200000 kHz as memory_clock_rate default value. \ Use 64 bits as memory_bus_width default value.") #else -#warning "get_device_info: querying memory_clock_rate and \ +#warning \ + "get_device_info: querying memory_clock_rate and \ memory_bus_width are not supported by the compiler used. \ Use 3200000 kHz as memory_clock_rate default value. \ Use 64 bits as memory_bus_width default value." #endif - size_t max_sub_group_size = 1; - std::vector sub_group_sizes = - dev.get_info(); - - for (const auto &sub_group_size : sub_group_sizes) - { - if (max_sub_group_size < sub_group_size) - max_sub_group_size = sub_group_size; - } - - prop.set_max_sub_group_size(max_sub_group_size); - - prop.set_max_work_items_per_compute_unit( - dev.get_info()); - int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; - prop.set_max_nd_range_size(max_nd_range_size); - - // Estimates max register size per work group, feel free to update the value - // according to device properties. - prop.set_max_register_size_per_work_group(65536); - - prop.set_global_mem_cache_size( - dev.get_info()); - out = prop; - } - - /// dpct device extension - class device_ext : public sycl::device - { - typedef std::mutex mutex_type; - - public: - device_ext() : sycl::device(), _ctx(*this) {} - ~device_ext() - { - std::lock_guard lock(m_mutex); - clear_queues(); - } - device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) - { - std::lock_guard lock(m_mutex); - init_queues(); - } + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + dev.get_info(); - int is_native_atomic_supported() { return 0; } - int get_major_version() const - { - return dpct::get_major_version(*this); - } + for (const auto& sub_group_size : sub_group_sizes) { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } - int get_minor_version() const - { - return dpct::get_minor_version(*this); - } + prop.set_max_sub_group_size(max_sub_group_size); - int get_max_compute_units() const - { - return get_device_info().get_max_compute_units(); - } - - /// Return the maximum clock frequency of this device in KHz. 
- int get_max_clock_frequency() const - { - return get_device_info().get_max_clock_frequency(); - } + prop.set_max_work_items_per_compute_unit( + dev.get_info()); + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); - int get_integrated() const { return get_device_info().get_integrated(); } - - int get_max_sub_group_size() const - { - return get_device_info().get_max_sub_group_size(); - } - - int get_max_register_size_per_work_group() const - { - return get_device_info().get_max_register_size_per_work_group(); - } - - int get_max_work_group_size() const - { - return get_device_info().get_max_work_group_size(); - } + // Estimates max register size per work group, feel free to update the value + // according to device properties. + prop.set_max_register_size_per_work_group(65536); - int get_mem_base_addr_align() const - { - return get_info(); - } - - size_t get_global_mem_size() const - { - return get_device_info().get_global_mem_size(); - } - - size_t get_max_mem_alloc_size() const - { - return get_device_info().get_max_mem_alloc_size(); - } + prop.set_global_mem_cache_size( + dev.get_info()); + out = prop; +} - /// Get the number of bytes of free and total memory on the SYCL device. - /// \param [out] free_memory The number of bytes of free memory on the SYCL device. - /// \param [out] total_memory The number of bytes of total memory on the SYCL device. - void get_memory_info(size_t &free_memory, size_t &total_memory) - { - total_memory = get_device_info().get_global_mem_size(); - const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " - "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " - "use total memory as free memory"; +/// dpct device extension +class device_ext : public sycl::device { + typedef std::mutex mutex_type; + + public: + device_ext() : sycl::device(), _ctx(*this) {} + ~device_ext() { + std::lock_guard lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device& base) : sycl::device(base), _ctx(*this) { + std::lock_guard lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { + return 0; + } + int get_major_version() const { + return dpct::get_major_version(*this); + } + + int get_minor_version() const { + return dpct::get_minor_version(*this); + } + + int get_max_compute_units() const { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { + return get_device_info().get_integrated(); + } + + int get_max_sub_group_size() const { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const { + return get_info(); + } + + size_t get_global_mem_size() const { + return get_device_info().get_global_mem_size(); + } + + size_t get_max_mem_alloc_size() const { + return get_device_info().get_max_mem_alloc_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the SYCL + /// device. \param [out] total_memory The number of bytes of total memory on + /// the SYCL device. 
+ void get_memory_info(size_t& free_memory, size_t& total_memory) { + total_memory = get_device_info().get_global_mem_size(); + const char* warning_info = + "get_memory_info: [warning] ext_intel_free_memory is not " + "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " + "use total memory as free memory"; #if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) - if (!has(sycl::aspect::ext_intel_free_memory)) - { - std::cerr << warning_info << std::endl; - free_memory = total_memory; - } - else - { - free_memory = get_info(); - } + if (!has(sycl::aspect::ext_intel_free_memory)) { + std::cerr << warning_info << std::endl; + free_memory = total_memory; + } else { + free_memory = get_info(); + } #else - std::cerr << warning_info << std::endl; - free_memory = total_memory; + std::cerr << warning_info << std::endl; + free_memory = total_memory; #if defined(_MSC_VER) && !defined(__clang__) #pragma message("Querying the number of bytes of free memory is not supported") #else #warning "Querying the number of bytes of free memory is not supported" #endif #endif - } - - void get_device_info(device_info &out) const - { - dpct::get_device_info(out, *this); - } - - device_info get_device_info() const - { - device_info prop; - dpct::get_device_info(prop, *this); - return prop; - } - - void reset() - { - std::lock_guard lock(m_mutex); - clear_queues(); - init_queues(); - } - - sycl::queue &in_order_queue() { return *_q_in_order; } - - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } - - sycl::queue &default_queue() - { - return in_order_queue(); - } - - void queues_wait_and_throw() - { - std::unique_lock lock(m_mutex); - std::vector> current_queues( - _queues); - lock.unlock(); - for (const auto &q : current_queues) - { - q->wait_and_throw(); - } - // Guard the destruct of current_queues to make sure the ref count is safe. 
- lock.lock(); - } - - sycl::queue *create_queue(bool enable_exception_handler = false) - { - return create_in_order_queue(enable_exception_handler); - } - - sycl::queue *create_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - return create_in_order_queue(context, device, enable_exception_handler); - } - - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(context, device, enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler); - } - - void destroy_queue(sycl::queue *&queue) - { - std::lock_guard lock(m_mutex); - _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr &q) -> bool - { - return q.get() == queue; - }), - _queues.end()); - queue = nullptr; - } - void set_saved_queue(sycl::queue *q) - { - std::lock_guard lock(m_mutex); - _saved_queue = q; - } - sycl::queue *get_saved_queue() const - { - std::lock_guard lock(m_mutex); - return _saved_queue; - } - sycl::context get_context() const { return _ctx; } - - private: - void clear_queues() - { - _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; - } - - void init_queues() - { - _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); - _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); - } - - /// Caller should acquire resource \p m_mutex before calling this function. - template - sycl::queue *create_queue_impl(bool enable_exception_handler, - Properties... properties) - { - sycl::async_handler eh = {}; - if (enable_exception_handler) - { - eh = exception_handler; - } - _queues.push_back(std::make_shared( - _ctx, *this, eh, - sycl::property_list( + } + + void get_device_info(device_info& out) const { + dpct::get_device_info(out, *this); + } + + device_info get_device_info() const { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } + + void reset() { + std::lock_guard lock(m_mutex); + clear_queues(); + init_queues(); + } + + sycl::queue& in_order_queue() { + return *_q_in_order; + } + + sycl::queue& out_of_order_queue() { + return *_q_out_of_order; + } + + sycl::queue& default_queue() { + return in_order_queue(); + } + + void queues_wait_and_throw() { + std::unique_lock lock(m_mutex); + std::vector> current_queues(_queues); + lock.unlock(); + for (const auto& q : current_queues) { + q->wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is safe. 
+ lock.lock(); + } + + sycl::queue* create_queue(bool enable_exception_handler = false) { + return create_in_order_queue(enable_exception_handler); + } + + sycl::queue* create_queue( + sycl::context context, + sycl::device device, + bool enable_exception_handler = false) { + return create_in_order_queue(context, device, enable_exception_handler); + } + + sycl::queue* create_in_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl( + enable_exception_handler, sycl::property::queue::in_order()); + } + + sycl::queue* create_in_order_queue( + sycl::context context, + sycl::device device, + bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl( + context, + device, + enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue* create_out_of_order_queue( + bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue*& queue) { + std::lock_guard lock(m_mutex); + _queues.erase( + std::remove_if( + _queues.begin(), + _queues.end(), + [=](const std::shared_ptr& q) -> bool { + return q.get() == queue; + }), + _queues.end()); + queue = nullptr; + } + void set_saved_queue(sycl::queue* q) { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + sycl::queue* get_saved_queue() const { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + sycl::context get_context() const { + return _ctx; + } + + private: + void clear_queues() { + _queues.clear(); + _q_in_order = _q_out_of_order = _saved_queue = nullptr; + } + + void init_queues() { + _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = &default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this function. + template + sycl::queue* create_queue_impl( + bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + _ctx, + *this, + eh, + sycl::property_list( #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), + sycl::property::queue::enable_profiling(), #endif - properties...))); - - return _queues.back().get(); - } - - template - sycl::queue *create_queue_impl(sycl::context context, sycl::device device, - bool enable_exception_handler, - Properties... 
properties) { - sycl::async_handler eh = {}; - if (enable_exception_handler) { - eh = exception_handler; - } - _queues.push_back(std::make_shared( - context, device, eh, - sycl::property_list( - #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), - #endif - properties...))); - - return _queues.back().get(); - } - - void get_version(int &major, int &minor) const - { - detail::get_version(*this, major, minor); - } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - sycl::context _ctx; - std::vector> _queues; - mutable mutex_type m_mutex; - }; - - /// device manager - class dev_mgr - { - public: - device_ext ¤t_device() - { - unsigned int dev_id = current_device_id(); - check_id(dev_id); - return *_devs[dev_id]; - } - device_ext &cpu_device() const - { - std::lock_guard lock(m_mutex); - if (_cpu_device == -1) - { - throw std::runtime_error("no valid cpu device"); - } - else - { - return *_devs[_cpu_device]; - } - } - device_ext &get_device(unsigned int id) const - { - std::lock_guard lock(m_mutex); - check_id(id); - return *_devs[id]; - } - unsigned int current_device_id() const - { - std::lock_guard lock(m_mutex); - auto it = _thread2dev_map.find(get_tid()); - if (it != _thread2dev_map.end()) - return it->second; - return DEFAULT_DEVICE_ID; - } - - /// Select device with a device ID. - /// \param [in] id The id of the device which can - /// be obtained through get_device_id(const sycl::device). - void select_device(unsigned int id) - { - std::lock_guard lock(m_mutex); - check_id(id); - _thread2dev_map[get_tid()] = id; - } - unsigned int device_count() { return _devs.size(); } - - unsigned int get_device_id(const sycl::device &dev) - { - unsigned int id = 0; - for (auto dev_item : _devs) - { - if (*dev_item == dev) - { - break; - } - id++; - } - return id; - } + properties...))); + + return _queues.back().get(); + } + + template + sycl::queue* create_queue_impl( + sycl::context context, + sycl::device device, + bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + context, + device, + eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back().get(); + } + + void get_version(int& major, int& minor) const { + detail::get_version(*this, major, minor); + } + sycl::queue *_q_in_order, *_q_out_of_order; + sycl::queue* _saved_queue; + sycl::context _ctx; + std::vector> _queues; + mutable mutex_type m_mutex; +}; + +/// device manager +class dev_mgr { + public: + device_ext& current_device() { + unsigned int dev_id = current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext& cpu_device() const { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) { + throw std::runtime_error("no valid cpu device"); + } else { + return *_devs[_cpu_device]; + } + } + device_ext& get_device(unsigned int id) const { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const { + std::lock_guard lock(m_mutex); + auto it = _thread2dev_map.find(get_tid()); + if (it != _thread2dev_map.end()) + return it->second; + return DEFAULT_DEVICE_ID; + } + + /// Select device with a device ID. + /// \param [in] id The id of the device which can + /// be obtained through get_device_id(const sycl::device). 
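The queue helpers above combine with the select_device defined next into a simple per-thread pattern. A minimal sketch, assuming at least two enumerated devices and the same assumed include path (worker is a hypothetical name):

#include <thread>

#include "ggml-sycl/dpct/helper.hpp" // assumed include path

void worker(unsigned int dev_id) {
  auto& mgr = dpct::dev_mgr::instance();
  mgr.select_device(dev_id); // per-thread mapping via _thread2dev_map
  dpct::device_ext& dev = mgr.current_device();

  // Queues are owned by the device_ext and released in clear_queues()/reset().
  sycl::queue* q = dev.create_in_order_queue(/*enable_exception_handler=*/true);
  q->single_task([] {}).wait();

  dev.queues_wait_and_throw(); // drains every queue of this device
  dev.destroy_queue(q);        // unregisters the queue and nulls q
}

int main() {
  std::thread t0(worker, 0), t1(worker, 1);
  t0.join();
  t1.join();
}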
+ void select_device(unsigned int id) { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()] = id; + } + unsigned int device_count() { + return _devs.size(); + } + + unsigned int get_device_id(const sycl::device& dev) { + unsigned int id = 0; + for (auto dev_item : _devs) { + if (*dev_item == dev) { + break; + } + id++; + } + return id; + } + + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector& selector = sycl::gpu_selector_v) { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr& instance() { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr&) = delete; + dev_mgr& operator=(const dev_mgr&) = delete; + dev_mgr(dev_mgr&&) = delete; + dev_mgr& operator=(dev_mgr&&) = delete; + + private: + mutable std::recursive_mutex m_mutex; + static bool compare_dev(sycl::device& device1, sycl::device& device2) { + dpct::device_info prop1; + dpct::get_device_info(prop1, device1); + dpct::device_info prop2; + dpct::get_device_info(prop2, device2); + return prop1.get_max_compute_units() > prop2.get_max_compute_units(); + } + static int convert_backend_index(std::string& backend) { + if (backend == "ext_oneapi_level_zero:gpu") + return 0; + if (backend == "opencl:gpu") + return 1; + if (backend == "ext_oneapi_cuda:gpu") + return 2; + if (backend == "ext_oneapi_hip:gpu") + return 3; + if (backend == "opencl:cpu") + return 4; + if (backend == "opencl:acc") + return 5; + printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); + GGML_ASSERT(false); + } + static bool compare_backend(std::string& backend1, std::string& backend2) { + return convert_backend_index(backend1) < convert_backend_index(backend2); + } + dev_mgr() { + sycl::device default_device = sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs; + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + + auto Platforms = sycl::platform::get_platforms(); + // Keep track of the number of devices per backend + std::map DeviceNums; + std::map> backend_devices; + + while (!Platforms.empty()) { + auto Platform = Platforms.back(); + Platforms.pop_back(); + auto devices = Platform.get_devices(); + std::string backend_type = get_device_backend_and_type(devices[0]); + for (const auto& device : devices) { + backend_devices[backend_type].push_back(device); + } + } - template - std::enable_if_t< - std::is_invocable_r_v> - select_device(const DeviceSelector &selector = sycl::gpu_selector_v) - { - sycl::device selected_device = sycl::device(selector); - unsigned int selected_device_id = get_device_id(selected_device); - select_device(selected_device_id); - } + std::vector keys; + for (auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { + keys.push_back(it->first); + } + std::sort(keys.begin(), keys.end(), compare_backend); + + for (auto& key : keys) { + std::vector devs = backend_devices[key]; + std::sort(devs.begin(), devs.end(), compare_dev); + for (const auto& dev : devs) { + sycl_all_devs.push_back(dev); + } + } - /// Returns the instance of device manager singleton. 
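The constructor above pins the default-selector device at index 0 and appends the remaining devices grouped by backend in the fixed order of convert_backend_index (level-zero GPU, opencl:gpu, cuda:gpu, hip:gpu, opencl:cpu, opencl:acc), each group sorted by descending max_compute_units. A minimal sketch to inspect the resulting order, under the same include-path assumption:

#include <cstdio>

#include "ggml-sycl/dpct/helper.hpp" // assumed include path

int main() {
  auto& mgr = dpct::dev_mgr::instance();
  for (unsigned int i = 0; i < mgr.device_count(); ++i) {
    dpct::device_ext& d = mgr.get_device(i); // device_ext is-a sycl::device
    std::printf(
        "[%u] %s\n", i, d.get_info<sycl::info::device::name>().c_str());
  }
  return 0;
}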
- static dev_mgr &instance() - { - static dev_mgr d_m; - return d_m; - } - dev_mgr(const dev_mgr &) = delete; - dev_mgr &operator=(const dev_mgr &) = delete; - dev_mgr(dev_mgr &&) = delete; - dev_mgr &operator=(dev_mgr &&) = delete; - - private: - mutable std::recursive_mutex m_mutex; - static bool compare_dev(sycl::device &device1, sycl::device &device2) - { - dpct::device_info prop1; - dpct::get_device_info(prop1, device1); - dpct::device_info prop2; - dpct::get_device_info(prop2, device2); - return prop1.get_max_compute_units() > prop2.get_max_compute_units(); - } - static int convert_backend_index(std::string & backend) { - if (backend == "ext_oneapi_level_zero:gpu") return 0; - if (backend == "opencl:gpu") return 1; - if (backend == "ext_oneapi_cuda:gpu") return 2; - if (backend == "ext_oneapi_hip:gpu") return 3; - if (backend == "opencl:cpu") return 4; - if (backend == "opencl:acc") return 5; - printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); - GGML_ASSERT(false); - } - static bool compare_backend(std::string &backend1, std::string &backend2) { - return convert_backend_index(backend1) < convert_backend_index(backend2); - } - dev_mgr() - { - sycl::device default_device = - sycl::device(sycl::default_selector_v); - _devs.push_back(std::make_shared(default_device)); - - std::vector sycl_all_devs; - // Collect other devices except for the default device. - if (default_device.is_cpu()) - _cpu_device = 0; - - auto Platforms = sycl::platform::get_platforms(); - // Keep track of the number of devices per backend - std::map DeviceNums; - std::map> backend_devices; - - while (!Platforms.empty()) { - auto Platform = Platforms.back(); - Platforms.pop_back(); - auto devices = Platform.get_devices(); - std::string backend_type = get_device_backend_and_type(devices[0]); - for (const auto &device : devices) { - backend_devices[backend_type].push_back(device); - } - } - - std::vector keys; - for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { - keys.push_back(it->first); - } - std::sort(keys.begin(), keys.end(), compare_backend); - - for (auto &key : keys) { - std::vector devs = backend_devices[key]; - std::sort(devs.begin(), devs.end(), compare_dev); - for (const auto &dev : devs) { - sycl_all_devs.push_back(dev); - } - } - - for (auto &dev : sycl_all_devs) - { - if (dev == default_device) - { - continue; - } - _devs.push_back(std::make_shared(dev)); - if (_cpu_device == -1 && dev.is_cpu()) - { - _cpu_device = _devs.size() - 1; - } - } - } - void check_id(unsigned int id) const - { - if (id >= _devs.size()) - { - throw std::runtime_error("invalid device id"); - } - } - std::vector> _devs; - /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current - /// thread id in _thread2dev_map, which means default device should be used - /// for the current thread. - const unsigned int DEFAULT_DEVICE_ID = 0; - /// thread-id to device-id map. 
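The selector-based select_device overload resolves a concrete sycl::device first and then delegates to the id-based overload through get_device_id. A minimal sketch, same include-path assumption; it throws if no GPU matches the selector:

#include "ggml-sycl/dpct/helper.hpp" // assumed include path

int main() {
  // sycl::device(sycl::gpu_selector_v) throws when no GPU is available,
  // so guard accordingly in real code.
  dpct::dev_mgr::instance().select_device(sycl::gpu_selector_v);
  sycl::queue& q = dpct::get_default_queue(); // in-order queue of that GPU
  (void)q;
  return 0;
}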
- std::map _thread2dev_map; - int _cpu_device = -1; - }; - - static inline sycl::queue &get_default_queue() - { - return dev_mgr::instance().current_device().default_queue(); - } - - namespace detail - { - enum class pointer_access_attribute - { - host_only = 0, - device_only, - host_device, - end - }; + for (auto& dev : sycl_all_devs) { + if (dev == default_device) { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) { + _cpu_device = _devs.size() - 1; + } + } + } + void check_id(unsigned int id) const { + if (id >= _devs.size()) { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int DEFAULT_DEVICE_ID = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; +}; + +static inline sycl::queue& get_default_queue() { + return dev_mgr::instance().current_device().default_queue(); +} - static pointer_access_attribute get_pointer_attribute(sycl::queue &q, - const void *ptr) - { - switch (sycl::get_pointer_type(ptr, q.get_context())) - { - case sycl::usm::alloc::unknown: - return pointer_access_attribute::host_only; - case sycl::usm::alloc::device: - return pointer_access_attribute::device_only; - case sycl::usm::alloc::shared: - case sycl::usm::alloc::host: - return pointer_access_attribute::host_device; - } - } +namespace detail { +enum class pointer_access_attribute { + host_only = 0, + device_only, + host_device, + end +}; + +static pointer_access_attribute get_pointer_attribute( + sycl::queue& q, + const void* ptr) { + switch (sycl::get_pointer_type(ptr, q.get_context())) { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } +} - template - inline constexpr std::uint64_t get_type_combination_id(ArgT Val) - { - static_assert((unsigned char)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); - static_assert(std::is_same_v, "Unsupported ArgT"); - return (std::uint64_t)Val; - } +template +inline constexpr std::uint64_t get_type_combination_id(ArgT Val) { + static_assert( + (unsigned char)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(std::is_same_v, "Unsupported ArgT"); + return (std::uint64_t)Val; +} - template - inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, - RestT... RestVal) - { - static_assert((std::uint8_t)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); - static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); - static_assert(std::is_same_v, "Unsupported FirstT"); - return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); - } +template +inline constexpr std::uint64_t get_type_combination_id( + FirstT FirstVal, + RestT... 
RestVal) { + static_assert( + (std::uint8_t)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); + static_assert(std::is_same_v, "Unsupported FirstT"); + return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); +} - class mem_mgr - { - mem_mgr() - { - // Reserved address space, no real memory allocation happens here. +class mem_mgr { + mem_mgr() { + // Reserved address space, no real memory allocation happens here. #if defined(__linux__) - mapped_address_space = - (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + mapped_address_space = (byte_t*)mmap( + nullptr, + mapped_region_size, + PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); #elif defined(_WIN64) - mapped_address_space = (byte_t *)VirtualAlloc( - NULL, // NULL specified as the base address parameter - mapped_region_size, // Size of allocation - MEM_RESERVE, // Allocate reserved pages - PAGE_NOACCESS); // Protection = no access + mapped_address_space = (byte_t*)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access #else #error "Only support Windows and Linux." #endif - next_free = mapped_address_space; - }; + next_free = mapped_address_space; + }; - public: - using buffer_id_t = int; + public: + using buffer_id_t = int; - struct allocation - { - buffer_t buffer; - byte_t *alloc_ptr; - size_t size; - }; + struct allocation { + buffer_t buffer; + byte_t* alloc_ptr; + size_t size; + }; - ~mem_mgr() - { + ~mem_mgr() { #if defined(__linux__) - munmap(mapped_address_space, mapped_region_size); + munmap(mapped_address_space, mapped_region_size); #elif defined(_WIN64) - VirtualFree(mapped_address_space, 0, MEM_RELEASE); + VirtualFree(mapped_address_space, 0, MEM_RELEASE); #else #error "Only support Windows and Linux." #endif - }; - - mem_mgr(const mem_mgr &) = delete; - mem_mgr &operator=(const mem_mgr &) = delete; - mem_mgr(mem_mgr &&) = delete; - mem_mgr &operator=(mem_mgr &&) = delete; - - /// Allocate - void *mem_alloc(size_t size) - { - if (!size) - return nullptr; - std::lock_guard lock(m_mutex); - if (next_free + size > mapped_address_space + mapped_region_size) - { - throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); - } - // Allocation - sycl::range<1> r(size); - buffer_t buf(r); - allocation A{buf, next_free, size}; - // Map allocation to device pointer - void *result = next_free; - m_map.emplace(next_free + size, A); - // Update pointer to the next free space. - next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); - - return result; - } - - /// Deallocate - void mem_free(const void *ptr) - { - if (!ptr) - return; - std::lock_guard lock(m_mutex); - auto it = get_map_iterator(ptr); - m_map.erase(it); - } - - /// map: device pointer -> allocation(buffer, alloc_ptr, size) - allocation translate_ptr(const void *ptr) - { - std::lock_guard lock(m_mutex); - auto it = get_map_iterator(ptr); - return it->second; - } - - /// Check if the pointer represents device pointer or not. - bool is_device_ptr(const void *ptr) const - { - std::lock_guard lock(m_mutex); - return (mapped_address_space <= ptr) && - (ptr < mapped_address_space + mapped_region_size); - } - - /// Returns the instance of memory manager singleton. 
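mem_alloc only hands out offsets into an address range reserved up front (PROT_NONE mmap on Linux, MEM_RESERVE VirtualAlloc on Windows); the real storage is the sycl buffer kept in each allocation record. A standalone model of just that bookkeeping, with pool_model as a hypothetical name and a scaled-down region:

#include <cassert>
#include <cstddef>

// Models only the pointer arithmetic of mem_mgr, not the sycl buffers.
struct pool_model {
  static constexpr size_t region_size = 1u << 20; // stand-in for 128 GiB
  static constexpr size_t alignment = 256;
  char* base;
  char* next;

  explicit pool_model(char* reserved) : base(reserved), next(reserved) {}

  void* alloc(size_t size) {
    // mem_mgr throws before handing out memory past the reserved range.
    assert(next + size <= base + region_size && "out of reserved space");
    void* result = next;
    // Same rounding as mem_alloc (extra_padding defaults to 0).
    next += (size + alignment - 1) & ~(alignment - 1);
    return result;
  }

  // Mirrors is_device_ptr(): the pointer lies inside the reserved range.
  bool owns(const void* p) const {
    const char* cp = static_cast<const char*>(p);
    return cp >= base && cp < base + region_size;
  }
};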
- static mem_mgr &instance() - { - static mem_mgr m; - return m; - } - - private: - std::map m_map; - mutable std::mutex m_mutex; - byte_t *mapped_address_space; - byte_t *next_free; - const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; - const size_t alignment = 256; - /// This padding may be defined to some positive value to debug - /// out of bound accesses. - const size_t extra_padding = 0; - - std::map::iterator get_map_iterator(const void *ptr) - { - auto it = m_map.upper_bound((byte_t *)ptr); - if (it == m_map.end()) - { - // Not a virtual pointer. - throw std::runtime_error("can not get buffer from non-virtual pointer"); - } - const allocation &alloc = it->second; - if (ptr < alloc.alloc_ptr) - { - // Out of bound. - // This may happen if there's a gap between allocations due to alignment - // or extra padding and pointer points to this gap. - throw std::runtime_error("invalid virtual pointer"); - } - return it; - } - }; - - template - class accessor; - template - class memory_traits - { - public: - static constexpr sycl::access::target target = - sycl::access::target::device; - static constexpr sycl::access_mode mode = - (Memory == constant) ? sycl::access_mode::read - : sycl::access_mode::read_write; - static constexpr size_t type_size = sizeof(T); - using element_t = - typename std::conditional::type; - using value_t = typename std::remove_cv::type; - template - using accessor_t = typename std::conditional< - Memory == local, sycl::local_accessor, - sycl::accessor>::type; - using pointer_t = T *; - }; - - static inline void *dpct_malloc(size_t size, sycl::queue &q) - { - return sycl::malloc_device(size, q.get_device(), q.get_context()); - } + }; + + mem_mgr(const mem_mgr&) = delete; + mem_mgr& operator=(const mem_mgr&) = delete; + mem_mgr(mem_mgr&&) = delete; + mem_mgr& operator=(mem_mgr&&) = delete; + + /// Allocate + void* mem_alloc(size_t size) { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) { + throw std::runtime_error( + "dpct_malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> r(size); + buffer_t buf(r); + allocation A{buf, next_free, size}; + // Map allocation to device pointer + void* result = next_free; + m_map.emplace(next_free + size, A); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void* ptr) { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void* ptr) { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void* ptr) const { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr& instance() { + static mem_mgr m; + return m; + } + + private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t* mapped_address_space; + byte_t* next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. 
+ const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void* ptr) { + auto it = m_map.upper_bound((byte_t*)ptr); + if (it == m_map.end()) { + // Not a virtual pointer. + throw std::runtime_error("can not get buffer from non-virtual pointer"); + } + const allocation& alloc = it->second; + if (ptr < alloc.alloc_ptr) { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("invalid virtual pointer"); + } + return it; + } +}; + +template +class accessor; +template +class memory_traits { + public: + static constexpr sycl::access::target target = sycl::access::target::device; + static constexpr sycl::access_mode mode = (Memory == constant) + ? sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional::type; + using value_t = typename std::remove_cv::type; + template + using accessor_t = typename std::conditional< + Memory == local, + sycl::local_accessor, + sycl::accessor>::type; + using pointer_t = T*; +}; + +static inline void* dpct_malloc(size_t size, sycl::queue& q) { + return sycl::malloc_device(size, q.get_device(), q.get_context()); +} #define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) - static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, - sycl::queue &q) - { - pitch = PITCH_DEFAULT_ALIGN(x); - return dpct_malloc(pitch * y * z, q); - } +static inline void* dpct_malloc( + size_t& pitch, + size_t x, + size_t y, + size_t z, + sycl::queue& q) { + pitch = PITCH_DEFAULT_ALIGN(x); + return dpct_malloc(pitch * y * z, q); +} - /** - * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] dev_ptr Pointer to the virtual device memory address. - * @param [in] value The value to be set. - * @param [in] size Number of elements to be set to the value. - * @return An event representing the memset operation. - */ - template - static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, - valueT value, size_t size) - { - return q.fill(dev_ptr, value, size); - } +/** + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr + * in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ +template +static inline sycl::event dpct_memset( + sycl::queue& q, + void* dev_ptr, + valueT value, + size_t size) { + return q.fill(dev_ptr, value, size); +} - /** - * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] data Pointer to the pitched device memory region. - * @param [in] value The value to be set. - * @param [in] size 3D memory region by number of elements. - * @return An event list representing the memset operations. 
- */ - template - static inline std::vector - dpct_memset(sycl::queue &q, pitched_data data, valueT value, - sycl::range<3> size) - { - std::vector event_list; - size_t slice = data.get_pitch() * data.get_y(); - unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *data_ptr = data_surface; - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); - data_ptr += data.get_pitch(); - } - data_surface += slice; - } - return event_list; - } +/** + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ +template +static inline std::vector dpct_memset( + sycl::queue& q, + pitched_data data, + valueT value, + sycl::range<3> size) { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char* data_surface = (unsigned char*)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char* data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; +} - /** - * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] ptr Pointer to the virtual device memory. - * @param [in] pitch The pitch size by number of elements, including padding. - * @param [in] val The value to be set. - * @param [in] x The width of memory region by number of elements. - * @param [in] y The height of memory region by number of elements. - * @return An event list representing the memset operations. - */ - template - static inline std::vector - dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, - size_t y) - { - return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, - sycl::range<3>(x, y, 1)); - } +/** + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @return An event list representing the memset operations. 
+ */
+template <typename valueT>
+static inline std::vector<sycl::event> dpct_memset(
+    sycl::queue& q,
+    void* ptr,
+    size_t pitch,
+    valueT val,
+    size_t x,
+    size_t y) {
+  return dpct_memset(
+      q, pitched_data(ptr, pitch, x, 1), val, sycl::range<3>(x, y, 1));
+}

-    static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
-                                                    const void *from_ptr,
-                                                    memcpy_direction dir)
-    {
-      switch (dir)
-      {
-      case memcpy_direction::host_to_host:
-      case memcpy_direction::host_to_device:
-      case memcpy_direction::device_to_host:
-      case memcpy_direction::device_to_device:
-        return dir;
-      case memcpy_direction::automatic:
-      {
-        // table[to_attribute][from_attribute]
-        static const memcpy_direction
-            direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
-                           [static_cast<unsigned>(pointer_access_attribute::end)] =
-                {{memcpy_direction::host_to_host,
-                  memcpy_direction::device_to_host,
-                  memcpy_direction::host_to_host},
-                 {memcpy_direction::host_to_device,
-                  memcpy_direction::device_to_device,
-                  memcpy_direction::device_to_device},
-                 {memcpy_direction::host_to_host,
-                  memcpy_direction::device_to_device,
-                  memcpy_direction::device_to_device}};
-        return direction_table[static_cast<unsigned>(get_pointer_attribute(
-            q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
-      }
-      default:
-        throw std::runtime_error("dpct_memcpy: invalid direction value");
-      }
-    }
+static memcpy_direction deduce_memcpy_direction(
+    sycl::queue& q,
+    void* to_ptr,
+    const void* from_ptr,
+    memcpy_direction dir) {
+  switch (dir) {
+    case memcpy_direction::host_to_host:
+    case memcpy_direction::host_to_device:
+    case memcpy_direction::device_to_host:
+    case memcpy_direction::device_to_device:
+      return dir;
+    case memcpy_direction::automatic: {
+      // table[to_attribute][from_attribute]
+      static const memcpy_direction
+          direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
+                         [static_cast<unsigned>(
+                             pointer_access_attribute::end)] = {
+              {memcpy_direction::host_to_host,
+               memcpy_direction::device_to_host,
+               memcpy_direction::host_to_host},
+              {memcpy_direction::host_to_device,
+               memcpy_direction::device_to_device,
+               memcpy_direction::device_to_device},
+              {memcpy_direction::host_to_host,
+               memcpy_direction::device_to_device,
+               memcpy_direction::device_to_device}};
+      return direction_table[static_cast<unsigned>(get_pointer_attribute(
+          q,
+          to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
+    }
+    default:
+      throw std::runtime_error("dpct_memcpy: invalid direction value");
+  }
+}

-    static sycl::event
-    dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
-                memcpy_direction direction,
-                const std::vector<sycl::event> &dep_events = {})
-    {
-      if (!size)
-        return sycl::event{};
-      return q.memcpy(to_ptr, from_ptr, size, dep_events);
-      GGML_UNUSED(direction);
-    }
+static sycl::event dpct_memcpy(
+    sycl::queue& q,
+    void* to_ptr,
+    const void* from_ptr,
+    size_t size,
+    memcpy_direction direction,
+    const std::vector<sycl::event>& dep_events = {}) {
+  if (!size)
+    return sycl::event{};
+  return q.memcpy(to_ptr, from_ptr, size, dep_events);
+  GGML_UNUSED(direction);
+}

-    // Get actual copy range and make sure it will not exceed range.
-    static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
-                                        size_t pitch)
-    {
-      return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
-    }
+// Get actual copy range and make sure it will not exceed range.
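Before the copy helpers below, a usage sketch for the memset family above. It reaches into the detail namespace purely for illustration, assumes the same include path, and uses a byte-valued fill so the element count stays consistent with the byte pitch; zero2d is a hypothetical name:

#include <sycl/sycl.hpp>

#include "ggml-sycl/dpct/helper.hpp" // assumed include path

// Zero a pitched 2D region of width_bytes x height.
void zero2d(sycl::queue& q, size_t width_bytes, size_t height) {
  size_t pitch = 0; // set to PITCH_DEFAULT_ALIGN(width_bytes) by dpct_malloc
  void* ptr = dpct::detail::dpct_malloc(pitch, width_bytes, height, 1, q);
  auto events = dpct::detail::dpct_memset(
      q, ptr, pitch, (unsigned char)0, width_bytes, height);
  sycl::event::wait(events);
  sycl::free(ptr, q.get_context());
}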
+static inline size_t get_copy_range( + sycl::range<3> size, + size_t slice, + size_t pitch) { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); +} - static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) - { - return slice * id.get(2) + pitch * id.get(1) + id.get(0); - } +static inline size_t get_offset(sycl::id<3> id, size_t slice, size_t pitch) { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); +} - /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr - /// and \p from_range to another specified by \p to_ptr and \p to_range. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) - { - // RAII for host pointer - class host_buffer - { - void *_buf; - size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends - - public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) - : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } - size_t get_size() const { return _size; } - ~host_buffer() - { - if (_buf) - { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); - } - } - }; - std::vector event_list; - - size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + - get_offset(from_id, from_slice, from_range.get(0)); - - if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) - { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; - } - direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); - size_t size_slice = size.get(1) * size.get(0); - switch (direction) - { - case host_to_host: - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; - if (to_range.get(0) == from_range.get(0) && - to_range.get(0) == size.get(0)) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); - } - else - { - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); - to_ptr += to_range.get(0); - from_ptr += from_range.get(0); - } - } - to_surface += to_slice; - from_surface += from_slice; - } - break; - case host_to_device: - { - host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); - std::vector host_events; - if (to_slice == size_slice) - { - // Copy host data to a temp host buffer with the shape of target. - host_events = - dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); - } - else - { - // Copy host data to a temp host buffer with the shape of target. - host_events = dpct_memcpy( - q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, - // If has padding data, not sure whether it is useless. So fill temp - // buffer with it. 
- std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); - } - // Copy from temp host buffer to device with only one submit. - event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); - break; - } - case device_to_host: - { - host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); - // Copy from host temp buffer to host target with reshaping. - event_list = dpct_memcpy( - q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), - sycl::id<3>(0, 0, 0), size, host_to_host, - // Copy from device to temp host buffer with only one submit. - std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); - break; - } - case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh){ - cgh.depends_on(dep_events); - cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; - default: - throw std::runtime_error("dpct_memcpy: invalid direction value"); - } - return event_list; - } +/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr +/// and \p from_range to another specified by \p to_ptr and \p to_range. +static inline std::vector dpct_memcpy( + sycl::queue& q, + void* to_ptr, + const void* from_ptr, + sycl::range<3> to_range, + sycl::range<3> from_range, + sycl::id<3> to_id, + sycl::id<3> from_id, + sycl::range<3> size, + memcpy_direction direction, + const std::vector& dep_events = {}) { + // RAII for host pointer + class host_buffer { + void* _buf; + size_t _size; + sycl::queue& _q; + const std::vector& _deps; // free operation depends + + public: + host_buffer( + size_t size, + sycl::queue& q, + const std::vector& deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void* get_ptr() const { + return _buf; + } + size_t get_size() const { + return _size; + } + ~host_buffer() { + if (_buf) { + _q.submit([&](sycl::handler& cgh) { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); + }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char* to_surface = + (unsigned char*)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char* from_surface = (const unsigned char*)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { + return {dpct_memcpy( + q, + to_surface, + from_surface, + to_slice * size.get(2), + direction, + dep_events)}; + } + direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char* to_ptr = to_surface; + const unsigned char* from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) { + event_list.push_back(dpct_memcpy( + q, to_ptr, from_ptr, size_slice, direction, dep_events)); + } else { + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(dpct_memcpy( + q, to_ptr, from_ptr, size.get(0), direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + 
break; + case host_to_device: { + host_buffer buf( + get_copy_range(size, to_slice, to_range.get(0)), q, event_list); + std::vector host_events; + if (to_slice == size_slice) { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + dep_events); + } else { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{dpct_memcpy( + q, + buf.get_ptr(), + to_surface, + buf.get_size(), + device_to_host, + dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy( + q, + to_surface, + buf.get_ptr(), + buf.get_size(), + host_to_device, + host_events)); + break; + } + case device_to_host: { + host_buffer buf( + get_copy_range(size, from_slice, from_range.get(0)), q, event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, + to_surface, + buf.get_ptr(), + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + buf.get_size(), + device_to_host, + dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler& cgh) { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); + })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; +} - /// memcpy 2D/3D matrix specified by pitched_data. - static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); - } +/// memcpy 2D/3D matrix specified by pitched_data. +static inline std::vector dpct_memcpy( + sycl::queue& q, + pitched_data to, + sycl::id<3> to_id, + pitched_data from, + sycl::id<3> from_id, + sycl::range<3> size, + memcpy_direction direction = automatic) { + return dpct_memcpy( + q, + to.get_data_ptr(), + from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), + to_id, + from_id, + size, + direction); +} - /// memcpy 2D matrix with pitch. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); - } +/// memcpy 2D matrix with pitch. 
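For the pitch-aware 2D copy declared next, sizes are byte counts (the helper walks rows as unsigned char), and the direction can stay automatic because deduce_memcpy_direction inspects the USM allocation types. A minimal sketch, with upload as a hypothetical name and the same include-path assumption:

#include <vector>

#include "ggml-sycl/dpct/helper.hpp" // assumed include path

// Copy a tightly packed host matrix into a pitched device allocation.
void upload(
    sycl::queue& q,
    void* dev_dst,
    size_t dev_pitch_bytes,
    const std::vector<float>& host,
    size_t width,
    size_t rows) {
  auto events = dpct::detail::dpct_memcpy(
      q,
      dev_dst,
      host.data(),
      /*to_pitch=*/dev_pitch_bytes,
      /*from_pitch=*/width * sizeof(float),
      /*x=*/width * sizeof(float),
      /*y=*/rows);
  sycl::event::wait(events);
}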
+static inline std::vector dpct_memcpy( + sycl::queue& q, + void* to_ptr, + const void* from_ptr, + size_t to_pitch, + size_t from_pitch, + size_t x, + size_t y, + memcpy_direction direction = automatic) { + return dpct_memcpy( + q, + to_ptr, + from_ptr, + sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), + direction); +} - namespace deprecated - { - - template - class usm_allocator - { - private: - using Alloc = sycl::usm_allocator; - Alloc _impl; - - public: - using value_type = typename std::allocator_traits::value_type; - using pointer = typename std::allocator_traits::pointer; - using const_pointer = typename std::allocator_traits::const_pointer; - using void_pointer = typename std::allocator_traits::void_pointer; - using const_void_pointer = - typename std::allocator_traits::const_void_pointer; - using reference = typename std::allocator_traits::value_type &; - using const_reference = - const typename std::allocator_traits::value_type &; - using difference_type = - typename std::allocator_traits::difference_type; - using size_type = typename std::allocator_traits::size_type; - using propagate_on_container_copy_assignment = typename std::allocator_traits< - Alloc>::propagate_on_container_copy_assignment; - using propagate_on_container_move_assignment = typename std::allocator_traits< - Alloc>::propagate_on_container_move_assignment; - using propagate_on_container_swap = - typename std::allocator_traits::propagate_on_container_swap; - using is_always_equal = - typename std::allocator_traits::is_always_equal; - - template - struct rebind - { - typedef usm_allocator other; - }; - - usm_allocator() : _impl(dpct::get_default_queue()) {} - ~usm_allocator() {} - usm_allocator(const usm_allocator &other) : _impl(other._impl) {} - usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} - pointer address(reference r) { return &r; } - const_pointer address(const_reference r) { return &r; } - pointer allocate(size_type cnt, const_void_pointer hint = nullptr) - { - return std::allocator_traits::allocate(_impl, cnt, hint); - } - void deallocate(pointer p, size_type cnt) - { - std::allocator_traits::deallocate(_impl, p, cnt); - } - size_type max_size() const - { - return std::allocator_traits::max_size(_impl); - } - bool operator==(const usm_allocator &other) const { return _impl == other._impl; } - bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } - }; - - } // namespace deprecated - - inline void dpct_free(void *ptr, - const sycl::queue &q) - { - if (ptr) - { - sycl::free(ptr, q.get_context()); - } - } +namespace deprecated { + +template +class usm_allocator { + private: + using Alloc = sycl::usm_allocator; + Alloc _impl; + + public: + using value_type = typename std::allocator_traits::value_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using void_pointer = typename std::allocator_traits::void_pointer; + using const_void_pointer = + typename std::allocator_traits::const_void_pointer; + using reference = typename std::allocator_traits::value_type&; + using const_reference = + const typename std::allocator_traits::value_type&; + using difference_type = + typename std::allocator_traits::difference_type; + using size_type = typename std::allocator_traits::size_type; + using propagate_on_container_copy_assignment = typename std::allocator_traits< + 
Alloc>::propagate_on_container_copy_assignment; + using propagate_on_container_move_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_move_assignment; + using propagate_on_container_swap = + typename std::allocator_traits::propagate_on_container_swap; + using is_always_equal = + typename std::allocator_traits::is_always_equal; + + template + struct rebind { + typedef usm_allocator other; + }; + + usm_allocator() : _impl(dpct::get_default_queue()) {} + ~usm_allocator() {} + usm_allocator(const usm_allocator& other) : _impl(other._impl) {} + usm_allocator(usm_allocator&& other) : _impl(std::move(other._impl)) {} + pointer address(reference r) { + return &r; + } + const_pointer address(const_reference r) { + return &r; + } + pointer allocate(size_type cnt, const_void_pointer hint = nullptr) { + return std::allocator_traits::allocate(_impl, cnt, hint); + } + void deallocate(pointer p, size_type cnt) { + std::allocator_traits::deallocate(_impl, p, cnt); + } + size_type max_size() const { + return std::allocator_traits::max_size(_impl); + } + bool operator==(const usm_allocator& other) const { + return _impl == other._impl; + } + bool operator!=(const usm_allocator& other) const { + return _impl != other._impl; + } +}; + +} // namespace deprecated + +inline void dpct_free(void* ptr, const sycl::queue& q) { + if (ptr) { + sycl::free(ptr, q.get_context()); + } +} - template - inline auto get_memory(const void *x) - { - T *new_x = reinterpret_cast(const_cast(x)); - return new_x; - } +template +inline auto get_memory(const void* x) { + T* new_x = reinterpret_cast(const_cast(x)); + return new_x; +} - template - inline typename DataType::T2 get_value(const T *s, sycl::queue &q) - { - using Ty = typename DataType::T2; - Ty s_h; - if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) - detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host) - .wait(); - else - s_h = *reinterpret_cast(s); - return s_h; - } +template +inline typename DataType::T2 get_value(const T* s, sycl::queue& q) { + using Ty = typename DataType::T2; + Ty s_h; + if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) + detail::dpct_memcpy( + q, (void*)&s_h, (const void*)s, sizeof(T), device_to_host) + .wait(); + else + s_h = *reinterpret_cast(s); + return s_h; +} - } // namespace detail +} // namespace detail - template - inline auto get_value(const T *s, sycl::queue &q) - { - return detail::get_value(s, q); - } +template +inline auto get_value(const T* s, sycl::queue& q) { + return detail::get_value(s, q); +} - namespace detail - { - template - inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, int lda, const void *b, - int ldb, const void *beta, void *c, int ldc) - { +namespace detail { +template +inline void gemm_impl( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void* a, + int lda, + const void* b, + int ldb, + const void* beta, + void* c, + int ldc) { #ifndef __INTEL_MKL__ - GGML_UNUSED(q); - GGML_UNUSED(a_trans); - GGML_UNUSED(b_trans); - GGML_UNUSED(m); - GGML_UNUSED(n); - GGML_UNUSED(k); - GGML_UNUSED(alpha); - GGML_UNUSED(a); - GGML_UNUSED(lda); - GGML_UNUSED(b); - GGML_UNUSED(ldb); - GGML_UNUSED(beta); - GGML_UNUSED(c); - GGML_UNUSED(ldc); - throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " - "Project 
does not support this API."); + GGML_UNUSED(q); + GGML_UNUSED(a_trans); + GGML_UNUSED(b_trans); + GGML_UNUSED(m); + GGML_UNUSED(n); + GGML_UNUSED(k); + GGML_UNUSED(alpha); + GGML_UNUSED(a); + GGML_UNUSED(lda); + GGML_UNUSED(b); + GGML_UNUSED(ldb); + GGML_UNUSED(beta); + GGML_UNUSED(c); + GGML_UNUSED(ldc); + throw std::runtime_error( + "The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); #else - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm( + q, + a_trans, + b_trans, + m, + n, + k, + alpha_value, + data_a, + lda, + data_b, + ldb, + beta_value, + data_c, + ldc); #endif - } +} - template - class vectorized_binary - { - public: - inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) - { - VecT v4; - for (size_t i = 0; i < v4.size(); ++i) - { - v4[i] = binary_op(a[i], b[i]); - } - return v4; - } - }; +template +class vectorized_binary { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } +}; + +template +class vectorized_binary< + VecT, + BinaryOperation, + std::void_t>> { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) { + return binary_op(a, b).template as(); + } +}; + +template +inline void gemm_batch_impl( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void** a, + int lda, + const void** b, + int ldb, + const void* beta, + void** c, + int ldc, + int batch_size) { + struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; + }; + + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + + matrix_info_t* matrix_info = + (matrix_info_t*)std::malloc(sizeof(matrix_info_t)); + matrix_info->transpose_info[0] = a_trans; + matrix_info->transpose_info[1] = b_trans; + matrix_info->value_info[0] = alpha_value; + matrix_info->value_info[1] = beta_value; + matrix_info->size_info[0] = m; + matrix_info->size_info[1] = n; + matrix_info->size_info[2] = k; + matrix_info->ld_info[0] = lda; + matrix_info->ld_info[1] = ldb; + matrix_info->ld_info[2] = ldc; + matrix_info->groupsize_info = batch_size; + + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, + matrix_info->transpose_info, + matrix_info->transpose_info + 1, + matrix_info->size_info, + matrix_info->size_info + 1, + matrix_info->size_info + 2, + matrix_info->value_info, + reinterpret_cast(a), + matrix_info->ld_info, + reinterpret_cast(b), + matrix_info->ld_info + 1, + matrix_info->value_info + 1, + reinterpret_cast(c), + matrix_info->ld_info + 2, + 1, + &(matrix_info->groupsize_info)); + + q.submit([&](sycl::handler& cgh) { + cgh.depends_on(e); + cgh.host_task([=] { 
std::free(matrix_info); }); + }); +} - template - class vectorized_binary< - VecT, BinaryOperation, - std::void_t>> - { - public: - inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) - { - return binary_op(a, b).template as(); - } - }; +template +inline void gemm_batch_impl( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void* a, + int lda, + long long int stride_a, + const void* b, + int ldb, + long long int stride_b, + const void* beta, + void* c, + int ldc, + long long int stride_c, + int batch_size) { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm_batch( + q, + a_trans, + b_trans, + m, + n, + k, + alpha_value, + data_a, + lda, + stride_a, + data_b, + ldb, + stride_b, + beta_value, + data_c, + ldc, + stride_c, + batch_size); +} - template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->transpose_info[0] = a_trans; - matrix_info->transpose_info[1] = b_trans; - matrix_info->value_info[0] = alpha_value; - matrix_info->value_info[1] = beta_value; - matrix_info->size_info[0] = m; - matrix_info->size_info[1] = n; - matrix_info->size_info[2] = k; - matrix_info->ld_info[0] = lda; - matrix_info->ld_info[1] = ldb; - matrix_info->ld_info[2] = ldc; - matrix_info->groupsize_info = batch_size; - - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); - } +} // namespace detail + +template +inline unsigned vectorized_binary( + unsigned a, + unsigned b, + const BinaryOperation binary_op) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + v0 = v4.template as>(); + return v0; +} - template - inline void - gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, - int k, const void *alpha, const void *a, int lda, - long long int stride_a, const void *b, int ldb, - long long int stride_b, const void *beta, void *c, - int ldc, long long int stride_c, int batch_size) - { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto 
data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); - } +static void async_dpct_memcpy( + void* to_ptr, + const void* from_ptr, + size_t size, + memcpy_direction direction = automatic, + sycl::queue& q = dpct::get_default_queue()) { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); +} - } // namespace detail - - template - inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op) - { - sycl::vec v0{a}, v1{b}; - auto v2 = v0.as(); - auto v3 = v1.as(); - auto v4 = - detail::vectorized_binary()(v2, v3, binary_op); - v0 = v4.template as>(); - return v0; - } - - static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction = automatic, - sycl::queue &q = dpct::get_default_queue()) - { - detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); - } - - static inline unsigned int select_device(unsigned int id) - { - dev_mgr::instance().select_device(id); - return id; - } - - template - T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, - unsigned int logical_sub_group_size = 32) - { - unsigned int id = g.get_local_linear_id(); - unsigned int start_index = - id / logical_sub_group_size * logical_sub_group_size; - unsigned int target_offset = (id % logical_sub_group_size) ^ mask; - return sycl::select_from_group(g, x, - target_offset < logical_sub_group_size - ? start_index + target_offset - : id); - } - - template - sycl::vec extract_and_sign_or_zero_extend4(T val) - { - return sycl::vec(val) - .template as, int8_t, uint8_t>, 4>>() - .template convert(); - } - - template - using dot_product_acc_t = - std::conditional_t && std::is_unsigned_v, - uint32_t, int32_t>; - - template - inline auto dp4a(T1 a, T2 b, T3 c) - { - dot_product_acc_t res = c; - auto va = extract_and_sign_or_zero_extend4(a); - auto vb = extract_and_sign_or_zero_extend4(b); - res += va[0] * vb[0]; - res += va[1] * vb[1]; - res += va[2] * vb[2]; - res += va[3] * vb[3]; - return res; - } - - struct sub_sat - { - template - auto operator()(const T x, const T y) const - { - return sycl::sub_sat(x, y); - } - }; - - template - inline T vectorized_min(T a, T b) - { - sycl::vec v0{a}, v1{b}; - auto v2 = v0.template as(); - auto v3 = v1.template as(); - auto v4 = sycl::min(v2, v3); - v0 = v4.template as>(); - return v0; - } - - inline float pow(const float a, const int b) { return sycl::pown(a, b); } - inline double pow(const double a, const int b) { return sycl::pown(a, b); } - inline float pow(const float a, const float b) { return sycl::pow(a, b); } - inline double pow(const double a, const double b) { return sycl::pow(a, b); } - template - inline typename std::enable_if_t, T> - pow(const T a, const U b) - { - return sycl::pow(a, static_cast(b)); - } - template - inline typename std::enable_if_t, double> - pow(const T a, const U b) - { - return sycl::pow(static_cast(a), static_cast(b)); - } - - inline double min(const double a, const float b) - { - return sycl::fmin(a, static_cast(b)); - } - inline double min(const float a, const double b) - { - return sycl::fmin(static_cast(a), b); - } - inline float min(const float a, const float b) { return sycl::fmin(a, b); } - inline double min(const double a, const double b) { return sycl::fmin(a, b); } - inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) - { - return 
sycl::min(a, static_cast(b)); - } - inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::int32_t min(const std::int32_t a, const std::int32_t b) - { - return sycl::min(a, b); - } - inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::int64_t min(const std::int64_t a, const std::int64_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - // max function overloads. - // For floating-point types, `float` or `double` arguments are acceptable. - // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or - // `std::int64_t` type arguments are acceptable. - inline double max(const double a, const float b) - { - return sycl::fmax(a, static_cast(b)); - } - inline double max(const float a, const double b) - { - return sycl::fmax(static_cast(a), b); - } - inline float max(const float a, const float b) { return sycl::fmax(a, b); } - inline double max(const double a, const double b) { return sycl::fmax(a, b); } - inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::int32_t max(const std::int32_t a, const std::int32_t b) - { - return sycl::max(a, b); - } - inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::int64_t max(const std::int64_t a, const std::int64_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - - inline void - has_capability_or_fail(const sycl::device &dev, - const std::initializer_list &props) - { - for (const auto &it : props) - { - if (dev.has(it)) - continue; - switch (it) - { - case sycl::aspect::fp64: - throw std::runtime_error("'double' is not supported in '" + - dev.get_info() 
+ - "' device"); - break; - case sycl::aspect::fp16: - throw std::runtime_error("'half' is not supported in '" + - dev.get_info() + - "' device"); - break; - default: +static inline unsigned int select_device(unsigned int id) { + dev_mgr::instance().select_device(id); + return id; +} + +template +T permute_sub_group_by_xor( + sycl::sub_group g, + T x, + unsigned int mask, + unsigned int logical_sub_group_size = 32) { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + return sycl::select_from_group( + g, + x, + target_offset < logical_sub_group_size ? start_index + target_offset + : id); +} + +template +sycl::vec extract_and_sign_or_zero_extend4(T val) { + return sycl::vec(val) + .template as, int8_t, uint8_t>, + 4>>() + .template convert(); +} + +template +using dot_product_acc_t = std::conditional_t< + std::is_unsigned_v && std::is_unsigned_v, + uint32_t, + int32_t>; + +template +inline auto dp4a(T1 a, T2 b, T3 c) { + dot_product_acc_t res = c; + auto va = extract_and_sign_or_zero_extend4(a); + auto vb = extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[0]; + res += va[1] * vb[1]; + res += va[2] * vb[2]; + res += va[3] * vb[3]; + return res; +} + +struct sub_sat { + template + auto operator()(const T x, const T y) const { + return sycl::sub_sat(x, y); + } +}; + +template +inline T vectorized_min(T a, T b) { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::min(v2, v3); + v0 = v4.template as>(); + return v0; +} + +inline float pow(const float a, const int b) { + return sycl::pown(a, b); +} +inline double pow(const double a, const int b) { + return sycl::pown(a, b); +} +inline float pow(const float a, const float b) { + return sycl::pow(a, b); +} +inline double pow(const double a, const double b) { + return sycl::pow(a, b); +} +template +inline typename std::enable_if_t, T> pow( + const T a, + const U b) { + return sycl::pow(a, static_cast(b)); +} +template +inline typename std::enable_if_t, double> pow( + const T a, + const U b) { + return sycl::pow(static_cast(a), static_cast(b)); +} + +inline double min(const double a, const float b) { + return sycl::fmin(a, static_cast(b)); +} +inline double min(const float a, const double b) { + return sycl::fmin(static_cast(a), b); +} +inline float min(const float a, const float b) { + return sycl::fmin(a, b); +} +inline double min(const double a, const double b) { + return sycl::fmin(a, b); +} +inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) { + return sycl::min(static_cast(a), b); +} +inline std::int32_t min(const std::int32_t a, const std::int32_t b) { + return sycl::min(a, b); +} +inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +inline std::int64_t min(const std::int64_t a, const std::int64_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) { + return sycl::min(a, b); +} +inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) { + return sycl::min(a, 
static_cast(b)); +} +inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) { + return sycl::min(a, static_cast(b)); +} +inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) { + return sycl::min(static_cast(a), b); +} +// max function overloads. +// For floating-point types, `float` or `double` arguments are acceptable. +// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or +// `std::int64_t` type arguments are acceptable. +inline double max(const double a, const float b) { + return sycl::fmax(a, static_cast(b)); +} +inline double max(const float a, const double b) { + return sycl::fmax(static_cast(a), b); +} +inline float max(const float a, const float b) { + return sycl::fmax(a, b); +} +inline double max(const double a, const double b) { + return sycl::fmax(a, b); +} +inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) { + return sycl::max(static_cast(a), b); +} +inline std::int32_t max(const std::int32_t a, const std::int32_t b) { + return sycl::max(a, b); +} +inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} +inline std::int64_t max(const std::int64_t a, const std::int64_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) { + return sycl::max(a, b); +} +inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} +inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) { + return sycl::max(a, static_cast(b)); +} +inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) { + return sycl::max(static_cast(a), b); +} + +inline void has_capability_or_fail( + const sycl::device& dev, + const std::initializer_list& props) { + for (const auto& it : props) { + if (dev.has(it)) + continue; + switch (it) { + case sycl::aspect::fp64: + throw std::runtime_error( + "'double' is not supported in '" + + dev.get_info() + "' device"); + break; + case sycl::aspect::fp16: + throw std::runtime_error( + "'half' is not supported in '" + + dev.get_info() + "' device"); + break; + default: #define __SYCL_ASPECT(ASPECT, ID) \ - case sycl::aspect::ASPECT: \ - return #ASPECT; + case sycl::aspect::ASPECT: \ + return #ASPECT; #define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) #define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) - auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string - { - switch (AspectNum) - { + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string { + switch (AspectNum) { #include #include - default: - return "unknown aspect"; - } - }; + default: + return "unknown aspect"; + } + }; #undef __SYCL_ASPECT_DEPRECATED_ALIAS #undef __SYCL_ASPECT_DEPRECATED #undef __SYCL_ASPECT - throw std::runtime_error( - "'" + getAspectNameStr(it) + "' is not supported in '" + - dev.get_info() + "' device"); - } - break; - } + throw 
std::runtime_error( + "'" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); } + break; + } +} - static inline unsigned int get_current_device_id() - { - return dev_mgr::instance().current_device_id(); - } - - static inline device_ext &get_current_device() - { - return dev_mgr::instance().current_device(); - } - - static inline sycl::queue &get_in_order_queue() - { - return dev_mgr::instance().current_device().in_order_queue(); - } - - static sycl::event - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction, - const std::vector &dep_events = {}) - { - if (!size) - return sycl::event{}; - return q.memcpy(to_ptr, from_ptr, size, dep_events); - GGML_UNUSED(direction); - } - - // Get actual copy range and make sure it will not exceed range. - static inline size_t get_copy_range(sycl::range<3> size, size_t slice, - size_t pitch) - { - return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); - } - - static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) - { - return slice * id.get(2) + pitch * id.get(1) + id.get(0); - } - - /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr - /// and \p from_range to another specified by \p to_ptr and \p to_range. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) - { - // RAII for host pointer - class host_buffer - { - void *_buf; - size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends - - public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) - : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } - size_t get_size() const { return _size; } - ~host_buffer() - { - if (_buf) - { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); - } - } - }; - std::vector event_list; - - size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + - get_offset(from_id, from_slice, from_range.get(0)); - - if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) - { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; - } - direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); - size_t size_slice = size.get(1) * size.get(0); - switch (direction) - { - case host_to_host: - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; - if (to_range.get(0) == from_range.get(0) && - to_range.get(0) == size.get(0)) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); - } - else - { - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); - to_ptr += to_range.get(0); - from_ptr += from_range.get(0); - } - } - to_surface += to_slice; - from_surface += from_slice; - } - break; - case host_to_device: - { - host_buffer 
buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); - std::vector host_events; - if (to_slice == size_slice) - { - // Copy host data to a temp host buffer with the shape of target. - host_events = - dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); - } - else - { - // Copy host data to a temp host buffer with the shape of target. - host_events = dpct_memcpy( - q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, - // If has padding data, not sure whether it is useless. So fill temp - // buffer with it. - std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); - } - // Copy from temp host buffer to device with only one submit. - event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); - break; - } - case device_to_host: - { - host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); - // Copy from host temp buffer to host target with reshaping. - event_list = dpct_memcpy( - q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), - sycl::id<3>(0, 0, 0), size, host_to_host, - // Copy from device to temp host buffer with only one submit. - std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); - break; - } - case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh) - { +static inline unsigned int get_current_device_id() { + return dev_mgr::instance().current_device_id(); +} + +static inline device_ext& get_current_device() { + return dev_mgr::instance().current_device(); +} + +static inline sycl::queue& get_in_order_queue() { + return dev_mgr::instance().current_device().in_order_queue(); +} + +static sycl::event dpct_memcpy( + sycl::queue& q, + void* to_ptr, + const void* from_ptr, + size_t size, + memcpy_direction direction, + const std::vector& dep_events = {}) { + if (!size) + return sycl::event{}; + return q.memcpy(to_ptr, from_ptr, size, dep_events); + GGML_UNUSED(direction); +} + +// Get actual copy range and make sure it will not exceed range. +static inline size_t get_copy_range( + sycl::range<3> size, + size_t slice, + size_t pitch) { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); +} + +static inline size_t get_offset(sycl::id<3> id, size_t slice, size_t pitch) { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); +} + +/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr +/// and \p from_range to another specified by \p to_ptr and \p to_range. 
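[annotation] As a rough illustration of the reshaping copy defined below (not part of the patch: the pointer names and byte counts are made up, and the call assumes these helpers are visible unqualified at this point in the header): copy a 16-byte-wide, 8-row, single-slice block from a tightly packed host staging buffer into a pitched device allocation. Dimension 0 of every range is in bytes.

    sycl::queue q = dpct::get_default_queue();
    constexpr size_t width = 16, height = 8, to_pitch = 64, from_pitch = 16;
    std::vector<unsigned char> src(from_pitch * height, 0x5A);
    unsigned char *d_dst =
        sycl::malloc_device<unsigned char>(to_pitch * height, q);
    auto events = dpct_memcpy(
        q, d_dst, src.data(),
        sycl::range<3>(to_pitch, height, 1),   // destination pitch / rows / slices
        sycl::range<3>(from_pitch, height, 1), // source layout
        sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
        sycl::range<3>(width, height, 1),      // bytes per row, rows, slices
        host_to_device);
    sycl::event::wait(events);

Because the source is packed (from_pitch == width) while the destination is not, this takes the host_to_device branch that stages the data in a temporary host buffer shaped like the target and then submits a single contiguous copy.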
+static inline std::vector dpct_memcpy( + sycl::queue& q, + void* to_ptr, + const void* from_ptr, + sycl::range<3> to_range, + sycl::range<3> from_range, + sycl::id<3> to_id, + sycl::id<3> from_id, + sycl::range<3> size, + memcpy_direction direction, + const std::vector& dep_events = {}) { + // RAII for host pointer + class host_buffer { + void* _buf; + size_t _size; + sycl::queue& _q; + const std::vector& _deps; // free operation depends + + public: + host_buffer( + size_t size, + sycl::queue& q, + const std::vector& deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void* get_ptr() const { + return _buf; + } + size_t get_size() const { + return _size; + } + ~host_buffer() { + if (_buf) { + _q.submit([&](sycl::handler& cgh) { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); + }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char* to_surface = + (unsigned char*)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char* from_surface = (const unsigned char*)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) { + return {dpct_memcpy( + q, + to_surface, + from_surface, + to_slice * size.get(2), + direction, + dep_events)}; + } + direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) { + unsigned char* to_ptr = to_surface; + const unsigned char* from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) { + event_list.push_back(dpct_memcpy( + q, to_ptr, from_ptr, size_slice, direction, dep_events)); + } else { + for (size_t y = 0; y < size.get(1); ++y) { + event_list.push_back(dpct_memcpy( + q, to_ptr, from_ptr, size.get(0), direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: { + host_buffer buf( + get_copy_range(size, to_slice, to_range.get(0)), q, event_list); + std::vector host_events; + if (to_slice == size_slice) { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + dep_events); + } else { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{dpct_memcpy( + q, + buf.get_ptr(), + to_surface, + buf.get_size(), + device_to_host, + dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy( + q, + to_surface, + buf.get_ptr(), + buf.get_size(), + host_to_device, + host_events)); + break; + } + case device_to_host: { + host_buffer buf( + get_copy_range(size, from_slice, from_range.get(0)), q, event_list); + // Copy from host temp buffer to host target with reshaping. 
+ event_list = dpct_memcpy( + q, + to_surface, + buf.get_ptr(), + to_range, + from_range, + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + size, + host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy( + q, + buf.get_ptr(), + from_surface, + buf.get_size(), + device_to_host, + dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler& cgh) { cgh.depends_on(dep_events); cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; - default: - throw std::runtime_error("dpct_memcpy: invalid direction value"); - } - return event_list; - } - - /// memcpy 2D/3D matrix specified by pitched_data. - static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); - } - - /// memcpy 2D matrix with pitch. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); - } - - inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, const void *b, library_data_t b_type, int ldb, - const void *beta, void *c, library_data_t c_type, int ldc, - library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } + size, [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); + })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; +} - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - 
break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, - a, lda, b, ldb, &beta_half, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } - } // gemm() - - /// Computes a batch of matrix-matrix product with general matrices. - /// \param [in] q The queue where the routine should be executed. - /// \param [in] a_trans Specifies the operation applied to A. - /// \param [in] b_trans Specifies the operation applied to B. - /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. - /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. - /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). - /// \param [in] alpha Scaling factor for the matrix-matrix product. - /// \param [in] a Input matrix A. - /// \param [in] a_type Data type of the matrix A. - /// \param [in] lda Leading dimension of A. - /// \param [in] b Input matrix B. - /// \param [in] b_type Data type of the matrix B. 
- /// \param [in] ldb Leading dimension of B. - /// \param [in] beta Scaling factor for matrix C. - /// \param [in, out] c Input/Output matrix C. - /// \param [in] c_type Data type of the matrix C. - /// \param [in] ldc Leading dimension of C. - /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. - /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } +/// memcpy 2D/3D matrix specified by pitched_data. +static inline std::vector dpct_memcpy( + sycl::queue& q, + pitched_data to, + sycl::id<3> to_id, + pitched_data from, + sycl::id<3> from_id, + sycl::range<3> size, + memcpy_direction direction = automatic) { + return dpct_memcpy( + q, + to.get_data_ptr(), + from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), + to_id, + from_id, + size, + direction); +} - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } +/// memcpy 2D matrix with pitch. 
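[annotation] A minimal sketch for the 2-D pitch wrapper below (hypothetical pointers and byte counts; device_to_device assumes both pointers are USM device allocations on q's device):

    // d_src packs 256-byte rows at a 256-byte pitch; d_dst uses a 512-byte
    // pitch. A single kernel gathers the 64 rows across the differing pitches.
    auto events = dpct_memcpy(q, d_dst, d_src,
                              /*to_pitch=*/512, /*from_pitch=*/256,
                              /*x=*/256, /*y=*/64, device_to_device);
    sycl::event::wait(events);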
+static inline std::vector dpct_memcpy( + sycl::queue& q, + void* to_ptr, + const void* from_ptr, + size_t to_pitch, + size_t from_pitch, + size_t x, + size_t y, + memcpy_direction direction = automatic) { + return dpct_memcpy( + q, + to_ptr, + from_ptr, + sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), + direction); +} + +inline void gemm( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void* a, + library_data_t a_type, + int lda, + const void* b, + library_data_t b_type, + int ldb, + const void* beta, + void* c, + library_data_t c_type, + int ldc, + library_data_t scaling_type) { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if ( + scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float): { + detail::gemm_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double): { + detail::gemm_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float, + library_data_t::real_float): { + detail:: + gemm_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); 
+ sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl( + q, + a_trans, + b_trans, + m, + n, + k, + &alpha_half, + a, + lda, + b, + ldb, + &beta_half, + c, + ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float): { + detail::gemm_impl< + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + float>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_int32, + library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, + a_trans, + b_trans, + m, + n, + k, + &alpha_float, + a, + lda, + b, + ldb, + &beta_float, + c, + ldc); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} // gemm() + +/// Computes a batch of matrix-matrix product with general matrices. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] a_trans Specifies the operation applied to A. +/// \param [in] b_trans Specifies the operation applied to B. +/// \param [in] m Specifies the number of rows of the matrix op(A) and of the +/// matrix C. \param [in] n Specifies the number of columns of the matrix op(B) +/// and of the matrix C. \param [in] k Specifies the number of columns of the +/// matrix op(A) and the number of rows of the matrix op(B). \param [in] alpha +/// Scaling factor for the matrix-matrix product. \param [in] a Input matrix A. +/// \param [in] a_type Data type of the matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] b Input matrix B. +/// \param [in] b_type Data type of the matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] beta Scaling factor for matrix C. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] c_type Data type of the matrix C. +/// \param [in] ldc Leading dimension of C. +/// \param [in] batch_size Specifies the number of matrix multiply operations to +/// perform. \param [in] scaling_type Data type of the scaling factors. 
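[annotation] The batched entry point below dispatches on get_type_combination_id exactly like gemm() above. A hedged usage sketch (names are hypothetical; the pointer arrays a_ptrs/b_ptrs/c_ptrs are assumed to be USM allocations holding one device pointer per batch entry, since the implementation reinterpret_casts them straight into oneMKL's group gemm_batch API):

    float alpha = 1.0f, beta = 0.0f;
    gemm_batch(q, oneapi::mkl::transpose::nontrans,
               oneapi::mkl::transpose::nontrans,
               /*m=*/32, /*n=*/32, /*k=*/64, &alpha,
               a_ptrs, library_data_t::real_float, /*lda=*/32,
               b_ptrs, library_data_t::real_float, /*ldb=*/64, &beta,
               c_ptrs, library_data_t::real_float, /*ldc=*/32,
               /*batch_size=*/4, library_data_t::real_float);

Everything is column-major, so lda >= m, ldb >= k (for non-transposed B), and ldc >= m.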
+inline void gemm_batch( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void* a[], + library_data_t a_type, + int lda, + const void* b[], + library_data_t b_type, + int ldb, + const void* beta, + void* c[], + library_data_t c_type, + int ldc, + int batch_size, + library_data_t scaling_type) { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if ( + scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float): { + detail::gemm_batch_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double): { + detail::gemm_batch_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); 
- break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float): { + detail::gemm_batch_impl< + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + float>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl< + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + float, + float>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_int32, + library_data_t::real_int32): { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + &alpha_float, + a, + lda, + b, + ldb, + &beta_float, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + b, + ldb, + beta, + c, + ldc, + batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float): { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + &alpha_half, + a, + lda, + b, + ldb, + &beta_half, + c, 
+ ldc, + batch_size); + break; } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} - /// Computes a batch of matrix-matrix product with general matrices. - /// \param [in] q The queue where the routine should be executed. - /// \param [in] a_trans Specifies the operation applied to A. - /// \param [in] b_trans Specifies the operation applied to B. - /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. - /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. - /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). - /// \param [in] alpha Scaling factor for the matrix-matrix product. - /// \param [in] a Input matrix A. - /// \param [in] a_type Data type of the matrix A. - /// \param [in] lda Leading dimension of A. - /// \param [in] stride_a Stride between the different A matrices. - /// \param [in] b Input matrix B. - /// \param [in] b_type Data type of the matrix B. - /// \param [in] ldb Leading dimension of B. - /// \param [in] stride_b Stride between the different B matrices. - /// \param [in] beta Scaling factor for matrix C. - /// \param [in, out] c Input/Output matrix C. - /// \param [in] c_type Data type of the matrix C. - /// \param [in] ldc Leading dimension of C. - /// \param [in] stride_c Stride between the different C matrices. - /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. - /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, long long int stride_a, const void *b, - library_data_t b_type, int ldb, long long int stride_b, - const void *beta, void *c, library_data_t c_type, - int ldc, long long int stride_c, int batch_size, - library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, 
library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } +/// Computes a batch of matrix-matrix product with general matrices. +/// \param [in] q The queue where the routine should be executed. +/// \param [in] a_trans Specifies the operation applied to A. +/// \param [in] b_trans Specifies the operation applied to B. +/// \param [in] m Specifies the number of rows of the matrix op(A) and of the +/// matrix C. \param [in] n Specifies the number of columns of the matrix op(B) +/// and of the matrix C. \param [in] k Specifies the number of columns of the +/// matrix op(A) and the number of rows of the matrix op(B). \param [in] alpha +/// Scaling factor for the matrix-matrix product. \param [in] a Input matrix A. +/// \param [in] a_type Data type of the matrix A. +/// \param [in] lda Leading dimension of A. +/// \param [in] stride_a Stride between the different A matrices. +/// \param [in] b Input matrix B. +/// \param [in] b_type Data type of the matrix B. +/// \param [in] ldb Leading dimension of B. +/// \param [in] stride_b Stride between the different B matrices. +/// \param [in] beta Scaling factor for matrix C. +/// \param [in, out] c Input/Output matrix C. +/// \param [in] c_type Data type of the matrix C. +/// \param [in] ldc Leading dimension of C. +/// \param [in] stride_c Stride between the different C matrices. +/// \param [in] batch_size Specifies the number of matrix multiply operations to +/// perform. \param [in] scaling_type Data type of the scaling factors. 
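[annotation] A sketch for the strided variant below, with the fp16 x fp16 -> fp32 combination that the SYCL backend's batched matmuls rely on (hypothetical device pointers d_a/d_b/d_c; note this combination is only compiled under __INTEL_MKL__ in this header). For a contiguous batch the stride is simply the element count of one matrix:

    int m = 64, n = 16, k = 128, lda = 64, ldb = 16, ldc = 64;
    float alpha = 1.0f, beta = 0.0f;
    gemm_batch(q, oneapi::mkl::transpose::nontrans,
               oneapi::mkl::transpose::trans,
               m, n, k, &alpha,
               d_a, library_data_t::real_half, lda, (long long)lda * k,
               d_b, library_data_t::real_half, ldb, (long long)ldb * k,
               &beta,
               d_c, library_data_t::real_float, ldc, (long long)ldc * n,
               /*batch_size=*/4, library_data_t::real_float);

With b_trans == trans, op(B) is k x n, so B itself is n x k and ldb >= n; the A and C strides follow their column-major layouts (lda * k and ldc * n elements).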
+inline void gemm_batch( + sycl::queue& q, + oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, + int m, + int n, + int k, + const void* alpha, + const void* a, + library_data_t a_type, + int lda, + long long int stride_a, + const void* b, + library_data_t b_type, + int ldb, + long long int stride_b, + const void* beta, + void* c, + library_data_t c_type, + int ldc, + long long int stride_c, + int batch_size, + library_data_t scaling_type) { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) { + scaling_type = library_data_t::complex_float; + } else if ( + scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) { + case detail::get_type_combination_id( + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double, + library_data_t::real_double): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float, + library_data_t::complex_float): { + detail::gemm_batch_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double, + library_data_t::complex_double): { + detail::gemm_batch_impl< + std::complex, + std::complex, + std::complex, + std::complex>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } #ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - 
library_data_t::real_int32, library_data_t::real_int32): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float): { + detail::gemm_batch_impl< + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + float>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, + library_data_t::real_bfloat16, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl< + oneapi::mkl::bfloat16, + oneapi::mkl::bfloat16, + float, + float>( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_int32, + library_data_t::real_int32): { + detail:: + gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, + library_data_t::real_int8, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float, + library_data_t::real_float): { + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + alpha, + a, + lda, + stride_a, + b, + ldb, + stride_b, + beta, + c, + ldc, + stride_c, + batch_size); + break; + } #endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, - &beta_half, c, ldc, stride_c, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } + case detail::get_type_combination_id( + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_half, + library_data_t::real_float): { + float 
alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, + a_trans, + b_trans, + m, + n, + k, + &alpha_half, + a, + lda, + stride_a, + b, + ldb, + stride_b, + &beta_half, + c, + ldc, + stride_c, + batch_size); + break; } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } +} - static inline void - async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, - size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic, - sycl::queue &q = get_default_queue()) - { - detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, - direction); - } - - using err0 = detail::generic_error_type; - using err1 = detail::generic_error_type; - - static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { - detail::dpct_free(ptr, q); - } - - /// dpct accessor used as device function parameter. - template class accessor; - template class accessor { - public: - using memory_t = detail::memory_traits; - using element_t = typename memory_t::element_t; - using pointer_t = typename memory_t::pointer_t; - using accessor_t = typename memory_t::template accessor_t<3>; - accessor(pointer_t data, const sycl::range<3> &in_range) - : _data(data), _range(in_range) {} - template - accessor(typename std::enable_if::type &acc) - : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<3> &in_range) - : accessor(acc.get_pointer(), in_range) {} - accessor operator[](size_t index) const { - sycl::range<2> sub(_range.get(1), _range.get(2)); - return accessor(_data + index * sub.size(), sub); - } - - pointer_t get_ptr() const { return _data; } - - private: - pointer_t _data; - sycl::range<3> _range; - }; - template class accessor { - public: - using memory_t = detail::memory_traits; - using element_t = typename memory_t::element_t; - using pointer_t = typename memory_t::pointer_t; - using accessor_t = typename memory_t::template accessor_t<2>; - accessor(pointer_t data, const sycl::range<2> &in_range) - : _data(data), _range(in_range) {} - template - accessor(typename std::enable_if::type &acc) - : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<2> &in_range) - : accessor(acc.get_pointer(), in_range) {} - - pointer_t operator[](size_t index) const { - return _data + _range.get(1) * index; - } - - pointer_t get_ptr() const { return _data; } - - private: - pointer_t _data; - sycl::range<2> _range; - }; - - namespace detail { - /// Device variable with address space of shared, global or constant. 
- template class device_memory { - public: - using accessor_t = - typename detail::memory_traits::template accessor_t; - using value_t = typename detail::memory_traits::value_t; - using dpct_accessor_t = dpct::accessor; - - device_memory() : device_memory(sycl::range(1)) {} - - /// Constructor of 1-D array with initializer list - device_memory(const sycl::range &in_range, - std::initializer_list &&init_list) - : device_memory(in_range) { - assert(init_list.size() <= in_range.size()); - _host_ptr = (value_t *)std::malloc(_size); - std::memset(_host_ptr, 0, _size); - std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); - } - - /// Constructor of 2-D array with initializer list - template - device_memory( - const typename std::enable_if>::type &in_range, - std::initializer_list> &&init_list) - : device_memory(in_range) { - assert(init_list.size() <= in_range[0]); - _host_ptr = (value_t *)std::malloc(_size); - std::memset(_host_ptr, 0, _size); - auto tmp_data = _host_ptr; - for (auto sub_list : init_list) { - assert(sub_list.size() <= in_range[1]); - std::memcpy(tmp_data, sub_list.begin(), - sub_list.size() * sizeof(T)); - tmp_data += in_range[1]; - } - } - - /// Constructor with range - device_memory(const sycl::range &range_in) - : _size(range_in.size() * sizeof(T)), _range(range_in), - _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { - static_assert( - (Memory == global) || (Memory == constant) || (Memory == shared), - "device memory region should be global, constant or shared"); - // Make sure that singleton class mem_mgr and dev_mgr will destruct - // later than this. - detail::mem_mgr::instance(); - dev_mgr::instance(); - } - - /// Constructor with range - template - device_memory(Args... Arguments) - : device_memory(sycl::range(Arguments...)) {} - - ~device_memory() { - if (_device_ptr && !_reference) - dpct::dpct_free(_device_ptr); - if (_host_ptr) - std::free(_host_ptr); - } - - /// Allocate memory with default queue, and init memory if has initial - /// value. - void init() { init(dpct::get_default_queue()); } - /// Allocate memory with specified queue, and init memory if has initial - /// value. - void init(sycl::queue &q) { - if (_device_ptr) - return; - if (!_size) - return; - allocate_device(q); - if (_host_ptr) - detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, - host_to_device); - } - - /// The variable is assigned to a device pointer. - void assign(value_t *src, size_t size) { - this->~device_memory(); - new (this) device_memory(src, size); - } - - /// Get memory pointer of the memory object, which is virtual pointer when - /// usm is not used, and device pointer when usm is used. - value_t *get_ptr() { return get_ptr(get_default_queue()); } - /// Get memory pointer of the memory object, which is virtual pointer when - /// usm is not used, and device pointer when usm is used. - value_t *get_ptr(sycl::queue &q) { - init(q); - return _device_ptr; - } - - /// Get the device memory object size in bytes. - size_t get_size() { return _size; } - - template - typename std::enable_if::type &operator[](size_t index) { - init(); - return _device_ptr[index]; - } - - /// Get dpct::accessor with dimension info for the device memory object - /// when usm is used and dimension is greater than 1. 
- template - typename std::enable_if::type - get_access(sycl::handler &cgh) { - return dpct_accessor_t((T *)_device_ptr, _range); - } - - private: - device_memory(value_t *memory_ptr, size_t size) - : _size(size), _range(size / sizeof(T)), _reference(true), - _device_ptr(memory_ptr) {} - - void allocate_device(sycl::queue &q) { - #ifndef DPCT_USM_LEVEL_NONE - if (Memory == shared) { - _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), - q.get_context()); - return; - } - #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY - if (Memory == constant) { - _device_ptr = (value_t *)sycl::malloc_device( - _size, q.get_device(), q.get_context(), - sycl::ext::oneapi::property::usm::device_read_only()); - return; - } - #endif - #endif - _device_ptr = (value_t *)detail::dpct_malloc(_size, q); - } - - size_t _size; - sycl::range _range; - bool _reference; - value_t *_host_ptr; - value_t *_device_ptr; - }; - template - class device_memory : public device_memory { - public: - using base = device_memory; - using value_t = typename base::value_t; - using accessor_t = - typename detail::memory_traits::template accessor_t<0>; - - /// Constructor with initial value. - device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} - - /// Default constructor - device_memory() : base(1) {} - }; - } // namespace detail +static inline void async_dpct_memcpy( + void* to_ptr, + size_t to_pitch, + const void* from_ptr, + size_t from_pitch, + size_t x, + size_t y, + memcpy_direction direction = automatic, + sycl::queue& q = get_default_queue()) { + detail::dpct_memcpy( + q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, direction); +} - template - using global_memory = detail::device_memory; - template - using constant_memory = detail::device_memory; - template - using shared_memory = detail::device_memory; +using err0 = detail::generic_error_type; +using err1 = detail::generic_error_type; +static inline void dpct_free(void* ptr, sycl::queue& q = get_default_queue()) { + detail::dpct_free(ptr, q); +} -} // COPY from DPCT head files +/// dpct accessor used as device function parameter. 
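+/// A minimal usage sketch (hypothetical names `ptr`, `rows`, `cols`, `r`,
+/// `c`): wrapping a raw device pointer together with its extent lets a
+/// kernel index it like a 2-D array:
+///   dpct::accessor<float, dpct::global, 2> acc(ptr, sycl::range<2>(rows, cols));
+///   float v = acc[r][c]; // operator[] advances by _range.get(1), the row pitch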
+template +class accessor; +template +class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3>& in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type& acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t& acc, const sycl::range<3>& in_range) + : accessor(acc.get_pointer(), in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { + return _data; + } + + private: + pointer_t _data; + sycl::range<3> _range; +}; +template +class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2>& in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type& acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t& acc, const sycl::range<2>& in_range) + : accessor(acc.get_pointer(), in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { + return _data; + } + + private: + pointer_t _data; + sycl::range<2> _range; +}; + +namespace detail { +/// Device variable with address space of shared, global or constant. +template +class device_memory { + public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using dpct_accessor_t = dpct::accessor; + + device_memory() : device_memory(sycl::range(1)) {} + + /// Constructor of 1-D array with initializer list + device_memory( + const sycl::range& in_range, + std::initializer_list&& init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t*)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type& in_range, + std::initializer_list>&& init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t*)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range& range_in) + : _size(range_in.size() * sizeof(T)), + _range(range_in), + _reference(false), + _host_ptr(nullptr), + _device_ptr(nullptr) { + static_assert( + (Memory == global) || (Memory == constant) || (Memory == shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class mem_mgr and dev_mgr will destruct + // later than this. + detail::mem_mgr::instance(); + dev_mgr::instance(); + } + + /// Constructor with range + template + device_memory(Args... 
Arguments) + : device_memory(sycl::range(Arguments...)) {} + + ~device_memory() { + if (_device_ptr && !_reference) + dpct::dpct_free(_device_ptr); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with default queue, and init memory if has initial + /// value. + void init() { + init(dpct::get_default_queue()); + } + /// Allocate memory with specified queue, and init memory if has initial + /// value. + void init(sycl::queue& q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, host_to_device); + } + + /// The variable is assigned to a device pointer. + void assign(value_t* src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size); + } + + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t* get_ptr() { + return get_ptr(get_default_queue()); + } + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t* get_ptr(sycl::queue& q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { + return _size; + } + + template + typename std::enable_if::type& operator[](size_t index) { + init(); + return _device_ptr[index]; + } + + /// Get dpct::accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type get_access( + sycl::handler& cgh) { + return dpct_accessor_t((T*)_device_ptr, _range); + } + + private: + device_memory(value_t* memory_ptr, size_t size) + : _size(size), + _range(size / sizeof(T)), + _reference(true), + _device_ptr(memory_ptr) {} + + void allocate_device(sycl::queue& q) { +#ifndef DPCT_USM_LEVEL_NONE + if (Memory == shared) { + _device_ptr = + (value_t*)sycl::malloc_shared(_size, q.get_device(), q.get_context()); + return; + } +#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == constant) { + _device_ptr = (value_t*)sycl::malloc_device( + _size, + q.get_device(), + q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } +#endif +#endif + _device_ptr = (value_t*)detail::dpct_malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t* _host_ptr; + value_t* _device_ptr; +}; +template +class device_memory : public device_memory { + public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. 
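+  /// A minimal sketch (hypothetical name `scale`), assuming the usual alias
+  /// defined below; the value is uploaded lazily on the first
+  /// init()/get_ptr() call:
+  ///   static dpct::constant_memory<float, 0> scale(0.5f);
+  ///   float* p = scale.get_ptr(); // allocates device memory and copies 0.5f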
+ device_memory(const value_t& val) : base(sycl::range<1>(1), {val}) {} + + /// Default constructor + device_memory() : base(1) {} +}; +} // namespace detail + +template +using global_memory = detail::device_memory; +template +using constant_memory = detail::device_memory; +template +using shared_memory = detail::device_memory; + +} // namespace dpct #endif // GGML_SYCL_DPCT_HELPER_HPP \ No newline at end of file diff --git a/ggml-sycl/mmq.cpp b/ggml-sycl/mmq.cpp index 3c1cee812e1f7..299906382738b 100644 --- a/ggml-sycl/mmq.cpp +++ b/ggml-sycl/mmq.cpp @@ -13,2860 +13,3704 @@ #include "mmq.hpp" #include "vecdotq.hpp" -typedef void (*allocate_tiles_sycl_t)(int **x_ql, sycl::half2 **x_dm, - int **x_qh, int **x_sc); -typedef void (*load_tiles_sycl_t)(const void *__restrict__ vx, - int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, - int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, - const int &i_max, const int &k, - const int &blocks_per_row); +typedef void (*allocate_tiles_sycl_t)( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc); +typedef void (*load_tiles_sycl_t)( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row); typedef float (*vec_dot_q_mul_mat_sycl_t)( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ms, - const int &i, const int &j, const int &k); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ms, + const int& i, + const int& j, + const int& k); template -static __dpct_inline__ void -allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_qs_q4_0, float *tile_x_d_q4_0) { - (void)x_qh; (void)x_sc; - - *x_ql = tile_x_qs_q4_0; - *x_dm = (sycl::half2 *)tile_x_d_q4_0; +static __dpct_inline__ void allocate_tiles_q4_0( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_qs_q4_0, + float* tile_x_d_q4_0) { + (void)x_qh; + (void)x_sc; + + *x_ql = tile_x_qs_q4_0; + *x_dm = (sycl::half2*)tile_x_d_q4_0; } template -static __dpct_inline__ void -load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; (void)x_sc; - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_0; - const int kqsx = k % QI4_0; - - const block_q4_0 * bx0 = (const block_q4_0 *) vx; - - float * x_dmf = (float *) x_dm; +static __dpct_inline__ void load_tiles_q4_0( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + (void)x_sc; + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_0; + const int kqsx = k % QI4_0; + + const block_q4_0* bx0 = (const 
block_q4_0*)vx; + + float* x_dmf = (float*)x_dm; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; + const block_q4_0* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; - const int kbxd = k % blocks_per_tile_x_row; + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = k % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q4_0* bxi = bx0 + i * blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; - } + x_dmf[i * (WARP_SIZE / QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } } static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; (void)x_sc; - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; - - int u[2*VDR_Q4_0_Q8_1_MMQ]; + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + (void)x_sc; + + const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2)); + const float* x_dmf = (const float*)x_dm; + + int u[2 * VDR_Q4_0_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; - } - - return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2 * l + 0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2 * l + 1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + return vec_dot_q4_0_q8_1_impl( + &x_ql[i * (WARP_SIZE + 1) + k], + u, + x_dmf[i * (WARP_SIZE / QI4_0) + i / QI4_0 + k / QI4_0], + y_ds[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)]); } - template -static __dpct_inline__ void -allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) { - (void)x_qh; (void)x_sc; - - *x_ql = tile_x_qs_q4_1; - *x_dm = tile_x_dm_q4_1; +static __dpct_inline__ void allocate_tiles_q4_1( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* 
tile_x_qs_q4_1, + sycl::half2* tile_x_dm_q4_1) { + (void)x_qh; + (void)x_sc; + + *x_ql = tile_x_qs_q4_1; + *x_dm = tile_x_dm_q4_1; } template -static __dpct_inline__ void -load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; (void)x_sc; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_1; - const int kqsx = k % QI4_1; - - const block_q4_1 * bx0 = (const block_q4_1 *) vx; +static __dpct_inline__ void load_tiles_q4_1( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_1; + const int kqsx = k % QI4_1; + + const block_q4_1* bx0 = (const block_q4_1*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; + const block_q4_1* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; - const int kbxd = k % blocks_per_tile_x_row; + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = k % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q4_1* bxi = bx0 + i * blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; - } + x_dm[i * (WARP_SIZE / QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } } static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; (void)x_sc; - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + (void)x_sc; + + const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2)); + + int u[2 * VDR_Q4_1_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + 
(kyqs + l + QI4_1) % WARP_SIZE]; - } - - return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2 * l + 0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2 * l + 1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + return vec_dot_q4_1_q8_1_impl( + &x_ql[i * (WARP_SIZE + 1) + k], + u, + x_dm[i * (WARP_SIZE / QI4_1) + i / QI4_1 + k / QI4_1], + y_ds[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)]); } template -static __dpct_inline__ void -allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q5_0, float *tile_x_d_q5_0) { - (void)x_qh; (void)x_sc; - - *x_ql = tile_x_ql_q5_0; - *x_dm = (sycl::half2 *)tile_x_d_q5_0; +static __dpct_inline__ void allocate_tiles_q5_0( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q5_0, + float* tile_x_d_q5_0) { + (void)x_qh; + (void)x_sc; + + *x_ql = tile_x_ql_q5_0; + *x_dm = (sycl::half2*)tile_x_d_q5_0; } template -static __dpct_inline__ void -load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; (void)x_sc; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_0; - const int kqsx = k % QI5_0; - - const block_q5_0 * bx0 = (const block_q5_0 *) vx; +static __dpct_inline__ void load_tiles_q5_0( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_0; + const int kqsx = k % QI5_0; + + const block_q5_0* bx0 = (const block_q5_0*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; + const block_q5_0* bxi = bx0 + i * blocks_per_row + kbx; - const int ql = get_int_from_uint8(bxi->qs, kqsx); - const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - qs0 = dpct::vectorized_binary( - qs0, 0x10101010, dpct::sub_sat()); // subtract 16 + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = dpct::vectorized_binary( + qs0, 0x10101010, dpct::sub_sat()); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2 * WARP_SIZE + 1) + 2 * k + 0] 
= qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - qs1 = dpct::vectorized_binary( - qs1, 0x10101010, dpct::sub_sat()); // subtract 16 + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = dpct::vectorized_binary( + qs1, 0x10101010, dpct::sub_sat()); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } + x_ql[i * (2 * WARP_SIZE + 1) + 2 * k + 1] = qs1; + } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = k % blocks_per_tile_x_row; + float* x_dmf = (float*)x_dm; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q5_0* bxi = bx0 + i * blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; - } + x_dmf[i * (WARP_SIZE / QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } } static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; (void)x_sc; - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + (void)x_sc; + + const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2)); + const int index_bx = i * (WARP_SIZE / QI5_0) + i / QI5_0 + k / QI5_0; + const float* x_dmf = (const float*)x_dm; + const float* y_df = (const float*)y_ds; + + int u[2 * VDR_Q5_0_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; - } - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2 * l + 0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2 * l + 1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + return vec_dot_q8_0_q8_1_impl( + &x_ql[i * (2 * WARP_SIZE + 1) + 2 * k], + u, + x_dmf[index_bx], + y_df[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)]); } template -static __dpct_inline__ void -allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q5_1, 
sycl::half2 *tile_x_dm_q5_1) { - (void)x_qh; (void)x_sc; - - *x_ql = tile_x_ql_q5_1; - *x_dm = tile_x_dm_q5_1; +static __dpct_inline__ void allocate_tiles_q5_1( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q5_1, + sycl::half2* tile_x_dm_q5_1) { + (void)x_qh; + (void)x_sc; + + *x_ql = tile_x_ql_q5_1; + *x_dm = tile_x_dm_q5_1; } template -static __dpct_inline__ void -load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; (void)x_sc; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_1; - const int kqsx = k % QI5_1; - - const block_q5_1 * bx0 = (const block_q5_1 *) vx; +static __dpct_inline__ void load_tiles_q5_1( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_1; + const int kqsx = k % QI5_1; + + const block_q5_1* bx0 = (const block_q5_1*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; + const block_q5_1* bxi = bx0 + i * blocks_per_row + kbx; - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2 * WARP_SIZE + 1) + 2 * k + 0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } + x_ql[i * (2 * WARP_SIZE + 1) + 2 * k + 1] = qs1; + } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; - const int kbxd = k % blocks_per_tile_x_row; + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = k % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; 
i0 += nwarps * QI5_1) { + int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q5_1* bxi = bx0 + i * blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; - } + x_dm[i * (WARP_SIZE / QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } } static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; (void)x_sc; - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + (void)x_sc; + + const int kyqs = k % (QI8_1 / 2) + QI8_1 * (k / (QI8_1 / 2)); + const int index_bx = i * (WARP_SIZE / QI5_1) + +i / QI5_1 + k / QI5_1; + + int u[2 * VDR_Q5_1_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; - } - - return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2 * l + 0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2 * l + 1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + return vec_dot_q8_1_q8_1_impl( + &x_ql[i * (2 * WARP_SIZE + 1) + 2 * k], + u, + x_dm[index_bx], + y_ds[j * (WARP_SIZE / QI8_1) + (2 * k / QI8_1) % (WARP_SIZE / QI8_1)]); } template -static __dpct_inline__ void -allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_qs_q8_0, float *tile_x_d_q8_0) { - (void)x_qh; (void)x_sc; - - *x_ql = tile_x_qs_q8_0; - *x_dm = (sycl::half2 *)tile_x_d_q8_0; +static __dpct_inline__ void allocate_tiles_q8_0( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_qs_q8_0, + float* tile_x_d_q8_0) { + (void)x_qh; + (void)x_sc; + + *x_ql = tile_x_qs_q8_0; + *x_dm = (sycl::half2*)tile_x_d_q8_0; } template -static __dpct_inline__ void -load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; (void)x_sc; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI8_0; - const int kqsx = k % QI8_0; - float * x_dmf = (float *) x_dm; - - const block_q8_0 * bx0 = (const block_q8_0 *) vx; +static __dpct_inline__ void load_tiles_q8_0( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + (void)x_sc; + + GGML_SYCL_ASSUME(i_offset >= 0); + 
GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI8_0; + const int kqsx = k % QI8_0; + float* x_dmf = (float*)x_dm; + + const block_q8_0* bx0 = (const block_q8_0*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; + const block_q8_0* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; - const int kbxd = k % blocks_per_tile_x_row; + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = k % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q8_0* bxi = bx0 + i * blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; - } + x_dmf[i * (WARP_SIZE / QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } } static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; (void)x_sc; - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + (void)x_sc; + + const float* x_dmf = (const float*)x_dm; + const float* y_df = (const float*)y_ds; + + return vec_dot_q8_0_q8_1_impl( + &x_ql[i * (WARP_SIZE + 1) + k], + &y_qs[j * WARP_SIZE + k], + x_dmf[i * (WARP_SIZE / QI8_0) + i / QI8_0 + k / QI8_0], + y_df[j * (WARP_SIZE / QI8_1) + k / QI8_1]); } template -static __dpct_inline__ void -allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K, - int *tile_x_sc_q2_K) { - (void)x_qh; - - *x_ql = tile_x_ql_q2_K; - *x_dm = tile_x_dm_q2_K; - *x_sc = tile_x_sc_q2_K; +static __dpct_inline__ void allocate_tiles_q2_K( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q2_K, + sycl::half2* tile_x_dm_q2_K, + int* tile_x_sc_q2_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q2_K; + *x_dm = tile_x_dm_q2_K; + *x_sc = tile_x_sc_q2_K; } template -static __dpct_inline__ void -load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, 
- const int &k, const int &blocks_per_row) { - (void)x_qh; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI2_K; - const int kqsx = k % QI2_K; - - const block_q2_K * bx0 = (const block_q2_K *) vx; +static __dpct_inline__ void load_tiles_q2_K( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI2_K; + const int kqsx = k % QI2_K; + + const block_q2_K* bx0 = (const block_q2_K*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; + const block_q2_K* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = k % blocks_per_tile_x_row; + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = k % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q2_K* bxi = bx0 + i * blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } + x_dm[i * (WARP_SIZE / QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE / 4); - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); + const block_q2_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 4)) / (QI2_K / 4); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); - } + x_sc[i * (WARP_SIZE / 4) + i / 4 + k % (WARP_SIZE / 4)] = + get_int_from_uint8_aligned(bxi->scales, k % (QI2_K / 4)); + } } static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; - - const int kbx = k / QI2_K; - const int ky = (k % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; - - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % 
(2*QI2_K)) / (QI2_K/2)); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + + const int kbx = k / QI2_K; + const int ky = (k % QI2_K) * QR2_K; + const float* y_df = (const float*)y_ds; + + int v[QR2_K * VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx * QI2_K + + (QI2_K / 2) * (ky / (2 * QI2_K)) + ky % (QI2_K / 2); + const int shift = 2 * ((ky % (2 * QI2_K)) / (QI2_K / 2)); #pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; - } - - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + for (int l = 0; l < QR2_K * VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t* scales = + ((const uint8_t*)&x_sc[i * (WARP_SIZE / 4) + i / 4 + kbx * 4]) + ky / 4; + + const int index_y = j * WARP_SIZE + (QR2_K * k) % WARP_SIZE; + return vec_dot_q2_K_q8_1_impl_mmq( + v, + &y_qs[index_y], + scales, + x_dm[i * (WARP_SIZE / QI2_K) + i / QI2_K + kbx], + y_df[index_y / QI8_1]); } template -static __dpct_inline__ void -allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K, - int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) { - - *x_ql = tile_x_ql_q3_K; - *x_dm = tile_x_dm_q3_K; - *x_qh = tile_x_qh_q3_K; - *x_sc = tile_x_sc_q3_K; +static __dpct_inline__ void allocate_tiles_q3_K( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q3_K, + sycl::half2* tile_x_dm_q3_K, + int* tile_x_qh_q3_K, + int* tile_x_sc_q3_K) { + *x_ql = tile_x_ql_q3_K; + *x_dm = tile_x_dm_q3_K; + *x_qh = tile_x_qh_q3_K; + *x_sc = tile_x_sc_q3_K; } template -static __dpct_inline__ void -load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI3_K; - const int kqsx = k % QI3_K; - - const block_q3_K * bx0 = (const block_q3_K *) vx; +static __dpct_inline__ void load_tiles_q3_K( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI3_K; + const int kqsx = k % QI3_K; + + const block_q3_K* bx0 = (const block_q3_K*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; + const block_q3_K* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 
1) + k] = get_int_from_uint8(bxi->qs, kqsx); - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = k % blocks_per_tile_x_row; + float* x_dmf = (float*)x_dm; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q3_K* bxi = bx0 + i * blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; - } + x_dmf[i * (WARP_SIZE / QI3_K) + i / QI3_K + kbxd] = bxi->d; + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + i_offset * 2 + k / (WARP_SIZE / 2); - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + const block_q3_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 2)) / (QI3_K / 2); - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); - } + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE / 2) + i / 2 + k % (WARP_SIZE / 2)] = + ~get_int_from_uint8(bxi->hmask, k % (QI3_K / 2)); + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + i_offset * 4 + k / (WARP_SIZE / 4); - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 4)) / (QI3_K / 4); - const int ksc = k % (QI3_K/4); + const int ksc = k % (QI3_K / 4); - const int ksc_low = ksc % (QI3_K/8); - const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + const int ksc_low = ksc % (QI3_K / 8); + const int shift_low = 4 * (ksc / (QI3_K / 8)); + const int sc_low = + (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; - const int ksc_high = QI3_K/8; - const int shift_high = 2 * ksc; - const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + const int ksc_high = QI3_K / 8; + const int shift_high = 2 * ksc; + const int sc_high = + ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & + 0x30303030; - const int sc = dpct::vectorized_binary( - sc_low | sc_high, 0x20202020, dpct::sub_sat()); + const int sc = dpct::vectorized_binary( + sc_low | sc_high, 0x20202020, dpct::sub_sat()); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; - } + x_sc[i * (WARP_SIZE / 4) + i / 4 + k % (WARP_SIZE / 4)] = sc; + } } static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat( - const int *__restrict__ x_ql, const 
sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - - const int kbx = k / QI3_K; - const int ky = (k % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + const int kbx = k / QI3_K; + const int ky = (k % QI3_K) * QR3_K; + const float* x_dmf = (const float*)x_dm; + const float* y_df = (const float*)y_ds; + + const int8_t* scales = + ((const int8_t*)(x_sc + i * (WARP_SIZE / 4) + i / 4 + kbx * 4)) + ky / 4; + + int v[QR3_K * VDR_Q3_K_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; - - v[l] = dpct::vectorized_binary(vll, vlh, dpct::sub_sat()); - } - - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + for (int l = 0; l < QR3_K * VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx * QI3_K + + (QI3_K / 2) * (ky / (2 * QI3_K)) + ky % (QI3_K / 2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = + x_qh[i * (WARP_SIZE / 2) + i / 2 + kbx * (QI3_K / 2) + (ky + l) % 8] >> + ((ky + l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = dpct::vectorized_binary(vll, vlh, dpct::sub_sat()); + } + + const int index_y = j * WARP_SIZE + (k * QR3_K) % WARP_SIZE; + return vec_dot_q3_K_q8_1_impl_mmq( + v, + &y_qs[index_y], + scales, + x_dmf[i * (WARP_SIZE / QI3_K) + i / QI3_K + kbx], + y_df[index_y / QI8_1]); } template -static __dpct_inline__ void -allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K, - int *tile_x_sc_q4_K) { - (void)x_qh; - - *x_ql = tile_x_ql_q4_K; - *x_dm = tile_x_dm_q4_K; - *x_sc = tile_x_sc_q4_K; +static __dpct_inline__ void allocate_tiles_q4_K( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q4_K, + sycl::half2* tile_x_dm_q4_K, + int* tile_x_sc_q4_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q4_K; + *x_dm = tile_x_dm_q4_K; + *x_sc = tile_x_sc_q4_K; } template -static __dpct_inline__ void -load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_K; // == 0 if QK_K == 256 - const int kqsx = k % QI4_K; // == k if QK_K == 256 - - const block_q4_K * bx0 
= (const block_q4_K *) vx; +static __dpct_inline__ void load_tiles_q4_K( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI4_K; // == 0 if QK_K == 256 + const int kqsx = k % QI4_K; // == k if QK_K == 256 + + const block_q4_K* bx0 = (const block_q4_K*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; + const block_q4_K* bxi = bx0 + i * blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } + x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q4_K* bxi = bx0 + i * blocks_per_row + kbxd; #if QK_K == 256 - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE / QI4_K) + i / QI4_K + kbxd] = bxi->dm; #else - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; + x_dm[i * (WARP_SIZE / QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; #endif - } + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 8)) / (QI4_K / 8); - const int * scales = (const int *) bxi->scales; + const int* scales = (const int*)bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE / 8); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + // scale arrangement after the following two lines: sc0,...,sc3, + // sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & + 0x0F0F0F0F; // lower 4 bits + scales8 |= + (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] 
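// What the two shift/mask lines above produce, as a sketch: block_q4_K packs
// eight 6-bit scales and eight 6-bit mins into 12 bytes. Repacked one byte
// per value, a tile row then reads sc0..sc3, sc4..sc7, m0..m3, m4..m7.
// Scalar reference for the assumed K-quant packing (q = bxi->scales):
//
//   // j in 0..3:  sc = q[j] & 63;                           m = q[j+4] & 63;
//   // j in 4..7:  sc = (q[j+4] & 0x0F) | ((q[j-4] >> 6) << 4);
//   //             m  = (q[j+4] >>  4) | ((q[j  ] >> 6) << 4);
//
// The SIMD form above recovers the same bytes four at a time: one masked
// shift for the low 4 bits, one for the top 2 bits.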
= scales8; - } + x_sc[i * (WARP_SIZE / 8) + i / 8 + ksc] = scales8; + } } static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); - - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + + const uint8_t* sc = + ((const uint8_t*)&x_sc[i * (WARP_SIZE / 8) + i / 8 + k / 16]) + + 2 * ((k % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K * k) % WARP_SIZE; + return vec_dot_q4_K_q8_1_impl_mmq( + &x_ql[i * (WARP_SIZE + 1) + k], + &y_qs[index_y], + sc, + sc + 8, + x_dm[i * (WARP_SIZE / QI4_K) + i / QI4_K], + &y_ds[index_y / QI8_1]); } template -static __dpct_inline__ void -allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K, - int *tile_x_sc_q5_K) { - (void)x_qh; - - *x_ql = tile_x_ql_q5_K; - *x_dm = tile_x_dm_q5_K; - *x_sc = tile_x_sc_q5_K; +static __dpct_inline__ void allocate_tiles_q5_K( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql_q5_K, + sycl::half2* tile_x_dm_q5_K, + int* tile_x_sc_q5_K) { + (void)x_qh; + + *x_ql = tile_x_ql_q5_K; + *x_dm = tile_x_dm_q5_K; + *x_sc = tile_x_sc_q5_K; } template -static __dpct_inline__ void -load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_K; // == 0 if QK_K == 256 - const int kqsx = k % QI5_K; // == k if QK_K == 256 - - const block_q5_K * bx0 = (const block_q5_K *) vx; +static __dpct_inline__ void load_tiles_q5_K( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI5_K; // == 0 if QK_K == 256 + const int kqsx = k % QI5_K; // == k if QK_K == 256 + + const block_q5_K* bx0 = (const block_q5_K*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR5_K*kqsx; + const block_q5_K* bxi = bx0 + i * blocks_per_row + kbx; + const int ky = QR5_K * kqsx; - const int ql = 
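// Q5_K reconstruction sketch: each weight is 5 bits, a low nibble from qs
// plus one bit from qh, dequantized later with the super-block scale/min:
//
//   // w = (ql & 0x0F) | (((qh >> bit) & 1) << 4);  // w in 0..31
//   // x = d * sc * w - dmin * m;                   // applied in vec_dot
//
// Only the integer part is materialized here: the masked nibbles (ql0/ql1)
// get OR-ed with the shifted high bits (qh0/qh1) below, while bxi->dm keeps
// the {d, dmin} pair for the dot-product stage.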
get_int_from_uint8_aligned(bxi->qs, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K / 4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K / 4)) + 1)) << 4) & 0x10101010; - const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); + const int kq0 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + 0; + const int kq1 = ky - ky % (QI5_K / 2) + k % (QI5_K / 4) + (QI5_K / 4); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; - } + x_ql[i * (2 * WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2 * WARP_SIZE + 1) + kq1] = ql1 | qh1; + } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q5_K* bxi = bx0 + i * blocks_per_row + kbxd; #if QK_K == 256 - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE / QI5_K) + i / QI5_K + kbxd] = bxi->dm; #endif - } + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 8)) / (QI5_K / 8); - const int * scales = (const int *) bxi->scales; + const int* scales = (const int*)bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE / 8); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + // scale arrangement after the following two lines: sc0,...,sc3, + // sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc % 2) + (ksc != 0)] >> (4 * (ksc & (ksc / 2)))) & + 0x0F0F0F0F; // lower 4 bits + scales8 |= + (scales[ksc / 2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } + x_sc[i * (WARP_SIZE / 8) + i / 8 + ksc] = scales8; + } } static __dpct_inline__ float 
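// Note the x_ql row stride used by the q5_K (and q6_K) loaders above: each
// 32-bit word of packed nibbles expands into two words of reconstructed
// values (the low- and high-nibble halves), so a tile row holds
// 2 * WARP_SIZE ints, and the extra "+ 1" pads rows so that consecutive
// rows start in different local-memory banks, e.g.
//
//   // int tile_x_ql[mmq_y][2 * WARP_SIZE + 1];  // sketch of the padded layout
//
// which keeps column walks through the tile free of bank conflicts.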
vec_dot_q5_K_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; - return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + + const uint8_t* sc = + ((const uint8_t*)&x_sc[i * (WARP_SIZE / 8) + i / 8 + k / 16]) + + 2 * ((k % 16) / 8); + + const int index_x = i * (QR5_K * WARP_SIZE + 1) + QR5_K * k; + const int index_y = j * WARP_SIZE + (QR5_K * k) % WARP_SIZE; + return vec_dot_q5_K_q8_1_impl_mmq( + &x_ql[index_x], + &y_qs[index_y], + sc, + sc + 8, + x_dm[i * (WARP_SIZE / QI5_K) + i / QI5_K], + &y_ds[index_y / QI8_1]); } template -static __dpct_inline__ void -allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc, - int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) { - (void)x_qh; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; +static __dpct_inline__ void allocate_tiles_q6_K( + int** x_ql, + sycl::half2** x_dm, + int** x_qh, + int** x_sc, + int* tile_x_ql, + sycl::half2* tile_x_dm, + int* tile_x_sc) { + (void)x_qh; + + *x_ql = tile_x_ql; + *x_dm = tile_x_dm; + *x_sc = tile_x_sc; } template -static __dpct_inline__ void -load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql, - sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh, - int *__restrict__ x_sc, const int &i_offset, const int &i_max, - const int &k, const int &blocks_per_row) { - (void)x_qh; - - GGML_SYCL_ASSUME(i_offset >= 0); - GGML_SYCL_ASSUME(i_offset < nwarps); - GGML_SYCL_ASSUME(k >= 0); - GGML_SYCL_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI6_K; // == 0 if QK_K == 256 - const int kqsx = k % QI6_K; // == k if QK_K == 256 - - const block_q6_K * bx0 = (const block_q6_K *) vx; +static __dpct_inline__ void load_tiles_q6_K( + const void* __restrict__ vx, + int* __restrict__ x_ql, + sycl::half2* __restrict__ x_dm, + int* __restrict__ x_qh, + int* __restrict__ x_sc, + const int& i_offset, + const int& i_max, + const int& k, + const int& blocks_per_row) { + (void)x_qh; + + GGML_SYCL_ASSUME(i_offset >= 0); + GGML_SYCL_ASSUME(i_offset < nwarps); + GGML_SYCL_ASSUME(k >= 0); + GGML_SYCL_ASSUME(k < WARP_SIZE); + + const int kbx = k / QI6_K; // == 0 if QK_K == 256 + const int kqsx = k % QI6_K; // == k if QK_K == 256 + + const block_q6_K* bx0 = (const block_q6_K*)vx; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + i_offset; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR6_K*kqsx; + const block_q6_K* bxi = bx0 + i * blocks_per_row + kbx; + const int ky = QR6_K * kqsx; - const int ql = get_int_from_uint8(bxi->ql, kqsx); - const int ql0 = (ql >> 0) & 
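// Q6_K sketch: a weight is 6 bits, a low nibble from ql plus 2 bits from qh,
// stored with a +32 bias so it fits an unsigned layout:
//
//   // w = (ql & 0x0F) | (((qh >> shift) & 3) << 4);  // w in 0..63
//   // x = d * sc * (w - 32);                         // signed dequant
//
// Unlike q5_K, the bias is stripped right away below with a saturating
// per-byte subtract, so the dot-product stage can treat the tile as plain
// signed int8 data.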
0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + const int qh = get_int_from_uint8( + bxi->qh, (QI6_K / 4) * (kqsx / (QI6_K / 2)) + kqsx % (QI6_K / 4)); + const int qh0 = + ((qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) << 4) & 0x30303030; + const int qh1 = + (qh >> (2 * ((kqsx % (QI6_K / 2)) / (QI6_K / 4)))) & 0x30303030; - const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); + const int kq0 = ky - ky % QI6_K + k % (QI6_K / 2) + 0; + const int kq1 = ky - ky % QI6_K + k % (QI6_K / 2) + (QI6_K / 2); - x_ql[i * (2 * WARP_SIZE + 1) + kq0] = - dpct::vectorized_binary(ql0 | qh0, 0x20202020, - dpct::sub_sat()); - x_ql[i * (2 * WARP_SIZE + 1) + kq1] = - dpct::vectorized_binary(ql1 | qh1, 0x20202020, - dpct::sub_sat()); - } + x_ql[i * (2 * WARP_SIZE + 1) + kq0] = dpct::vectorized_binary( + ql0 | qh0, 0x20202020, dpct::sub_sat()); + x_ql[i * (2 * WARP_SIZE + 1) + kq1] = dpct::vectorized_binary( + ql1 | qh1, 0x20202020, dpct::sub_sat()); + } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float* x_dmf = (float*)x_dm; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; + const block_q6_K* bxi = bx0 + i * blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; - } + x_dmf[i * (WARP_SIZE / QI6_K) + i / QI6_K + kbxd] = bxi->d; + } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + i_offset * 8 + k / (WARP_SIZE / 8)) % mmq_y; - if (need_check) { - i = sycl::min(i, i_max); - } + if (need_check) { + i = sycl::min(i, i_max); + } - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + const block_q6_K* bxi = + bx0 + i * blocks_per_row + (k % (WARP_SIZE / 8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); - } + x_sc[i * (WARP_SIZE / 8) + i / 8 + k % (WARP_SIZE / 8)] = + get_int_from_int8(bxi->scales, k % (QI6_K / 8)); + } } static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat( - const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm, - const int *__restrict__ x_qh, const int *__restrict__ x_sc, - const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds, - const int &i, const int &j, const int &k) { - (void)x_qh; - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * sc = 
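// dpct::vectorized_binary(a, b, dpct::sub_sat()) as used above is the SYCL
// stand-in for CUDA's __vsubss4: a saturating subtract over four int8 lanes
// packed in one 32-bit int. A scalar sketch of the assumed behaviour:
//
//   // for (int n = 0; n < 4; ++n)
//   //   out[n] = clamp(int(a[n]) - int(b[n]), -128, 127);
//
// With b == 0x20202020 this strips the +32 bias from every byte; saturation
// never triggers for q6_K inputs (0..63), but it keeps the helper a faithful
// drop-in for the intrinsic.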
((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); - - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + const int* __restrict__ x_ql, + const sycl::half2* __restrict__ x_dm, + const int* __restrict__ x_qh, + const int* __restrict__ x_sc, + const int* __restrict__ y_qs, + const sycl::half2* __restrict__ y_ds, + const int& i, + const int& j, + const int& k) { + (void)x_qh; + + const float* x_dmf = (const float*)x_dm; + const float* y_df = (const float*)y_ds; + + const int8_t* sc = + ((const int8_t*)&x_sc[i * (WARP_SIZE / 8) + i / 8 + k / 8]); + + const int index_x = i * (QR6_K * WARP_SIZE + 1) + QR6_K * k; + const int index_y = j * WARP_SIZE + (QR6_K * k) % WARP_SIZE; + return vec_dot_q6_K_q8_1_impl_mmq( + &x_ql[index_x], + &y_qs[index_y], + sc, + x_dmf[i * (WARP_SIZE / QI6_K) + i / QI6_K], + &y_df[index_y / QI8_1]); } bool ggml_sycl_supports_mmq(enum ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - return true; - default: - return false; - } + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return true; + default: + return false; + } } - -template +template < + int qk, + int qr, + int qi, + bool need_sum, + typename block_q_t, + int mmq_x, + int mmq_y, + int nwarps, + load_tiles_sycl_t load_tiles, + int vdr, + vec_dot_q_mul_mat_sycl_t vec_dot> /* DPCT1110:8: The total declared local variable size in device function mul_mat_q exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. 
*/ -static __dpct_inline__ void -mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy, - float *__restrict__ dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, - int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh, - int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs, - sycl::half2 *tile_y_ds) { - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; - - const int & ncols_dst = ncols_y; - - const int row_dst_0 = item_ct1.get_group(2) * mmq_y; - const int & row_x_0 = row_dst_0; - - const int col_dst_0 = item_ct1.get_group(1) * mmq_x; - const int & col_y_0 = col_dst_0; - - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; - - for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - - load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, - tile_x_qh, tile_x_sc, item_ct1.get_local_id(1), - nrows_x - row_x_0 - 1, item_ct1.get_local_id(2), - blocks_per_row_x); +static __dpct_inline__ void mul_mat_q( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + int* tile_x_ql, + sycl::half2* tile_x_dm, + int* tile_x_qh, + int* tile_x_sc, + const sycl::nd_item<3>& item_ct1, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_col_y = nrows_y / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int& ncols_dst = ncols_y; + + const int row_dst_0 = item_ct1.get_group(2) * mmq_y; + const int& row_x_0 = row_dst_0; + + const int col_dst_0 = item_ct1.get_group(1) * mmq_x; + const int& col_y_0 = col_dst_0; + + float sum[mmq_y / WARP_SIZE][mmq_x / nwarps] = {{0.0f}}; + + for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { + load_tiles( + x + row_x_0 * blocks_per_row_x + ib0, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1.get_local_id(1), + nrows_x - row_x_0 - 1, + item_ct1.get_local_id(2), + blocks_per_row_x); #pragma unroll - for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2); - const int kbxd = kqs / QI8_1; + for (int ir = 0; ir < qr; ++ir) { + const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2); + const int kbxd = kqs / QI8_1; #pragma unroll - for (int i = 0; i < mmq_x; i += nwarps) { - const int col_y_eff = dpct::min( - (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i), - ncols_y - 1); // to prevent out-of-bounds memory accesses + for (int i = 0; i < mmq_x; i += nwarps) { + const int col_y_eff = dpct::min( + (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i), + ncols_y - 1); // to prevent out-of-bounds memory accesses - const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; + const block_q8_1* by0 = + &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + kbxd]; - const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE + - kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned( - by0->qs, item_ct1.get_local_id(2) % QI8_1); - } + const int index_y = + (item_ct1.get_local_id(1) + i) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned( + by0->qs, 
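// Shape of mul_mat_q as a whole, sketched (one work-group per mmq_y x mmq_x
// tile of dst; the y operand has been requantized to q8_1 beforehand):
//
//   // for ib0 in [0, blocks_per_row_x) step blocks_per_warp:
//   //   load_tiles(x)                    -> tile_x_* in local memory
//   //   for ir in [0, qr):
//   //     stage y quants and scales      -> tile_y_qs / tile_y_ds
//   //     barrier()
//   //     sum[i][j] += vec_dot(x tile, y tile)   // per-lane partials
//   //     barrier()
//   // bounds-checked write of sum[][] to dst
//
// The qr factor is there because one pass over the x tile only covers
// WARP_SIZE / qr of the q8_1 positions, so the y staging repeats qr times.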
item_ct1.get_local_id(2) % QI8_1); + } #pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = - (ids0 + item_ct1.get_local_id(1) * QI8_1 + - item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) % - mmq_x; - const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1); - const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const sycl::half2 *dsi_src = - &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + - ir * (WARP_SIZE / QI8_1) + kby] - .ds; - sycl::half2 *dsi_dst = - &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = (*dsi_src)[0]; - } - } - - /* - DPCT1118:9: SYCL group functions and algorithms must be encountered - in converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:56: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. - */ - item_ct1.barrier(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + item_ct1.get_local_id(1) * QI8_1 + + item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) % + mmq_x; + const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1); + const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1); + + // if the sum is not needed it's faster to transform the scale to f32 + // ahead of time + const sycl::half2* dsi_src = + &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) + + ir * (WARP_SIZE / QI8_1) + kby] + .ds; + sycl::half2* dsi_dst = &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float* dfi_dst = (float*)dsi_dst; + *dfi_dst = (*dsi_src)[0]; + } + } + + /* + DPCT1118:9: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:56: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + + // #pragma unroll // unrolling this loop causes too much register pressure + for (int k = ir * WARP_SIZE / qr; k < (ir + 1) * WARP_SIZE / qr; + k += vdr) { #pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { + for (int j = 0; j < mmq_x; j += nwarps) { #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i / WARP_SIZE][j / nwarps] += vec_dot( - tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, - tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i, - item_ct1.get_local_id(1) + j, k); - } - } - } - - /* - DPCT1118:10: SYCL group functions and algorithms must be encountered - in converged control flow. You may need to adjust the code. - */ - /* - DPCT1065:57: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for - better performance if there is no access to global memory. 
- */ - item_ct1.barrier(); + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + sum[i / WARP_SIZE][j / nwarps] += vec_dot( + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + tile_y_qs, + tile_y_ds, + item_ct1.get_local_id(2) + i, + item_ct1.get_local_id(1) + j, + k); + } } + } + + /* + DPCT1118:10: SYCL group functions and algorithms must be encountered + in converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:57: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); } + } #pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { - const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1); + for (int j = 0; j < mmq_x; j += nwarps) { + const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1); - if (col_dst >= ncols_dst) { - return; - } + if (col_dst >= ncols_dst) { + return; + } #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i; + for (int i = 0; i < mmq_y; i += WARP_SIZE) { + const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i; - if (row_dst >= nrows_dst) { - continue; - } + if (row_dst >= nrows_dst) { + continue; + } - dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; - } + dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE][j / nwarps]; } + } } -#define MMQ_X_Q4_0_RDNA2 64 -#define MMQ_Y_Q4_0_RDNA2 128 -#define NWARPS_Q4_0_RDNA2 8 -#define MMQ_X_Q4_0_RDNA1 64 -#define MMQ_Y_Q4_0_RDNA1 64 -#define NWARPS_Q4_0_RDNA1 8 +#define MMQ_X_Q4_0_RDNA2 64 +#define MMQ_Y_Q4_0_RDNA2 128 +#define NWARPS_Q4_0_RDNA2 8 +#define MMQ_X_Q4_0_RDNA1 64 +#define MMQ_Y_Q4_0_RDNA1 64 +#define NWARPS_Q4_0_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q4_0_AMPERE 4 -#define MMQ_Y_Q4_0_AMPERE 32 +#define MMQ_X_Q4_0_AMPERE 4 +#define MMQ_Y_Q4_0_AMPERE 32 #define NWARPS_Q4_0_AMPERE 4 #else -#define MMQ_X_Q4_0_AMPERE 64 -#define MMQ_Y_Q4_0_AMPERE 128 +#define MMQ_X_Q4_0_AMPERE 64 +#define MMQ_Y_Q4_0_AMPERE 128 #define NWARPS_Q4_0_AMPERE 4 #endif -#define MMQ_X_Q4_0_PASCAL 64 -#define MMQ_Y_Q4_0_PASCAL 64 +#define MMQ_X_Q4_0_PASCAL 64 +#define MMQ_Y_Q4_0_PASCAL 64 #define NWARPS_Q4_0_PASCAL 8 -template static void - mul_mat_q4_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0, - int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - - const int mmq_x = MMQ_X_Q4_0_AMPERE; - const int mmq_y = MMQ_Y_Q4_0_AMPERE; - const int nwarps = NWARPS_Q4_0_AMPERE; - allocate_tiles_q4_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_qs_q4_0, tile_x_d_q4_0); - mul_mat_q, VDR_Q4_0_Q8_1_MMQ, - vec_dot_q4_0_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q4_0( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_qs_q4_0, + float* tile_x_d_q4_0, + int* 
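// The MMQ_X_*/MMQ_Y_*/NWARPS_* macros above are per-architecture tile
// shapes (MMQ_X columns of y, MMQ_Y rows of x, NWARPS sub-groups per
// work-group) carried over from the CUDA kernels. Note that these device
// wrappers hard-code the *_AMPERE shape (see the sycl_todo below), while
// the host launchers further down still pick a shape by device class for
// the launch geometry, so the two selections have to agree.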
tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + + const int mmq_x = MMQ_X_Q4_0_AMPERE; + const int mmq_y = MMQ_Y_Q4_0_AMPERE; + const int nwarps = NWARPS_Q4_0_AMPERE; + allocate_tiles_q4_0( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_qs_q4_0, + tile_x_d_q4_0); + mul_mat_q< + QK4_0, + QR4_0, + QI4_0, + true, + block_q4_0, + mmq_x, + mmq_y, + nwarps, + load_tiles_q4_0, + VDR_Q4_0_Q8_1_MMQ, + vec_dot_q4_0_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q4_1_RDNA2 64 -#define MMQ_Y_Q4_1_RDNA2 128 -#define NWARPS_Q4_1_RDNA2 8 -#define MMQ_X_Q4_1_RDNA1 64 -#define MMQ_Y_Q4_1_RDNA1 64 -#define NWARPS_Q4_1_RDNA1 8 +#define MMQ_X_Q4_1_RDNA2 64 +#define MMQ_Y_Q4_1_RDNA2 128 +#define NWARPS_Q4_1_RDNA2 8 +#define MMQ_X_Q4_1_RDNA1 64 +#define MMQ_Y_Q4_1_RDNA1 64 +#define NWARPS_Q4_1_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q4_1_AMPERE 4 -#define MMQ_Y_Q4_1_AMPERE 32 +#define MMQ_X_Q4_1_AMPERE 4 +#define MMQ_Y_Q4_1_AMPERE 32 #define NWARPS_Q4_1_AMPERE 4 #else -#define MMQ_X_Q4_1_AMPERE 64 -#define MMQ_Y_Q4_1_AMPERE 128 +#define MMQ_X_Q4_1_AMPERE 64 +#define MMQ_Y_Q4_1_AMPERE 128 #define NWARPS_Q4_1_AMPERE 4 #endif -#define MMQ_X_Q4_1_PASCAL 64 -#define MMQ_Y_Q4_1_PASCAL 64 +#define MMQ_X_Q4_1_PASCAL 64 +#define MMQ_Y_Q4_1_PASCAL 64 #define NWARPS_Q4_1_PASCAL 8 -template static void - mul_mat_q4_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1, - sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q4_1_AMPERE; - const int mmq_y = MMQ_Y_Q4_1_AMPERE; - const int nwarps = NWARPS_Q4_1_AMPERE; - allocate_tiles_q4_1(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_qs_q4_1, tile_x_dm_q4_1); - mul_mat_q, VDR_Q4_1_Q8_1_MMQ, - vec_dot_q4_1_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q4_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_qs_q4_1, + sycl::half2* tile_x_dm_q4_1, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q4_1_AMPERE; + const int mmq_y = MMQ_Y_Q4_1_AMPERE; + const int nwarps = NWARPS_Q4_1_AMPERE; + allocate_tiles_q4_1( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_qs_q4_1, + tile_x_dm_q4_1); + mul_mat_q< + QK4_1, + QR4_1, + QI4_1, + true, + block_q4_1, + mmq_x, + mmq_y, + nwarps, + load_tiles_q4_1, + VDR_Q4_1_Q8_1_MMQ, + vec_dot_q4_1_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + 
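// Reading the mul_mat_q<...> template arguments used by these wrappers:
// qk is the number of weights per quant block (32 for q4_0..q8_0, QK_K ==
// 256 for the K-quants), qr is how many q8_1 values map onto one packed x
// value, qi is how many 32-bit ints one block occupies, and need_sum says
// whether the dot product also needs the q8_1 per-block sum: q4_0, q4_1,
// q5_1, q4_K and q5_K pass true (the y loader then keeps the full half2
// {d, s}), while q5_0, q8_0 and q2_K/q3_K/q6_K pass false, letting the
// loader store just the f32 scale, which the need_sum branch in mul_mat_q
// above notes is faster.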
nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q5_0_RDNA2 64 -#define MMQ_Y_Q5_0_RDNA2 128 -#define NWARPS_Q5_0_RDNA2 8 -#define MMQ_X_Q5_0_RDNA1 64 -#define MMQ_Y_Q5_0_RDNA1 64 -#define NWARPS_Q5_0_RDNA1 8 +#define MMQ_X_Q5_0_RDNA2 64 +#define MMQ_Y_Q5_0_RDNA2 128 +#define NWARPS_Q5_0_RDNA2 8 +#define MMQ_X_Q5_0_RDNA1 64 +#define MMQ_Y_Q5_0_RDNA1 64 +#define NWARPS_Q5_0_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q5_0_AMPERE 4 -#define MMQ_Y_Q5_0_AMPERE 32 +#define MMQ_X_Q5_0_AMPERE 4 +#define MMQ_Y_Q5_0_AMPERE 32 #define NWARPS_Q5_0_AMPERE 4 #else -#define MMQ_X_Q5_0_AMPERE 128 -#define MMQ_Y_Q5_0_AMPERE 64 +#define MMQ_X_Q5_0_AMPERE 128 +#define MMQ_Y_Q5_0_AMPERE 64 #define NWARPS_Q5_0_AMPERE 4 #endif -#define MMQ_X_Q5_0_PASCAL 64 -#define MMQ_Y_Q5_0_PASCAL 64 +#define MMQ_X_Q5_0_PASCAL 64 +#define MMQ_Y_Q5_0_PASCAL 64 #define NWARPS_Q5_0_PASCAL 8 -template static void - mul_mat_q5_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0, - int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q5_0_AMPERE; - const int mmq_y = MMQ_Y_Q5_0_AMPERE; - const int nwarps = NWARPS_Q5_0_AMPERE; - allocate_tiles_q5_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q5_0, tile_x_d_q5_0); - mul_mat_q, VDR_Q5_0_Q8_1_MMQ, - vec_dot_q5_0_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q5_0( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q5_0, + float* tile_x_d_q5_0, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_0_AMPERE; + const int mmq_y = MMQ_Y_Q5_0_AMPERE; + const int nwarps = NWARPS_Q5_0_AMPERE; + allocate_tiles_q5_0( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q5_0, + tile_x_d_q5_0); + mul_mat_q< + QK5_0, + QR5_0, + QI5_0, + false, + block_q5_0, + mmq_x, + mmq_y, + nwarps, + load_tiles_q5_0, + VDR_Q5_0_Q8_1_MMQ, + vec_dot_q5_0_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q5_1_RDNA2 64 -#define MMQ_Y_Q5_1_RDNA2 128 -#define NWARPS_Q5_1_RDNA2 8 -#define MMQ_X_Q5_1_RDNA1 64 -#define MMQ_Y_Q5_1_RDNA1 64 -#define NWARPS_Q5_1_RDNA1 8 +#define MMQ_X_Q5_1_RDNA2 64 +#define MMQ_Y_Q5_1_RDNA2 128 +#define NWARPS_Q5_1_RDNA2 8 +#define MMQ_X_Q5_1_RDNA1 64 +#define MMQ_Y_Q5_1_RDNA1 64 +#define NWARPS_Q5_1_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q5_1_AMPERE 4 -#define MMQ_Y_Q5_1_AMPERE 32 +#define MMQ_X_Q5_1_AMPERE 4 +#define MMQ_Y_Q5_1_AMPERE 32 #define NWARPS_Q5_1_AMPERE 4 #else -#define MMQ_X_Q5_1_AMPERE 128 -#define 
MMQ_Y_Q5_1_AMPERE 64 +#define MMQ_X_Q5_1_AMPERE 128 +#define MMQ_Y_Q5_1_AMPERE 64 #define NWARPS_Q5_1_AMPERE 4 #endif -#define MMQ_X_Q5_1_PASCAL 64 -#define MMQ_Y_Q5_1_PASCAL 64 +#define MMQ_X_Q5_1_PASCAL 64 +#define MMQ_Y_Q5_1_PASCAL 64 #define NWARPS_Q5_1_PASCAL 8 -template static void -mul_mat_q5_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1, - sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q5_1_AMPERE; - const int mmq_y = MMQ_Y_Q5_1_AMPERE; - const int nwarps = NWARPS_Q5_1_AMPERE; - allocate_tiles_q5_1(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q5_1, tile_x_dm_q5_1); - mul_mat_q, VDR_Q5_1_Q8_1_MMQ, - vec_dot_q5_1_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q5_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q5_1, + sycl::half2* tile_x_dm_q5_1, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_1_AMPERE; + const int mmq_y = MMQ_Y_Q5_1_AMPERE; + const int nwarps = NWARPS_Q5_1_AMPERE; + allocate_tiles_q5_1( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q5_1, + tile_x_dm_q5_1); + mul_mat_q< + QK5_1, + QR5_1, + QI5_1, + true, + block_q5_1, + mmq_x, + mmq_y, + nwarps, + load_tiles_q5_1, + VDR_Q5_1_Q8_1_MMQ, + vec_dot_q5_1_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q8_0_RDNA2 64 -#define MMQ_Y_Q8_0_RDNA2 128 -#define NWARPS_Q8_0_RDNA2 8 -#define MMQ_X_Q8_0_RDNA1 64 -#define MMQ_Y_Q8_0_RDNA1 64 -#define NWARPS_Q8_0_RDNA1 8 +#define MMQ_X_Q8_0_RDNA2 64 +#define MMQ_Y_Q8_0_RDNA2 128 +#define NWARPS_Q8_0_RDNA2 8 +#define MMQ_X_Q8_0_RDNA1 64 +#define MMQ_Y_Q8_0_RDNA1 64 +#define NWARPS_Q8_0_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q8_0_AMPERE 4 -#define MMQ_Y_Q8_0_AMPERE 32 +#define MMQ_X_Q8_0_AMPERE 4 +#define MMQ_Y_Q8_0_AMPERE 32 #define NWARPS_Q8_0_AMPERE 4 #else -#define MMQ_X_Q8_0_AMPERE 128 -#define MMQ_Y_Q8_0_AMPERE 64 +#define MMQ_X_Q8_0_AMPERE 128 +#define MMQ_Y_Q8_0_AMPERE 64 #define NWARPS_Q8_0_AMPERE 4 #endif -#define MMQ_X_Q8_0_PASCAL 64 -#define MMQ_Y_Q8_0_PASCAL 64 +#define MMQ_X_Q8_0_PASCAL 64 +#define MMQ_Y_Q8_0_PASCAL 64 #define NWARPS_Q8_0_PASCAL 8 -template static void - mul_mat_q8_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0, - int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm 
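// For the f32-scaled formats the "dm" tile actually holds plain floats:
// these wrappers allocate a float accessor (tile_x_d_q8_0 here), the
// allocate_tiles_* helpers presumably park it behind the sycl::half2* slot,
// and the consuming side casts it straight back, e.g.
//
//   // const float* x_dmf = (const float*)x_dm;  // as in the q6_K dot above
//
// Formats carrying a {scale, min} pair keep genuine half2 data in the same
// slot instead.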
= nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q8_0_AMPERE; - const int mmq_y = MMQ_Y_Q8_0_AMPERE; - const int nwarps = NWARPS_Q8_0_AMPERE; - allocate_tiles_q8_0(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_qs_q8_0, tile_x_d_q8_0); - mul_mat_q, VDR_Q8_0_Q8_1_MMQ, - vec_dot_q8_0_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q8_0( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_qs_q8_0, + float* tile_x_d_q8_0, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q8_0_AMPERE; + const int mmq_y = MMQ_Y_Q8_0_AMPERE; + const int nwarps = NWARPS_Q8_0_AMPERE; + allocate_tiles_q8_0( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_qs_q8_0, + tile_x_d_q8_0); + mul_mat_q< + QK8_0, + QR8_0, + QI8_0, + false, + block_q8_0, + mmq_x, + mmq_y, + nwarps, + load_tiles_q8_0, + VDR_Q8_0_Q8_1_MMQ, + vec_dot_q8_0_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q2_K_RDNA2 64 -#define MMQ_Y_Q2_K_RDNA2 128 -#define NWARPS_Q2_K_RDNA2 8 -#define MMQ_X_Q2_K_RDNA1 128 -#define MMQ_Y_Q2_K_RDNA1 32 -#define NWARPS_Q2_K_RDNA1 8 +#define MMQ_X_Q2_K_RDNA2 64 +#define MMQ_Y_Q2_K_RDNA2 128 +#define NWARPS_Q2_K_RDNA2 8 +#define MMQ_X_Q2_K_RDNA1 128 +#define MMQ_Y_Q2_K_RDNA1 32 +#define NWARPS_Q2_K_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q2_K_AMPERE 4 -#define MMQ_Y_Q2_K_AMPERE 32 +#define MMQ_X_Q2_K_AMPERE 4 +#define MMQ_Y_Q2_K_AMPERE 32 #define NWARPS_Q2_K_AMPERE 4 #else -#define MMQ_X_Q2_K_AMPERE 64 -#define MMQ_Y_Q2_K_AMPERE 128 +#define MMQ_X_Q2_K_AMPERE 64 +#define MMQ_Y_Q2_K_AMPERE 128 #define NWARPS_Q2_K_AMPERE 4 #endif -#define MMQ_X_Q2_K_PASCAL 64 -#define MMQ_Y_Q2_K_PASCAL 64 +#define MMQ_X_Q2_K_PASCAL 64 +#define MMQ_Y_Q2_K_PASCAL 64 #define NWARPS_Q2_K_PASCAL 8 -template static void -mul_mat_q2_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K, - sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs, - sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q2_K_AMPERE; - const int mmq_y = MMQ_Y_Q2_K_AMPERE; - const int nwarps = NWARPS_Q2_K_AMPERE; - allocate_tiles_q2_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K); - mul_mat_q, VDR_Q2_K_Q8_1_MMQ, - vec_dot_q2_K_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q2_K( + const void* __restrict__ vx, + const void* __restrict__ 
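// From here on the wrappers cover the K-quants, which use QK_K == 256
// element super-blocks instead of the 32-element blocks of q4_0..q8_0:
// blocks_per_row_x shrinks by 8x, and a single load_tiles call unpacks a
// correspondingly larger bundle of state per block (quants, scales and,
// where present, mins).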
vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q2_K, + sycl::half2* tile_x_dm_q2_K, + int* tile_x_sc_q2_K, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q2_K_AMPERE; + const int mmq_y = MMQ_Y_Q2_K_AMPERE; + const int nwarps = NWARPS_Q2_K_AMPERE; + allocate_tiles_q2_K( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q2_K, + tile_x_dm_q2_K, + tile_x_sc_q2_K); + mul_mat_q< + QK_K, + QR2_K, + QI2_K, + false, + block_q2_K, + mmq_x, + mmq_y, + nwarps, + load_tiles_q2_K, + VDR_Q2_K_Q8_1_MMQ, + vec_dot_q2_K_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q3_K_RDNA2 128 -#define MMQ_Y_Q3_K_RDNA2 64 -#define NWARPS_Q3_K_RDNA2 8 -#define MMQ_X_Q3_K_RDNA1 32 -#define MMQ_Y_Q3_K_RDNA1 128 -#define NWARPS_Q3_K_RDNA1 8 +#define MMQ_X_Q3_K_RDNA2 128 +#define MMQ_Y_Q3_K_RDNA2 64 +#define NWARPS_Q3_K_RDNA2 8 +#define MMQ_X_Q3_K_RDNA1 32 +#define MMQ_Y_Q3_K_RDNA1 128 +#define NWARPS_Q3_K_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q3_K_AMPERE 4 -#define MMQ_Y_Q3_K_AMPERE 32 +#define MMQ_X_Q3_K_AMPERE 4 +#define MMQ_Y_Q3_K_AMPERE 32 #define NWARPS_Q3_K_AMPERE 4 #else -#define MMQ_X_Q3_K_AMPERE 128 -#define MMQ_Y_Q3_K_AMPERE 128 +#define MMQ_X_Q3_K_AMPERE 128 +#define MMQ_Y_Q3_K_AMPERE 128 #define NWARPS_Q3_K_AMPERE 4 #endif -#define MMQ_X_Q3_K_PASCAL 64 -#define MMQ_Y_Q3_K_PASCAL 64 +#define MMQ_X_Q3_K_PASCAL 64 +#define MMQ_Y_Q3_K_PASCAL 64 #define NWARPS_Q3_K_PASCAL 8 -template static void -mul_mat_q3_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K, - sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K, - int *tile_y_qs, sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q3_K_AMPERE; - const int mmq_y = MMQ_Y_Q3_K_AMPERE; - const int nwarps = NWARPS_Q3_K_AMPERE; - allocate_tiles_q3_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K, - tile_x_sc_q3_K); - mul_mat_q, VDR_Q3_K_Q8_1_MMQ, - vec_dot_q3_K_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q3_K( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q3_K, + sycl::half2* tile_x_dm_q3_K, + int* tile_x_qh_q3_K, + int* tile_x_sc_q3_K, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q3_K_AMPERE; + const int mmq_y = 
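// q3_K is the one format here that keeps a live x_qh tile: its third bit
// comes from the block's high-bit mask and is combined with the low bits
// inside the dot product (see vec_dot_q3_K_q8_1_mul_mat at the top of this
// hunk), rather than OR-ed into x_ql at load time the way q5_K and q6_K
// fold their high bits in. Hence allocate_tiles_q3_K is the only allocator
// taking four tile pointers.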
MMQ_Y_Q3_K_AMPERE; + const int nwarps = NWARPS_Q3_K_AMPERE; + allocate_tiles_q3_K( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q3_K, + tile_x_dm_q3_K, + tile_x_qh_q3_K, + tile_x_sc_q3_K); + mul_mat_q< + QK_K, + QR3_K, + QI3_K, + false, + block_q3_K, + mmq_x, + mmq_y, + nwarps, + load_tiles_q3_K, + VDR_Q3_K_Q8_1_MMQ, + vec_dot_q3_K_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q4_K_RDNA2 64 -#define MMQ_Y_Q4_K_RDNA2 128 -#define NWARPS_Q4_K_RDNA2 8 -#define MMQ_X_Q4_K_RDNA1 32 -#define MMQ_Y_Q4_K_RDNA1 64 -#define NWARPS_Q4_K_RDNA1 8 +#define MMQ_X_Q4_K_RDNA2 64 +#define MMQ_Y_Q4_K_RDNA2 128 +#define NWARPS_Q4_K_RDNA2 8 +#define MMQ_X_Q4_K_RDNA1 32 +#define MMQ_Y_Q4_K_RDNA1 64 +#define NWARPS_Q4_K_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q4_K_AMPERE 4 -#define MMQ_Y_Q4_K_AMPERE 32 +#define MMQ_X_Q4_K_AMPERE 4 +#define MMQ_Y_Q4_K_AMPERE 32 #define NWARPS_Q4_K_AMPERE 4 #else -#define MMQ_X_Q4_K_AMPERE 64 -#define MMQ_Y_Q4_K_AMPERE 128 +#define MMQ_X_Q4_K_AMPERE 64 +#define MMQ_Y_Q4_K_AMPERE 128 #define NWARPS_Q4_K_AMPERE 4 #endif -#define MMQ_X_Q4_K_PASCAL 64 -#define MMQ_Y_Q4_K_PASCAL 64 +#define MMQ_X_Q4_K_PASCAL 64 +#define MMQ_Y_Q4_K_PASCAL 64 #define NWARPS_Q4_K_PASCAL 8 -template static void - mul_mat_q4_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K, - sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs, - sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q4_K_AMPERE; - const int mmq_y = MMQ_Y_Q4_K_AMPERE; - const int nwarps = NWARPS_Q4_K_AMPERE; - allocate_tiles_q4_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K); - mul_mat_q, VDR_Q4_K_Q8_1_MMQ, - vec_dot_q4_K_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q4_K( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q4_K, + sycl::half2* tile_x_dm_q4_K, + int* tile_x_sc_q4_K, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q4_K_AMPERE; + const int mmq_y = MMQ_Y_Q4_K_AMPERE; + const int nwarps = NWARPS_Q4_K_AMPERE; + allocate_tiles_q4_K( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q4_K, + tile_x_dm_q4_K, + tile_x_sc_q4_K); + mul_mat_q< + QK_K, + QR4_K, + QI4_K, + true, + block_q4_K, + mmq_x, + mmq_y, + nwarps, + load_tiles_q4_K, + VDR_Q4_K_Q8_1_MMQ, + vec_dot_q4_K_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q5_K_RDNA2 64 -#define 
MMQ_Y_Q5_K_RDNA2 128 -#define NWARPS_Q5_K_RDNA2 8 -#define MMQ_X_Q5_K_RDNA1 32 -#define MMQ_Y_Q5_K_RDNA1 64 -#define NWARPS_Q5_K_RDNA1 8 +#define MMQ_X_Q5_K_RDNA2 64 +#define MMQ_Y_Q5_K_RDNA2 128 +#define NWARPS_Q5_K_RDNA2 8 +#define MMQ_X_Q5_K_RDNA1 32 +#define MMQ_Y_Q5_K_RDNA1 64 +#define NWARPS_Q5_K_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q5_K_AMPERE 4 -#define MMQ_Y_Q5_K_AMPERE 32 +#define MMQ_X_Q5_K_AMPERE 4 +#define MMQ_Y_Q5_K_AMPERE 32 #define NWARPS_Q5_K_AMPERE 4 #else -#define MMQ_X_Q5_K_AMPERE 64 -#define MMQ_Y_Q5_K_AMPERE 128 +#define MMQ_X_Q5_K_AMPERE 64 +#define MMQ_Y_Q5_K_AMPERE 128 #define NWARPS_Q5_K_AMPERE 4 #endif -#define MMQ_X_Q5_K_PASCAL 64 -#define MMQ_Y_Q5_K_PASCAL 64 +#define MMQ_X_Q5_K_PASCAL 64 +#define MMQ_Y_Q5_K_PASCAL 64 #define NWARPS_Q5_K_PASCAL 8 -template static void -mul_mat_q5_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K, - sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs, - sycl::half2 *tile_y_ds) { - int * tile_x_ql = nullptr; - sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q5_K_AMPERE; - const int mmq_y = MMQ_Y_Q5_K_AMPERE; - const int nwarps = NWARPS_Q5_K_AMPERE; - allocate_tiles_q5_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K); - mul_mat_q, VDR_Q5_K_Q8_1_MMQ, - vec_dot_q5_K_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q5_K( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql_q5_K, + sycl::half2* tile_x_dm_q5_K, + int* tile_x_sc_q5_K, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + int* tile_x_ql = nullptr; + sycl::half2* tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + int* tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q5_K_AMPERE; + const int mmq_y = MMQ_Y_Q5_K_AMPERE; + const int nwarps = NWARPS_Q5_K_AMPERE; + allocate_tiles_q5_K( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql_q5_K, + tile_x_dm_q5_K, + tile_x_sc_q5_K); + mul_mat_q< + QK_K, + QR5_K, + QI5_K, + true, + block_q5_K, + mmq_x, + mmq_y, + nwarps, + load_tiles_q5_K, + VDR_Q5_K_Q8_1_MMQ, + vec_dot_q5_K_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -#define MMQ_X_Q6_K_RDNA2 64 -#define MMQ_Y_Q6_K_RDNA2 128 -#define NWARPS_Q6_K_RDNA2 8 -#define MMQ_X_Q6_K_RDNA1 32 -#define MMQ_Y_Q6_K_RDNA1 64 -#define NWARPS_Q6_K_RDNA1 8 +#define MMQ_X_Q6_K_RDNA2 64 +#define MMQ_Y_Q6_K_RDNA2 128 +#define NWARPS_Q6_K_RDNA2 8 +#define MMQ_X_Q6_K_RDNA1 32 +#define MMQ_Y_Q6_K_RDNA1 64 +#define NWARPS_Q6_K_RDNA1 8 #if defined(SYCL_USE_XMX) -#define MMQ_X_Q6_K_AMPERE 4 -#define MMQ_Y_Q6_K_AMPERE 32 +#define MMQ_X_Q6_K_AMPERE 4 +#define MMQ_Y_Q6_K_AMPERE 32 #define NWARPS_Q6_K_AMPERE 4 #else -#define MMQ_X_Q6_K_AMPERE 64 -#define MMQ_Y_Q6_K_AMPERE 64 +#define MMQ_X_Q6_K_AMPERE 64 +#define 
MMQ_Y_Q6_K_AMPERE 64 #define NWARPS_Q6_K_AMPERE 4 #endif -#define MMQ_X_Q6_K_PASCAL 64 -#define MMQ_Y_Q6_K_PASCAL 64 +#define MMQ_X_Q6_K_PASCAL 64 +#define MMQ_Y_Q6_K_PASCAL 64 #define NWARPS_Q6_K_PASCAL 8 -template static void - mul_mat_q6_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, - const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm, - int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) { - // int * tile_x_ql = nullptr; - // sycl::half2 *tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - // int * tile_x_sc = nullptr; - -//sycl_todo: change according to hardware - const int mmq_x = MMQ_X_Q6_K_AMPERE; - const int mmq_y = MMQ_Y_Q6_K_AMPERE; - const int nwarps = NWARPS_Q6_K_AMPERE; - allocate_tiles_q6_K(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc, - tile_x_ql, tile_x_dm, tile_x_sc); - mul_mat_q, VDR_Q6_K_Q8_1_MMQ, - vec_dot_q6_K_q8_1_mul_mat>( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql, - tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds); +template +static void mul_mat_q6_K( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + const sycl::nd_item<3>& item_ct1, + int* tile_x_ql, + sycl::half2* tile_x_dm, + int* tile_x_sc, + int* tile_y_qs, + sycl::half2* tile_y_ds) { + // int * tile_x_ql = nullptr; + // sycl::half2 *tile_x_dm = nullptr; + int* tile_x_qh = nullptr; + // int * tile_x_sc = nullptr; + + // sycl_todo: change according to hardware + const int mmq_x = MMQ_X_Q6_K_AMPERE; + const int mmq_y = MMQ_Y_Q6_K_AMPERE; + const int nwarps = NWARPS_Q6_K_AMPERE; + allocate_tiles_q6_K( + &tile_x_ql, + &tile_x_dm, + &tile_x_qh, + &tile_x_sc, + tile_x_ql, + tile_x_dm, + tile_x_sc); + mul_mat_q< + QK_K, + QR6_K, + QI6_K, + false, + block_q6_K, + mmq_x, + mmq_y, + nwarps, + load_tiles_q6_K, + VDR_Q6_K_Q8_1_MMQ, + vec_dot_q6_K_q8_1_mul_mat>( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + tile_x_ql, + tile_x_dm, + tile_x_qh, + tile_x_sc, + item_ct1, + tile_y_qs, + tile_y_ds); } -static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q4_0_RDNA2; - mmq_y = MMQ_Y_Q4_0_RDNA2; - nwarps = NWARPS_Q4_0_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q4_0_RDNA1; - mmq_y = MMQ_Y_Q4_0_RDNA1; - nwarps = NWARPS_Q4_0_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q4_0_AMPERE; - mmq_y = MMQ_Y_Q4_0_AMPERE; - nwarps = NWARPS_Q4_0_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q4_0_PASCAL; - mmq_y = MMQ_Y_Q4_0_PASCAL; - nwarps = NWARPS_Q4_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - 
DPCT1049:20: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q4_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q4_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q4_0_acc_ct1.get_pointer(), - tile_x_d_q4_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q4_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_0_RDNA2; + mmq_y = MMQ_Y_Q4_0_RDNA2; + nwarps = NWARPS_Q4_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_0_RDNA1; + mmq_y = MMQ_Y_Q4_0_RDNA1; + nwarps = NWARPS_Q4_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_0_AMPERE; + mmq_y = MMQ_Y_Q4_0_AMPERE; + nwarps = NWARPS_Q4_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_0_PASCAL; + mmq_y = MMQ_Y_Q4_0_PASCAL; + nwarps = NWARPS_Q4_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:20: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
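[Note: the DPCT1049 hint above repeats before every launch in this file. A
minimal runtime form of the check it suggests, assuming the surrounding
stream, nwarps, and WARP_SIZE are in scope, could be:

    const size_t max_wg =
        stream->get_device().get_info<sycl::info::device::max_work_group_size>();
    GGML_ASSERT((size_t)(nwarps * WARP_SIZE) <= max_wg);

sycl::info::device::max_work_group_size is the standard SYCL 2020 descriptor
the generated comment refers to; the patch itself only reformats the hint and
leaves the check to the caller.]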
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q4_0_acc_ct1.get_pointer(), + tile_x_d_q4_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:21: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q4_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q4_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q4_0_acc_ct1.get_pointer(), - tile_x_d_q4_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:21: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q4_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q4_0_acc_ct1.get_pointer(), + tile_x_d_q4_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q4_1_RDNA2; - mmq_y = MMQ_Y_Q4_1_RDNA2; - nwarps = NWARPS_Q4_1_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q4_1_RDNA1; - mmq_y = MMQ_Y_Q4_1_RDNA1; - nwarps = NWARPS_Q4_1_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q4_1_AMPERE; - mmq_y = MMQ_Y_Q4_1_AMPERE; - nwarps = NWARPS_Q4_1_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q4_1_PASCAL; - mmq_y = MMQ_Y_Q4_1_PASCAL; - nwarps = NWARPS_Q4_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:22: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q4_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); - sycl::local_accessor tile_x_dm_q4_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_1( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q4_1_acc_ct1.get_pointer(), - tile_x_dm_q4_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q4_1_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_1_RDNA2; + mmq_y = MMQ_Y_Q4_1_RDNA2; + nwarps = NWARPS_Q4_1_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_1_RDNA1; + mmq_y = MMQ_Y_Q4_1_RDNA1; + nwarps = NWARPS_Q4_1_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_1_AMPERE; + mmq_y = MMQ_Y_Q4_1_AMPERE; + nwarps = NWARPS_Q4_1_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_1_PASCAL; + mmq_y = MMQ_Y_Q4_1_PASCAL; + nwarps = NWARPS_Q4_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:22: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_1( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q4_1_acc_ct1.get_pointer(), + tile_x_dm_q4_1_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:23: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
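[Note: the extent mmq_y * (WARP_SIZE) + +mmq_y in the q4_1 tiles above is not
a formatting artifact introduced by this patch: the doubled sign appears in
the pre-format lines as well, and the second + is a unary plus, so

    mmq_y * (WARP_SIZE) + +mmq_y == mmq_y * WARP_SIZE + mmq_y

and the allocation is the same size as the q4_0 equivalent.]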
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q4_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); - sycl::local_accessor tile_x_dm_q4_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_1( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q4_1_acc_ct1.get_pointer(), - tile_x_dm_q4_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:23: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_1( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q4_1_acc_ct1.get_pointer(), + tile_x_dm_q4_1_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q5_0_RDNA2; - mmq_y = MMQ_Y_Q5_0_RDNA2; - nwarps = NWARPS_Q5_0_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q5_0_RDNA1; - mmq_y = MMQ_Y_Q5_0_RDNA1; - nwarps = NWARPS_Q5_0_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q5_0_AMPERE; - mmq_y = MMQ_Y_Q5_0_AMPERE; - nwarps = NWARPS_Q5_0_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q5_0_PASCAL; - mmq_y = MMQ_Y_Q5_0_PASCAL; - nwarps = NWARPS_Q5_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:24: The work-group size 
passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_0_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q5_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_0_acc_ct1.get_pointer(), - tile_x_d_q5_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q5_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_0_RDNA2; + mmq_y = MMQ_Y_Q5_0_RDNA2; + nwarps = NWARPS_Q5_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_0_RDNA1; + mmq_y = MMQ_Y_Q5_0_RDNA1; + nwarps = NWARPS_Q5_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_0_AMPERE; + mmq_y = MMQ_Y_Q5_0_AMPERE; + nwarps = NWARPS_Q5_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_0_PASCAL; + mmq_y = MMQ_Y_Q5_0_PASCAL; + nwarps = NWARPS_Q5_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:24: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_0_acc_ct1.get_pointer(), + tile_x_d_q5_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:25: The work-group size passed to the SYCL kernel may exceed - the limit. 
To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_0_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q5_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_0_acc_ct1.get_pointer(), - tile_x_d_q5_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:25: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q5_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_0_acc_ct1.get_pointer(), + tile_x_d_q5_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q5_1_RDNA2; - mmq_y = MMQ_Y_Q5_1_RDNA2; - nwarps = NWARPS_Q5_1_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q5_1_RDNA1; - mmq_y = MMQ_Y_Q5_1_RDNA1; - nwarps = NWARPS_Q5_1_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q5_1_AMPERE; - mmq_y = MMQ_Y_Q5_1_AMPERE; - nwarps = NWARPS_Q5_1_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q5_1_PASCAL; - mmq_y = MMQ_Y_Q5_1_PASCAL; - nwarps = NWARPS_Q5_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); 
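[Note: every launcher derives its grid the same way, by ceiling division of
the matrix extents by the tile sizes selected above. With hypothetical
extents, say nrows_x = 1000 and mmq_y = 64:

    const int block_num_x = (1000 + 64 - 1) / 64;   // 16 work-groups along x

block_dims is (1, nwarps, WARP_SIZE), so each work-group runs
nwarps * WARP_SIZE work-items. The possible partial tile at the edge is why
the nrows_x % mmq_y == 0 test below selects between the checked and unchecked
kernel instantiations.]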
- - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:26: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_1_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q5_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_1( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_1_acc_ct1.get_pointer(), - tile_x_dm_q5_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q5_1_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_1_RDNA2; + mmq_y = MMQ_Y_Q5_1_RDNA2; + nwarps = NWARPS_Q5_1_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_1_RDNA1; + mmq_y = MMQ_Y_Q5_1_RDNA1; + nwarps = NWARPS_Q5_1_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_1_AMPERE; + mmq_y = MMQ_Y_Q5_1_AMPERE; + nwarps = NWARPS_Q5_1_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_1_PASCAL; + mmq_y = MMQ_Y_Q5_1_PASCAL; + nwarps = NWARPS_Q5_1_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:26: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_1( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_1_acc_ct1.get_pointer(), + tile_x_dm_q5_1_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:27: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_1_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q5_1_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_1( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_1_acc_ct1.get_pointer(), - tile_x_dm_q5_1_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:27: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_1_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_1( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_1_acc_ct1.get_pointer(), + tile_x_dm_q5_1_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q8_0_RDNA2; - mmq_y = MMQ_Y_Q8_0_RDNA2; - nwarps = NWARPS_Q8_0_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q8_0_RDNA1; - mmq_y = MMQ_Y_Q8_0_RDNA1; - nwarps = NWARPS_Q8_0_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q8_0_AMPERE; - mmq_y = MMQ_Y_Q8_0_AMPERE; - nwarps = NWARPS_Q8_0_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q8_0_PASCAL; - mmq_y = MMQ_Y_Q8_0_PASCAL; - nwarps = NWARPS_Q8_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:28: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
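[Note: dpct::has_capability_or_fail, used to gate every fp16 kernel in these
launchers, is a DPCT helper. A plain SYCL 2020 guard with the same effect,
sketched here as an illustration only (assuming the usual <sycl/sycl.hpp>
include), would be:

    if (!stream->get_device().has(sycl::aspect::fp16)) {
      throw sycl::exception(
          sycl::make_error_code(sycl::errc::feature_not_supported),
          "device does not support fp16");
    }

Any such throw lands in the function-level try/catch that each launcher in
this file ends with.]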
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q8_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q8_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q8_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q8_0_acc_ct1.get_pointer(), - tile_x_d_q8_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q8_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q8_0_RDNA2; + mmq_y = MMQ_Y_Q8_0_RDNA2; + nwarps = NWARPS_Q8_0_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q8_0_RDNA1; + mmq_y = MMQ_Y_Q8_0_RDNA1; + nwarps = NWARPS_Q8_0_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q8_0_AMPERE; + mmq_y = MMQ_Y_Q8_0_AMPERE; + nwarps = NWARPS_Q8_0_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q8_0_PASCAL; + mmq_y = MMQ_Y_Q8_0_PASCAL; + nwarps = NWARPS_Q8_0_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:28: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q8_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q8_0_acc_ct1.get_pointer(), + tile_x_d_q8_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:29: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_qs_q8_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_d_q8_0_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), - cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q8_0( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_qs_q8_0_acc_ct1.get_pointer(), - tile_x_d_q8_0_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:29: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_qs_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_d_q8_0_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q8_0( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_qs_q8_0_acc_ct1.get_pointer(), + tile_x_d_q8_0_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q2_K_RDNA2; - mmq_y = MMQ_Y_Q2_K_RDNA2; - nwarps = NWARPS_Q2_K_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q2_K_RDNA1; - mmq_y = MMQ_Y_Q2_K_RDNA1; - nwarps = NWARPS_Q2_K_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q2_K_AMPERE; - mmq_y = MMQ_Y_Q2_K_AMPERE; - nwarps = NWARPS_Q2_K_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q2_K_PASCAL; - mmq_y = MMQ_Y_Q2_K_PASCAL; - nwarps = NWARPS_Q2_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:30: The work-group size passed to 
the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), - cgh); - sycl::local_accessor tile_x_sc_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q2_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q2_K_acc_ct1.get_pointer(), - tile_x_dm_q2_K_acc_ct1.get_pointer(), - tile_x_sc_q2_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q2_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q2_K_RDNA2; + mmq_y = MMQ_Y_Q2_K_RDNA2; + nwarps = NWARPS_Q2_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q2_K_RDNA1; + mmq_y = MMQ_Y_Q2_K_RDNA1; + nwarps = NWARPS_Q2_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q2_K_AMPERE; + mmq_y = MMQ_Y_Q2_K_AMPERE; + nwarps = NWARPS_Q2_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q2_K_PASCAL; + mmq_y = MMQ_Y_Q2_K_PASCAL; + nwarps = NWARPS_Q2_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:30: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), cgh); + sycl::local_accessor tile_x_sc_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q2_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q2_K_acc_ct1.get_pointer(), + tile_x_dm_q2_K_acc_ct1.get_pointer(), + tile_x_sc_q2_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:31: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), - cgh); - sycl::local_accessor tile_x_sc_q2_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q2_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q2_K_acc_ct1.get_pointer(), - tile_x_dm_q2_K_acc_ct1.get_pointer(), - tile_x_sc_q2_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:31: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
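[Note: the local_accessor template arguments did not survive this transcript;
in the upstream file the quant tiles are sycl::local_accessor<int, 1> and the
scale tiles sycl::local_accessor<sycl::half2, 1>. With those element types,
the q2_K ranges above imply a per-work-group local-memory footprint of:

    const size_t tile_bytes =
        (mmq_y * WARP_SIZE + mmq_y) * sizeof(int)                           // tile_x_ql
      + (mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K) * sizeof(sycl::half2) // tile_x_dm
      + (mmq_y * (WARP_SIZE / 4) + mmq_y / 4) * sizeof(int)                 // tile_x_sc
      + (mmq_x * WARP_SIZE) * sizeof(int)                                   // tile_y_qs
      + (mmq_x * WARP_SIZE / QI8_1) * sizeof(sycl::half2);                  // tile_y_ds

a figure worth checking against info::device::local_mem_size on small
devices.]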
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K), cgh); + sycl::local_accessor tile_x_sc_q2_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q2_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q2_K_acc_ct1.get_pointer(), + tile_x_dm_q2_K_acc_ct1.get_pointer(), + tile_x_sc_q2_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - +static void ggml_mul_mat_q3_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { #if QK_K == 256 - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q3_K_RDNA2; - mmq_y = MMQ_Y_Q3_K_RDNA2; - nwarps = NWARPS_Q3_K_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q3_K_RDNA1; - mmq_y = MMQ_Y_Q3_K_RDNA1; - nwarps = NWARPS_Q3_K_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q3_K_AMPERE; - mmq_y = MMQ_Y_Q3_K_AMPERE; - nwarps = NWARPS_Q3_K_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q3_K_PASCAL; - mmq_y = MMQ_Y_Q3_K_PASCAL; - nwarps = NWARPS_Q3_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:32: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
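[Note: unlike the other launchers, ggml_mul_mat_q3_K_q8_1_sycl wraps its
whole body in the #if QK_K == 256 guard visible above; when ggml is built
with 64-wide super-blocks the function compiles to an empty stub and this
q3_K MMQ path is silently skipped.]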
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), - cgh); - sycl::local_accessor tile_x_qh_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); - sycl::local_accessor tile_x_sc_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q3_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q3_K_acc_ct1.get_pointer(), - tile_x_dm_q3_K_acc_ct1.get_pointer(), - tile_x_qh_q3_K_acc_ct1.get_pointer(), - tile_x_sc_q3_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q3_K_RDNA2; + mmq_y = MMQ_Y_Q3_K_RDNA2; + nwarps = NWARPS_Q3_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q3_K_RDNA1; + mmq_y = MMQ_Y_Q3_K_RDNA1; + nwarps = NWARPS_Q3_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q3_K_AMPERE; + mmq_y = MMQ_Y_Q3_K_AMPERE; + nwarps = NWARPS_Q3_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q3_K_PASCAL; + mmq_y = MMQ_Y_Q3_K_PASCAL; + nwarps = NWARPS_Q3_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:32: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), cgh); + sycl::local_accessor tile_x_qh_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); + sycl::local_accessor tile_x_sc_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q3_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q3_K_acc_ct1.get_pointer(), + tile_x_dm_q3_K_acc_ct1.get_pointer(), + tile_x_qh_q3_K_acc_ct1.get_pointer(), + tile_x_sc_q3_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:33: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), - cgh); - sycl::local_accessor tile_x_qh_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); - sycl::local_accessor tile_x_sc_q3_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q3_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q3_K_acc_ct1.get_pointer(), - tile_x_dm_q3_K_acc_ct1.get_pointer(), - tile_x_qh_q3_K_acc_ct1.get_pointer(), - tile_x_sc_q3_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:33: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K), cgh); + sycl::local_accessor tile_x_qh_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh); + sycl::local_accessor tile_x_sc_q3_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q3_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q3_K_acc_ct1.get_pointer(), + tile_x_dm_q3_K_acc_ct1.get_pointer(), + tile_x_qh_q3_K_acc_ct1.get_pointer(), + tile_x_sc_q3_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } + } #endif -} -catch (sycl::exception const &exc) { +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q4_K_RDNA2; - mmq_y = MMQ_Y_Q4_K_RDNA2; - nwarps = NWARPS_Q4_K_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q4_K_RDNA1; - mmq_y = MMQ_Y_Q4_K_RDNA1; - nwarps = NWARPS_Q4_K_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q4_K_AMPERE; - mmq_y = MMQ_Y_Q4_K_AMPERE; - nwarps = NWARPS_Q4_K_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q4_K_PASCAL; - mmq_y = MMQ_Y_Q4_K_PASCAL; - nwarps = NWARPS_Q4_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:34: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
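[Note: need_check is a compile-time template flag rather than a runtime
branch; the transcript drops template argument lists, but upstream each
mul_mat_qN call in these lambdas carries need_check as a template argument.
The shape of the pattern, with hypothetical names:

    template <bool need_check>
    static void tile_kernel(int row, int nrows) {
      if (need_check && row >= nrows) {
        return;  // bounds test, folded away when need_check == false
      }
      // ... tile work ...
    }

The nrows_x % mmq_y == 0 case dispatches tile_kernel<false> and pays no
per-row check.]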
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), - cgh); - sycl::local_accessor tile_x_sc_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q4_K_acc_ct1.get_pointer(), - tile_x_dm_q4_K_acc_ct1.get_pointer(), - tile_x_sc_q4_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q4_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q4_K_RDNA2; + mmq_y = MMQ_Y_Q4_K_RDNA2; + nwarps = NWARPS_Q4_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q4_K_RDNA1; + mmq_y = MMQ_Y_Q4_K_RDNA1; + nwarps = NWARPS_Q4_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q4_K_AMPERE; + mmq_y = MMQ_Y_Q4_K_AMPERE; + nwarps = NWARPS_Q4_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q4_K_PASCAL; + mmq_y = MMQ_Y_Q4_K_PASCAL; + nwarps = NWARPS_Q4_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:34: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), cgh); + sycl::local_accessor tile_x_sc_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q4_K_acc_ct1.get_pointer(), + tile_x_dm_q4_K_acc_ct1.get_pointer(), + tile_x_sc_q4_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:35: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), - cgh); - sycl::local_accessor tile_x_sc_q4_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q4_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q4_K_acc_ct1.get_pointer(), - tile_x_dm_q4_K_acc_ct1.get_pointer(), - tile_x_sc_q4_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:35: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K), cgh); + sycl::local_accessor tile_x_sc_q4_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q4_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q4_K_acc_ct1.get_pointer(), + tile_x_dm_q4_K_acc_ct1.get_pointer(), + tile_x_sc_q4_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q5_K_RDNA2; - mmq_y = MMQ_Y_Q5_K_RDNA2; - nwarps = NWARPS_Q5_K_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q5_K_RDNA1; - mmq_y = MMQ_Y_Q5_K_RDNA1; - nwarps = NWARPS_Q5_K_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q5_K_AMPERE; - mmq_y = MMQ_Y_Q5_K_AMPERE; - nwarps = NWARPS_Q5_K_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q5_K_PASCAL; - mmq_y = MMQ_Y_Q5_K_PASCAL; - nwarps = NWARPS_Q5_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:36: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
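Editor's note: each quant type's launcher opens with the same capability ladder, falling through from newest to oldest device generation to pick tile sizes. Factored out as a sketch (the struct and function are hypothetical; the VER_* thresholds and MMQ_X/MMQ_Y/NWARPS constants come from this file's headers):

struct mmq_config {
    int mmq_x;
    int mmq_y;
    int nwarps;
};

static mmq_config pick_q5_K_config(const int cc) {
    if (cc >= VER_GEN13) return {MMQ_X_Q5_K_RDNA2,  MMQ_Y_Q5_K_RDNA2,  NWARPS_Q5_K_RDNA2};
    if (cc >= VER_GEN12) return {MMQ_X_Q5_K_RDNA1,  MMQ_Y_Q5_K_RDNA1,  NWARPS_Q5_K_RDNA1};
    if (cc >= VER_GEN9)  return {MMQ_X_Q5_K_AMPERE, MMQ_Y_Q5_K_AMPERE, NWARPS_Q5_K_AMPERE};
    if (cc >= VER_4VEC)  return {MMQ_X_Q5_K_PASCAL, MMQ_Y_Q5_K_PASCAL, NWARPS_Q5_K_PASCAL};
    GGML_ASSERT(false); // no tile sizes tuned for older devices
    return {0, 0, 0};   // unreachable
}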
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), - cgh); - sycl::local_accessor tile_x_sc_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_K_acc_ct1.get_pointer(), - tile_x_dm_q5_K_acc_ct1.get_pointer(), - tile_x_sc_q5_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q5_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q5_K_RDNA2; + mmq_y = MMQ_Y_Q5_K_RDNA2; + nwarps = NWARPS_Q5_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q5_K_RDNA1; + mmq_y = MMQ_Y_Q5_K_RDNA1; + nwarps = NWARPS_Q5_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q5_K_AMPERE; + mmq_y = MMQ_Y_Q5_K_AMPERE; + nwarps = NWARPS_Q5_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q5_K_PASCAL; + mmq_y = MMQ_Y_Q5_K_PASCAL; + nwarps = NWARPS_Q5_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:36: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
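Editor's note: block_num_x and block_num_y use the usual ceil-division so a partially filled tile still gets a work-group, and the nd_range multiplies group counts by the per-group dims per SYCL's global-range convention. The arithmetic in isolation:

static inline int ceil_div(const int n, const int d) {
    return (n + d - 1) / d; // tiles needed to cover n items, d at a time
}

// e.g. nrows_x = 1000, mmq_y = 64 gives ceil_div(1000, 64) == 16 tiles;
// the 16th is only partially filled, which is exactly the case the
// need_check = true instantiation above handles.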
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), cgh); + sycl::local_accessor tile_x_sc_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_K_acc_ct1.get_pointer(), + tile_x_dm_q5_K_acc_ct1.get_pointer(), + tile_x_sc_q5_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:37: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), - cgh); - sycl::local_accessor tile_x_sc_q5_K_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q5_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_q5_K_acc_ct1.get_pointer(), - tile_x_dm_q5_K_acc_ct1.get_pointer(), - tile_x_sc_q5_K_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:37: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
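Editor's note: the local_accessor extents follow a fixed per-quant-type recipe, and the trailing "+ mmq_y" terms look like row padding on the x tiles. One way to sanity-check a configuration is to total the bytes against info::device::local_mem_size. The element types are not visible in this hunk, so this sketch assumes the int/half2 layout of the analogous CUDA kernels:

#include <sycl/sycl.hpp>

// Assumed element types: int for the ql/sc/qs tiles, sycl::half2 for dm/ds.
// WARP_SIZE, QI5_K and QI8_1 are the constants from this backend's headers.
static size_t q5_K_tile_bytes(const int mmq_x, const int mmq_y) {
    size_t b = 0;
    b += sizeof(int)         * (mmq_y * (2 * WARP_SIZE) + mmq_y);             // tile_x_ql
    b += sizeof(sycl::half2) * (mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K); // tile_x_dm
    b += sizeof(int)         * (mmq_y * (WARP_SIZE / 8) + mmq_y / 8);         // tile_x_sc
    b += sizeof(int)         * (mmq_x * WARP_SIZE);                           // tile_y_qs
    b += sizeof(sycl::half2) * (mmq_x * WARP_SIZE / QI8_1);                   // tile_y_ds
    return b; // compare against the device's local memory before launching
}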
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K), cgh); + sycl::local_accessor tile_x_sc_q5_K_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q5_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_q5_K_acc_ct1.get_pointer(), + tile_x_dm_q5_K_acc_ct1.get_pointer(), + tile_x_sc_q5_K_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } -static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols_x, - const int nrows_x, const int ncols_y, - const int nrows_y, const int nrows_dst, - dpct::queue_ptr stream) try { - - int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); - const int compute_capability = g_device_caps[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= VER_GEN13) { - mmq_x = MMQ_X_Q6_K_RDNA2; - mmq_y = MMQ_Y_Q6_K_RDNA2; - nwarps = NWARPS_Q6_K_RDNA2; - } else if (compute_capability >= VER_GEN12) { - mmq_x = MMQ_X_Q6_K_RDNA1; - mmq_y = MMQ_Y_Q6_K_RDNA1; - nwarps = NWARPS_Q6_K_RDNA1; - } else if (compute_capability >= VER_GEN9) { - mmq_x = MMQ_X_Q6_K_AMPERE; - mmq_y = MMQ_Y_Q6_K_AMPERE; - nwarps = NWARPS_Q6_K_AMPERE; - } else if (compute_capability >= VER_4VEC) { - mmq_x = MMQ_X_Q6_K_PASCAL; - mmq_y = MMQ_Y_Q6_K_PASCAL; - nwarps = NWARPS_Q6_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const sycl::range<3> block_nums(1, block_num_y, block_num_x); - const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - /* - DPCT1049:38: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. 
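Editor's note: every launch is preceded by dpct::has_capability_or_fail(..., {sycl::aspect::fp16}), since the tile scale factors are half precision. In plain SYCL the guard reduces to an aspect query (sketch; helper name illustrative):

#include <sycl/sycl.hpp>
#include <stdexcept>

static void require_fp16(const sycl::device &dev) {
    // Same condition the dpct helper enforces before each submit.
    if (!dev.has(sycl::aspect::fp16)) {
        throw std::runtime_error("SYCL device lacks fp16 support");
    }
}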
- */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), - cgh); - sycl::local_accessor tile_x_sc_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q6_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_acc_ct1.get_pointer(), - tile_x_dm_acc_ct1.get_pointer(), - tile_x_sc_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); +static void ggml_mul_mat_q6_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols_x, + const int nrows_x, + const int ncols_y, + const int nrows_y, + const int nrows_dst, + dpct::queue_ptr stream) try { + int id; + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); + const int compute_capability = g_device_caps[id].cc; + + int mmq_x, mmq_y, nwarps; + if (compute_capability >= VER_GEN13) { + mmq_x = MMQ_X_Q6_K_RDNA2; + mmq_y = MMQ_Y_Q6_K_RDNA2; + nwarps = NWARPS_Q6_K_RDNA2; + } else if (compute_capability >= VER_GEN12) { + mmq_x = MMQ_X_Q6_K_RDNA1; + mmq_y = MMQ_Y_Q6_K_RDNA1; + nwarps = NWARPS_Q6_K_RDNA1; + } else if (compute_capability >= VER_GEN9) { + mmq_x = MMQ_X_Q6_K_AMPERE; + mmq_y = MMQ_Y_Q6_K_AMPERE; + nwarps = NWARPS_Q6_K_AMPERE; + } else if (compute_capability >= VER_4VEC) { + mmq_x = MMQ_X_Q6_K_PASCAL; + mmq_y = MMQ_Y_Q6_K_PASCAL; + nwarps = NWARPS_Q6_K_PASCAL; + } else { + GGML_ASSERT(false); + } + + const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; + const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; + const sycl::range<3> block_nums(1, block_num_y, block_num_x); + const sycl::range<3> block_dims(1, nwarps, WARP_SIZE); + + if (nrows_x % mmq_y == 0) { + const bool need_check = false; + /* + DPCT1049:38: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
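Editor's note: each launcher body is a function-level try block, so a failed submit exits loudly instead of silently corrupting results. The shape, stripped to its skeleton (the catch body is taken from this file; the queue-only signature is a simplification):

#include <sycl/sycl.hpp>
#include <cstdlib>
#include <iostream>

static void launcher_shape(sycl::queue &stream) try {
    // ... submit work on the queue, as each ggml_mul_mat_*_sycl does ...
    (void)stream;
} catch (sycl::exception const &exc) {
    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
              << ", line:" << __LINE__ << std::endl;
    std::exit(1);
}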
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), cgh); + sycl::local_accessor tile_x_sc_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q6_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_acc_ct1.get_pointer(), + tile_x_dm_acc_ct1.get_pointer(), + tile_x_sc_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } - } else { - const bool need_check = true; - /* - DPCT1049:39: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor tile_x_ql_acc_ct1( - sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); - sycl::local_accessor tile_x_dm_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), - cgh); - sycl::local_accessor tile_x_sc_acc_ct1( - sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); - sycl::local_accessor tile_y_qs_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE), cgh); - sycl::local_accessor tile_y_ds_acc_ct1( - sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - mul_mat_q6_K( - vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, - nrows_dst, item_ct1, - tile_x_ql_acc_ct1.get_pointer(), - tile_x_dm_acc_ct1.get_pointer(), - tile_x_sc_acc_ct1.get_pointer(), - tile_y_qs_acc_ct1.get_pointer(), - tile_y_ds_acc_ct1.get_pointer()); - }); + }); + } + } else { + const bool need_check = true; + /* + DPCT1049:39: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
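Editor's note: comparing the accessor extents across the three K-quant launchers, q5_K and q6_K reserve twice the ql width of q4_K, presumably to give the load_tiles_* helpers room to unpack the extra high bits into full int lanes before the dot products run (an inference from the sizes, not stated in the patch):

// Row widths of tile_x_ql as sized in this file (WARP_SIZE assumed 32):
constexpr int q4_K_ql_ints(const int mmq_y) { return mmq_y * 32 + mmq_y; }
constexpr int q5_K_ql_ints(const int mmq_y) { return mmq_y * (2 * 32) + mmq_y; }
constexpr int q6_K_ql_ints(const int mmq_y) { return mmq_y * (2 * 32) + mmq_y; }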
+ */ + { + dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor tile_x_ql_acc_ct1( + sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh); + sycl::local_accessor tile_x_dm_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K), cgh); + sycl::local_accessor tile_x_sc_acc_ct1( + sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh); + sycl::local_accessor tile_y_qs_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE), cgh); + sycl::local_accessor tile_y_ds_acc_ct1( + sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + mul_mat_q6_K( + vx, + vy, + dst, + ncols_x, + nrows_x, + ncols_y, + nrows_y, + nrows_dst, + item_ct1, + tile_x_ql_acc_ct1.get_pointer(), + tile_x_dm_acc_ct1.get_pointer(), + tile_x_sc_acc_ct1.get_pointer(), + tile_y_qs_acc_ct1.get_pointer(), + tile_y_ds_acc_ct1.get_pointer()); }); - } + }); } -} -catch (sycl::exception const &exc) { + } +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); } void ggml_sycl_op_mul_mat_q( - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream) try { - - const int64_t ne00 = src0->ne[0]; - - const int64_t ne10 = src1->ne[0]; - GGML_ASSERT(ne10 % QK8_1 == 0); - - const int64_t ne0 = dst->ne[0]; - - const int64_t row_diff = row_high - row_low; - - int device_id; - SYCL_CHECK( - CHECK_TRY_ERROR(device_id = get_current_device_id())); - - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into - const int64_t nrows_dst = dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device ? 
ne0 : row_diff; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q4_1: - ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_0: - ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_1: - ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q8_0: - ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q2_K: - ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q3_K: - ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q4_K: - ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q5_K: - ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - case GGML_TYPE_Q6_K: - ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); - break; - default: - GGML_ASSERT(false); - break; - } - - (void) src1; - (void) dst; - (void) src1_ddf_i; -} -catch (sycl::exception const &exc) { + const ggml_tensor* src0, + const ggml_tensor* src1, + ggml_tensor* dst, + const char* src0_dd_i, + const float* src1_ddf_i, + const char* src1_ddq_i, + float* dst_dd_i, + const int64_t row_low, + const int64_t row_high, + const int64_t src1_ncols, + const int64_t src1_padded_row_size, + const dpct::queue_ptr& stream) try { + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + + int device_id; + SYCL_CHECK(CHECK_TRY_ERROR(device_id = get_current_device_id())); + + // the main device has a larger memory buffer to hold the results from all + // GPUs nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel + // writes into + const int64_t nrows_dst = + dst->backend == GGML_BACKEND_TYPE_GPU && device_id == g_main_device + ? 
ne0 + : row_diff; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + ggml_mul_mat_q4_0_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q4_1: + ggml_mul_mat_q4_1_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q5_0: + ggml_mul_mat_q5_0_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q5_1: + ggml_mul_mat_q5_1_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q8_0: + ggml_mul_mat_q8_0_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q2_K: + ggml_mul_mat_q2_K_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q3_K: + ggml_mul_mat_q3_K_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q4_K: + ggml_mul_mat_q4_K_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q5_K: + ggml_mul_mat_q5_K_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + case GGML_TYPE_Q6_K: + ggml_mul_mat_q6_K_q8_1_sycl( + src0_dd_i, + src1_ddq_i, + dst_dd_i, + ne00, + row_diff, + src1_ncols, + src1_padded_row_size, + nrows_dst, + stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void)src1; + (void)dst; + (void)src1_ddf_i; +} catch (sycl::exception const& exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; std::exit(1); diff --git a/ggml-sycl/mmq.hpp b/ggml-sycl/mmq.hpp index 94c8ce47b35d3..3fdae39a9ba3d 100644 --- a/ggml-sycl/mmq.hpp +++ b/ggml-sycl/mmq.hpp @@ -16,11 +16,18 @@ #include "common.hpp" void ggml_sycl_op_mul_mat_q( - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream); + const ggml_tensor* src0, + const ggml_tensor* src1, + ggml_tensor* dst, + const char* src0_dd_i, + const float* src1_ddf_i, + const char* src1_ddq_i, + float* dst_dd_i, + const int64_t row_low, + const int64_t row_high, + const int64_t src1_ncols, + const int64_t src1_padded_row_size, + const dpct::queue_ptr& stream); bool ggml_sycl_supports_mmq(enum ggml_type type); diff --git a/ggml-sycl/mmvq.cpp b/ggml-sycl/mmvq.cpp index 4945bb8edb5b2..5cf076acc48d3 100644 --- a/ggml-sycl/mmvq.cpp +++ b/ggml-sycl/mmvq.cpp @@ -13,787 +13,1007 @@ #include "mmvq.hpp" #include "vecdotq.hpp" -typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); - - -template -static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> 
&item_ct1, - const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); - } - - // sum up partial sums and write back result +typedef float (*vec_dot_q_sycl_t)( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs); + +template < + int qk, + int qi, + typename block_q_t, + int vdr, + vec_dot_q_sycl_t vec_dot_q_sycl> +static void mul_mat_vec_q( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint32_t* iq3xxs_grid_ptr, + const uint64_t* ksigns64_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } -} + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} template -static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1, - const uint64_t *iq2xxs_grid_ptr, const uint8_t *ksigns_iq2xs_ptr, - const uint8_t *kmask_iq2xs_ptr ) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * 
(qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid_ptr, ksigns_iq2xs_ptr, kmask_iq2xs_ptr); - } - - // sum up partial sums and write back result +static void mul_mat_vec_q_iq2_xxs_q8_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint64_t* iq2xxs_grid_ptr, + const uint8_t* ksigns_iq2xs_ptr, + const uint8_t* kmask_iq2xs_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xxs_q8_1( + &x[ibx], + &y[iby], + iqs, + iq2xxs_grid_ptr, + ksigns_iq2xs_ptr, + kmask_iq2xs_ptr); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } } template -static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1, - const uint64_t *iq2xs_grid_ptr, const uint64_t *ksigns64_ptr ) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr); - } - - // sum up partial sums and write back result +static void mul_mat_vec_q_iq2_xs_q8_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint64_t* iq2xs_grid_ptr, + const uint64_t* ksigns64_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int 
blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xs_q8_1( + &x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } } template -static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1, - const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr ) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr); - } - - // sum up partial sums and write back result +static void mul_mat_vec_q_iq3_xxs_q8_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint32_t* iq3xxs_grid_ptr, + const uint64_t* ksigns64_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq3_xxs_q8_1( + &x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - 
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } } template -static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1, - const uint32_t *iq3s_grid_ptr, const uint64_t *ksigns64_ptr ) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr); - } - - // sum up partial sums and write back result +static void mul_mat_vec_q_iq3_s_q8_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint32_t* iq3s_grid_ptr, + const uint64_t* ksigns64_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += + vec_dot_iq3_s_q8_1(&x[ibx], &y[iby], iqs, iq3s_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } } template -static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1, - const uint32_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - - if (row >= nrows) { - return; - } - - const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - -// partial 
sum for each thread - float tmp = 0.0f; - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; - i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i; // x block index - - const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - - const int iqs = - vdr * - (item_ct1.get_local_id(2) % - (qi / vdr)); // x block quant index when casting the quants to int - - tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr); - } - - // sum up partial sums and write back result +static void mul_mat_vec_q_iq1_s_q8_1( + const void* __restrict__ vx, + const void* __restrict__ vy, + float* __restrict__ dst, + const int ncols, + const int nrows, + const sycl::nd_item<3>& item_ct1, + const uint32_t* iq1s_grid_ptr, + const uint64_t* ksigns64_ptr) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = (const block_q_t*)vx; + const block_q8_1* y = (const block_q8_1*)vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += + vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += - dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); - } - - if (item_ct1.get_local_id(2) == 0) { - dst[row] = tmp; - } + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } } -static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q4_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto 
iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK4_0, + QI4_0, + block_q4_0, + VDR_Q4_0_Q8_1_MMVQ, + vec_dot_q4_0_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK4_1 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q4_1_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK4_0, + QI4_1, + block_q4_1, + VDR_Q4_1_Q8_1_MMVQ, + vec_dot_q4_1_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK5_0 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q5_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + 
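Editor's note: regardless of quant type, every mmvq launcher first calls .init(*stream) on the lookup tables and only then captures raw pointers via get_ptr() inside the submit, because the kernels cannot reference host-side globals directly. The pattern in isolation (a sketch; the 2-entry table is hypothetical, declared in the style this backend uses for iq3xxs_grid and ksigns64):

#include <sycl/sycl.hpp>
#include <cstdint>
#include "dpct/helper.hpp" // the dpct shim bundled with this backend

static dpct::global_memory<const uint32_t, 1> my_table(sycl::range<1>(2),
                                                       {0x01020304u, 0x05060708u});

static void launch(dpct::queue_ptr stream) {
    my_table.init(*stream); // make sure the device copy exists on this queue
    stream->submit([&](sycl::handler &cgh) {
        auto ptr = my_table.get_ptr(); // raw device pointer, captured by value
        cgh.parallel_for(sycl::range<1>(1), [=](sycl::id<1>) {
            (void)ptr[0]; // kernels index the table like a plain C array
        });
    });
}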
stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK5_0, + QI5_0, + block_q5_0, + VDR_Q5_0_Q8_1_MMVQ, + vec_dot_q5_0_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK5_1 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q5_1_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK5_1, + QI5_1, + block_q5_1, + VDR_Q5_1_Q8_1_MMVQ, + vec_dot_q5_1_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q8_0_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + 
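Editor's note: the per-thread partial sums in every mul_mat_vec_q* kernel are combined with the same XOR-mask butterfly over the 32-wide sub-group, which is why each lambda carries [[intel::reqd_sub_group_size(32)]]. Factored out (the loop is taken from the kernels above):

#include <sycl/sycl.hpp>
#include "dpct/helper.hpp"

static inline float warp_reduce_sum(float v, const sycl::nd_item<3> &item) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        v += dpct::permute_sub_group_by_xor(item.get_sub_group(), v, mask);
    }
    return v; // after five XOR steps every lane holds the full sum
}

Only lane 0 then writes dst[row], so the redundant copies in the other lanes are simply discarded.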
ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK8_0, + QI8_0, + block_q8_0, + VDR_Q8_0_Q8_1_MMVQ, + vec_dot_q8_0_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q2_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK_K, + QI2_K, + block_q2_K, + VDR_Q2_K_Q8_1_MMVQ, + vec_dot_q2_K_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q3_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + 
iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK_K, + QI3_K, + block_q3_K, + VDR_Q3_K_Q8_1_MMVQ, + vec_dot_q3_K_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q4_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK_K, + QI4_K, + block_q4_K, + VDR_Q4_K_Q8_1_MMVQ, + vec_dot_q4_K_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q5_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, 
GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK_K, + QI5_K, + block_q5_K, + VDR_Q5_K_Q8_1_MMVQ, + vec_dot_q5_K_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_q6_K_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q< + QK_K, + QI6_K, + block_q6_K, + VDR_Q6_K_Q8_1_MMVQ, + vec_dot_q6_K_q8_1>( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq2xxs_grid.init(*stream); - ksigns_iq2xs.init(*stream); - kmask_iq2xs.init(*stream); - - - stream->submit([&](sycl::handler &cgh) { - auto iq2xxs_grid_ptr_ct1 = iq2xxs_grid.get_ptr(); - auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); - auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq2_xxs_q8_1( - vx, vy, dst, ncols, nrows, item_ct1, - iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1); - }); - }); - } +static void mul_mat_vec_iq2_xxs_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + 
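Editor's note: the iq2_xxs path is the one launcher that ships three tables to the kernel: the grid of 8-value patterns plus ksigns_iq2xs and kmask_iq2xs for sign decoding. A sketch of how the two sign tables combine, on the assumption that this vec-dot mirrors ggml's CPU dequantization of iq2_xxs:

#include <cstdint>

// Assumed decode: a 7-bit field of the block's aux word selects a byte
// from ksigns_iq2xs, and kmask_iq2xs[j] (that is, 1 << j) isolates
// element j's sign bit.
static inline float apply_iq2_sign(const float g, const uint8_t signs,
                                   const uint8_t kmask_j) {
    return (signs & kmask_j) ? -g : g;
}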
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq2xxs_grid.init(*stream); + ksigns_iq2xs.init(*stream); + kmask_iq2xs.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq2xxs_grid_ptr_ct1 = iq2xxs_grid.get_ptr(); + auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); + auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq2_xxs_q8_1( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq2xxs_grid_ptr_ct1, + ksigns_iq2xs_ptr_ct1, + kmask_iq2xs_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq2xs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq2xs_grid_ptr_ct1 = iq2xs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq2_xs_q8_1( - vx, vy, dst, ncols, nrows, item_ct1, - iq2xs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_iq2_xs_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq2xs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq2xs_grid_ptr_ct1 = iq2xs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq2_xs_q8_1( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq2xs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3xxs_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq3_xxs_q8_1( - vx, vy, dst, ncols, nrows, item_ct1, - iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_iq3_xxs_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr 
stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq3_xxs_q8_1( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3xxs_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq3s_grid.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq3_s_q8_1( - vx, vy, dst, ncols, nrows, item_ct1, - iq3s_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_iq3_s_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3s_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq3s_grid_ptr_ct1 = iq3s_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq3_s_q8_1( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq3s_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } -static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, - dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - { - iq1s_grid_gpu.init(*stream); - ksigns64.init(*stream); - - stream->submit([&](sycl::handler &cgh) { - auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr(); - auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq1_s_q8_1( - vx, vy, dst, ncols, nrows, item_ct1, - iq1s_grid_ptr_ct1, ksigns64_ptr_ct1); - }); - }); - } +static void mul_mat_vec_iq1_s_q8_1_sycl( + const void* vx, + const void* vy, + float* dst, + const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y 
- 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq1s_grid_gpu.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler& cgh) { + auto iq1s_grid_ptr_ct1 = iq1s_grid_gpu.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq1_s_q8_1( + vx, + vy, + dst, + ncols, + nrows, + item_ct1, + iq1s_grid_ptr_ct1, + ksigns64_ptr_ct1); + }); + }); + } } void ggml_sycl_op_mul_mat_vec_q( - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream) { - - GGML_ASSERT(ggml_nrows(src1) == 1); - - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ASSERT(false); - break; - } - - (void) src1; - (void) dst; - (void) src1_ddf_i; - (void) src1_ncols; - (void) src1_padded_row_size; + const ggml_tensor* src0, + const ggml_tensor* src1, + ggml_tensor* dst, + const char* src0_dd_i, + const float* src1_ddf_i, + const char* src1_ddq_i, + float* dst_dd_i, + const int64_t row_low, + const int64_t row_high, + const int64_t src1_ncols, + const int64_t src1_padded_row_size, + const dpct::queue_ptr& stream) { + GGML_ASSERT(ggml_nrows(src1) == 1); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + switch (src0->type) { 
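    // [editorial note] Dispatch sketch: src0->type selects the matching
    // per-quantization launcher defined above, and every case forwards the
    // same arguments (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff,
    // stream). The GGML_ASSERT(ggml_nrows(src1) == 1) above is what limits
    // this path to a single activation row, i.e. a true matrix-vector
    // product.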
+ case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q4_K_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q6_K_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + default: + GGML_ASSERT(false); + break; + } + + (void)src1; + (void)dst; + (void)src1_ddf_i; + (void)src1_ncols; + (void)src1_padded_row_size; } - diff --git a/ggml-sycl/mmvq.hpp b/ggml-sycl/mmvq.hpp index 46f0d14150e79..82f4be6527179 100644 --- a/ggml-sycl/mmvq.hpp +++ b/ggml-sycl/mmvq.hpp @@ -16,10 +16,17 @@ #include "common.hpp" void ggml_sycl_op_mul_mat_vec_q( - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, - const dpct::queue_ptr &stream); + const ggml_tensor* src0, + const ggml_tensor* src1, + ggml_tensor* dst, + const char* src0_dd_i, + const float* src1_ddf_i, + const char* src1_ddq_i, + float* dst_dd_i, + const int64_t row_low, + const int64_t row_high, + const int64_t src1_ncols, + const int64_t src1_padded_row_size, + const dpct::queue_ptr& stream); #endif // GGML_SYCL_MMVQ_HPP \ No newline at end of file diff --git a/ggml-sycl/vecdotq.hpp b/ggml-sycl/vecdotq.hpp index a8c7af7b0a25c..a69494d5b0ed7 100644 --- a/ggml-sycl/vecdotq.hpp +++ b/ggml-sycl/vecdotq.hpp @@ -13,1143 +13,1250 @@ #ifndef GGML_SYCL_VECDOTQ_HPP #define GGML_SYCL_VECDOTQ_HPP -static __dpct_inline__ int get_int_from_int8(const int8_t *x8, const int &i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment +static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { + const uint16_t* x16 = + (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte + // alignment - int x32 = 0; - x32 |= 
x16[0] << 0; - x32 |= x16[1] << 16; + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; - return x32; + return x32; } -static __dpct_inline__ int get_int_from_uint8(const uint8_t *x8, - const int &i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment +static __dpct_inline__ int get_int_from_uint8( + const uint8_t* x8, + const int& i32) { + const uint16_t* x16 = + (const uint16_t*)(x8 + sizeof(int) * i32); // assume at least 2 byte + // alignment - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; + int x32 = 0; + x32 |= x16[0] << 0; + x32 |= x16[1] << 16; - return x32; + return x32; } -static __dpct_inline__ int get_int_from_int8_aligned(const int8_t *x8, - const int &i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +static __dpct_inline__ int get_int_from_int8_aligned( + const int8_t* x8, + const int& i32) { + return *( + (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } -static __dpct_inline__ int get_int_from_uint8_aligned(const uint8_t *x8, - const int &i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment +static __dpct_inline__ int get_int_from_uint8_aligned( + const uint8_t* x8, + const int& i32) { + return *( + (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } -// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called -// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +// VDR = vec dot ratio, how many contiguous integers each thread processes when +// the vec dot kernel is called MMVQ = mul_mat_vec_q, MMQ = mul_mat_q #define VDR_Q4_0_Q8_1_MMVQ 2 -#define VDR_Q4_0_Q8_1_MMQ 4 +#define VDR_Q4_0_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, - const float &d4, - const sycl::half2 &ds8) { - int sumi = 0; +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl( + const int* v, + const int* u, + const float& d4, + const sycl::half2& ds8) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; - const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; - // SIMD dot product of quantized values - sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); - sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); - } + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } - const sycl::float2 ds8f = - ds8.convert(); + const sycl::float2 ds8f = + ds8.convert(); - // second part effectively subtracts 8 from each quant value - return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); } #define VDR_Q4_1_Q8_1_MMVQ 2 -#define VDR_Q4_1_Q8_1_MMQ 4 +#define VDR_Q4_1_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_1_q8_1_impl(const int *v, const int *u, - const sycl::half2 &dm4, - const sycl::half2 &ds8) { - - int sumi = 0; +static __dpct_inline__ float vec_dot_q4_1_q8_1_impl( + const int* v, + const int* u, + const sycl::half2& dm4, + const sycl::half2& ds8) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; - const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + for (int i = 0; i < vdr; ++i) { + 
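    // [editorial note] Nibble extraction, same pattern as
    // vec_dot_q4_0_q8_1_impl above: v[i] packs eight 4-bit quants into one
    // int; (v[i] >> 0) & 0x0F0F0F0F keeps the low nibble of each byte and
    // (v[i] >> 4) & 0x0F0F0F0F the high one. Worked example:
    //   v[i] = 0xA1B2C3D4  ->  vi0 = 0x01020304, vi1 = 0x0A0B0C0D
    // each of which feeds one dpct::dp4a against four q8_1 values.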
const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; - // SIMD dot product of quantized values - sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); - sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); - } + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } #ifdef GGML_SYCL_F16 - const sycl::float2 tmp = - (dm4 * ds8).convert(); - const float d4d8 = tmp.x(); - const float m4s8 = tmp.y(); + const sycl::float2 tmp = + (dm4 * ds8).convert(); + const float d4d8 = tmp.x(); + const float m4s8 = tmp.y(); #else - const sycl::float2 dm4f = - dm4.convert(); - const sycl::float2 ds8f = - ds8.convert(); - const float d4d8 = dm4f.x() * ds8f.x(); - const float m4s8 = dm4f.y() * ds8f.y(); + const sycl::float2 dm4f = + dm4.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d4d8 = dm4f.x() * ds8f.x(); + const float m4s8 = dm4f.y() * ds8f.y(); #endif // GGML_SYCL_F16 - // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it - return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple + // threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); } #define VDR_Q5_0_Q8_1_MMVQ 2 -#define VDR_Q5_0_Q8_1_MMQ 4 +#define VDR_Q5_0_Q8_1_MMQ 4 template -static __dpct_inline__ float -vec_dot_q5_0_q8_1_impl(const int *vl, const int *vh, const int *u, - const float &d5, const sycl::half2 &ds8) { - int sumi = 0; +static __dpct_inline__ float vec_dot_q5_0_q8_1_impl( + const int* vl, + const int* vh, + const int* u, + const float& d5, + const sycl::half2& ds8) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits - vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 - vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 - vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 - vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = dpct::dp4a(vi0, u[2 * i + 0], - sumi); // SIMD dot product of quantized values - - int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits - vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 - vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 - vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 - vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = dpct::dp4a(vi1, u[2 * i + 1], - sumi); // SIMD dot product of quantized values - } - - const sycl::float2 ds8f = - ds8.convert(); - - // second part effectively subtracts 16 from each quant value - return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y()); + for (int i = 0; i < vdr; ++i) { + int vi0 = + (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = dpct::dp4a( + vi0, + u[2 * i + 0], + sumi); // SIMD dot product of quantized values + + int vi1 = + (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = dpct::dp4a( + vi1, + u[2 * i + 1], + sumi); // SIMD dot product of quantized values + } + + const sycl::float2 ds8f = + ds8.convert(); + + // second part effectively subtracts 16 
from each quant value + return d5 * (sumi * ds8f.x() - (16 * vdr / QI5_0) * ds8f.y()); } #define VDR_Q5_1_Q8_1_MMVQ 2 -#define VDR_Q5_1_Q8_1_MMQ 4 +#define VDR_Q5_1_Q8_1_MMQ 4 template -static __dpct_inline__ float -vec_dot_q5_1_q8_1_impl(const int *vl, const int *vh, const int *u, - const sycl::half2 &dm5, const sycl::half2 &ds8) { - - int sumi = 0; +static __dpct_inline__ float vec_dot_q5_1_q8_1_impl( + const int* vl, + const int* vh, + const int* u, + const sycl::half2& dm5, + const sycl::half2& ds8) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits - vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 - vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 - vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 - vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = dpct::dp4a(vi0, u[2 * i + 0], - sumi); // SIMD dot product of quantized values - - int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits - vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 - vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 - vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 - vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = dpct::dp4a(vi1, u[2 * i + 1], - sumi); // SIMD dot product of quantized values - } + for (int i = 0; i < vdr; ++i) { + int vi0 = + (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = dpct::dp4a( + vi0, + u[2 * i + 0], + sumi); // SIMD dot product of quantized values + + int vi1 = + (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = dpct::dp4a( + vi1, + u[2 * i + 1], + sumi); // SIMD dot product of quantized values + } #ifdef GGML_SYCL_F16 - const sycl::float2 tmp = - (dm5 * ds8).convert(); - const float d5d8 = tmp.x(); - const float m5s8 = tmp.y(); - + const sycl::float2 tmp = + (dm5 * ds8).convert(); + const float d5d8 = tmp.x(); + const float m5s8 = tmp.y(); #else - const sycl::float2 dm5f = - dm5.convert(); - const sycl::float2 ds8f = - ds8.convert(); - const float d5d8 = dm5f.x() * ds8f.x(); - const float m5s8 = dm5f.y() * ds8f.y(); + const sycl::float2 dm5f = + dm5.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d5d8 = dm5f.x() * ds8f.x(); + const float m5s8 = dm5f.y() * ds8f.y(); #endif // GGML_SYCL_F16 - // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it - return sumi*d5d8 + m5s8 / (QI5_1 / vdr); + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads + // adding it + return sumi * d5d8 + m5s8 / (QI5_1 / vdr); } #define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMQ 8 template -static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u, - const float &d8_0, - const float &d8_1) { - - int sumi = 0; +static __dpct_inline__ float vec_dot_q8_0_q8_1_impl( + const int* v, + const int* u, + const float& d8_0, + const float& d8_1) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - // SIMD dot product of quantized values - sumi = dpct::dp4a(v[i], u[i], sumi); - } + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized 
values + sumi = dpct::dp4a(v[i], u[i], sumi); + } - return d8_0*d8_1 * sumi; + return d8_0 * d8_1 * sumi; } template -static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u, - const sycl::half2 &dm8, - const sycl::half2 &ds8) { - - int sumi = 0; +static __dpct_inline__ float vec_dot_q8_1_q8_1_impl( + const int* v, + const int* u, + const sycl::half2& dm8, + const sycl::half2& ds8) { + int sumi = 0; #pragma unroll - for (int i = 0; i < vdr; ++i) { - // SIMD dot product of quantized values - sumi = dpct::dp4a(v[i], u[i], sumi); - } + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = dpct::dp4a(v[i], u[i], sumi); + } #ifdef GGML_SYCL_F16 - const sycl::float2 tmp = - (dm8 * ds8).convert(); - const float d8d8 = tmp.x(); - const float m8s8 = tmp.y(); + const sycl::float2 tmp = + (dm8 * ds8).convert(); + const float d8d8 = tmp.x(); + const float m8s8 = tmp.y(); #else - const sycl::float2 dm8f = - dm8.convert(); - const sycl::float2 ds8f = - ds8.convert(); - const float d8d8 = dm8f.x() * ds8f.x(); - const float m8s8 = dm8f.y() * ds8f.y(); + const sycl::float2 dm8f = + dm8.convert(); + const sycl::float2 ds8f = + ds8.convert(); + const float d8d8 = dm8f.x() * ds8f.x(); + const float m8s8 = dm8f.y() * ds8f.y(); #endif // GGML_SYCL_F16 - // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it - return sumi*d8d8 + m8s8 / (QI8_1 / vdr); + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads + // adding it + return sumi * d8d8 + m8s8 / (QI8_1 / vdr); } #define VDR_Q2_K_Q8_1_MMVQ 1 -#define VDR_Q2_K_Q8_1_MMQ 2 +#define VDR_Q2_K_Q8_1_MMQ 2 // contiguous v/x values static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmvq( - const int &v, const int *__restrict__ u, const uint8_t *__restrict__ scales, - const sycl::half2 &dm2, const float *__restrict__ d8) { - - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const int& v, + const int* __restrict__ u, + const uint8_t* __restrict__ scales, + const sycl::half2& dm2, + const float* __restrict__ d8) { + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll - for (int i = 0; i < QR2_K; ++i) { - const int sc = scales[2*i]; - - const int vi = (v >> (2*i)) & 0x03030303; - - sumf_d += - d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; - sumf_m += d8[i] * - dpct::dp4a( - m, u[i], - 0); // multiply constant q2_K part with sum of q8_1 values - } + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2 * i]; + + const int vi = (v >> (2 * i)) & 0x03030303; + + sumf_d += + d8[i] * (dpct::dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * + dpct::dp4a( + m, + u[i], + 0); // multiply constant q2_K part with sum of q8_1 values + } - const sycl::float2 dm2f = - dm2.convert(); + const sycl::float2 dm2f = + dm2.convert(); - return dm2f.x() * sumf_d - dm2f.y() * sumf_m; + return dm2f.x() * sumf_d - dm2f.y() * sumf_m; } // contiguous u/y values -static __dpct_inline__ float -vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, - const uint8_t *__restrict__ scales, - const sycl::half2 &dm2, const float &d8) { - - int sumi_d = 0; - int sumi_m = 0; +static __dpct_inline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int* __restrict__ v, + const int* __restrict__ u, + const uint8_t* __restrict__ scales, + const sycl::half2& dm2, + const float& d8) { + 
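  // [editorial note] q2_K scale packing, as in the mmvq variant above: each
  // scales[] byte carries a 4-bit scale (low nibble) and a 4-bit min (high
  // nibble). The "fill int with 4x m" trick below broadcasts the min across
  // all four bytes of an int. Worked example:
  //   sc = 0x5A  ->  scale = sc & 0xF = 0xA,  m = 0x05050505
  // so dpct::dp4a(m, u[i], acc) accumulates 5 * (sum of four q8 values)
  // in a single instruction.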
int sumi_d = 0; + int sumi_m = 0; #pragma unroll - for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - int sumi_d_sc = 0; + for (int i0 = 0; i0 < QI8_1; i0 += QI8_1 / 2) { + int sumi_d_sc = 0; - const int sc = scales[i0 / (QI8_1/2)]; + const int sc = scales[i0 / (QI8_1 / 2)]; - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; #pragma unroll - for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = dpct::dp4a(m, u[i], - sumi_m); // multiply sum of q8_1 values with m - } - - sumi_d += sumi_d_sc * (sc & 0xF); + for (int i = i0; i < i0 + QI8_1 / 2; ++i) { + sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product + sumi_m = dpct::dp4a( + m, + u[i], + sumi_m); // multiply sum of q8_1 values with m } - const sycl::float2 dm2f = - dm2.convert(); + sumi_d += sumi_d_sc * (sc & 0xF); + } - return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m); + const sycl::float2 dm2f = + dm2.convert(); + + return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m); } #define VDR_Q3_K_Q8_1_MMVQ 1 -#define VDR_Q3_K_Q8_1_MMQ 2 +#define VDR_Q3_K_Q8_1_MMQ 2 // contiguous v/x values static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmvq( - const int &vl, const int &vh, const int *__restrict__ u, - const uint8_t *__restrict__ scales, const int &scale_offset, - const float &d3, const float *__restrict__ d8) { - - float sumf = 0.0f; + const int& vl, + const int& vh, + const int* __restrict__ u, + const uint8_t* __restrict__ scales, + const int& scale_offset, + const float& d3, + const float* __restrict__ d8) { + float sumf = 0.0f; #pragma unroll - for (int i = 0; i < QR3_K; ++i) { - const int isc = scale_offset + 2*i; + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2 * i; - const int isc_low = isc % (QK_K/32); - const int sc_shift_low = 4 * (isc / (QK_K/32)); - const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + const int isc_low = isc % (QK_K / 32); + const int sc_shift_low = 4 * (isc / (QK_K / 32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; - const int isc_high = isc % (QK_K/64); - const int sc_shift_high = 2 * (isc / (QK_K/64)); - const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + const int isc_high = isc % (QK_K / 64); + const int sc_shift_high = 2 * (isc / (QK_K / 64)); + const int sc_high = ((scales[(QK_K / 32) + isc_high] >> sc_shift_high) & 3) + << 4; - const int sc = (sc_low | sc_high) - 32; + const int sc = (sc_low | sc_high) - 32; - const int vil = (vl >> (2*i)) & 0x03030303; + const int vil = (vl >> (2 * i)) & 0x03030303; - const int vih = ((vh >> i) << 2) & 0x04040404; + const int vih = ((vh >> i) << 2) & 0x04040404; - const int vi = - dpct::vectorized_binary(vil, vih, dpct::sub_sat()); + const int vi = + dpct::vectorized_binary(vil, vih, dpct::sub_sat()); - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } - return d3 * sumf; + return d3 * sumf; } // contiguous u/y values -static __dpct_inline__ float -vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, - const int8_t *__restrict__ scales, const float &d3, - const float &d8) { - - int sumi = 0; +static __dpct_inline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int* __restrict__ v, + const int* __restrict__ u, + const int8_t* __restrict__ scales, + const float& d3, + const float& d8) { + int sumi = 
0; #pragma unroll - for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { - int sumi_sc = 0; - - for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product - } + for (int i0 = 0; i0 < QR3_K * VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1 / 2) { + int sumi_sc = 0; - sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + for (int i = i0; i < i0 + QI8_1 / 2; ++i) { + sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product } - return d3*d8 * sumi; + sumi += sumi_sc * scales[i0 / (QI8_1 / 2)]; + } + + return d3 * d8 * sumi; } #define VDR_Q4_K_Q8_1_MMVQ 2 -#define VDR_Q4_K_Q8_1_MMQ 8 +#define VDR_Q4_K_Q8_1_MMQ 8 // contiguous v/x values static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_vmmq( - const int *__restrict__ v, const int *__restrict__ u, - const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, - const sycl::half2 &dm4, const float *__restrict__ d8) { - - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const int* __restrict__ v, + const int* __restrict__ u, + const uint8_t* __restrict__ sc, + const uint8_t* __restrict__ m, + const sycl::half2& dm4, + const float* __restrict__ d8) { + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll - for (int i = 0; i < QR4_K; ++i) { - const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; - const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; - - const int dot1 = - dpct::dp4a(v1i, u[2 * i + 1], - dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product - const int dot2 = - dpct::dp4a(0x01010101, u[2 * i + 1], - dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u - - sumf_d += d8[i] * (dot1 * sc[i]); - sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values - } - - const sycl::float2 dm4f = - dm4.convert(); - - return dm4f.x() * sumf_d - dm4f.y() * sumf_m; + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4 * i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4 * i)) & 0x0F0F0F0F; + + const int dot1 = dpct::dp4a( + v1i, + u[2 * i + 1], + dpct::dp4a(v0i, u[2 * i + 0], 0)); // SIMD dot product + const int dot2 = dpct::dp4a( + 0x01010101, + u[2 * i + 1], + dpct::dp4a(0x01010101, u[2 * i + 0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * + (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const sycl::float2 dm4f = + dm4.convert(); + + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; } // contiguous u/y values static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq( - const int *__restrict__ v, const int *__restrict__ u, - const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, - const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) { - - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const int* __restrict__ v, + const int* __restrict__ u, + const uint8_t* __restrict__ sc, + const uint8_t* __restrict__ m, + const sycl::half2& dm4, + const sycl::half2* __restrict__ ds8) { + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll - for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { - int sumi_d = 0; + for (int i = 0; i < QR4_K * VDR_Q4_K_Q8_1_MMQ / QI8_1; ++i) { + int sumi_d = 0; #pragma unroll - for (int j = 0; j < QI8_1; ++j) { - sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F, - u[i * QI8_1 + j], sumi_d); // SIMD dot product - } + for (int j = 0; j < QI8_1; ++j) { + sumi_d = dpct::dp4a( + (v[j] >> (4 * i)) & 0x0F0F0F0F, + u[i * QI8_1 + j], + sumi_d); // SIMD dot product + } - const sycl::float2 ds8f = - ds8[i].convert(); + const sycl::float2 ds8f = + ds8[i].convert(); - sumf_d += ds8f.x() * 
(sc[i] * sumi_d); - sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val - } + sumf_d += ds8f.x() * (sc[i] * sumi_d); + sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val + } - const sycl::float2 dm4f = - dm4.convert(); + const sycl::float2 dm4f = + dm4.convert(); - return dm4f.x() * sumf_d - dm4f.y() * sumf_m; + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; } #define VDR_Q5_K_Q8_1_MMVQ 2 -#define VDR_Q5_K_Q8_1_MMQ 8 +#define VDR_Q5_K_Q8_1_MMQ 8 // contiguous v/x values static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_vmmq( - const int *__restrict__ vl, const int *__restrict__ vh, - const int *__restrict__ u, const uint8_t *__restrict__ sc, - const uint8_t *__restrict__ m, const sycl::half2 &dm5, - const float *__restrict__ d8) { - - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const int* __restrict__ vl, + const int* __restrict__ vh, + const int* __restrict__ u, + const uint8_t* __restrict__ sc, + const uint8_t* __restrict__ m, + const sycl::half2& dm5, + const float* __restrict__ d8) { + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll - for (int i = 0; i < QR5_K; ++i) { - const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; - const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4 * i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4 * i)) & 0x0F0F0F0F; - const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; - const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; - const int v0i = vl0i | vh0i; - const int v1i = vl1i | vh1i; + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; - const int dot1 = - dpct::dp4a(v0i, u[2 * i + 0], - dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product - const int dot2 = - dpct::dp4a(0x01010101, u[2 * i + 0], - dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u + const int dot1 = dpct::dp4a( + v0i, + u[2 * i + 0], + dpct::dp4a(v1i, u[2 * i + 1], 0)); // SIMD dot product + const int dot2 = dpct::dp4a( + 0x01010101, + u[2 * i + 0], + dpct::dp4a(0x01010101, u[2 * i + 1], 0)); // sum of u - sumf_d += d8[i] * (dot1 * sc[i]); - sumf_m += d8[i] * (dot2 * m[i]); + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + } - } + const sycl::float2 dm5f = + dm5.convert(); - const sycl::float2 dm5f = - dm5.convert(); - - return dm5f.x() * sumf_d - dm5f.y() * sumf_m; + return dm5f.x() * sumf_d - dm5f.y() * sumf_m; } // contiguous u/y values static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq( - const int *__restrict__ v, const int *__restrict__ u, - const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m, - const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) { - - float sumf_d = 0.0f; - float sumf_m = 0.0f; + const int* __restrict__ v, + const int* __restrict__ u, + const uint8_t* __restrict__ sc, + const uint8_t* __restrict__ m, + const sycl::half2& dm4, + const sycl::half2* __restrict__ ds8) { + float sumf_d = 0.0f; + float sumf_m = 0.0f; #pragma unroll - for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { - int sumi_d = 0; + for (int i = 0; i < QR5_K * VDR_Q5_K_Q8_1_MMQ / QI8_1; ++i) { + int sumi_d = 0; #pragma unroll - for (int j = 0; j < QI8_1; ++j) { - sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j], - sumi_d); // SIMD dot product - } + for (int j = 0; j < QI8_1; ++j) { + sumi_d = dpct::dp4a( + v[i * QI8_1 + j], + u[i * QI8_1 + j], + sumi_d); // SIMD dot product + } - const sycl::float2 ds8f = - ds8[i].convert(); + const 
sycl::float2 ds8f = + ds8[i].convert(); - sumf_d += ds8f.x() * (sc[i] * sumi_d); - sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val - } + sumf_d += ds8f.x() * (sc[i] * sumi_d); + sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val + } - const sycl::float2 dm4f = - dm4.convert(); + const sycl::float2 dm4f = + dm4.convert(); - return dm4f.x() * sumf_d - dm4f.y() * sumf_m; + return dm4f.x() * sumf_d - dm4f.y() * sumf_m; } #define VDR_Q6_K_Q8_1_MMVQ 1 -#define VDR_Q6_K_Q8_1_MMQ 8 +#define VDR_Q6_K_Q8_1_MMQ 8 // contiguous v/x values -static __dpct_inline__ float -vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, - const int *__restrict__ u, - const int8_t *__restrict__ scales, const float &d, - const float *__restrict__ d8) { - - float sumf = 0.0f; +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int& vl, + const int& vh, + const int* __restrict__ u, + const int8_t* __restrict__ scales, + const float& d, + const float* __restrict__ d8) { + float sumf = 0.0f; #pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4 * i]; - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; - const int vi = dpct::vectorized_binary( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 + const int vi = dpct::vectorized_binary( + (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } - return d*sumf; + return d * sumf; } // contiguous u/y values -static __dpct_inline__ float -vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u, - const int8_t *__restrict__ sc, const float &d6, - const float *__restrict__ d8) { - - float sumf_d = 0.0f; +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int* __restrict__ v, + const int* __restrict__ u, + const int8_t* __restrict__ sc, + const float& d6, + const float* __restrict__ d8) { + float sumf_d = 0.0f; #pragma unroll - for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { - sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale #pragma unroll - for (int i = i0; i < i0 + 2; ++i) { - sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0], - sumi_d.x()); // SIMD dot product - sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1], - sumi_d.x()); // SIMD dot product - - sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4], - sumi_d.y()); // SIMD dot product - sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5], - sumi_d.y()); // SIMD dot product - } - - sumf_d += d8[i0 / 4] * - (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y()); + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x() = dpct::dp4a( + v[2 * i + 0], + u[2 * i + 0], + sumi_d.x()); // SIMD dot product + sumi_d.x() = dpct::dp4a( + v[2 * i + 1], + u[2 * i + 1], + sumi_d.x()); // SIMD dot product + + sumi_d.y() = dpct::dp4a( + v[2 * i + 4], + u[2 * i + 4], + sumi_d.y()); // SIMD dot product + sumi_d.y() = dpct::dp4a( + v[2 * i + 5], + u[2 * i + 5], + sumi_d.y()); // SIMD dot product } - return d6 * sumf_d; -} - + sumf_d += d8[i0 / 4] * + (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y()); + } -static 
__dpct_inline__ float -vec_dot_q4_0_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { + return d6 * sumf_d; +} - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; +static __dpct_inline__ float vec_dot_q4_0_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q4_0* bq4_0 = (const block_q4_0*)vbq; - int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; #pragma unroll - for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); - } + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + } - return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); + return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); } -static __dpct_inline__ float -vec_dot_q4_1_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; +static __dpct_inline__ float vec_dot_q4_1_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q4_1* bq4_1 = (const block_q4_1*)vbq; - int v[VDR_Q4_1_Q8_1_MMVQ]; - int u[2*VDR_Q4_1_Q8_1_MMVQ]; + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2 * VDR_Q4_1_Q8_1_MMVQ]; #pragma unroll - for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); - } + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + } - return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); + return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); } -static __dpct_inline__ float -vec_dot_q5_0_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - - const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; +static __dpct_inline__ float vec_dot_q5_0_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q5_0* bq5_0 = (const block_q5_0*)vbq; - int vl[VDR_Q5_0_Q8_1_MMVQ]; - int vh[VDR_Q5_0_Q8_1_MMVQ]; - int u[2*VDR_Q5_0_Q8_1_MMVQ]; + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2 * VDR_Q5_0_Q8_1_MMVQ]; #pragma unroll - for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); - vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); - } - - return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); + vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl( + vl, vh, 
u, bq5_0->d, bq8_1->ds); } -static __dpct_inline__ float -vec_dot_q5_1_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - - const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; +static __dpct_inline__ float vec_dot_q5_1_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q5_1* bq5_1 = (const block_q5_1*)vbq; - int vl[VDR_Q5_1_Q8_1_MMVQ]; - int vh[VDR_Q5_1_Q8_1_MMVQ]; - int u[2*VDR_Q5_1_Q8_1_MMVQ]; + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2 * VDR_Q5_1_Q8_1_MMVQ]; #pragma unroll - for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); - vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); - } - - return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); + vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl( + vl, vh, u, bq5_1->dm, bq8_1->ds); } -static __dpct_inline__ float -vec_dot_q8_0_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - - const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; +static __dpct_inline__ float vec_dot_q8_0_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q8_0* bq8_0 = (const block_q8_0*)vbq; - int v[VDR_Q8_0_Q8_1_MMVQ]; - int u[VDR_Q8_0_Q8_1_MMVQ]; + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; #pragma unroll - for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_int8(bq8_0->qs, iqs + i); - u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - } + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_from_int8(bq8_0->qs, iqs + i); + u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + } - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, - bq8_1->ds[0]); + return vec_dot_q8_0_q8_1_impl( + v, u, bq8_0->d, bq8_1->ds[0]); } -static __dpct_inline__ float -vec_dot_q2_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { +static __dpct_inline__ float vec_dot_q2_K_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q2_K* bq2_K = (const block_q2_K*)vbq; - const block_q2_K * bq2_K = (const block_q2_K *) vbq; + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); - const int bq8_offset = QR2_K * (iqs / QI8_1); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + const uint8_t* scales = bq2_K->scales + scale_offset; - const uint8_t * scales = bq2_K->scales + scale_offset; - - const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); - int u[QR2_K]; - float d8[QR2_K]; + const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; #pragma unroll - for (int i = 0; i < QR2_K; ++ i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + i].ds[0]; - } + for (int i = 0; i < QR2_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs 
% QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds[0]; + } - return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); } -static __dpct_inline__ float -vec_dot_q3_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { +static __dpct_inline__ float vec_dot_q3_K_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { + const block_q3_K* bq3_K = (const block_q3_K*)vbq; - const block_q3_K * bq3_K = (const block_q3_K *) vbq; + const int bq8_offset = QR3_K * (iqs / (QI3_K / 2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); - const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); - const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + const float d = bq3_K->d; - const float d = bq3_K->d; + const int vl = get_int_from_uint8(bq3_K->qs, iqs); - const int vl = get_int_from_uint8(bq3_K->qs, iqs); + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = + ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K / 2)) >> bq8_offset; - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; - - int u[QR3_K]; - float d8[QR3_K]; + int u[QR3_K]; + float d8[QR3_K]; #pragma unroll - for (int i = 0; i < QR3_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + i].ds[0]; - } + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = bq8_1[bq8_offset + i].ds[0]; + } - return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); + return vec_dot_q3_K_q8_1_impl_mmvq( + vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __dpct_inline__ float -vec_dot_q4_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - +static __dpct_inline__ float vec_dot_q4_K_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { #ifndef GGML_QKK_64 - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - int v[2]; - int u[2*QR4_K]; - float d8[QR4_K]; - - // iqs is in 0,2..30. 
bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); - - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 - // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 - // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 - // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - v[0] = q4[0]; - v[1] = q4[4]; - - const uint16_t * scales = (const uint16_t *)bq4_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - - for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds[0]; - - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } - - return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + const block_q4_K* bq4_K = (const block_q4_K*)vbq; + + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int* q4 = + (const int*)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t* scales = (const uint16_t*)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + const uint8_t* sc = (const uint8_t*)aux; + const uint8_t* m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1* bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); #else -#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics - const block_q4_K * bq4_K = (const block_q4_K *) vbq; +#if __SYCL_ARCH__ >= \ + VER_4VEC // lowest compute capability for integer intrinsics + const block_q4_K* bq4_K = (const block_q4_K*)vbq; - float sumf_d = 0.0f; - float sumf_m = 0.0f; + float sumf_d = 0.0f; + float sumf_m = 0.0f; - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; + uint16_t aux16[2]; + const uint8_t* s = (const uint8_t*)aux16; - const uint16_t * a = (const uint16_t *)bq4_K->scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; + const uint16_t* a = (const uint16_t*)bq4_K->scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; - const float dall = bq4_K->dm[0]; - const float dmin = bq4_K->dm[1]; + const float dall = bq4_K->dm[0]; + const float dmin = bq4_K->dm[1]; - const float d8_1 = bq8_1[0].ds[0]; - const float d8_2 = bq8_1[1].ds[1]; + const float d8_1 = 
bq8_1[0].ds[0]; + const float d8_2 = bq8_1[1].ds[1]; - const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); - const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); - const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); - const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); + const int ui1 = *((const int*)bq8_1[0].qs + (iqs / 2)); + const int ui2 = *((const int*)bq8_1[0].qs + (iqs / 2) + 4); + const int ui3 = *((const int*)bq8_1[1].qs + (iqs / 2)); + const int ui4 = *((const int*)bq8_1[1].qs + (iqs / 2) + 4); - const int * q4 = (const int *)bq4_K->qs + (iqs/2); - const int v1 = q4[0]; - const int v2 = q4[4]; + const int* q4 = (const int*)bq4_K->qs + (iqs / 2); + const int v1 = q4[0]; + const int v2 = q4[4]; - const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0)); - const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0)); + const int dot1 = + dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = dpct::dp4a( + ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0)); + const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0)); - sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); - sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); + sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); + sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); - return dall * sumf_d - dmin * sumf_m; + return dall * sumf_d - dmin * sumf_m; #else - bad_arch(); + bad_arch(); #endif // __SYCL_ARCH__ >= VER_4VEC #endif } -static __dpct_inline__ float -vec_dot_q5_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - +static __dpct_inline__ float vec_dot_q5_K_q8_1( + const void* __restrict__ vbq, + const block_q8_1* __restrict__ bq8_1, + const int& iqs) { #ifndef GGML_QKK_64 - const block_q5_K * bq5_K = (const block_q5_K *) vbq; - - int vl[2]; - int vh[2]; - int u[2*QR5_K]; - float d8[QR5_K]; - - const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); - const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); - - vl[0] = ql[0]; - vl[1] = ql[4]; - - vh[0] = qh[0] >> bq8_offset; - vh[1] = qh[4] >> bq8_offset; - - const uint16_t * scales = (const uint16_t *)bq5_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; + const block_q5_K* bq5_K = (const block_q5_K*)vbq; + + int vl[2]; + int vh[2]; + int u[2 * QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2)); + const int* ql = + (const int*)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const int* qh = (const int*)(bq5_K->qh + 4 * ((iqs / 2) % 4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t* scales = (const uint16_t*)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset / 2; + if (j < 2) { + aux[0] = 

-static __dpct_inline__ float
-vec_dot_q5_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
+static __dpct_inline__ float vec_dot_q5_K_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs) {
#ifndef GGML_QKK_64
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
-
-    int vl[2];
-    int vh[2];
-    int u[2*QR5_K];
-    float d8[QR5_K];
-
-    const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
-    const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-    const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
-
-    vl[0] = ql[0];
-    vl[1] = ql[4];
-
-    vh[0] = qh[0] >> bq8_offset;
-    vh[1] = qh[4] >> bq8_offset;
-
-    const uint16_t * scales = (const uint16_t *)bq5_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m = sc + 2;
+  const block_q5_K* bq5_K = (const block_q5_K*)vbq;
+
+  int vl[2];
+  int vh[2];
+  int u[2 * QR5_K];
+  float d8[QR5_K];
+
+  const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2));
+  const int* ql =
+      (const int*)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+  const int* qh = (const int*)(bq5_K->qh + 4 * ((iqs / 2) % 4));
+
+  vl[0] = ql[0];
+  vl[1] = ql[4];
+
+  vh[0] = qh[0] >> bq8_offset;
+  vh[1] = qh[4] >> bq8_offset;
+
+  const uint16_t* scales = (const uint16_t*)bq5_K->scales;
+  uint16_t aux[2];
+  const int j = bq8_offset / 2;
+  if (j < 2) {
+    aux[0] = scales[j + 0] & 0x3f3f;
+    aux[1] = scales[j + 2] & 0x3f3f;
+  } else {
+    aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+    aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+  }
+  const uint8_t* sc = (const uint8_t*)aux;
+  const uint8_t* m = sc + 2;

#pragma unroll
-    for (int i = 0; i < QR5_K; ++i) {
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds[0];
+  for (int i = 0; i < QR5_K; ++i) {
+    const block_q8_1* bq8i = bq8_1 + bq8_offset + i;
+    d8[i] = bq8i->ds[0];

-        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-        u[2*i+0] = q8[0];
-        u[2*i+1] = q8[4];
-    }
+    const int* q8 = (const int*)bq8i->qs + ((iqs / 2) % 4);
+    u[2 * i + 0] = q8[0];
+    u[2 * i + 1] = q8[4];
+  }

-    return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+  return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);

#else

-#if __SYCL_ARCH__ >= VER_4VEC // lowest compute capability for integer intrinsics
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+#if __SYCL_ARCH__ >= \
+    VER_4VEC // lowest compute capability for integer intrinsics
+  const block_q5_K* bq5_K = (const block_q5_K*)vbq;

-    const int8_t * s = bq5_K->scales;
+  const int8_t* s = bq5_K->scales;

-    const float d = bq5_K->d;
+  const float d = bq5_K->d;

-    const float d8_1 = bq8_1[0].ds[0];
-    const float d8_2 = bq8_1[1].ds[1];
+  const float d8_1 = bq8_1[0].ds[0];
+  const float d8_2 = bq8_1[1].ds[1];

-    const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
-    const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
-    const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2));
-    const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4);
+  const int ui1 = *((const int*)bq8_1[0].qs + (iqs / 2));
+  const int ui2 = *((const int*)bq8_1[0].qs + (iqs / 2) + 4);
+  const int ui3 = *((const int*)bq8_1[1].qs + (iqs / 2));
+  const int ui4 = *((const int*)bq8_1[1].qs + (iqs / 2) + 4);

-    const int * ql = (const int *)bq5_K->qs + (iqs/2);
-    const int vl1 = ql[0];
-    const int vl2 = ql[4];
+  const int* ql = (const int*)bq5_K->qs + (iqs / 2);
+  const int vl1 = ql[0];
+  const int vl2 = ql[4];

-    const int step = 4 * (iqs/2); // 0, 4, 8, 12
-    const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
-    const int in = step%8; // 0, 4, 0, 4
-    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+  const int step = 4 * (iqs / 2); // 0, 4, 8, 12
+  const int im = step / 8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6
+  const int in = step % 8; // 0, 4, 0, 4
+  const int vh = (*((const int*)(bq5_K->qh + in))) >> im;

-    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
-    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
-    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
-    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+  const int v1 =
+      (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+  const int v2 =
+      (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+  const int v3 =
+      (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+  const int v4 =
+      (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);

-    const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1])
-                       + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);
+  const float sumf_d =
+      d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1]) +
+      d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]);

-    return d * sumf_d;
+  return d * sumf_d;

#else
-    bad_arch();
+  bad_arch();
#endif // __SYCL_ARCH__ >= VER_4VEC
#endif
}
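
The 0x3f3f, 0x0f0f, and 0xc0c0 masks in the q4_K and q5_K paths above unpack the K-quant super-block scales: 12 bytes holding eight 6-bit (scale, min) pairs. A scalar sketch of that packing, modeled on ggml's get_scale_min_k4 helper (written here for illustration; the vectorized code performs the same extraction two bytes at a time):

#include <cstdint>

// One (scale, min) pair from the 12-byte q4_K/q5_K scales array: for j < 4
// the values sit in the low 6 bits of bytes j and j + 4; for j >= 4 the low
// nibbles of byte j + 4 are combined with the top 2 bits of bytes j - 4 and j.
static void get_scale_min_ref(int j, const uint8_t* q, uint8_t* d, uint8_t* m) {
    if (j < 4) {
        *d = q[j] & 63;
        *m = q[j + 4] & 63;
    } else {
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
    }
}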

-static __dpct_inline__ float
-vec_dot_q6_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
+static __dpct_inline__ float vec_dot_q6_K_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs) {
+  const block_q6_K* bq6_K = (const block_q6_K*)vbq;

-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+  const int bq8_offset =
+      2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4);
+  const int scale_offset =
+      (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8);
+  const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4));

-    const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
-    const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
-    const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+  const int vl = get_int_from_uint8(bq6_K->ql, iqs);
+  const int vh =
+      get_int_from_uint8(
+          bq6_K->qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >>
+      vh_shift;

-    const int vl = get_int_from_uint8(bq6_K->ql, iqs);
-    const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+  const int8_t* scales = bq6_K->scales + scale_offset;

-    const int8_t * scales = bq6_K->scales + scale_offset;
-
-    int u[QR6_K];
-    float d8[QR6_K];
+  int u[QR6_K];
+  float d8[QR6_K];

#pragma unroll
-    for (int i = 0; i < QR6_K; ++i) {
-        u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
-        d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
-    }
+  for (int i = 0; i < QR6_K; ++i) {
+    u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2 * i].qs, iqs % QI8_1);
+    d8[i] = bq8_1[bq8_offset + 2 * i].ds[0];
+  }

-    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+  return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}
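
vec_dot_q6_K_q8_1 only gathers the vl/vh words; the 6-bit values themselves are recombined inside vec_dot_q6_K_q8_1_impl_mmvq, which is not shown in this hunk. A hedged scalar sketch of that reconstruction, assuming the usual q6_K layout (low nibble in ql, two high bits in qh, values stored with a bias of 32):

#include <cstdint>

// Assumed q6_K value reconstruction: low nibble from ql, two high bits from
// qh shifted into place, then the 32 bias removed (illustrative sketch).
static int q6_value_ref(uint8_t ql_byte, uint8_t qh_byte, int shift) {
    const int q = (ql_byte & 0x0F) | (((qh_byte >> shift) & 3) << 4);
    return q - 32;
}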

-static __dpct_inline__ float
-vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq,
-                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                     const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs,
-                     const uint8_t *kmask_iq2xs) {
+static __dpct_inline__ float vec_dot_iq2_xxs_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs,
+    const uint64_t* iq2xxs_grid,
+    const uint8_t* ksigns_iq2xs,
+    const uint8_t* kmask_iq2xs) {
#if QK_K == 256
-    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+  const block_iq2_xxs* bq2 = (const block_iq2_xxs*)vbq;

#if QR2_XXS == 8
-    const int ib32 = iqs;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const uint8_t * aux8 = (const uint8_t *)q2;
-    const int8_t * q8 = bq8_1[ib32].qs;
-    uint32_t aux32 = q2[2] | (q2[3] << 16);
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]);
-        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
-        for (int j = 0; j < 8; ++j) {
-            sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
-        q8 += 8;
-        aux32 >>= 7;
-    }
-    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
-    return d * sumi;
-#else
-    // iqs is 0...15
-    const int ib32 = iqs/2;
-    const int il = iqs%2;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const uint8_t * aux8 = (const uint8_t *)q2;
-    const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
-    const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
-    const uint32_t aux32 = q2[2] | (q2[3] << 16);
-    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f;
-    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
-    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
-    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
-    int sumi1 = 0, sumi2 = 0;
-    for (int j = 0; j < 8; ++j) {
-        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
-        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
-    }
-    return d * (sumi1 + sumi2);
+  const int ib32 = iqs;
+  const uint16_t* q2 = bq2->qs + 4 * ib32;
+  const uint8_t* aux8 = (const uint8_t*)q2;
+  const int8_t* q8 = bq8_1[ib32].qs;
+  uint32_t aux32 = q2[2] | (q2[3] << 16);
+  int sumi = 0;
+  for (int l = 0; l < 4; ++l) {
+    const uint8_t* grid = (const uint8_t*)(iq2xxs_grid + aux8[l]);
+    const uint8_t signs = ksigns_iq2xs[aux32 & 127];
+    for (int j = 0; j < 8; ++j) {
+      sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
+    }
+    q8 += 8;
+    aux32 >>= 7;
+  }
+  const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f;
+  return d * sumi;
+#else
+  // iqs is 0...15
+  const int ib32 = iqs / 2;
+  const int il = iqs % 2;
+  const uint16_t* q2 = bq2->qs + 4 * ib32;
+  const uint8_t* aux8 = (const uint8_t*)q2;
+  const uint8_t* grid1 = (const uint8_t*)(iq2xxs_grid + aux8[2 * il + 0]);
+  const uint8_t* grid2 = (const uint8_t*)(iq2xxs_grid + aux8[2 * il + 1]);
+  const uint32_t aux32 = q2[2] | (q2[3] << 16);
+  const float d =
+      (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f;
+  const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14 * il) & 127];
+  const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14 * il + 7)) & 127];
+  const int8_t* q8 = bq8_1[ib32].qs + 16 * il;
+  int sumi1 = 0, sumi2 = 0;
+  for (int j = 0; j < 8; ++j) {
+    sumi1 += q8[j + 0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+    sumi2 += q8[j + 8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+  }
+  return d * (sumi1 + sumi2);
#endif
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
}
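
The (0.5f + aux32) factor at the end of the QR2_XXS == 8 branch above is easy to misread: aux32 starts out holding four 7-bit sign indices plus the group's 4-bit scale in its top bits, and the four aux32 >>= 7 steps consume the sign indices so that only the scale remains. A small self-contained check of that bookkeeping (constants chosen arbitrarily):

#include <cassert>
#include <cstdint>

int main() {
    // Four 7-bit sign indices (7, 21, 63, 1) and a 4-bit scale of 10,
    // packed the way the iq2_xxs loop above consumes them.
    uint32_t aux32 = 7u | (21u << 7) | (63u << 14) | (1u << 21) | (10u << 28);
    for (int l = 0; l < 4; ++l) {
        const uint32_t sign_index = aux32 & 127; // what ksigns_iq2xs indexes
        (void)sign_index;
        aux32 >>= 7;
    }
    assert(aux32 == 10); // only the scale bits survive, hence (0.5f + aux32)
    return 0;
}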

-static __dpct_inline__ float
-vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq,
-                    const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                    const uint64_t *iq2xs_grid, const uint64_t *ksigns64) {
-#if DPCT_COMPATIBILITY_TEMP >= \
+static __dpct_inline__ float vec_dot_iq2_xs_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs,
+    const uint64_t* iq2xs_grid,
+    const uint64_t* ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
    MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
-    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
-
-    const int ib32 = iqs;
-    const uint16_t * q2 = bq2->qs + 4*ib32;
-    const int8_t * q8 = bq8_1[ib32].qs;
-    const uint8_t ls1 = bq2->scales[ib32] & 0xf;
-    const uint8_t ls2 = bq2->scales[ib32] >> 4;
-    int sumi1 = 0;
-    for (int l = 0; l < 2; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs[1], signs[1], std::minus<>());
-        sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1);
-        sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1);
-        q8 += 8;
-    }
-    int sumi2 = 0;
-    for (int l = 2; l < 4; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs[1], signs[1], std::minus<>());
-        sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2);
-        sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2);
-        q8 += 8;
-    }
-    const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
-    return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
+  const block_iq2_xs* bq2 = (const block_iq2_xs*)vbq;
+
+  const int ib32 = iqs;
+  const uint16_t* q2 = bq2->qs + 4 * ib32;
+  const int8_t* q8 = bq8_1[ib32].qs;
+  const uint8_t ls1 = bq2->scales[ib32] & 0xf;
+  const uint8_t ls2 = bq2->scales[ib32] >> 4;
+  int sumi1 = 0;
+  for (int l = 0; l < 2; ++l) {
+    const uint32_t* grid = (const uint32_t*)(iq2xs_grid + (q2[l] & 511));
+    const uint32_t* signs = (const uint32_t*)(ksigns64 + (q2[l] >> 9));
+    const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+        grid[0] ^ signs[0], signs[0], std::minus<>());
+    const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+        grid[1] ^ signs[1], signs[1], std::minus<>());
+    sumi1 = dpct::dp4a(grid_l, *((const int*)q8 + 0), sumi1);
+    sumi1 = dpct::dp4a(grid_h, *((const int*)q8 + 1), sumi1);
+    q8 += 8;
+  }
+  int sumi2 = 0;
+  for (int l = 2; l < 4; ++l) {
+    const uint32_t* grid = (const uint32_t*)(iq2xs_grid + (q2[l] & 511));
+    const uint32_t* signs = (const uint32_t*)(ksigns64 + (q2[l] >> 9));
+    const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+        grid[0] ^ signs[0], signs[0], std::minus<>());
+    const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+        grid[1] ^ signs[1], signs[1], std::minus<>());
+    sumi2 = dpct::dp4a(grid_l, *((const int*)q8 + 0), sumi2);
+    sumi2 = dpct::dp4a(grid_h, *((const int*)q8 + 1), sumi2);
+    q8 += 8;
+  }
+  const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f;
+  return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
}
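
The ksigns64 table used by vec_dot_iq2_xs_q8_1 above expands a 7-bit sign index into eight bytes that are each either 0x00 or 0xFF, and vectorized_binary(grid ^ signs, signs, std::minus<>()) then negates exactly the masked lanes. The underlying two's-complement identity, restated per byte (illustrative sketch):

#include <cstdint>

// With s either 0x00 or 0xFF, (g ^ s) - s yields +g or -g: XOR against 0xFF
// is bitwise NOT, and subtracting -1 adds the 1 that completes the negation.
static int8_t apply_sign_ref(int8_t g, uint8_t s) {
    const int8_t m = (int8_t)s; // 0 or -1
    return (int8_t)((g ^ m) - m);
}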

-static __dpct_inline__ float
-vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
-                     const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                     const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) {
-#if DPCT_COMPATIBILITY_TEMP >= \
+static __dpct_inline__ float vec_dot_iq3_xxs_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs,
+    const uint32_t* iq3xxs_grid,
+    const uint64_t* ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
    MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
-    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
-
-    const int ib32 = iqs;
-    const uint8_t * q3 = bq2->qs + 8*ib32;
-    const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
-    const int8_t * q8 = bq8_1[ib32].qs;
-    uint32_t aux32 = gas[0] | (gas[1] << 16);
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
-        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid1[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid2[0] ^ signs[1], signs[1], std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
-        q8 += 8;
-        aux32 >>= 7;
-    }
-    const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
-    return d * sumi;
+  const block_iq3_xxs* bq2 = (const block_iq3_xxs*)vbq;
+
+  const int ib32 = iqs;
+  const uint8_t* q3 = bq2->qs + 8 * ib32;
+  const uint16_t* gas = (const uint16_t*)(bq2->qs + QK_K / 4) + 2 * ib32;
+  const int8_t* q8 = bq8_1[ib32].qs;
+  uint32_t aux32 = gas[0] | (gas[1] << 16);
+  int sumi = 0;
+  for (int l = 0; l < 4; ++l) {
+    const uint32_t* grid1 = iq3xxs_grid + q3[2 * l + 0];
+    const uint32_t* grid2 = iq3xxs_grid + q3[2 * l + 1];
+    const uint32_t* signs = (const uint32_t*)(ksigns64 + (aux32 & 127));
+    const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+        grid1[0] ^ signs[0], signs[0], std::minus<>());
+    const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+        grid2[0] ^ signs[1], signs[1], std::minus<>());
+    sumi = dpct::dp4a(grid_l, *((int*)q8 + 0), sumi);
+    sumi = dpct::dp4a(grid_h, *((int*)q8 + 1), sumi);
+    q8 += 8;
+    aux32 >>= 7;
+  }
+  const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f;
+  return d * sumi;
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
}
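
Unlike the 64-bit iq2 grid entries, each iq3xxs_grid entry is a single 32-bit word packing four magnitudes, which is why the loop above performs two grid lookups per group of eight q8 values. A sketch of that expansion (function and parameter names are my own, mirroring the loop):

#include <cstdint>

// Expand one group of eight iq3 magnitudes from two 32-bit grid entries
// selected by consecutive byte indices (illustrative sketch).
static void expand_iq3_group(const uint32_t* grid_table, const uint8_t* q3,
                             int l, uint8_t out[8]) {
    const uint32_t g1 = grid_table[q3[2 * l + 0]]; // first four magnitudes
    const uint32_t g2 = grid_table[q3[2 * l + 1]]; // last four magnitudes
    for (int j = 0; j < 4; ++j) {
        out[j + 0] = (uint8_t)(g1 >> (8 * j));
        out[j + 4] = (uint8_t)(g2 >> (8 * j));
    }
}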

-static __dpct_inline__ float
-vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                   const uint32_t *iq3s_grid, const uint64_t *ksigns64) {
-#if DPCT_COMPATIBILITY_TEMP >= \
+static __dpct_inline__ float vec_dot_iq3_s_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs,
+    const uint32_t* iq3s_grid,
+    const uint64_t* ksigns64) {
+#if DPCT_COMPATIBILITY_TEMP >= \
    MIN_CC_DP4A // lowest compute capability for integer intrinsics
#if QK_K == 256
-    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
-
-    const int ib32 = iqs;
-    const uint8_t * qs = bq2->qs + 8*ib32;
-    const int8_t * q8 = bq8_1[ib32].qs;
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
-        const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
-        uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
-        uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
-            ((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201, std::equal_to<>());
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid1[0] ^ signs0, signs0, std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid2[0] ^ signs1, signs1, std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
-        q8 += 8;
-    }
-    const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * bq8_1[ib32].ds[0];
-    return d * sumi;
+  const block_iq3_s* bq2 = (const block_iq3_s*)vbq;
+
+  const int ib32 = iqs;
+  const uint8_t* qs = bq2->qs + 8 * ib32;
+  const int8_t* q8 = bq8_1[ib32].qs;
+  int sumi = 0;
+  for (int l = 0; l < 4; ++l) {
+    const uint32_t* grid1 =
+        iq3s_grid + (qs[2 * l + 0] | ((bq2->qh[ib32] << (8 - 2 * l)) & 256));
+    const uint32_t* grid2 =
+        iq3s_grid + (qs[2 * l + 1] | ((bq2->qh[ib32] << (7 - 2 * l)) & 256));
+    uint32_t signs0 = dpct::vectorized_binary<sycl::uchar4>(
+        ((bq2->signs[4 * ib32 + l] & 0xf) * 0x01010101) & 0x08040201,
+        0x08040201,
+        std::equal_to<>());
+    uint32_t signs1 = dpct::vectorized_binary<sycl::uchar4>(
+        ((bq2->signs[4 * ib32 + l] >> 4) * 0x01010101) & 0x08040201,
+        0x08040201,
+        std::equal_to<>());
+    const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+        grid1[0] ^ signs0, signs0, std::minus<>());
+    const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+        grid2[0] ^ signs1, signs1, std::minus<>());
+    sumi = dpct::dp4a(grid_l, *((int*)q8 + 0), sumi);
+    sumi = dpct::dp4a(grid_h, *((int*)q8 + 1), sumi);
+    q8 += 8;
+  }
+  const float d = (float)bq2->d *
+      (1 + 2 * ((bq2->scales[ib32 / 2] >> 4 * (ib32 % 2)) & 0xf)) *
+      bq8_1[ib32].ds[0];
+  return d * sumi;
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
}

-static __dpct_inline__ float
-vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
-                   const block_q8_1 *__restrict__ bq8_1, const int &iqs,
-                   const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
+static __dpct_inline__ float vec_dot_iq1_s_q8_1(
+    const void* __restrict__ vbq,
+    const block_q8_1* __restrict__ bq8_1,
+    const int& iqs,
+    const uint32_t* iq1s_grid,
+    const uint64_t* ksigns64) {
#if QK_K == 256
-    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
-
-    const int ib32 = iqs;
-    const uint8_t * qs = bq1->qs + 4*ib32;
-    const int8_t * q8 = bq8_1[ib32].qs;
-    int sumi = 0;
-    for (int l = 0; l < 4; ++l) {
-        const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
-        const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
-        const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
-            grid[0] ^ signs[0], signs[0], std::minus<>());
-        const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
-            grid[1] ^ signs[1], signs[1], std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
-        q8 += 8;
-    }
-    const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
-    return d * sumi;
+  const block_iq1_s* bq1 = (const block_iq1_s*)vbq;
+
+  const int ib32 = iqs;
+  const uint8_t* qs = bq1->qs + 4 * ib32;
+  const int8_t* q8 = bq8_1[ib32].qs;
+  int sumi = 0;
+  for (int l = 0; l < 4; ++l) {
+    const uint32_t* grid = (const uint32_t*)(iq1s_grid + qs[l]);
+    const uint32_t* signs = (const uint32_t*)(ksigns64 + (qs[l] >> 8));
+    const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
+        grid[0] ^ signs[0], signs[0], std::minus<>());
+    const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
+        grid[1] ^ signs[1], signs[1], std::minus<>());
+    sumi = dpct::dp4a(grid_l, *((int*)q8 + 0), sumi);
+    sumi = dpct::dp4a(grid_h, *((int*)q8 + 1), sumi);
+    q8 += 8;
+  }
+  const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
+  return d * sumi;
#else
-    assert(false);
-    return 0.f;
+  assert(false);
+  return 0.f;
#endif
}
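
The iq3_s path above builds its 0x00/0xFF lane masks on the fly rather than through ksigns64: multiplying a 4-bit sign field by 0x01010101 broadcasts it into every byte, & 0x08040201 keeps one distinct bit per byte, and the byte-wise equality test against 0x08040201 turns each set bit into a full-byte mask. A scalar restatement of that trick (illustrative sketch):

#include <cstdint>

// Turn a 4-bit sign field into four 0x00/0xFF mask bytes, one per lane,
// the way the iq3_s vectorized_binary/equal_to expression does.
static uint32_t broadcast_sign_mask(uint8_t four_bits) {
    const uint32_t rep = (uint32_t)four_bits * 0x01010101u; // copy into all bytes
    const uint32_t sel = rep & 0x08040201u; // bit i survives only in byte i
    uint32_t mask = 0;
    for (int b = 0; b < 4; ++b) {
        const uint32_t have = (sel >> (8 * b)) & 0xFFu;
        const uint32_t want = (0x08040201u >> (8 * b)) & 0xFFu;
        if (have == want) {
            mask |= 0xFFu << (8 * b); // this lane's sign bit was set
        }
    }
    return mask;
}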