diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index e7d260bd4ebe3..106a7f217bc81 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -3307,6 +3307,7 @@ class sycl_gpu_mgr { void detect_sycl_gpu_list_with_max_cu() try { int device_count = dpct::dev_mgr::instance().device_count(); + sycl::platform platform; for (int id = 0; id < device_count; id++) { sycl::device device = dpct::dev_mgr::instance().get_device(id); @@ -3314,8 +3315,10 @@ class sycl_gpu_mgr { continue; dpct::device_info prop; dpct::get_device_info(prop, device); - if (max_compute_units < prop.get_max_compute_units()) + if (max_compute_units < prop.get_max_compute_units()) { max_compute_units = prop.get_max_compute_units(); + platform = device.get_platform(); + } } for (int id = 0; id < device_count; id++) { @@ -3325,7 +3328,7 @@ class sycl_gpu_mgr { dpct::device_info prop; dpct::get_device_info(prop, device); if (max_compute_units == prop.get_max_compute_units() && - is_ext_oneapi_device(device)) { + platform == device.get_platform()) { gpus.push_back(id); devices.push_back(device); work_group_size = prop.get_max_work_group_size(); @@ -3357,15 +3360,6 @@ class sycl_gpu_mgr { } GGML_ASSERT(false); } - - bool is_ext_oneapi_device(const sycl::device &dev) { - sycl::backend dev_backend = dev.get_backend(); - if (dev_backend == sycl::backend::ext_oneapi_level_zero || - dev_backend == sycl::backend::ext_oneapi_cuda || - dev_backend == sycl::backend::ext_oneapi_hip) - return true; - return false; - } }; static sycl_gpu_mgr *g_sycl_gpu_mgr = NULL; @@ -17400,6 +17394,7 @@ GGML_API GGML_CALL void ggml_backend_sycl_set_single_device_mode(int main_gpu_id g_sycl_gpu_mgr = new sycl_gpu_mgr(main_gpu_id); g_ggml_sycl_backend_gpu_mode = SYCL_SINGLE_GPU_MODE; ggml_init_by_gpus(g_sycl_gpu_mgr->get_gpu_count()); + ggml_sycl_set_main_device(0); g_ggml_backend_sycl_buffer_type_initialized = false; } @@ -17419,6 +17414,7 @@ GGML_API GGML_CALL void ggml_backend_sycl_set_mul_device_mode() { g_sycl_gpu_mgr = new sycl_gpu_mgr(); g_ggml_sycl_backend_gpu_mode = SYCL_MUL_GPU_MODE; ggml_init_by_gpus(g_sycl_gpu_mgr->get_gpu_count()); + ggml_sycl_set_main_device(0); g_ggml_backend_sycl_buffer_type_initialized = false; }