diff --git a/CMakeLists.txt b/CMakeLists.txt index 58a1805ba10fd..1998a7afdcbc8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -620,7 +620,8 @@ if (LLAMA_SYCL) endif() set(GGML_HEADERS_SYCL ggml-sycl.h) - set(GGML_SOURCES_SYCL ggml-sycl.cpp) + file(GLOB GGML_SOURCES_SYCL "ggml-sycl.cpp") + # list(APPEND GGML_SOURCES_SYCL "ggml-sycl/*.cpp") if (WIN32) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -fsycl sycl7 OpenCL mkl_sycl_blas_dll.lib mkl_intel_ilp64_dll.lib mkl_sequential_dll.lib mkl_core_dll.lib) @@ -1255,7 +1256,7 @@ install(FILES ${CMAKE_CURRENT_BINARY_DIR}/LlamaConfig.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/Llama) set(GGML_PUBLIC_HEADERS "ggml.h" "ggml-alloc.h" "ggml-backend.h" - "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}" + "${GGML_HEADERS_CUDA}" "${GGML_HEADERS_OPENCL}" "${GGML_HEADERS_SYCL}" "${GGML_HEADERS_METAL}" "${GGML_HEADERS_MPI}" "${GGML_HEADERS_EXTRA}") set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a9b310243f04f..3e88c2e230515 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -36,2915 +36,7 @@ #include "ggml.h" #include "ggml-backend-impl.h" -/* -Following definition copied from DPCT head files, which are used by ggml-sycl.cpp -*/ -// COPY from DPCT head files -#include -#include -#include - -#if defined(__linux__) -#include -#elif defined(_WIN64) -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#else -#error "Only support Windows and Linux." -#endif - -#if defined(__linux__) -#include -#include -#endif -#if defined(_WIN64) -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#endif - -#define DPCT_COMPATIBILITY_TEMP (900) - -#if defined(_MSC_VER) -#define __dpct_align__(n) __declspec(align(n)) -#define __dpct_inline__ __forceinline -#else -#define __dpct_align__(n) __attribute__((aligned(n))) -#define __dpct_inline__ __inline__ __attribute__((always_inline)) -#endif - -#if defined(_MSC_VER) -#define __dpct_noinline__ __declspec(noinline) -#else -#define __dpct_noinline__ __attribute__((noinline)) -#endif - - -std::string get_device_type_name(const sycl::device &Device) { - auto DeviceType = Device.get_info(); - switch (DeviceType) { - case sycl::info::device_type::cpu: - return "cpu"; - case sycl::info::device_type::gpu: - return "gpu"; - case sycl::info::device_type::host: - return "host"; - case sycl::info::device_type::accelerator: - return "acc"; - default: - return "unknown"; - } -} - -std::string get_device_backend_and_type(const sycl::device &device) { - std::stringstream device_type; - sycl::backend backend = device.get_backend(); - device_type << backend << ":" << get_device_type_name(device); - return device_type.str(); -} - -namespace dpct -{ - typedef sycl::queue *queue_ptr; - typedef sycl::event *event_ptr; - typedef char *device_ptr; - typedef uint8_t byte_t; - typedef sycl::buffer buffer_t; - - /// SYCL default exception handler - inline auto exception_handler = [](sycl::exception_list exceptions) - { - for (std::exception_ptr const &e : exceptions) - { - try - { - std::rethrow_exception(e); - } - catch (sycl::exception const &e) - { - std::cerr << "Caught asynchronous SYCL exception:" << std::endl - << e.what() << std::endl - << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - } - } - }; - - enum error_code - { - success = 0, - default_error = 999 - }; - - enum memcpy_direction - { - host_to_host, - host_to_device, - device_to_host, - device_to_device, - automatic - }; - - enum memory_region - { - global = 0, // device global memory - constant, // device constant memory - local, // device local memory - shared, // memory which can be accessed by host and device - }; - - enum class library_data_t : unsigned char - { - real_float = 0, - complex_float, - real_double, - complex_double, - real_half, - complex_half, - real_bfloat16, - complex_bfloat16, - real_int4, - complex_int4, - real_uint4, - complex_uint4, - real_int8, - complex_int8, - real_uint8, - complex_uint8, - real_int16, - complex_int16, - real_uint16, - complex_uint16, - real_int32, - complex_int32, - real_uint32, - complex_uint32, - real_int64, - complex_int64, - real_uint64, - complex_uint64, - real_int8_4, - real_int8_32, - real_uint8_4, - library_data_t_size - }; - - template - struct DataType - { - using T2 = T; - }; - template - struct DataType> - { - using T2 = std::complex; - }; - - static void destroy_event(event_ptr event) - { - delete event; - } - - static inline unsigned int get_tid() - { -#if defined(__linux__) - return syscall(SYS_gettid); -#elif defined(_WIN64) - return GetCurrentThreadId(); -#else -#error "Only support Windows and Linux." -#endif - } - - namespace detail - { - static void get_version(const sycl::device &dev, int &major, int &minor) - { - // Version string has the following format: - // a. OpenCL - // b. - // c. e.g gfx1030 - std::string ver; - ver = dev.get_info(); - std::string::size_type i = 0; - while (i < ver.size()) { - if (isdigit(ver[i])) - break; - i++; - } - major = std::stoi(&(ver[i])); - while (i < ver.size()) { - if (ver[i] == '.') - break; - i++; - } - if (i < ver.size()) { - // a. and b. - i++; - minor = std::stoi(&(ver[i])); - } else { - // c. - minor = 0; - } - } - - template - class generic_error_type - { - public: - generic_error_type() = default; - generic_error_type(T value) : value{value} {} - operator T() const { return value; } - - private: - T value; - }; - - } // namespace detail - - /// Pitched 2D/3D memory data. - class pitched_data - { - public: - pitched_data() : pitched_data(nullptr, 0, 0, 0) {} - pitched_data(void *data, size_t pitch, size_t x, size_t y) - : _data(data), _pitch(pitch), _x(x), _y(y) {} - - void *get_data_ptr() { return _data; } - void set_data_ptr(void *data) { _data = data; } - - size_t get_pitch() { return _pitch; } - void set_pitch(size_t pitch) { _pitch = pitch; } - - size_t get_x() { return _x; } - void set_x(size_t x) { _x = x; }; - - size_t get_y() { return _y; } - void set_y(size_t y) { _y = y; } - - private: - void *_data; - size_t _pitch, _x, _y; - }; - - class device_info - { - public: - // get interface - const char *get_name() const { return _name; } - char *get_name() { return _name; } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() const - { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else - { - return _max_work_item_sizes_i; - } - } - template , - std::enable_if_t> || - std::is_same_v, - int> = 0> - auto get_max_work_item_sizes() - { - if constexpr (std::is_same_v>) - return sycl::range<3>(_max_work_item_sizes_i[0], - _max_work_item_sizes_i[1], - _max_work_item_sizes_i[2]); - else - { - return _max_work_item_sizes_i; - } - } - bool get_host_unified_memory() const { return _host_unified_memory; } - int get_major_version() const { return _major; } - int get_minor_version() const { return _minor; } - int get_integrated() const { return _integrated; } - int get_max_clock_frequency() const { return _frequency; } - int get_max_compute_units() const { return _max_compute_units; } - int get_max_work_group_size() const { return _max_work_group_size; } - int get_max_sub_group_size() const { return _max_sub_group_size; } - int get_max_work_items_per_compute_unit() const - { - return _max_work_items_per_compute_unit; - } - int get_max_register_size_per_work_group() const - { - return _max_register_size_per_work_group; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() const - { - if constexpr (std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - template || - std::is_same_v, - int> = 0> - auto get_max_nd_range_size() - { - if constexpr (std::is_same_v) - return _max_nd_range_size; - else - return _max_nd_range_size_i; - } - size_t get_global_mem_size() const { return _global_mem_size; } - size_t get_local_mem_size() const { return _local_mem_size; } - size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; } - /// Returns the maximum clock rate of device's global memory in kHz. If - /// compiler does not support this API then returns default value 3200000 kHz. - unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } - /// Returns the maximum bus width between device and memory in bits. If - /// compiler does not support this API then returns default value 64 bits. - unsigned int get_memory_bus_width() const { return _memory_bus_width; } - uint32_t get_device_id() const { return _device_id; } - std::array get_uuid() const { return _uuid; } - /// Returns global memory cache size in bytes. - unsigned int get_global_mem_cache_size() const - { - return _global_mem_cache_size; - } - - // set interface - void set_name(const char *name) - { - size_t length = strlen(name); - if (length < 256) - { - std::memcpy(_name, name, length + 1); - } - else - { - std::memcpy(_name, name, 255); - _name[255] = '\0'; - } - } - void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) - { - for (int i = 0; i < 3; ++i) - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - [[deprecated]] void - set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) - { - for (int i = 0; i < 3; ++i) - { - _max_work_item_sizes_i[i] = max_work_item_sizes[i]; - } - } - void set_host_unified_memory(bool host_unified_memory) - { - _host_unified_memory = host_unified_memory; - } - void set_major_version(int major) { _major = major; } - void set_minor_version(int minor) { _minor = minor; } - void set_integrated(int integrated) { _integrated = integrated; } - void set_max_clock_frequency(int frequency) { _frequency = frequency; } - void set_max_compute_units(int max_compute_units) - { - _max_compute_units = max_compute_units; - } - void set_global_mem_size(size_t global_mem_size) - { - _global_mem_size = global_mem_size; - } - void set_local_mem_size(size_t local_mem_size) - { - _local_mem_size = local_mem_size; - } - void set_max_mem_alloc_size(size_t max_mem_alloc_size) - { - _max_mem_alloc_size = max_mem_alloc_size; - } - void set_max_work_group_size(int max_work_group_size) - { - _max_work_group_size = max_work_group_size; - } - void set_max_sub_group_size(int max_sub_group_size) - { - _max_sub_group_size = max_sub_group_size; - } - void - set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) - { - _max_work_items_per_compute_unit = max_work_items_per_compute_unit; - } - void set_max_nd_range_size(int max_nd_range_size[]) - { - for (int i = 0; i < 3; i++) - { - _max_nd_range_size[i] = max_nd_range_size[i]; - _max_nd_range_size_i[i] = max_nd_range_size[i]; - } - } - void set_memory_clock_rate(unsigned int memory_clock_rate) - { - _memory_clock_rate = memory_clock_rate; - } - void set_memory_bus_width(unsigned int memory_bus_width) - { - _memory_bus_width = memory_bus_width; - } - void - set_max_register_size_per_work_group(int max_register_size_per_work_group) - { - _max_register_size_per_work_group = max_register_size_per_work_group; - } - void set_device_id(uint32_t device_id) - { - _device_id = device_id; - } - void set_uuid(std::array uuid) - { - _uuid = std::move(uuid); - } - void set_global_mem_cache_size(unsigned int global_mem_cache_size) - { - _global_mem_cache_size = global_mem_cache_size; - } - - private: - char _name[256]; - int _max_work_item_sizes_i[3]; - bool _host_unified_memory = false; - int _major; - int _minor; - int _integrated = 0; - int _frequency; - // Set estimated value 3200000 kHz as default value. - unsigned int _memory_clock_rate = 3200000; - // Set estimated value 64 bits as default value. - unsigned int _memory_bus_width = 64; - unsigned int _global_mem_cache_size; - int _max_compute_units; - int _max_work_group_size; - int _max_sub_group_size; - int _max_work_items_per_compute_unit; - int _max_register_size_per_work_group; - size_t _global_mem_size; - size_t _local_mem_size; - size_t _max_mem_alloc_size; - size_t _max_nd_range_size[3]; - int _max_nd_range_size_i[3]; - uint32_t _device_id; - std::array _uuid; - }; - - static int get_major_version(const sycl::device &dev) - { - int major, minor; - detail::get_version(dev, major, minor); - return major; - } - - static int get_minor_version(const sycl::device &dev) - { - int major, minor; - detail::get_version(dev, major, minor); - return minor; - } - - static void get_device_info(device_info &out, const sycl::device &dev) - { - device_info prop; - prop.set_name(dev.get_info().c_str()); - - int major, minor; - detail::get_version(dev, major, minor); - prop.set_major_version(major); - prop.set_minor_version(minor); - - prop.set_max_work_item_sizes( -#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) - // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes - // is an enum class element - dev.get_info()); -#else - // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by - // an int - dev.get_info>()); -#endif - prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); - - prop.set_max_clock_frequency( - dev.get_info() * 1000); - - prop.set_max_compute_units( - dev.get_info()); - prop.set_max_work_group_size( - dev.get_info()); - prop.set_global_mem_size(dev.get_info()); - prop.set_local_mem_size(dev.get_info()); - prop.set_max_mem_alloc_size(dev.get_info()); - -#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) - if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) - { - unsigned int tmp = - dev.get_info(); - if (tmp != 0) - prop.set_memory_clock_rate(1000 * tmp); - } - if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) - { - prop.set_memory_bus_width( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_id)) - { - prop.set_device_id( - dev.get_info()); - } - if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) - { - prop.set_uuid(dev.get_info()); - } -#elif defined(_MSC_VER) && !defined(__clang__) -#pragma message("get_device_info: querying memory_clock_rate and \ - memory_bus_width are not supported by the compiler used. \ - Use 3200000 kHz as memory_clock_rate default value. \ - Use 64 bits as memory_bus_width default value.") -#else -#warning "get_device_info: querying memory_clock_rate and \ - memory_bus_width are not supported by the compiler used. \ - Use 3200000 kHz as memory_clock_rate default value. \ - Use 64 bits as memory_bus_width default value." -#endif - - size_t max_sub_group_size = 1; - std::vector sub_group_sizes = - dev.get_info(); - - for (const auto &sub_group_size : sub_group_sizes) - { - if (max_sub_group_size < sub_group_size) - max_sub_group_size = sub_group_size; - } - - prop.set_max_sub_group_size(max_sub_group_size); - - prop.set_max_work_items_per_compute_unit( - dev.get_info()); - int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; - prop.set_max_nd_range_size(max_nd_range_size); - - // Estimates max register size per work group, feel free to update the value - // according to device properties. - prop.set_max_register_size_per_work_group(65536); - - prop.set_global_mem_cache_size( - dev.get_info()); - out = prop; - } - - /// dpct device extension - class device_ext : public sycl::device - { - typedef std::mutex mutex_type; - - public: - device_ext() : sycl::device(), _ctx(*this) {} - ~device_ext() - { - std::lock_guard lock(m_mutex); - clear_queues(); - } - device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) - { - std::lock_guard lock(m_mutex); - init_queues(); - } - - int is_native_atomic_supported() { return 0; } - int get_major_version() const - { - return dpct::get_major_version(*this); - } - - int get_minor_version() const - { - return dpct::get_minor_version(*this); - } - - int get_max_compute_units() const - { - return get_device_info().get_max_compute_units(); - } - - /// Return the maximum clock frequency of this device in KHz. - int get_max_clock_frequency() const - { - return get_device_info().get_max_clock_frequency(); - } - - int get_integrated() const { return get_device_info().get_integrated(); } - - int get_max_sub_group_size() const - { - return get_device_info().get_max_sub_group_size(); - } - - int get_max_register_size_per_work_group() const - { - return get_device_info().get_max_register_size_per_work_group(); - } - - int get_max_work_group_size() const - { - return get_device_info().get_max_work_group_size(); - } - - int get_mem_base_addr_align() const - { - return get_info(); - } - - size_t get_global_mem_size() const - { - return get_device_info().get_global_mem_size(); - } - - size_t get_max_mem_alloc_size() const - { - return get_device_info().get_max_mem_alloc_size(); - } - - /// Get the number of bytes of free and total memory on the SYCL device. - /// \param [out] free_memory The number of bytes of free memory on the SYCL device. - /// \param [out] total_memory The number of bytes of total memory on the SYCL device. - void get_memory_info(size_t &free_memory, size_t &total_memory) - { - total_memory = get_device_info().get_global_mem_size(); - const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " - "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " - "use total memory as free memory"; -#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) - if (!has(sycl::aspect::ext_intel_free_memory)) - { - std::cerr << warning_info << std::endl; - free_memory = total_memory; - } - else - { - free_memory = get_info(); - } -#else - std::cerr << warning_info << std::endl; - free_memory = total_memory; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma message("Querying the number of bytes of free memory is not supported") -#else -#warning "Querying the number of bytes of free memory is not supported" -#endif -#endif - } - - void get_device_info(device_info &out) const - { - dpct::get_device_info(out, *this); - } - - device_info get_device_info() const - { - device_info prop; - dpct::get_device_info(prop, *this); - return prop; - } - - void reset() - { - std::lock_guard lock(m_mutex); - clear_queues(); - init_queues(); - } - - sycl::queue &in_order_queue() { return *_q_in_order; } - - sycl::queue &out_of_order_queue() { return *_q_out_of_order; } - - sycl::queue &default_queue() - { - return in_order_queue(); - } - - void queues_wait_and_throw() - { - std::unique_lock lock(m_mutex); - std::vector> current_queues( - _queues); - lock.unlock(); - for (const auto &q : current_queues) - { - q->wait_and_throw(); - } - // Guard the destruct of current_queues to make sure the ref count is safe. - lock.lock(); - } - - sycl::queue *create_queue(bool enable_exception_handler = false) - { - return create_in_order_queue(enable_exception_handler); - } - - sycl::queue *create_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - return create_in_order_queue(context, device, enable_exception_handler); - } - - sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, - bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(context, device, enable_exception_handler, - sycl::property::queue::in_order()); - } - - sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { - std::lock_guard lock(m_mutex); - return create_queue_impl(enable_exception_handler); - } - - void destroy_queue(sycl::queue *&queue) - { - std::lock_guard lock(m_mutex); - _queues.erase(std::remove_if(_queues.begin(), _queues.end(), - [=](const std::shared_ptr &q) -> bool - { - return q.get() == queue; - }), - _queues.end()); - queue = nullptr; - } - void set_saved_queue(sycl::queue *q) - { - std::lock_guard lock(m_mutex); - _saved_queue = q; - } - sycl::queue *get_saved_queue() const - { - std::lock_guard lock(m_mutex); - return _saved_queue; - } - sycl::context get_context() const { return _ctx; } - - private: - void clear_queues() - { - _queues.clear(); - _q_in_order = _q_out_of_order = _saved_queue = nullptr; - } - - void init_queues() - { - _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); - _q_out_of_order = create_queue_impl(true); - _saved_queue = &default_queue(); - } - - /// Caller should acquire resource \p m_mutex before calling this function. - template - sycl::queue *create_queue_impl(bool enable_exception_handler, - Properties... properties) - { - sycl::async_handler eh = {}; - if (enable_exception_handler) - { - eh = exception_handler; - } - _queues.push_back(std::make_shared( - _ctx, *this, eh, - sycl::property_list( -#ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), -#endif - properties...))); - - return _queues.back().get(); - } - - template - sycl::queue *create_queue_impl(sycl::context context, sycl::device device, - bool enable_exception_handler, - Properties... properties) { - sycl::async_handler eh = {}; - if (enable_exception_handler) { - eh = exception_handler; - } - _queues.push_back(std::make_shared( - context, device, eh, - sycl::property_list( - #ifdef DPCT_PROFILING_ENABLED - sycl::property::queue::enable_profiling(), - #endif - properties...))); - - return _queues.back().get(); - } - - void get_version(int &major, int &minor) const - { - detail::get_version(*this, major, minor); - } - sycl::queue *_q_in_order, *_q_out_of_order; - sycl::queue *_saved_queue; - sycl::context _ctx; - std::vector> _queues; - mutable mutex_type m_mutex; - }; - - /// device manager - class dev_mgr - { - public: - device_ext ¤t_device() - { - unsigned int dev_id = current_device_id(); - check_id(dev_id); - return *_devs[dev_id]; - } - device_ext &cpu_device() const - { - std::lock_guard lock(m_mutex); - if (_cpu_device == -1) - { - throw std::runtime_error("no valid cpu device"); - } - else - { - return *_devs[_cpu_device]; - } - } - device_ext &get_device(unsigned int id) const - { - std::lock_guard lock(m_mutex); - check_id(id); - return *_devs[id]; - } - unsigned int current_device_id() const - { - std::lock_guard lock(m_mutex); - auto it = _thread2dev_map.find(get_tid()); - if (it != _thread2dev_map.end()) - return it->second; - return DEFAULT_DEVICE_ID; - } - - /// Select device with a device ID. - /// \param [in] id The id of the device which can - /// be obtained through get_device_id(const sycl::device). - void select_device(unsigned int id) - { - std::lock_guard lock(m_mutex); - check_id(id); - _thread2dev_map[get_tid()] = id; - } - unsigned int device_count() { return _devs.size(); } - - unsigned int get_device_id(const sycl::device &dev) - { - unsigned int id = 0; - for (auto dev_item : _devs) - { - if (*dev_item == dev) - { - break; - } - id++; - } - return id; - } - - template - std::enable_if_t< - std::is_invocable_r_v> - select_device(const DeviceSelector &selector = sycl::gpu_selector_v) - { - sycl::device selected_device = sycl::device(selector); - unsigned int selected_device_id = get_device_id(selected_device); - select_device(selected_device_id); - } - - /// Returns the instance of device manager singleton. - static dev_mgr &instance() - { - static dev_mgr d_m; - return d_m; - } - dev_mgr(const dev_mgr &) = delete; - dev_mgr &operator=(const dev_mgr &) = delete; - dev_mgr(dev_mgr &&) = delete; - dev_mgr &operator=(dev_mgr &&) = delete; - - private: - mutable std::recursive_mutex m_mutex; - static bool compare_dev(sycl::device &device1, sycl::device &device2) - { - dpct::device_info prop1; - dpct::get_device_info(prop1, device1); - dpct::device_info prop2; - dpct::get_device_info(prop2, device2); - return prop1.get_max_compute_units() > prop2.get_max_compute_units(); - } - static int convert_backend_index(std::string & backend) { - if (backend == "ext_oneapi_level_zero:gpu") return 0; - if (backend == "opencl:gpu") return 1; - if (backend == "ext_oneapi_cuda:gpu") return 2; - if (backend == "ext_oneapi_hip:gpu") return 3; - if (backend == "opencl:cpu") return 4; - if (backend == "opencl:acc") return 5; - printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); - GGML_ASSERT(false); - } - static bool compare_backend(std::string &backend1, std::string &backend2) { - return convert_backend_index(backend1) < convert_backend_index(backend2); - } - dev_mgr() - { - sycl::device default_device = - sycl::device(sycl::default_selector_v); - _devs.push_back(std::make_shared(default_device)); - - std::vector sycl_all_devs; - // Collect other devices except for the default device. - if (default_device.is_cpu()) - _cpu_device = 0; - - auto Platforms = sycl::platform::get_platforms(); - // Keep track of the number of devices per backend - std::map DeviceNums; - std::map> backend_devices; - - while (!Platforms.empty()) { - auto Platform = Platforms.back(); - Platforms.pop_back(); - auto devices = Platform.get_devices(); - std::string backend_type = get_device_backend_and_type(devices[0]); - for (const auto &device : devices) { - backend_devices[backend_type].push_back(device); - } - } - - std::vector keys; - for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { - keys.push_back(it->first); - } - std::sort(keys.begin(), keys.end(), compare_backend); - - for (auto &key : keys) { - std::vector devs = backend_devices[key]; - std::sort(devs.begin(), devs.end(), compare_dev); - for (const auto &dev : devs) { - sycl_all_devs.push_back(dev); - } - } - - for (auto &dev : sycl_all_devs) - { - if (dev == default_device) - { - continue; - } - _devs.push_back(std::make_shared(dev)); - if (_cpu_device == -1 && dev.is_cpu()) - { - _cpu_device = _devs.size() - 1; - } - } - } - void check_id(unsigned int id) const - { - if (id >= _devs.size()) - { - throw std::runtime_error("invalid device id"); - } - } - std::vector> _devs; - /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current - /// thread id in _thread2dev_map, which means default device should be used - /// for the current thread. - const unsigned int DEFAULT_DEVICE_ID = 0; - /// thread-id to device-id map. - std::map _thread2dev_map; - int _cpu_device = -1; - }; - - static inline sycl::queue &get_default_queue() - { - return dev_mgr::instance().current_device().default_queue(); - } - - namespace detail - { - enum class pointer_access_attribute - { - host_only = 0, - device_only, - host_device, - end - }; - - static pointer_access_attribute get_pointer_attribute(sycl::queue &q, - const void *ptr) - { - switch (sycl::get_pointer_type(ptr, q.get_context())) - { - case sycl::usm::alloc::unknown: - return pointer_access_attribute::host_only; - case sycl::usm::alloc::device: - return pointer_access_attribute::device_only; - case sycl::usm::alloc::shared: - case sycl::usm::alloc::host: - return pointer_access_attribute::host_device; - } - } - - template - inline constexpr std::uint64_t get_type_combination_id(ArgT Val) - { - static_assert((unsigned char)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); - static_assert(std::is_same_v, "Unsupported ArgT"); - return (std::uint64_t)Val; - } - - template - inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, - RestT... RestVal) - { - static_assert((std::uint8_t)library_data_t::library_data_t_size <= - std::numeric_limits::max() && - "library_data_t size exceeds limit."); - static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); - static_assert(std::is_same_v, "Unsupported FirstT"); - return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); - } - - class mem_mgr - { - mem_mgr() - { - // Reserved address space, no real memory allocation happens here. -#if defined(__linux__) - mapped_address_space = - (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, - MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); -#elif defined(_WIN64) - mapped_address_space = (byte_t *)VirtualAlloc( - NULL, // NULL specified as the base address parameter - mapped_region_size, // Size of allocation - MEM_RESERVE, // Allocate reserved pages - PAGE_NOACCESS); // Protection = no access -#else -#error "Only support Windows and Linux." -#endif - next_free = mapped_address_space; - }; - - public: - using buffer_id_t = int; - - struct allocation - { - buffer_t buffer; - byte_t *alloc_ptr; - size_t size; - }; - - ~mem_mgr() - { -#if defined(__linux__) - munmap(mapped_address_space, mapped_region_size); -#elif defined(_WIN64) - VirtualFree(mapped_address_space, 0, MEM_RELEASE); -#else -#error "Only support Windows and Linux." -#endif - }; - - mem_mgr(const mem_mgr &) = delete; - mem_mgr &operator=(const mem_mgr &) = delete; - mem_mgr(mem_mgr &&) = delete; - mem_mgr &operator=(mem_mgr &&) = delete; - - /// Allocate - void *mem_alloc(size_t size) - { - if (!size) - return nullptr; - std::lock_guard lock(m_mutex); - if (next_free + size > mapped_address_space + mapped_region_size) - { - throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); - } - // Allocation - sycl::range<1> r(size); - buffer_t buf(r); - allocation A{buf, next_free, size}; - // Map allocation to device pointer - void *result = next_free; - m_map.emplace(next_free + size, A); - // Update pointer to the next free space. - next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); - - return result; - } - - /// Deallocate - void mem_free(const void *ptr) - { - if (!ptr) - return; - std::lock_guard lock(m_mutex); - auto it = get_map_iterator(ptr); - m_map.erase(it); - } - - /// map: device pointer -> allocation(buffer, alloc_ptr, size) - allocation translate_ptr(const void *ptr) - { - std::lock_guard lock(m_mutex); - auto it = get_map_iterator(ptr); - return it->second; - } - - /// Check if the pointer represents device pointer or not. - bool is_device_ptr(const void *ptr) const - { - std::lock_guard lock(m_mutex); - return (mapped_address_space <= ptr) && - (ptr < mapped_address_space + mapped_region_size); - } - - /// Returns the instance of memory manager singleton. - static mem_mgr &instance() - { - static mem_mgr m; - return m; - } - - private: - std::map m_map; - mutable std::mutex m_mutex; - byte_t *mapped_address_space; - byte_t *next_free; - const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; - const size_t alignment = 256; - /// This padding may be defined to some positive value to debug - /// out of bound accesses. - const size_t extra_padding = 0; - - std::map::iterator get_map_iterator(const void *ptr) - { - auto it = m_map.upper_bound((byte_t *)ptr); - if (it == m_map.end()) - { - // Not a virtual pointer. - throw std::runtime_error("can not get buffer from non-virtual pointer"); - } - const allocation &alloc = it->second; - if (ptr < alloc.alloc_ptr) - { - // Out of bound. - // This may happen if there's a gap between allocations due to alignment - // or extra padding and pointer points to this gap. - throw std::runtime_error("invalid virtual pointer"); - } - return it; - } - }; - - template - class accessor; - template - class memory_traits - { - public: - static constexpr sycl::access::target target = - sycl::access::target::device; - static constexpr sycl::access_mode mode = - (Memory == constant) ? sycl::access_mode::read - : sycl::access_mode::read_write; - static constexpr size_t type_size = sizeof(T); - using element_t = - typename std::conditional::type; - using value_t = typename std::remove_cv::type; - template - using accessor_t = typename std::conditional< - Memory == local, sycl::local_accessor, - sycl::accessor>::type; - using pointer_t = T *; - }; - - static inline void *dpct_malloc(size_t size, sycl::queue &q) - { - return sycl::malloc_device(size, q.get_device(), q.get_context()); - } - -#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) - static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, - sycl::queue &q) - { - pitch = PITCH_DEFAULT_ALIGN(x); - return dpct_malloc(pitch * y * z, q); - } - - /** - * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] dev_ptr Pointer to the virtual device memory address. - * @param [in] value The value to be set. - * @param [in] size Number of elements to be set to the value. - * @return An event representing the memset operation. - */ - template - static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, - valueT value, size_t size) - { - return q.fill(dev_ptr, value, size); - } - - /** - * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] data Pointer to the pitched device memory region. - * @param [in] value The value to be set. - * @param [in] size 3D memory region by number of elements. - * @return An event list representing the memset operations. - */ - template - static inline std::vector - dpct_memset(sycl::queue &q, pitched_data data, valueT value, - sycl::range<3> size) - { - std::vector event_list; - size_t slice = data.get_pitch() * data.get_y(); - unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *data_ptr = data_surface; - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); - data_ptr += data.get_pitch(); - } - data_surface += slice; - } - return event_list; - } - - /** - * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. - * @tparam valueT The type of the element to be set. - * @param [in] q The queue in which the operation is done. - * @param [in] ptr Pointer to the virtual device memory. - * @param [in] pitch The pitch size by number of elements, including padding. - * @param [in] val The value to be set. - * @param [in] x The width of memory region by number of elements. - * @param [in] y The height of memory region by number of elements. - * @return An event list representing the memset operations. - */ - template - static inline std::vector - dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, - size_t y) - { - return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, - sycl::range<3>(x, y, 1)); - } - - static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, - const void *from_ptr, - memcpy_direction dir) - { - switch (dir) - { - case memcpy_direction::host_to_host: - case memcpy_direction::host_to_device: - case memcpy_direction::device_to_host: - case memcpy_direction::device_to_device: - return dir; - case memcpy_direction::automatic: - { - // table[to_attribute][from_attribute] - static const memcpy_direction - direction_table[static_cast(pointer_access_attribute::end)] - [static_cast(pointer_access_attribute::end)] = - {{memcpy_direction::host_to_host, - memcpy_direction::device_to_host, - memcpy_direction::host_to_host}, - {memcpy_direction::host_to_device, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device}, - {memcpy_direction::host_to_host, - memcpy_direction::device_to_device, - memcpy_direction::device_to_device}}; - return direction_table[static_cast(get_pointer_attribute( - q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; - } - default: - throw std::runtime_error("dpct_memcpy: invalid direction value"); - } - } - - static sycl::event - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction, - const std::vector &dep_events = {}) - { - if (!size) - return sycl::event{}; - return q.memcpy(to_ptr, from_ptr, size, dep_events); - GGML_UNUSED(direction); - } - - // Get actual copy range and make sure it will not exceed range. - static inline size_t get_copy_range(sycl::range<3> size, size_t slice, - size_t pitch) - { - return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); - } - - static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) - { - return slice * id.get(2) + pitch * id.get(1) + id.get(0); - } - - /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr - /// and \p from_range to another specified by \p to_ptr and \p to_range. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) - { - // RAII for host pointer - class host_buffer - { - void *_buf; - size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends - - public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) - : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } - size_t get_size() const { return _size; } - ~host_buffer() - { - if (_buf) - { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); - } - } - }; - std::vector event_list; - - size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + - get_offset(from_id, from_slice, from_range.get(0)); - - if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) - { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; - } - direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); - size_t size_slice = size.get(1) * size.get(0); - switch (direction) - { - case host_to_host: - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; - if (to_range.get(0) == from_range.get(0) && - to_range.get(0) == size.get(0)) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); - } - else - { - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); - to_ptr += to_range.get(0); - from_ptr += from_range.get(0); - } - } - to_surface += to_slice; - from_surface += from_slice; - } - break; - case host_to_device: - { - host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); - std::vector host_events; - if (to_slice == size_slice) - { - // Copy host data to a temp host buffer with the shape of target. - host_events = - dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); - } - else - { - // Copy host data to a temp host buffer with the shape of target. - host_events = dpct_memcpy( - q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, - // If has padding data, not sure whether it is useless. So fill temp - // buffer with it. - std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); - } - // Copy from temp host buffer to device with only one submit. - event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); - break; - } - case device_to_host: - { - host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); - // Copy from host temp buffer to host target with reshaping. - event_list = dpct_memcpy( - q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), - sycl::id<3>(0, 0, 0), size, host_to_host, - // Copy from device to temp host buffer with only one submit. - std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); - break; - } - case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh){ - cgh.depends_on(dep_events); - cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; - default: - throw std::runtime_error("dpct_memcpy: invalid direction value"); - } - return event_list; - } - - /// memcpy 2D/3D matrix specified by pitched_data. - static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); - } - - /// memcpy 2D matrix with pitch. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); - } - - namespace deprecated - { - - template - class usm_allocator - { - private: - using Alloc = sycl::usm_allocator; - Alloc _impl; - - public: - using value_type = typename std::allocator_traits::value_type; - using pointer = typename std::allocator_traits::pointer; - using const_pointer = typename std::allocator_traits::const_pointer; - using void_pointer = typename std::allocator_traits::void_pointer; - using const_void_pointer = - typename std::allocator_traits::const_void_pointer; - using reference = typename std::allocator_traits::value_type &; - using const_reference = - const typename std::allocator_traits::value_type &; - using difference_type = - typename std::allocator_traits::difference_type; - using size_type = typename std::allocator_traits::size_type; - using propagate_on_container_copy_assignment = typename std::allocator_traits< - Alloc>::propagate_on_container_copy_assignment; - using propagate_on_container_move_assignment = typename std::allocator_traits< - Alloc>::propagate_on_container_move_assignment; - using propagate_on_container_swap = - typename std::allocator_traits::propagate_on_container_swap; - using is_always_equal = - typename std::allocator_traits::is_always_equal; - - template - struct rebind - { - typedef usm_allocator other; - }; - - usm_allocator() : _impl(dpct::get_default_queue()) {} - ~usm_allocator() {} - usm_allocator(const usm_allocator &other) : _impl(other._impl) {} - usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} - pointer address(reference r) { return &r; } - const_pointer address(const_reference r) { return &r; } - pointer allocate(size_type cnt, const_void_pointer hint = nullptr) - { - return std::allocator_traits::allocate(_impl, cnt, hint); - } - void deallocate(pointer p, size_type cnt) - { - std::allocator_traits::deallocate(_impl, p, cnt); - } - size_type max_size() const - { - return std::allocator_traits::max_size(_impl); - } - bool operator==(const usm_allocator &other) const { return _impl == other._impl; } - bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } - }; - - } // namespace deprecated - - inline void dpct_free(void *ptr, - const sycl::queue &q) - { - if (ptr) - { - sycl::free(ptr, q.get_context()); - } - } - - template - inline auto get_memory(const void *x) - { - T *new_x = reinterpret_cast(const_cast(x)); - return new_x; - } - - template - inline typename DataType::T2 get_value(const T *s, sycl::queue &q) - { - using Ty = typename DataType::T2; - Ty s_h; - if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) - detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host) - .wait(); - else - s_h = *reinterpret_cast(s); - return s_h; - } - - } // namespace detail - - template - inline auto get_value(const T *s, sycl::queue &q) - { - return detail::get_value(s, q); - } - - namespace detail - { - template - inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, int lda, const void *b, - int ldb, const void *beta, void *c, int ldc) - { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); - } - - template - class vectorized_binary - { - public: - inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) - { - VecT v4; - for (size_t i = 0; i < v4.size(); ++i) - { - v4[i] = binary_op(a[i], b[i]); - } - return v4; - } - }; - - template - class vectorized_binary< - VecT, BinaryOperation, - std::void_t>> - { - public: - inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) - { - return binary_op(a, b).template as(); - } - }; - - template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); - matrix_info->transpose_info[0] = a_trans; - matrix_info->transpose_info[1] = b_trans; - matrix_info->value_info[0] = alpha_value; - matrix_info->value_info[1] = beta_value; - matrix_info->size_info[0] = m; - matrix_info->size_info[1] = n; - matrix_info->size_info[2] = k; - matrix_info->ld_info[0] = lda; - matrix_info->ld_info[1] = ldb; - matrix_info->ld_info[2] = ldc; - matrix_info->groupsize_info = batch_size; - - sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); - } - - template - inline void - gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, - int k, const void *alpha, const void *a, int lda, - long long int stride_a, const void *b, int ldb, - long long int stride_b, const void *beta, void *c, - int ldc, long long int stride_c, int batch_size) - { - Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); - Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - auto data_a = get_memory(a); - auto data_b = get_memory(b); - auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); - } - - } // namespace detail - - template - inline unsigned vectorized_binary(unsigned a, unsigned b, - const BinaryOperation binary_op) - { - sycl::vec v0{a}, v1{b}; - auto v2 = v0.as(); - auto v3 = v1.as(); - auto v4 = - detail::vectorized_binary()(v2, v3, binary_op); - v0 = v4.template as>(); - return v0; - } - - static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction = automatic, - sycl::queue &q = dpct::get_default_queue()) - { - detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); - } - - static inline unsigned int select_device(unsigned int id) - { - dev_mgr::instance().select_device(id); - return id; - } - - template - T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, - unsigned int logical_sub_group_size = 32) - { - unsigned int id = g.get_local_linear_id(); - unsigned int start_index = - id / logical_sub_group_size * logical_sub_group_size; - unsigned int target_offset = (id % logical_sub_group_size) ^ mask; - return sycl::select_from_group(g, x, - target_offset < logical_sub_group_size - ? start_index + target_offset - : id); - } - - template - sycl::vec extract_and_sign_or_zero_extend4(T val) - { - return sycl::vec(val) - .template as, int8_t, uint8_t>, 4>>() - .template convert(); - } - - template - using dot_product_acc_t = - std::conditional_t && std::is_unsigned_v, - uint32_t, int32_t>; - - template - inline auto dp4a(T1 a, T2 b, T3 c) - { - dot_product_acc_t res = c; - auto va = extract_and_sign_or_zero_extend4(a); - auto vb = extract_and_sign_or_zero_extend4(b); - res += va[0] * vb[0]; - res += va[1] * vb[1]; - res += va[2] * vb[2]; - res += va[3] * vb[3]; - return res; - } - - struct sub_sat - { - template - auto operator()(const T x, const T y) const - { - return sycl::sub_sat(x, y); - } - }; - - template - inline T vectorized_min(T a, T b) - { - sycl::vec v0{a}, v1{b}; - auto v2 = v0.template as(); - auto v3 = v1.template as(); - auto v4 = sycl::min(v2, v3); - v0 = v4.template as>(); - return v0; - } - - inline float pow(const float a, const int b) { return sycl::pown(a, b); } - inline double pow(const double a, const int b) { return sycl::pown(a, b); } - inline float pow(const float a, const float b) { return sycl::pow(a, b); } - inline double pow(const double a, const double b) { return sycl::pow(a, b); } - template - inline typename std::enable_if_t, T> - pow(const T a, const U b) - { - return sycl::pow(a, static_cast(b)); - } - template - inline typename std::enable_if_t, double> - pow(const T a, const U b) - { - return sycl::pow(static_cast(a), static_cast(b)); - } - - inline double min(const double a, const float b) - { - return sycl::fmin(a, static_cast(b)); - } - inline double min(const float a, const double b) - { - return sycl::fmin(static_cast(a), b); - } - inline float min(const float a, const float b) { return sycl::fmin(a, b); } - inline double min(const double a, const double b) { return sycl::fmin(a, b); } - inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::int32_t min(const std::int32_t a, const std::int32_t b) - { - return sycl::min(a, b); - } - inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::int64_t min(const std::int64_t a, const std::int64_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) - { - return sycl::min(a, b); - } - inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) - { - return sycl::min(a, static_cast(b)); - } - inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) - { - return sycl::min(static_cast(a), b); - } - // max function overloads. - // For floating-point types, `float` or `double` arguments are acceptable. - // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or - // `std::int64_t` type arguments are acceptable. - inline double max(const double a, const float b) - { - return sycl::fmax(a, static_cast(b)); - } - inline double max(const float a, const double b) - { - return sycl::fmax(static_cast(a), b); - } - inline float max(const float a, const float b) { return sycl::fmax(a, b); } - inline double max(const double a, const double b) { return sycl::fmax(a, b); } - inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::int32_t max(const std::int32_t a, const std::int32_t b) - { - return sycl::max(a, b); - } - inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::int64_t max(const std::int64_t a, const std::int64_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) - { - return sycl::max(a, b); - } - inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) - { - return sycl::max(a, static_cast(b)); - } - inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) - { - return sycl::max(static_cast(a), b); - } - - inline void - has_capability_or_fail(const sycl::device &dev, - const std::initializer_list &props) - { - for (const auto &it : props) - { - if (dev.has(it)) - continue; - switch (it) - { - case sycl::aspect::fp64: - throw std::runtime_error("'double' is not supported in '" + - dev.get_info() + - "' device"); - break; - case sycl::aspect::fp16: - throw std::runtime_error("'half' is not supported in '" + - dev.get_info() + - "' device"); - break; - default: -#define __SYCL_ASPECT(ASPECT, ID) \ - case sycl::aspect::ASPECT: \ - return #ASPECT; -#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) -#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) - auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string - { - switch (AspectNum) - { -#include -#include - default: - return "unknown aspect"; - } - }; -#undef __SYCL_ASPECT_DEPRECATED_ALIAS -#undef __SYCL_ASPECT_DEPRECATED -#undef __SYCL_ASPECT - throw std::runtime_error( - "'" + getAspectNameStr(it) + "' is not supported in '" + - dev.get_info() + "' device"); - } - break; - } - } - - static inline unsigned int get_current_device_id() - { - return dev_mgr::instance().current_device_id(); - } - - static inline device_ext &get_current_device() - { - return dev_mgr::instance().current_device(); - } - - static inline sycl::queue &get_in_order_queue() - { - return dev_mgr::instance().current_device().in_order_queue(); - } - - static sycl::event - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, - memcpy_direction direction, - const std::vector &dep_events = {}) - { - if (!size) - return sycl::event{}; - return q.memcpy(to_ptr, from_ptr, size, dep_events); - GGML_UNUSED(direction); - } - - // Get actual copy range and make sure it will not exceed range. - static inline size_t get_copy_range(sycl::range<3> size, size_t slice, - size_t pitch) - { - return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); - } - - static inline size_t get_offset(sycl::id<3> id, size_t slice, - size_t pitch) - { - return slice * id.get(2) + pitch * id.get(1) + id.get(0); - } - - /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr - /// and \p from_range to another specified by \p to_ptr and \p to_range. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - sycl::range<3> to_range, sycl::range<3> from_range, - sycl::id<3> to_id, sycl::id<3> from_id, - sycl::range<3> size, memcpy_direction direction, - const std::vector &dep_events = {}) - { - // RAII for host pointer - class host_buffer - { - void *_buf; - size_t _size; - sycl::queue &_q; - const std::vector &_deps; // free operation depends - - public: - host_buffer(size_t size, sycl::queue &q, - const std::vector &deps) - : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} - void *get_ptr() const { return _buf; } - size_t get_size() const { return _size; } - ~host_buffer() - { - if (_buf) - { - _q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(_deps); - cgh.host_task([buf = _buf] { std::free(buf); }); }); - } - } - }; - std::vector event_list; - - size_t to_slice = to_range.get(1) * to_range.get(0), - from_slice = from_range.get(1) * from_range.get(0); - unsigned char *to_surface = - (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); - const unsigned char *from_surface = - (const unsigned char *)from_ptr + - get_offset(from_id, from_slice, from_range.get(0)); - - if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) - { - return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), - direction, dep_events)}; - } - direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); - size_t size_slice = size.get(1) * size.get(0); - switch (direction) - { - case host_to_host: - for (size_t z = 0; z < size.get(2); ++z) - { - unsigned char *to_ptr = to_surface; - const unsigned char *from_ptr = from_surface; - if (to_range.get(0) == from_range.get(0) && - to_range.get(0) == size.get(0)) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, - direction, dep_events)); - } - else - { - for (size_t y = 0; y < size.get(1); ++y) - { - event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), - direction, dep_events)); - to_ptr += to_range.get(0); - from_ptr += from_range.get(0); - } - } - to_surface += to_slice; - from_surface += from_slice; - } - break; - case host_to_device: - { - host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, - event_list); - std::vector host_events; - if (to_slice == size_slice) - { - // Copy host data to a temp host buffer with the shape of target. - host_events = - dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, - host_to_host, dep_events); - } - else - { - // Copy host data to a temp host buffer with the shape of target. - host_events = dpct_memcpy( - q, buf.get_ptr(), from_surface, to_range, from_range, - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, - // If has padding data, not sure whether it is useless. So fill temp - // buffer with it. - std::vector{ - dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), - device_to_host, dep_events)}); - } - // Copy from temp host buffer to device with only one submit. - event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), - buf.get_size(), host_to_device, - host_events)); - break; - } - case device_to_host: - { - host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, - event_list); - // Copy from host temp buffer to host target with reshaping. - event_list = dpct_memcpy( - q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), - sycl::id<3>(0, 0, 0), size, host_to_host, - // Copy from device to temp host buffer with only one submit. - std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, - buf.get_size(), - device_to_host, dep_events)}); - break; - } - case device_to_device: - event_list.push_back(q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(dep_events); - cgh.parallel_for( - size, - [=](sycl::id<3> id) { - to_surface[get_offset(id, to_slice, to_range.get(0))] = - from_surface[get_offset(id, from_slice, from_range.get(0))]; - }); })); - break; - default: - throw std::runtime_error("dpct_memcpy: invalid direction value"); - } - return event_list; - } - - /// memcpy 2D/3D matrix specified by pitched_data. - static inline std::vector - dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, - pitched_data from, sycl::id<3> from_id, sycl::range<3> size, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), - sycl::range<3>(to.get_pitch(), to.get_y(), 1), - sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, - size, direction); - } - - /// memcpy 2D matrix with pitch. - static inline std::vector - dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, - size_t to_pitch, size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic) - { - return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), - sycl::range<3>(from_pitch, y, 1), - sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), - sycl::range<3>(x, y, 1), direction); - } - - inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, const void *b, library_data_t b_type, int ldb, - const void *beta, void *c, library_data_t c_type, int ldc, - library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, - lda, b, ldb, beta, c, ldc); - break; - } -#ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, - a, lda, b, ldb, &beta_half, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_impl( - q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); - break; - } -#endif // __INTEL_MKL__ - default: - throw std::runtime_error("the combination of data type is unsupported"); - } - } // gemm() - - /// Computes a batch of matrix-matrix product with general matrices. - /// \param [in] q The queue where the routine should be executed. - /// \param [in] a_trans Specifies the operation applied to A. - /// \param [in] b_trans Specifies the operation applied to B. - /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. - /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. - /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). - /// \param [in] alpha Scaling factor for the matrix-matrix product. - /// \param [in] a Input matrix A. - /// \param [in] a_type Data type of the matrix A. - /// \param [in] lda Leading dimension of A. - /// \param [in] b Input matrix B. - /// \param [in] b_type Data type of the matrix B. - /// \param [in] ldb Leading dimension of B. - /// \param [in] beta Scaling factor for matrix C. - /// \param [in, out] c Input/Output matrix C. - /// \param [in] c_type Data type of the matrix C. - /// \param [in] ldc Leading dimension of C. - /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. - /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } -#ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - float alpha_float = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_float = - dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } -#endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } - } - - /// Computes a batch of matrix-matrix product with general matrices. - /// \param [in] q The queue where the routine should be executed. - /// \param [in] a_trans Specifies the operation applied to A. - /// \param [in] b_trans Specifies the operation applied to B. - /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. - /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. - /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). - /// \param [in] alpha Scaling factor for the matrix-matrix product. - /// \param [in] a Input matrix A. - /// \param [in] a_type Data type of the matrix A. - /// \param [in] lda Leading dimension of A. - /// \param [in] stride_a Stride between the different A matrices. - /// \param [in] b Input matrix B. - /// \param [in] b_type Data type of the matrix B. - /// \param [in] ldb Leading dimension of B. - /// \param [in] stride_b Stride between the different B matrices. - /// \param [in] beta Scaling factor for matrix C. - /// \param [in, out] c Input/Output matrix C. - /// \param [in] c_type Data type of the matrix C. - /// \param [in] ldc Leading dimension of C. - /// \param [in] stride_c Stride between the different C matrices. - /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. - /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a, library_data_t a_type, - int lda, long long int stride_a, const void *b, - library_data_t b_type, int ldb, long long int stride_b, - const void *beta, void *c, library_data_t c_type, - int ldc, long long int stride_c, int batch_size, - library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - - std::uint64_t key = - detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); - switch (key) - { - case detail::get_type_combination_id( - library_data_t::real_float, library_data_t::real_float, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_double, library_data_t::real_double, - library_data_t::real_double, library_data_t::real_double): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_half): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } -#ifdef __INTEL_MKL__ - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_bfloat16, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_bfloat16, library_data_t::real_bfloat16, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - stride_a, b, ldb, stride_b, beta, c, ldc, - stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_int32, library_data_t::real_int32): - { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_int8, library_data_t::real_int8, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_float, library_data_t::real_float): - { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, - beta, c, ldc, stride_c, batch_size); - break; - } -#endif - case detail::get_type_combination_id( - library_data_t::real_half, library_data_t::real_half, - library_data_t::real_half, library_data_t::real_float): - { - float alpha_value = - dpct::get_value(reinterpret_cast(alpha), q); - float beta_value = - dpct::get_value(reinterpret_cast(beta), q); - sycl::half alpha_half(alpha_value); - sycl::half beta_half(beta_value); - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, - &beta_half, c, ldc, stride_c, batch_size); - break; - } - default: - throw std::runtime_error("the combination of data type is unsupported"); - } - } - - static inline void - async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, - size_t from_pitch, size_t x, size_t y, - memcpy_direction direction = automatic, - sycl::queue &q = get_default_queue()) - { - detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, - direction); - } - - using err0 = detail::generic_error_type; - using err1 = detail::generic_error_type; - - static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { - detail::dpct_free(ptr, q); - } - - /// dpct accessor used as device function parameter. - template class accessor; - template class accessor { - public: - using memory_t = detail::memory_traits; - using element_t = typename memory_t::element_t; - using pointer_t = typename memory_t::pointer_t; - using accessor_t = typename memory_t::template accessor_t<3>; - accessor(pointer_t data, const sycl::range<3> &in_range) - : _data(data), _range(in_range) {} - template - accessor(typename std::enable_if::type &acc) - : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<3> &in_range) - : accessor(acc.get_pointer(), in_range) {} - accessor operator[](size_t index) const { - sycl::range<2> sub(_range.get(1), _range.get(2)); - return accessor(_data + index * sub.size(), sub); - } - - pointer_t get_ptr() const { return _data; } - - private: - pointer_t _data; - sycl::range<3> _range; - }; - template class accessor { - public: - using memory_t = detail::memory_traits; - using element_t = typename memory_t::element_t; - using pointer_t = typename memory_t::pointer_t; - using accessor_t = typename memory_t::template accessor_t<2>; - accessor(pointer_t data, const sycl::range<2> &in_range) - : _data(data), _range(in_range) {} - template - accessor(typename std::enable_if::type &acc) - : accessor(acc, acc.get_range()) {} - accessor(const accessor_t &acc, const sycl::range<2> &in_range) - : accessor(acc.get_pointer(), in_range) {} - - pointer_t operator[](size_t index) const { - return _data + _range.get(1) * index; - } - - pointer_t get_ptr() const { return _data; } - - private: - pointer_t _data; - sycl::range<2> _range; - }; - - namespace detail { - /// Device variable with address space of shared, global or constant. - template class device_memory { - public: - using accessor_t = - typename detail::memory_traits::template accessor_t; - using value_t = typename detail::memory_traits::value_t; - using dpct_accessor_t = dpct::accessor; - - device_memory() : device_memory(sycl::range(1)) {} - - /// Constructor of 1-D array with initializer list - device_memory(const sycl::range &in_range, - std::initializer_list &&init_list) - : device_memory(in_range) { - assert(init_list.size() <= in_range.size()); - _host_ptr = (value_t *)std::malloc(_size); - std::memset(_host_ptr, 0, _size); - std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); - } - - /// Constructor of 2-D array with initializer list - template - device_memory( - const typename std::enable_if>::type &in_range, - std::initializer_list> &&init_list) - : device_memory(in_range) { - assert(init_list.size() <= in_range[0]); - _host_ptr = (value_t *)std::malloc(_size); - std::memset(_host_ptr, 0, _size); - auto tmp_data = _host_ptr; - for (auto sub_list : init_list) { - assert(sub_list.size() <= in_range[1]); - std::memcpy(tmp_data, sub_list.begin(), - sub_list.size() * sizeof(T)); - tmp_data += in_range[1]; - } - } - - /// Constructor with range - device_memory(const sycl::range &range_in) - : _size(range_in.size() * sizeof(T)), _range(range_in), - _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { - static_assert( - (Memory == global) || (Memory == constant) || (Memory == shared), - "device memory region should be global, constant or shared"); - // Make sure that singleton class mem_mgr and dev_mgr will destruct - // later than this. - detail::mem_mgr::instance(); - dev_mgr::instance(); - } - - /// Constructor with range - template - device_memory(Args... Arguments) - : device_memory(sycl::range(Arguments...)) {} - - ~device_memory() { - if (_device_ptr && !_reference) - dpct::dpct_free(_device_ptr); - if (_host_ptr) - std::free(_host_ptr); - } - - /// Allocate memory with default queue, and init memory if has initial - /// value. - void init() { init(dpct::get_default_queue()); } - /// Allocate memory with specified queue, and init memory if has initial - /// value. - void init(sycl::queue &q) { - if (_device_ptr) - return; - if (!_size) - return; - allocate_device(q); - if (_host_ptr) - detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, - host_to_device); - } - - /// The variable is assigned to a device pointer. - void assign(value_t *src, size_t size) { - this->~device_memory(); - new (this) device_memory(src, size); - } - - /// Get memory pointer of the memory object, which is virtual pointer when - /// usm is not used, and device pointer when usm is used. - value_t *get_ptr() { return get_ptr(get_default_queue()); } - /// Get memory pointer of the memory object, which is virtual pointer when - /// usm is not used, and device pointer when usm is used. - value_t *get_ptr(sycl::queue &q) { - init(q); - return _device_ptr; - } - - /// Get the device memory object size in bytes. - size_t get_size() { return _size; } - - template - typename std::enable_if::type &operator[](size_t index) { - init(); - return _device_ptr[index]; - } - - /// Get dpct::accessor with dimension info for the device memory object - /// when usm is used and dimension is greater than 1. - template - typename std::enable_if::type - get_access(sycl::handler &cgh) { - return dpct_accessor_t((T *)_device_ptr, _range); - } - - private: - device_memory(value_t *memory_ptr, size_t size) - : _size(size), _range(size / sizeof(T)), _reference(true), - _device_ptr(memory_ptr) {} - - void allocate_device(sycl::queue &q) { - #ifndef DPCT_USM_LEVEL_NONE - if (Memory == shared) { - _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), - q.get_context()); - return; - } - #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY - if (Memory == constant) { - _device_ptr = (value_t *)sycl::malloc_device( - _size, q.get_device(), q.get_context(), - sycl::ext::oneapi::property::usm::device_read_only()); - return; - } - #endif - #endif - _device_ptr = (value_t *)detail::dpct_malloc(_size, q); - } - - size_t _size; - sycl::range _range; - bool _reference; - value_t *_host_ptr; - value_t *_device_ptr; - }; - template - class device_memory : public device_memory { - public: - using base = device_memory; - using value_t = typename base::value_t; - using accessor_t = - typename detail::memory_traits::template accessor_t<0>; - - /// Constructor with initial value. - device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} - - /// Default constructor - device_memory() : base(1) {} - }; - } // namespace detail - - template - using global_memory = detail::device_memory; - template - using constant_memory = detail::device_memory; - template - using shared_memory = detail::device_memory; - - -} // COPY from DPCT head files +#include "ggml-sycl/dpct/helper.hpp" #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL diff --git a/ggml-sycl/dpct/helper.hpp b/ggml-sycl/dpct/helper.hpp new file mode 100644 index 0000000000000..788f9724efeb9 --- /dev/null +++ b/ggml-sycl/dpct/helper.hpp @@ -0,0 +1,2937 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + + +#include +#include +#include +#include + +#include "ggml.h" + +#if defined(__linux__) +#include +#elif defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#error "Only support Windows and Linux." +#endif + +#if defined(__linux__) +#include +#include +#endif +#if defined(_WIN64) +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif + +#define DPCT_COMPATIBILITY_TEMP (900) + +#if defined(_MSC_VER) +#define __dpct_align__(n) __declspec(align(n)) +#define __dpct_inline__ __forceinline +#else +#define __dpct_align__(n) __attribute__((aligned(n))) +#define __dpct_inline__ __inline__ __attribute__((always_inline)) +#endif + +#if defined(_MSC_VER) +#define __dpct_noinline__ __declspec(noinline) +#else +#define __dpct_noinline__ __attribute__((noinline)) +#endif + +std::string get_device_type_name(const sycl::device &Device) { + auto DeviceType = Device.get_info(); + switch (DeviceType) { + case sycl::info::device_type::cpu: + return "cpu"; + case sycl::info::device_type::gpu: + return "gpu"; + case sycl::info::device_type::host: + return "host"; + case sycl::info::device_type::accelerator: + return "acc"; + default: + return "unknown"; + } +} + +std::string get_device_backend_and_type(const sycl::device &device) { + std::stringstream device_type; + sycl::backend backend = device.get_backend(); + device_type << backend << ":" << get_device_type_name(device); + return device_type.str(); +} + +namespace dpct +{ + typedef sycl::queue *queue_ptr; + typedef sycl::event *event_ptr; + typedef char *device_ptr; + typedef uint8_t byte_t; + typedef sycl::buffer buffer_t; + + /// SYCL default exception handler + inline auto exception_handler = [](sycl::exception_list exceptions) + { + for (std::exception_ptr const &e : exceptions) + { + try + { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) + { + std::cerr << "Caught asynchronous SYCL exception:" << std::endl + << e.what() << std::endl + << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + } + } + }; + + enum error_code + { + success = 0, + default_error = 999 + }; + + enum memcpy_direction + { + host_to_host, + host_to_device, + device_to_host, + device_to_device, + automatic + }; + + enum memory_region + { + global = 0, // device global memory + constant, // device constant memory + local, // device local memory + shared, // memory which can be accessed by host and device + }; + + enum class library_data_t : unsigned char + { + real_float = 0, + complex_float, + real_double, + complex_double, + real_half, + complex_half, + real_bfloat16, + complex_bfloat16, + real_int4, + complex_int4, + real_uint4, + complex_uint4, + real_int8, + complex_int8, + real_uint8, + complex_uint8, + real_int16, + complex_int16, + real_uint16, + complex_uint16, + real_int32, + complex_int32, + real_uint32, + complex_uint32, + real_int64, + complex_int64, + real_uint64, + complex_uint64, + real_int8_4, + real_int8_32, + real_uint8_4, + library_data_t_size + }; + + template + struct DataType + { + using T2 = T; + }; + template + struct DataType> + { + using T2 = std::complex; + }; + + static void destroy_event(event_ptr event) + { + delete event; + } + + static inline unsigned int get_tid() + { +#if defined(__linux__) + return syscall(SYS_gettid); +#elif defined(_WIN64) + return GetCurrentThreadId(); +#else +#error "Only support Windows and Linux." +#endif + } + + namespace detail + { + static void get_version(const sycl::device &dev, int &major, int &minor) + { + // Version string has the following format: + // a. OpenCL + // b. + // c. e.g gfx1030 + std::string ver; + ver = dev.get_info(); + std::string::size_type i = 0; + while (i < ver.size()) { + if (isdigit(ver[i])) + break; + i++; + } + major = std::stoi(&(ver[i])); + while (i < ver.size()) { + if (ver[i] == '.') + break; + i++; + } + if (i < ver.size()) { + // a. and b. + i++; + minor = std::stoi(&(ver[i])); + } else { + // c. + minor = 0; + } + } + + template + class generic_error_type + { + public: + generic_error_type() = default; + generic_error_type(T value) : value{value} {} + operator T() const { return value; } + + private: + T value; + }; + + } // namespace detail + + /// Pitched 2D/3D memory data. + class pitched_data + { + public: + pitched_data() : pitched_data(nullptr, 0, 0, 0) {} + pitched_data(void *data, size_t pitch, size_t x, size_t y) + : _data(data), _pitch(pitch), _x(x), _y(y) {} + + void *get_data_ptr() { return _data; } + void set_data_ptr(void *data) { _data = data; } + + size_t get_pitch() { return _pitch; } + void set_pitch(size_t pitch) { _pitch = pitch; } + + size_t get_x() { return _x; } + void set_x(size_t x) { _x = x; }; + + size_t get_y() { return _y; } + void set_y(size_t y) { _y = y; } + + private: + void *_data; + size_t _pitch, _x, _y; + }; + + class device_info + { + public: + // get interface + const char *get_name() const { return _name; } + char *get_name() { return _name; } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() const + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + template , + std::enable_if_t> || + std::is_same_v, + int> = 0> + auto get_max_work_item_sizes() + { + if constexpr (std::is_same_v>) + return sycl::range<3>(_max_work_item_sizes_i[0], + _max_work_item_sizes_i[1], + _max_work_item_sizes_i[2]); + else + { + return _max_work_item_sizes_i; + } + } + bool get_host_unified_memory() const { return _host_unified_memory; } + int get_major_version() const { return _major; } + int get_minor_version() const { return _minor; } + int get_integrated() const { return _integrated; } + int get_max_clock_frequency() const { return _frequency; } + int get_max_compute_units() const { return _max_compute_units; } + int get_max_work_group_size() const { return _max_work_group_size; } + int get_max_sub_group_size() const { return _max_sub_group_size; } + int get_max_work_items_per_compute_unit() const + { + return _max_work_items_per_compute_unit; + } + int get_max_register_size_per_work_group() const + { + return _max_register_size_per_work_group; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() const + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + template || + std::is_same_v, + int> = 0> + auto get_max_nd_range_size() + { + if constexpr (std::is_same_v) + return _max_nd_range_size; + else + return _max_nd_range_size_i; + } + size_t get_global_mem_size() const { return _global_mem_size; } + size_t get_local_mem_size() const { return _local_mem_size; } + size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; } + /// Returns the maximum clock rate of device's global memory in kHz. If + /// compiler does not support this API then returns default value 3200000 kHz. + unsigned int get_memory_clock_rate() const { return _memory_clock_rate; } + /// Returns the maximum bus width between device and memory in bits. If + /// compiler does not support this API then returns default value 64 bits. + unsigned int get_memory_bus_width() const { return _memory_bus_width; } + uint32_t get_device_id() const { return _device_id; } + std::array get_uuid() const { return _uuid; } + /// Returns global memory cache size in bytes. + unsigned int get_global_mem_cache_size() const + { + return _global_mem_cache_size; + } + + // set interface + void set_name(const char *name) + { + size_t length = strlen(name); + if (length < 256) + { + std::memcpy(_name, name, length + 1); + } + else + { + std::memcpy(_name, name, 255); + _name[255] = '\0'; + } + } + void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + [[deprecated]] void + set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes) + { + for (int i = 0; i < 3; ++i) + { + _max_work_item_sizes_i[i] = max_work_item_sizes[i]; + } + } + void set_host_unified_memory(bool host_unified_memory) + { + _host_unified_memory = host_unified_memory; + } + void set_major_version(int major) { _major = major; } + void set_minor_version(int minor) { _minor = minor; } + void set_integrated(int integrated) { _integrated = integrated; } + void set_max_clock_frequency(int frequency) { _frequency = frequency; } + void set_max_compute_units(int max_compute_units) + { + _max_compute_units = max_compute_units; + } + void set_global_mem_size(size_t global_mem_size) + { + _global_mem_size = global_mem_size; + } + void set_local_mem_size(size_t local_mem_size) + { + _local_mem_size = local_mem_size; + } + void set_max_mem_alloc_size(size_t max_mem_alloc_size) + { + _max_mem_alloc_size = max_mem_alloc_size; + } + void set_max_work_group_size(int max_work_group_size) + { + _max_work_group_size = max_work_group_size; + } + void set_max_sub_group_size(int max_sub_group_size) + { + _max_sub_group_size = max_sub_group_size; + } + void + set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit) + { + _max_work_items_per_compute_unit = max_work_items_per_compute_unit; + } + void set_max_nd_range_size(int max_nd_range_size[]) + { + for (int i = 0; i < 3; i++) + { + _max_nd_range_size[i] = max_nd_range_size[i]; + _max_nd_range_size_i[i] = max_nd_range_size[i]; + } + } + void set_memory_clock_rate(unsigned int memory_clock_rate) + { + _memory_clock_rate = memory_clock_rate; + } + void set_memory_bus_width(unsigned int memory_bus_width) + { + _memory_bus_width = memory_bus_width; + } + void + set_max_register_size_per_work_group(int max_register_size_per_work_group) + { + _max_register_size_per_work_group = max_register_size_per_work_group; + } + void set_device_id(uint32_t device_id) + { + _device_id = device_id; + } + void set_uuid(std::array uuid) + { + _uuid = std::move(uuid); + } + void set_global_mem_cache_size(unsigned int global_mem_cache_size) + { + _global_mem_cache_size = global_mem_cache_size; + } + + private: + char _name[256]; + int _max_work_item_sizes_i[3]; + bool _host_unified_memory = false; + int _major; + int _minor; + int _integrated = 0; + int _frequency; + // Set estimated value 3200000 kHz as default value. + unsigned int _memory_clock_rate = 3200000; + // Set estimated value 64 bits as default value. + unsigned int _memory_bus_width = 64; + unsigned int _global_mem_cache_size; + int _max_compute_units; + int _max_work_group_size; + int _max_sub_group_size; + int _max_work_items_per_compute_unit; + int _max_register_size_per_work_group; + size_t _global_mem_size; + size_t _local_mem_size; + size_t _max_mem_alloc_size; + size_t _max_nd_range_size[3]; + int _max_nd_range_size_i[3]; + uint32_t _device_id; + std::array _uuid; + }; + + static int get_major_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return major; + } + + static int get_minor_version(const sycl::device &dev) + { + int major, minor; + detail::get_version(dev, major, minor); + return minor; + } + + static void get_device_info(device_info &out, const sycl::device &dev) + { + device_info prop; + prop.set_name(dev.get_info().c_str()); + + int major, minor; + detail::get_version(dev, major, minor); + prop.set_major_version(major); + prop.set_minor_version(minor); + + prop.set_max_work_item_sizes( +#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902) + // oneAPI DPC++ compiler older than 2022/09/02, where max_work_item_sizes + // is an enum class element + dev.get_info()); +#else + // SYCL 2020-conformant code, max_work_item_sizes is a struct templated by + // an int + dev.get_info>()); +#endif + prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations)); + + prop.set_max_clock_frequency( + dev.get_info() * 1000); + + prop.set_max_compute_units( + dev.get_info()); + prop.set_max_work_group_size( + dev.get_info()); + prop.set_global_mem_size(dev.get_info()); + prop.set_local_mem_size(dev.get_info()); + prop.set_max_mem_alloc_size(dev.get_info()); + +#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6) + if (dev.has(sycl::aspect::ext_intel_memory_clock_rate)) + { + unsigned int tmp = + dev.get_info(); + if (tmp != 0) + prop.set_memory_clock_rate(1000 * tmp); + } + if (dev.has(sycl::aspect::ext_intel_memory_bus_width)) + { + prop.set_memory_bus_width( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_id)) + { + prop.set_device_id( + dev.get_info()); + } + if (dev.has(sycl::aspect::ext_intel_device_info_uuid)) + { + prop.set_uuid(dev.get_info()); + } +#elif defined(_MSC_VER) && !defined(__clang__) +#pragma message("get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value.") +#else +#warning "get_device_info: querying memory_clock_rate and \ + memory_bus_width are not supported by the compiler used. \ + Use 3200000 kHz as memory_clock_rate default value. \ + Use 64 bits as memory_bus_width default value." +#endif + + size_t max_sub_group_size = 1; + std::vector sub_group_sizes = + dev.get_info(); + + for (const auto &sub_group_size : sub_group_sizes) + { + if (max_sub_group_size < sub_group_size) + max_sub_group_size = sub_group_size; + } + + prop.set_max_sub_group_size(max_sub_group_size); + + prop.set_max_work_items_per_compute_unit( + dev.get_info()); + int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF}; + prop.set_max_nd_range_size(max_nd_range_size); + + // Estimates max register size per work group, feel free to update the value + // according to device properties. + prop.set_max_register_size_per_work_group(65536); + + prop.set_global_mem_cache_size( + dev.get_info()); + out = prop; + } + + /// dpct device extension + class device_ext : public sycl::device + { + typedef std::mutex mutex_type; + + public: + device_ext() : sycl::device(), _ctx(*this) {} + ~device_ext() + { + std::lock_guard lock(m_mutex); + clear_queues(); + } + device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this) + { + std::lock_guard lock(m_mutex); + init_queues(); + } + + int is_native_atomic_supported() { return 0; } + int get_major_version() const + { + return dpct::get_major_version(*this); + } + + int get_minor_version() const + { + return dpct::get_minor_version(*this); + } + + int get_max_compute_units() const + { + return get_device_info().get_max_compute_units(); + } + + /// Return the maximum clock frequency of this device in KHz. + int get_max_clock_frequency() const + { + return get_device_info().get_max_clock_frequency(); + } + + int get_integrated() const { return get_device_info().get_integrated(); } + + int get_max_sub_group_size() const + { + return get_device_info().get_max_sub_group_size(); + } + + int get_max_register_size_per_work_group() const + { + return get_device_info().get_max_register_size_per_work_group(); + } + + int get_max_work_group_size() const + { + return get_device_info().get_max_work_group_size(); + } + + int get_mem_base_addr_align() const + { + return get_info(); + } + + size_t get_global_mem_size() const + { + return get_device_info().get_global_mem_size(); + } + + size_t get_max_mem_alloc_size() const + { + return get_device_info().get_max_mem_alloc_size(); + } + + /// Get the number of bytes of free and total memory on the SYCL device. + /// \param [out] free_memory The number of bytes of free memory on the SYCL device. + /// \param [out] total_memory The number of bytes of total memory on the SYCL device. + void get_memory_info(size_t &free_memory, size_t &total_memory) + { + total_memory = get_device_info().get_global_mem_size(); + const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not " + "supported (export/set ZES_ENABLE_SYSMAN=1 to support), " + "use total memory as free memory"; +#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105) + if (!has(sycl::aspect::ext_intel_free_memory)) + { + std::cerr << warning_info << std::endl; + free_memory = total_memory; + } + else + { + free_memory = get_info(); + } +#else + std::cerr << warning_info << std::endl; + free_memory = total_memory; +#if defined(_MSC_VER) && !defined(__clang__) +#pragma message("Querying the number of bytes of free memory is not supported") +#else +#warning "Querying the number of bytes of free memory is not supported" +#endif +#endif + } + + void get_device_info(device_info &out) const + { + dpct::get_device_info(out, *this); + } + + device_info get_device_info() const + { + device_info prop; + dpct::get_device_info(prop, *this); + return prop; + } + + void reset() + { + std::lock_guard lock(m_mutex); + clear_queues(); + init_queues(); + } + + sycl::queue &in_order_queue() { return *_q_in_order; } + + sycl::queue &out_of_order_queue() { return *_q_out_of_order; } + + sycl::queue &default_queue() + { + return in_order_queue(); + } + + void queues_wait_and_throw() + { + std::unique_lock lock(m_mutex); + std::vector> current_queues( + _queues); + lock.unlock(); + for (const auto &q : current_queues) + { + q->wait_and_throw(); + } + // Guard the destruct of current_queues to make sure the ref count is safe. + lock.lock(); + } + + sycl::queue *create_queue(bool enable_exception_handler = false) + { + return create_in_order_queue(enable_exception_handler); + } + + sycl::queue *create_queue(sycl::context context, sycl::device device, + bool enable_exception_handler = false) { + return create_in_order_queue(context, device, enable_exception_handler); + } + + sycl::queue *create_in_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue *create_in_order_queue(sycl::context context, sycl::device device, + bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(context, device, enable_exception_handler, + sycl::property::queue::in_order()); + } + + sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) { + std::lock_guard lock(m_mutex); + return create_queue_impl(enable_exception_handler); + } + + void destroy_queue(sycl::queue *&queue) + { + std::lock_guard lock(m_mutex); + _queues.erase(std::remove_if(_queues.begin(), _queues.end(), + [=](const std::shared_ptr &q) -> bool + { + return q.get() == queue; + }), + _queues.end()); + queue = nullptr; + } + void set_saved_queue(sycl::queue *q) + { + std::lock_guard lock(m_mutex); + _saved_queue = q; + } + sycl::queue *get_saved_queue() const + { + std::lock_guard lock(m_mutex); + return _saved_queue; + } + sycl::context get_context() const { return _ctx; } + + private: + void clear_queues() + { + _queues.clear(); + _q_in_order = _q_out_of_order = _saved_queue = nullptr; + } + + void init_queues() + { + _q_in_order = create_queue_impl(true, sycl::property::queue::in_order()); + _q_out_of_order = create_queue_impl(true); + _saved_queue = &default_queue(); + } + + /// Caller should acquire resource \p m_mutex before calling this function. + template + sycl::queue *create_queue_impl(bool enable_exception_handler, + Properties... properties) + { + sycl::async_handler eh = {}; + if (enable_exception_handler) + { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + _ctx, *this, eh, + sycl::property_list( +#ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), +#endif + properties...))); + + return _queues.back().get(); + } + + template + sycl::queue *create_queue_impl(sycl::context context, sycl::device device, + bool enable_exception_handler, + Properties... properties) { + sycl::async_handler eh = {}; + if (enable_exception_handler) { + eh = exception_handler; + } + _queues.push_back(std::make_shared( + context, device, eh, + sycl::property_list( + #ifdef DPCT_PROFILING_ENABLED + sycl::property::queue::enable_profiling(), + #endif + properties...))); + + return _queues.back().get(); + } + + void get_version(int &major, int &minor) const + { + detail::get_version(*this, major, minor); + } + sycl::queue *_q_in_order, *_q_out_of_order; + sycl::queue *_saved_queue; + sycl::context _ctx; + std::vector> _queues; + mutable mutex_type m_mutex; + }; + + /// device manager + class dev_mgr + { + public: + device_ext ¤t_device() + { + unsigned int dev_id = current_device_id(); + check_id(dev_id); + return *_devs[dev_id]; + } + device_ext &cpu_device() const + { + std::lock_guard lock(m_mutex); + if (_cpu_device == -1) + { + throw std::runtime_error("no valid cpu device"); + } + else + { + return *_devs[_cpu_device]; + } + } + device_ext &get_device(unsigned int id) const + { + std::lock_guard lock(m_mutex); + check_id(id); + return *_devs[id]; + } + unsigned int current_device_id() const + { + std::lock_guard lock(m_mutex); + auto it = _thread2dev_map.find(get_tid()); + if (it != _thread2dev_map.end()) + return it->second; + return DEFAULT_DEVICE_ID; + } + + /// Select device with a device ID. + /// \param [in] id The id of the device which can + /// be obtained through get_device_id(const sycl::device). + void select_device(unsigned int id) + { + std::lock_guard lock(m_mutex); + check_id(id); + _thread2dev_map[get_tid()] = id; + } + unsigned int device_count() { return _devs.size(); } + + unsigned int get_device_id(const sycl::device &dev) + { + unsigned int id = 0; + for (auto dev_item : _devs) + { + if (*dev_item == dev) + { + break; + } + id++; + } + return id; + } + + template + std::enable_if_t< + std::is_invocable_r_v> + select_device(const DeviceSelector &selector = sycl::gpu_selector_v) + { + sycl::device selected_device = sycl::device(selector); + unsigned int selected_device_id = get_device_id(selected_device); + select_device(selected_device_id); + } + + /// Returns the instance of device manager singleton. + static dev_mgr &instance() + { + static dev_mgr d_m; + return d_m; + } + dev_mgr(const dev_mgr &) = delete; + dev_mgr &operator=(const dev_mgr &) = delete; + dev_mgr(dev_mgr &&) = delete; + dev_mgr &operator=(dev_mgr &&) = delete; + + private: + mutable std::recursive_mutex m_mutex; + static bool compare_dev(sycl::device &device1, sycl::device &device2) + { + dpct::device_info prop1; + dpct::get_device_info(prop1, device1); + dpct::device_info prop2; + dpct::get_device_info(prop2, device2); + return prop1.get_max_compute_units() > prop2.get_max_compute_units(); + } + static int convert_backend_index(std::string & backend) { + if (backend == "ext_oneapi_level_zero:gpu") return 0; + if (backend == "opencl:gpu") return 1; + if (backend == "ext_oneapi_cuda:gpu") return 2; + if (backend == "ext_oneapi_hip:gpu") return 3; + if (backend == "opencl:cpu") return 4; + if (backend == "opencl:acc") return 5; + printf("convert_backend_index: can't handle backend=%s\n", backend.c_str()); + GGML_ASSERT(false); + } + static bool compare_backend(std::string &backend1, std::string &backend2) { + return convert_backend_index(backend1) < convert_backend_index(backend2); + } + dev_mgr() + { + sycl::device default_device = + sycl::device(sycl::default_selector_v); + _devs.push_back(std::make_shared(default_device)); + + std::vector sycl_all_devs; + // Collect other devices except for the default device. + if (default_device.is_cpu()) + _cpu_device = 0; + + auto Platforms = sycl::platform::get_platforms(); + // Keep track of the number of devices per backend + std::map DeviceNums; + std::map> backend_devices; + + while (!Platforms.empty()) { + auto Platform = Platforms.back(); + Platforms.pop_back(); + auto devices = Platform.get_devices(); + std::string backend_type = get_device_backend_and_type(devices[0]); + for (const auto &device : devices) { + backend_devices[backend_type].push_back(device); + } + } + + std::vector keys; + for(auto it = backend_devices.begin(); it != backend_devices.end(); ++it) { + keys.push_back(it->first); + } + std::sort(keys.begin(), keys.end(), compare_backend); + + for (auto &key : keys) { + std::vector devs = backend_devices[key]; + std::sort(devs.begin(), devs.end(), compare_dev); + for (const auto &dev : devs) { + sycl_all_devs.push_back(dev); + } + } + + for (auto &dev : sycl_all_devs) + { + if (dev == default_device) + { + continue; + } + _devs.push_back(std::make_shared(dev)); + if (_cpu_device == -1 && dev.is_cpu()) + { + _cpu_device = _devs.size() - 1; + } + } + } + void check_id(unsigned int id) const + { + if (id >= _devs.size()) + { + throw std::runtime_error("invalid device id"); + } + } + std::vector> _devs; + /// DEFAULT_DEVICE_ID is used, if current_device_id() can not find current + /// thread id in _thread2dev_map, which means default device should be used + /// for the current thread. + const unsigned int DEFAULT_DEVICE_ID = 0; + /// thread-id to device-id map. + std::map _thread2dev_map; + int _cpu_device = -1; + }; + + static inline sycl::queue &get_default_queue() + { + return dev_mgr::instance().current_device().default_queue(); + } + + namespace detail + { + enum class pointer_access_attribute + { + host_only = 0, + device_only, + host_device, + end + }; + + static pointer_access_attribute get_pointer_attribute(sycl::queue &q, + const void *ptr) + { + switch (sycl::get_pointer_type(ptr, q.get_context())) + { + case sycl::usm::alloc::unknown: + return pointer_access_attribute::host_only; + case sycl::usm::alloc::device: + return pointer_access_attribute::device_only; + case sycl::usm::alloc::shared: + case sycl::usm::alloc::host: + return pointer_access_attribute::host_device; + } + } + + template + inline constexpr std::uint64_t get_type_combination_id(ArgT Val) + { + static_assert((unsigned char)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(std::is_same_v, "Unsupported ArgT"); + return (std::uint64_t)Val; + } + + template + inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal, + RestT... RestVal) + { + static_assert((std::uint8_t)library_data_t::library_data_t_size <= + std::numeric_limits::max() && + "library_data_t size exceeds limit."); + static_assert(sizeof...(RestT) <= 8 && "Too many parameters"); + static_assert(std::is_same_v, "Unsupported FirstT"); + return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal); + } + + class mem_mgr + { + mem_mgr() + { + // Reserved address space, no real memory allocation happens here. +#if defined(__linux__) + mapped_address_space = + (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE, + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); +#elif defined(_WIN64) + mapped_address_space = (byte_t *)VirtualAlloc( + NULL, // NULL specified as the base address parameter + mapped_region_size, // Size of allocation + MEM_RESERVE, // Allocate reserved pages + PAGE_NOACCESS); // Protection = no access +#else +#error "Only support Windows and Linux." +#endif + next_free = mapped_address_space; + }; + + public: + using buffer_id_t = int; + + struct allocation + { + buffer_t buffer; + byte_t *alloc_ptr; + size_t size; + }; + + ~mem_mgr() + { +#if defined(__linux__) + munmap(mapped_address_space, mapped_region_size); +#elif defined(_WIN64) + VirtualFree(mapped_address_space, 0, MEM_RELEASE); +#else +#error "Only support Windows and Linux." +#endif + }; + + mem_mgr(const mem_mgr &) = delete; + mem_mgr &operator=(const mem_mgr &) = delete; + mem_mgr(mem_mgr &&) = delete; + mem_mgr &operator=(mem_mgr &&) = delete; + + /// Allocate + void *mem_alloc(size_t size) + { + if (!size) + return nullptr; + std::lock_guard lock(m_mutex); + if (next_free + size > mapped_address_space + mapped_region_size) + { + throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool"); + } + // Allocation + sycl::range<1> r(size); + buffer_t buf(r); + allocation A{buf, next_free, size}; + // Map allocation to device pointer + void *result = next_free; + m_map.emplace(next_free + size, A); + // Update pointer to the next free space. + next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1); + + return result; + } + + /// Deallocate + void mem_free(const void *ptr) + { + if (!ptr) + return; + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + m_map.erase(it); + } + + /// map: device pointer -> allocation(buffer, alloc_ptr, size) + allocation translate_ptr(const void *ptr) + { + std::lock_guard lock(m_mutex); + auto it = get_map_iterator(ptr); + return it->second; + } + + /// Check if the pointer represents device pointer or not. + bool is_device_ptr(const void *ptr) const + { + std::lock_guard lock(m_mutex); + return (mapped_address_space <= ptr) && + (ptr < mapped_address_space + mapped_region_size); + } + + /// Returns the instance of memory manager singleton. + static mem_mgr &instance() + { + static mem_mgr m; + return m; + } + + private: + std::map m_map; + mutable std::mutex m_mutex; + byte_t *mapped_address_space; + byte_t *next_free; + const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024; + const size_t alignment = 256; + /// This padding may be defined to some positive value to debug + /// out of bound accesses. + const size_t extra_padding = 0; + + std::map::iterator get_map_iterator(const void *ptr) + { + auto it = m_map.upper_bound((byte_t *)ptr); + if (it == m_map.end()) + { + // Not a virtual pointer. + throw std::runtime_error("can not get buffer from non-virtual pointer"); + } + const allocation &alloc = it->second; + if (ptr < alloc.alloc_ptr) + { + // Out of bound. + // This may happen if there's a gap between allocations due to alignment + // or extra padding and pointer points to this gap. + throw std::runtime_error("invalid virtual pointer"); + } + return it; + } + }; + + template + class accessor; + template + class memory_traits + { + public: + static constexpr sycl::access::target target = + sycl::access::target::device; + static constexpr sycl::access_mode mode = + (Memory == constant) ? sycl::access_mode::read + : sycl::access_mode::read_write; + static constexpr size_t type_size = sizeof(T); + using element_t = + typename std::conditional::type; + using value_t = typename std::remove_cv::type; + template + using accessor_t = typename std::conditional< + Memory == local, sycl::local_accessor, + sycl::accessor>::type; + using pointer_t = T *; + }; + + static inline void *dpct_malloc(size_t size, sycl::queue &q) + { + return sycl::malloc_device(size, q.get_device(), q.get_context()); + } + +#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F)) + static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z, + sycl::queue &q) + { + pitch = PITCH_DEFAULT_ALIGN(x); + return dpct_malloc(pitch * y * z, q); + } + + /** + * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] dev_ptr Pointer to the virtual device memory address. + * @param [in] value The value to be set. + * @param [in] size Number of elements to be set to the value. + * @return An event representing the memset operation. + */ + template + static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr, + valueT value, size_t size) + { + return q.fill(dev_ptr, value, size); + } + + /** + * @brief Sets \p value to the 3D memory region pointed by \p data in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] data Pointer to the pitched device memory region. + * @param [in] value The value to be set. + * @param [in] size 3D memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, pitched_data data, valueT value, + sycl::range<3> size) + { + std::vector event_list; + size_t slice = data.get_pitch() * data.get_y(); + unsigned char *data_surface = (unsigned char *)data.get_data_ptr(); + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *data_ptr = data_surface; + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0))); + data_ptr += data.get_pitch(); + } + data_surface += slice; + } + return event_list; + } + + /** + * @brief Sets \p val to the pitched 2D memory region pointed by \p ptr in \p q. + * @tparam valueT The type of the element to be set. + * @param [in] q The queue in which the operation is done. + * @param [in] ptr Pointer to the virtual device memory. + * @param [in] pitch The pitch size by number of elements, including padding. + * @param [in] val The value to be set. + * @param [in] x The width of memory region by number of elements. + * @param [in] y The height of memory region by number of elements. + * @return An event list representing the memset operations. + */ + template + static inline std::vector + dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x, + size_t y) + { + return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val, + sycl::range<3>(x, y, 1)); + } + + static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr, + const void *from_ptr, + memcpy_direction dir) + { + switch (dir) + { + case memcpy_direction::host_to_host: + case memcpy_direction::host_to_device: + case memcpy_direction::device_to_host: + case memcpy_direction::device_to_device: + return dir; + case memcpy_direction::automatic: + { + // table[to_attribute][from_attribute] + static const memcpy_direction + direction_table[static_cast(pointer_access_attribute::end)] + [static_cast(pointer_access_attribute::end)] = + {{memcpy_direction::host_to_host, + memcpy_direction::device_to_host, + memcpy_direction::host_to_host}, + {memcpy_direction::host_to_device, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}, + {memcpy_direction::host_to_host, + memcpy_direction::device_to_device, + memcpy_direction::device_to_device}}; + return direction_table[static_cast(get_pointer_attribute( + q, to_ptr))][static_cast(get_pointer_attribute(q, from_ptr))]; + } + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; + return q.memcpy(to_ptr, from_ptr, size, dep_events); + GGML_UNUSED(direction); + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler &cgh){ + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + namespace deprecated + { + + template + class usm_allocator + { + private: + using Alloc = sycl::usm_allocator; + Alloc _impl; + + public: + using value_type = typename std::allocator_traits::value_type; + using pointer = typename std::allocator_traits::pointer; + using const_pointer = typename std::allocator_traits::const_pointer; + using void_pointer = typename std::allocator_traits::void_pointer; + using const_void_pointer = + typename std::allocator_traits::const_void_pointer; + using reference = typename std::allocator_traits::value_type &; + using const_reference = + const typename std::allocator_traits::value_type &; + using difference_type = + typename std::allocator_traits::difference_type; + using size_type = typename std::allocator_traits::size_type; + using propagate_on_container_copy_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_copy_assignment; + using propagate_on_container_move_assignment = typename std::allocator_traits< + Alloc>::propagate_on_container_move_assignment; + using propagate_on_container_swap = + typename std::allocator_traits::propagate_on_container_swap; + using is_always_equal = + typename std::allocator_traits::is_always_equal; + + template + struct rebind + { + typedef usm_allocator other; + }; + + usm_allocator() : _impl(dpct::get_default_queue()) {} + ~usm_allocator() {} + usm_allocator(const usm_allocator &other) : _impl(other._impl) {} + usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {} + pointer address(reference r) { return &r; } + const_pointer address(const_reference r) { return &r; } + pointer allocate(size_type cnt, const_void_pointer hint = nullptr) + { + return std::allocator_traits::allocate(_impl, cnt, hint); + } + void deallocate(pointer p, size_type cnt) + { + std::allocator_traits::deallocate(_impl, p, cnt); + } + size_type max_size() const + { + return std::allocator_traits::max_size(_impl); + } + bool operator==(const usm_allocator &other) const { return _impl == other._impl; } + bool operator!=(const usm_allocator &other) const { return _impl != other._impl; } + }; + + } // namespace deprecated + + inline void dpct_free(void *ptr, + const sycl::queue &q) + { + if (ptr) + { + sycl::free(ptr, q.get_context()); + } + } + + template + inline auto get_memory(const void *x) + { + T *new_x = reinterpret_cast(const_cast(x)); + return new_x; + } + + template + inline typename DataType::T2 get_value(const T *s, sycl::queue &q) + { + using Ty = typename DataType::T2; + Ty s_h; + if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only) + detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host) + .wait(); + else + s_h = *reinterpret_cast(s); + return s_h; + } + + } // namespace detail + + template + inline auto get_value(const T *s, sycl::queue &q) + { + return detail::get_value(s, q); + } + + namespace detail + { + template + inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, int lda, const void *b, + int ldb, const void *beta, void *c, int ldc) + { +#ifndef __INTEL_MKL__ + GGML_UNUSED(q); + GGML_UNUSED(a_trans); + GGML_UNUSED(b_trans); + GGML_UNUSED(m); + GGML_UNUSED(n); + GGML_UNUSED(k); + GGML_UNUSED(alpha); + GGML_UNUSED(a); + GGML_UNUSED(lda); + GGML_UNUSED(b); + GGML_UNUSED(ldb); + GGML_UNUSED(beta); + GGML_UNUSED(c); + GGML_UNUSED(ldc); + throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces " + "Project does not support this API."); +#else + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + data_b, ldb, beta_value, data_c, ldc); +#endif + } + + template + class vectorized_binary + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + VecT v4; + for (size_t i = 0; i < v4.size(); ++i) + { + v4[i] = binary_op(a[i], b[i]); + } + return v4; + } + }; + + template + class vectorized_binary< + VecT, BinaryOperation, + std::void_t>> + { + public: + inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op) + { + return binary_op(a, b).template as(); + } + }; + + template + inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void **a, int lda, + const void **b, int ldb, const void *beta, void **c, + int ldc, int batch_size) + { + struct matrix_info_t + { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; + }; + + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + + matrix_info_t *matrix_info = + (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); + matrix_info->transpose_info[0] = a_trans; + matrix_info->transpose_info[1] = b_trans; + matrix_info->value_info[0] = alpha_value; + matrix_info->value_info[1] = beta_value; + matrix_info->size_info[0] = m; + matrix_info->size_info[1] = n; + matrix_info->size_info[2] = k; + matrix_info->ld_info[0] = lda; + matrix_info->ld_info[1] = ldb; + matrix_info->ld_info[2] = ldc; + matrix_info->groupsize_info = batch_size; + + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, + matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, matrix_info->value_info, + reinterpret_cast(a), matrix_info->ld_info, + reinterpret_cast(b), matrix_info->ld_info + 1, + matrix_info->value_info + 1, reinterpret_cast(c), + matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); + + q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(e); + cgh.host_task([=] { std::free(matrix_info); }); }); + } + + template + inline void + gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, + int k, const void *alpha, const void *a, int lda, + long long int stride_a, const void *b, int ldb, + long long int stride_b, const void *beta, void *c, + int ldc, long long int stride_c, int batch_size) + { + Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); + Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); + auto data_a = get_memory(a); + auto data_b = get_memory(b); + auto data_c = get_memory(c); + oneapi::mkl::blas::column_major::gemm_batch( + q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, + data_c, ldc, stride_c, batch_size); + } + + } // namespace detail + + template + inline unsigned vectorized_binary(unsigned a, unsigned b, + const BinaryOperation binary_op) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.as(); + auto v3 = v1.as(); + auto v4 = + detail::vectorized_binary()(v2, v3, binary_op); + v0 = v4.template as>(); + return v0; + } + + static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction = automatic, + sycl::queue &q = dpct::get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction); + } + + static inline unsigned int select_device(unsigned int id) + { + dev_mgr::instance().select_device(id); + return id; + } + + template + T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask, + unsigned int logical_sub_group_size = 32) + { + unsigned int id = g.get_local_linear_id(); + unsigned int start_index = + id / logical_sub_group_size * logical_sub_group_size; + unsigned int target_offset = (id % logical_sub_group_size) ^ mask; + return sycl::select_from_group(g, x, + target_offset < logical_sub_group_size + ? start_index + target_offset + : id); + } + + template + sycl::vec extract_and_sign_or_zero_extend4(T val) + { + return sycl::vec(val) + .template as, int8_t, uint8_t>, 4>>() + .template convert(); + } + + template + using dot_product_acc_t = + std::conditional_t && std::is_unsigned_v, + uint32_t, int32_t>; + + template + inline auto dp4a(T1 a, T2 b, T3 c) + { + dot_product_acc_t res = c; + auto va = extract_and_sign_or_zero_extend4(a); + auto vb = extract_and_sign_or_zero_extend4(b); + res += va[0] * vb[0]; + res += va[1] * vb[1]; + res += va[2] * vb[2]; + res += va[3] * vb[3]; + return res; + } + + struct sub_sat + { + template + auto operator()(const T x, const T y) const + { + return sycl::sub_sat(x, y); + } + }; + + template + inline T vectorized_min(T a, T b) + { + sycl::vec v0{a}, v1{b}; + auto v2 = v0.template as(); + auto v3 = v1.template as(); + auto v4 = sycl::min(v2, v3); + v0 = v4.template as>(); + return v0; + } + + inline float pow(const float a, const int b) { return sycl::pown(a, b); } + inline double pow(const double a, const int b) { return sycl::pown(a, b); } + inline float pow(const float a, const float b) { return sycl::pow(a, b); } + inline double pow(const double a, const double b) { return sycl::pow(a, b); } + template + inline typename std::enable_if_t, T> + pow(const T a, const U b) + { + return sycl::pow(a, static_cast(b)); + } + template + inline typename std::enable_if_t, double> + pow(const T a, const U b) + { + return sycl::pow(static_cast(a), static_cast(b)); + } + + inline double min(const double a, const float b) + { + return sycl::fmin(a, static_cast(b)); + } + inline double min(const float a, const double b) + { + return sycl::fmin(static_cast(a), b); + } + inline float min(const float a, const float b) { return sycl::fmin(a, b); } + inline double min(const double a, const double b) { return sycl::fmin(a, b); } + inline std::uint32_t min(const std::uint32_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint32_t min(const std::int32_t a, const std::uint32_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int32_t min(const std::int32_t a, const std::int32_t b) + { + return sycl::min(a, b); + } + inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int64_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int64_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::int64_t min(const std::int64_t a, const std::int64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b) + { + return sycl::min(a, b); + } + inline std::uint64_t min(const std::uint64_t a, const std::int32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::int32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b) + { + return sycl::min(a, static_cast(b)); + } + inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b) + { + return sycl::min(static_cast(a), b); + } + // max function overloads. + // For floating-point types, `float` or `double` arguments are acceptable. + // For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or + // `std::int64_t` type arguments are acceptable. + inline double max(const double a, const float b) + { + return sycl::fmax(a, static_cast(b)); + } + inline double max(const float a, const double b) + { + return sycl::fmax(static_cast(a), b); + } + inline float max(const float a, const float b) { return sycl::fmax(a, b); } + inline double max(const double a, const double b) { return sycl::fmax(a, b); } + inline std::uint32_t max(const std::uint32_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint32_t max(const std::int32_t a, const std::uint32_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int32_t max(const std::int32_t a, const std::int32_t b) + { + return sycl::max(a, b); + } + inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int64_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int64_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::int64_t max(const std::int64_t a, const std::int64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b) + { + return sycl::max(a, b); + } + inline std::uint64_t max(const std::uint64_t a, const std::int32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::int32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b) + { + return sycl::max(a, static_cast(b)); + } + inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b) + { + return sycl::max(static_cast(a), b); + } + + inline void + has_capability_or_fail(const sycl::device &dev, + const std::initializer_list &props) + { + for (const auto &it : props) + { + if (dev.has(it)) + continue; + switch (it) + { + case sycl::aspect::fp64: + throw std::runtime_error("'double' is not supported in '" + + dev.get_info() + + "' device"); + break; + case sycl::aspect::fp16: + throw std::runtime_error("'half' is not supported in '" + + dev.get_info() + + "' device"); + break; + default: +#define __SYCL_ASPECT(ASPECT, ID) \ + case sycl::aspect::ASPECT: \ + return #ASPECT; +#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID) +#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE) + auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string + { + switch (AspectNum) + { +#include +#include + default: + return "unknown aspect"; + } + }; +#undef __SYCL_ASPECT_DEPRECATED_ALIAS +#undef __SYCL_ASPECT_DEPRECATED +#undef __SYCL_ASPECT + throw std::runtime_error( + "'" + getAspectNameStr(it) + "' is not supported in '" + + dev.get_info() + "' device"); + } + break; + } + } + + static inline unsigned int get_current_device_id() + { + return dev_mgr::instance().current_device_id(); + } + + static inline device_ext &get_current_device() + { + return dev_mgr::instance().current_device(); + } + + static inline sycl::queue &get_in_order_queue() + { + return dev_mgr::instance().current_device().in_order_queue(); + } + + static sycl::event + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size, + memcpy_direction direction, + const std::vector &dep_events = {}) + { + if (!size) + return sycl::event{}; + return q.memcpy(to_ptr, from_ptr, size, dep_events); + GGML_UNUSED(direction); + } + + // Get actual copy range and make sure it will not exceed range. + static inline size_t get_copy_range(sycl::range<3> size, size_t slice, + size_t pitch) + { + return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0); + } + + static inline size_t get_offset(sycl::id<3> id, size_t slice, + size_t pitch) + { + return slice * id.get(2) + pitch * id.get(1) + id.get(0); + } + + /// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr + /// and \p from_range to another specified by \p to_ptr and \p to_range. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + sycl::range<3> to_range, sycl::range<3> from_range, + sycl::id<3> to_id, sycl::id<3> from_id, + sycl::range<3> size, memcpy_direction direction, + const std::vector &dep_events = {}) + { + // RAII for host pointer + class host_buffer + { + void *_buf; + size_t _size; + sycl::queue &_q; + const std::vector &_deps; // free operation depends + + public: + host_buffer(size_t size, sycl::queue &q, + const std::vector &deps) + : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {} + void *get_ptr() const { return _buf; } + size_t get_size() const { return _size; } + ~host_buffer() + { + if (_buf) + { + _q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(_deps); + cgh.host_task([buf = _buf] { std::free(buf); }); }); + } + } + }; + std::vector event_list; + + size_t to_slice = to_range.get(1) * to_range.get(0), + from_slice = from_range.get(1) * from_range.get(0); + unsigned char *to_surface = + (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0)); + const unsigned char *from_surface = + (const unsigned char *)from_ptr + + get_offset(from_id, from_slice, from_range.get(0)); + + if (to_slice == from_slice && to_slice == size.get(1) * size.get(0)) + { + return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2), + direction, dep_events)}; + } + direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction); + size_t size_slice = size.get(1) * size.get(0); + switch (direction) + { + case host_to_host: + for (size_t z = 0; z < size.get(2); ++z) + { + unsigned char *to_ptr = to_surface; + const unsigned char *from_ptr = from_surface; + if (to_range.get(0) == from_range.get(0) && + to_range.get(0) == size.get(0)) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice, + direction, dep_events)); + } + else + { + for (size_t y = 0; y < size.get(1); ++y) + { + event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0), + direction, dep_events)); + to_ptr += to_range.get(0); + from_ptr += from_range.get(0); + } + } + to_surface += to_slice; + from_surface += from_slice; + } + break; + case host_to_device: + { + host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q, + event_list); + std::vector host_events; + if (to_slice == size_slice) + { + // Copy host data to a temp host buffer with the shape of target. + host_events = + dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, + host_to_host, dep_events); + } + else + { + // Copy host data to a temp host buffer with the shape of target. + host_events = dpct_memcpy( + q, buf.get_ptr(), from_surface, to_range, from_range, + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host, + // If has padding data, not sure whether it is useless. So fill temp + // buffer with it. + std::vector{ + dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(), + device_to_host, dep_events)}); + } + // Copy from temp host buffer to device with only one submit. + event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(), + buf.get_size(), host_to_device, + host_events)); + break; + } + case device_to_host: + { + host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q, + event_list); + // Copy from host temp buffer to host target with reshaping. + event_list = dpct_memcpy( + q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0), + sycl::id<3>(0, 0, 0), size, host_to_host, + // Copy from device to temp host buffer with only one submit. + std::vector{dpct_memcpy(q, buf.get_ptr(), from_surface, + buf.get_size(), + device_to_host, dep_events)}); + break; + } + case device_to_device: + event_list.push_back(q.submit([&](sycl::handler &cgh) + { + cgh.depends_on(dep_events); + cgh.parallel_for( + size, + [=](sycl::id<3> id) { + to_surface[get_offset(id, to_slice, to_range.get(0))] = + from_surface[get_offset(id, from_slice, from_range.get(0))]; + }); })); + break; + default: + throw std::runtime_error("dpct_memcpy: invalid direction value"); + } + return event_list; + } + + /// memcpy 2D/3D matrix specified by pitched_data. + static inline std::vector + dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id, + pitched_data from, sycl::id<3> from_id, sycl::range<3> size, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(), + sycl::range<3>(to.get_pitch(), to.get_y(), 1), + sycl::range<3>(from.get_pitch(), from.get_y(), 1), to_id, from_id, + size, direction); + } + + /// memcpy 2D matrix with pitch. + static inline std::vector + dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, + size_t to_pitch, size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic) + { + return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1), + sycl::range<3>(from_pitch, y, 1), + sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), + sycl::range<3>(x, y, 1), direction); + } + + inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, const void *b, library_data_t b_type, int ldb, + const void *beta, void *c, library_data_t c_type, int ldc, + library_data_t scaling_type) + { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_impl(q, a_trans, b_trans, m, n, k, &alpha_half, + a, lda, b, ldb, &beta_half, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } // gemm() + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a[], + library_data_t a_type, int lda, const void *b[], + library_data_t b_type, int ldb, const void *beta, + void *c[], library_data_t c_type, int ldc, + int batch_size, library_data_t scaling_type) + { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + b, ldb, beta, c, ldc, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + float alpha_float = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_float = + dpct::get_value(reinterpret_cast(beta), q); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, + a, lda, b, ldb, &beta_float, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, + batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } + + /// Computes a batch of matrix-matrix product with general matrices. + /// \param [in] q The queue where the routine should be executed. + /// \param [in] a_trans Specifies the operation applied to A. + /// \param [in] b_trans Specifies the operation applied to B. + /// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C. + /// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C. + /// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B). + /// \param [in] alpha Scaling factor for the matrix-matrix product. + /// \param [in] a Input matrix A. + /// \param [in] a_type Data type of the matrix A. + /// \param [in] lda Leading dimension of A. + /// \param [in] stride_a Stride between the different A matrices. + /// \param [in] b Input matrix B. + /// \param [in] b_type Data type of the matrix B. + /// \param [in] ldb Leading dimension of B. + /// \param [in] stride_b Stride between the different B matrices. + /// \param [in] beta Scaling factor for matrix C. + /// \param [in, out] c Input/Output matrix C. + /// \param [in] c_type Data type of the matrix C. + /// \param [in] ldc Leading dimension of C. + /// \param [in] stride_c Stride between the different C matrices. + /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. + /// \param [in] scaling_type Data type of the scaling factors. + inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, + oneapi::mkl::transpose b_trans, int m, int n, int k, + const void *alpha, const void *a, library_data_t a_type, + int lda, long long int stride_a, const void *b, + library_data_t b_type, int ldb, long long int stride_b, + const void *beta, void *c, library_data_t c_type, + int ldc, long long int stride_c, int batch_size, + library_data_t scaling_type) + { + if (scaling_type == library_data_t::real_float && + c_type == library_data_t::complex_float) + { + scaling_type = library_data_t::complex_float; + } + else if (scaling_type == library_data_t::real_double && + c_type == library_data_t::complex_double) + { + scaling_type = library_data_t::complex_double; + } + + std::uint64_t key = + detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); + switch (key) + { + case detail::get_type_combination_id( + library_data_t::real_float, library_data_t::real_float, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_double, library_data_t::real_double, + library_data_t::real_double, library_data_t::real_double): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_float, library_data_t::complex_float, + library_data_t::complex_float, library_data_t::complex_float): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::complex_double, library_data_t::complex_double, + library_data_t::complex_double, library_data_t::complex_double): + { + detail::gemm_batch_impl, std::complex, + std::complex, std::complex>( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_half): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#ifdef __INTEL_MKL__ + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_bfloat16, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_bfloat16, library_data_t::real_bfloat16, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_int32, library_data_t::real_int32): + { + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, + a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_int8, library_data_t::real_int8, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_float, library_data_t::real_float): + { + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size); + break; + } +#endif + case detail::get_type_combination_id( + library_data_t::real_half, library_data_t::real_half, + library_data_t::real_half, library_data_t::real_float): + { + float alpha_value = + dpct::get_value(reinterpret_cast(alpha), q); + float beta_value = + dpct::get_value(reinterpret_cast(beta), q); + sycl::half alpha_half(alpha_value); + sycl::half beta_half(beta_value); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b, + &beta_half, c, ldc, stride_c, batch_size); + break; + } + default: + throw std::runtime_error("the combination of data type is unsupported"); + } + } + + static inline void + async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr, + size_t from_pitch, size_t x, size_t y, + memcpy_direction direction = automatic, + sycl::queue &q = get_default_queue()) + { + detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y, + direction); + } + + using err0 = detail::generic_error_type; + using err1 = detail::generic_error_type; + + static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { + detail::dpct_free(ptr, q); + } + + /// dpct accessor used as device function parameter. + template class accessor; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<3> &in_range) + : accessor(acc.get_pointer(), in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<3> _range; + }; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<2> &in_range) + : accessor(acc.get_pointer(), in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<2> _range; + }; + + namespace detail { + /// Device variable with address space of shared, global or constant. + template class device_memory { + public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using dpct_accessor_t = dpct::accessor; + + device_memory() : device_memory(sycl::range(1)) {} + + /// Constructor of 1-D array with initializer list + device_memory(const sycl::range &in_range, + std::initializer_list &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type &in_range, + std::initializer_list> &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), + sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range &range_in) + : _size(range_in.size() * sizeof(T)), _range(range_in), + _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { + static_assert( + (Memory == global) || (Memory == constant) || (Memory == shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class mem_mgr and dev_mgr will destruct + // later than this. + detail::mem_mgr::instance(); + dev_mgr::instance(); + } + + /// Constructor with range + template + device_memory(Args... Arguments) + : device_memory(sycl::range(Arguments...)) {} + + ~device_memory() { + if (_device_ptr && !_reference) + dpct::dpct_free(_device_ptr); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with default queue, and init memory if has initial + /// value. + void init() { init(dpct::get_default_queue()); } + /// Allocate memory with specified queue, and init memory if has initial + /// value. + void init(sycl::queue &q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, + host_to_device); + } + + /// The variable is assigned to a device pointer. + void assign(value_t *src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size); + } + + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr() { return get_ptr(get_default_queue()); } + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr(sycl::queue &q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { return _size; } + + template + typename std::enable_if::type &operator[](size_t index) { + init(); + return _device_ptr[index]; + } + + /// Get dpct::accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type + get_access(sycl::handler &cgh) { + return dpct_accessor_t((T *)_device_ptr, _range); + } + + private: + device_memory(value_t *memory_ptr, size_t size) + : _size(size), _range(size / sizeof(T)), _reference(true), + _device_ptr(memory_ptr) {} + + void allocate_device(sycl::queue &q) { + #ifndef DPCT_USM_LEVEL_NONE + if (Memory == shared) { + _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), + q.get_context()); + return; + } + #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == constant) { + _device_ptr = (value_t *)sycl::malloc_device( + _size, q.get_device(), q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } + #endif + #endif + _device_ptr = (value_t *)detail::dpct_malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t *_host_ptr; + value_t *_device_ptr; + }; + template + class device_memory : public device_memory { + public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. + device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} + + /// Default constructor + device_memory() : base(1) {} + }; + } // namespace detail + + template + using global_memory = detail::device_memory; + template + using constant_memory = detail::device_memory; + template + using shared_memory = detail::device_memory; + + +} // COPY from DPCT head files \ No newline at end of file