From b39b852f9a448f07e2fe9ff5353486dc29bbd930 Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Wed, 11 Dec 2024 01:00:35 -0800 Subject: [PATCH] simplify SyclQueue --- onedal/_device_offload.py | 43 +++++++++++++++++++++-------------- onedal/common/_backend.py | 2 +- onedal/neighbors/neighbors.py | 10 ++++---- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/onedal/_device_offload.py b/onedal/_device_offload.py index 3d83f4bcd5..1ff5caa693 100644 --- a/onedal/_device_offload.py +++ b/onedal/_device_offload.py @@ -32,21 +32,19 @@ else: from onedal import _dpc_backend - SyclQueueImplementation = getattr(_dpc_backend, "SyclQueue", None) + SyclQueueImplementation = getattr(_dpc_backend, "SyclQueue", object) -class SyclQueue: +class SyclQueue(SyclQueueImplementation): def __init__(self, target=None): - if target and isinstance(target, SyclQueueImplementation): - self.implementation = target - elif target and SyclQueueImplementation is not None: - self.implementation = SyclQueueImplementation(target) + if target is None: + super().__init__() else: - self.implementation = None + super().__init__(target) @property def sycl_device(self): - return getattr(self.implementation, "sycl_device", None) + return getattr(super(), "sycl_device", None) class SyclQueueManager: @@ -67,7 +65,7 @@ def get_global_queue() -> Optional[SyclQueue]: if target == "auto": # queue will be created from the provided data to each function call - return SyclQueue(None) + return None if isinstance(target, (str, int)): q = SyclQueue(target) @@ -111,14 +109,19 @@ def from_data(*data) -> Optional[SyclQueue]: # no interface found - try next data object continue - # extract the queue, verify it aligns with the global queue + # extract the queue global_queue = SyclQueueManager.get_global_queue() - data_queue = SyclQueue(usm_iface["syclobj"]) + data_queue = usm_iface["syclobj"] + if not data_queue: + # no queue, i.e. host data, no more work to do + continue + + # update the global queue if not set if global_queue is None: SyclQueueManager.update_global_queue(data_queue) global_queue = data_queue - # if the data item is on device, assert it's compatible with device in global queue + # if either queue points to a device, assert it's always the same device data_dev = data_queue.sycl_device global_dev = global_queue.sycl_device if (data_dev and global_dev) is not None and data_dev != global_dev: @@ -260,14 +263,20 @@ def wrapper_impl(obj, *args, **kwargs): return _run_on_device(func, obj, *args, **kwargs) hostargs, hostkwargs = _get_host_inputs(*args, **kwargs) + if hostkwargs.get("queue") is None: + # no queue provided, get it from the data + data_queue = SyclQueueManager.from_data(*hostargs) + if queue_param: + # if queue_param requested, add it to the hostkwargs + hostkwargs["queue"] = data_queue + else: + # use the provided queue + data_queue = hostkwargs["queue"] + data = (*args, *kwargs.values()) - data_queue = SyclQueueManager.from_data(*data) - if queue_param and hostkwargs.get("queue") is None: - hostkwargs["queue"] = data_queue result = _run_on_device(func, obj, *hostargs, **hostkwargs) - usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None) - if usm_iface is not None: + if data_queue is not None: result = _copy_to_usm(data_queue, result) if dpnp_available and isinstance(data[0], dpnp.ndarray): result = _convert_to_dpnp(result) diff --git a/onedal/common/_backend.py b/onedal/common/_backend.py index 6fb5e9c84d..a81d4a31cf 100644 --- a/onedal/common/_backend.py +++ b/onedal/common/_backend.py @@ -97,7 +97,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.method(*args, **kwargs) # use globally configured queue (from `target_offload` configuration or provided data) - queue = getattr(SyclQueueManager.get_global_queue(), "implementation", None) + queue = SyclQueueManager.get_global_queue() if queue is not None and not (self.backend.is_dpc or self.backend.is_spmd): raise RuntimeError("Operations using queues require the DPC/SPMD backend") diff --git a/onedal/neighbors/neighbors.py b/onedal/neighbors/neighbors.py index 3d979b4001..1f7f6f8986 100755 --- a/onedal/neighbors/neighbors.py +++ b/onedal/neighbors/neighbors.py @@ -275,7 +275,7 @@ def _fit(self, X, y): _fit_y = None # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function - queue = getattr(SyclQueueManager.get_global_queue(), "implementation") + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if _is_classifier(self) or (_is_regressor(self) and gpu_device): @@ -446,7 +446,7 @@ def _get_daal_params(self, data): def _onedal_fit(self, X, y): # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function - queue = getattr(SyclQueueManager.get_global_queue(), "implementation") + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if self.effective_metric_ == "euclidean" and not gpu_device: params = self._get_daal_params(X) @@ -604,7 +604,7 @@ def _get_daal_params(self, data): def _onedal_fit(self, X, y): # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function - queue = getattr(SyclQueueManager.get_global_queue(), "implementation") + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if self.effective_metric_ == "euclidean" and not gpu_device: params = self._get_daal_params(X) @@ -632,7 +632,7 @@ def _onedal_predict(self, model, X, params): return bf_knn_classification_prediction(**params).compute(X, model) # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function - queue = getattr(SyclQueueManager.get_global_queue(), "implementation") + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) X = _convert_to_supported(X) @@ -754,7 +754,7 @@ def _get_daal_params(self, data): def _onedal_fit(self, X, y): # global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function - queue = getattr(SyclQueueManager.get_global_queue(), "implementation") + queue = SyclQueueManager.get_global_queue() gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) if self.effective_metric_ == "euclidean" and not gpu_device: params = self._get_daal_params(X)