Skip to content

Commit

Permalink
simplify SyclQueue
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Dec 11, 2024
1 parent 511b44f commit b39b852
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 23 deletions.
43 changes: 26 additions & 17 deletions onedal/_device_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,19 @@
else:
from onedal import _dpc_backend

SyclQueueImplementation = getattr(_dpc_backend, "SyclQueue", None)
SyclQueueImplementation = getattr(_dpc_backend, "SyclQueue", object)


class SyclQueue:
class SyclQueue(SyclQueueImplementation):
def __init__(self, target=None):
if target and isinstance(target, SyclQueueImplementation):
self.implementation = target
elif target and SyclQueueImplementation is not None:
self.implementation = SyclQueueImplementation(target)
if target is None:
super().__init__()
else:
self.implementation = None
super().__init__(target)

@property
def sycl_device(self):
return getattr(self.implementation, "sycl_device", None)
return getattr(super(), "sycl_device", None)


class SyclQueueManager:
Expand All @@ -67,7 +65,7 @@ def get_global_queue() -> Optional[SyclQueue]:

if target == "auto":
# queue will be created from the provided data to each function call
return SyclQueue(None)
return None

if isinstance(target, (str, int)):
q = SyclQueue(target)
Expand Down Expand Up @@ -111,14 +109,19 @@ def from_data(*data) -> Optional[SyclQueue]:
# no interface found - try next data object
continue

# extract the queue, verify it aligns with the global queue
# extract the queue
global_queue = SyclQueueManager.get_global_queue()
data_queue = SyclQueue(usm_iface["syclobj"])
data_queue = usm_iface["syclobj"]
if not data_queue:
# no queue, i.e. host data, no more work to do
continue

# update the global queue if not set
if global_queue is None:
SyclQueueManager.update_global_queue(data_queue)
global_queue = data_queue

# if the data item is on device, assert it's compatible with device in global queue
# if either queue points to a device, assert it's always the same device
data_dev = data_queue.sycl_device
global_dev = global_queue.sycl_device
if (data_dev and global_dev) is not None and data_dev != global_dev:
Expand Down Expand Up @@ -260,14 +263,20 @@ def wrapper_impl(obj, *args, **kwargs):
return _run_on_device(func, obj, *args, **kwargs)

hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
if hostkwargs.get("queue") is None:
# no queue provided, get it from the data
data_queue = SyclQueueManager.from_data(*hostargs)
if queue_param:
# if queue_param requested, add it to the hostkwargs
hostkwargs["queue"] = data_queue
else:
# use the provided queue
data_queue = hostkwargs["queue"]

data = (*args, *kwargs.values())
data_queue = SyclQueueManager.from_data(*data)
if queue_param and hostkwargs.get("queue") is None:
hostkwargs["queue"] = data_queue
result = _run_on_device(func, obj, *hostargs, **hostkwargs)

usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
if usm_iface is not None:
if data_queue is not None:
result = _copy_to_usm(data_queue, result)
if dpnp_available and isinstance(data[0], dpnp.ndarray):
result = _convert_to_dpnp(result)
Expand Down
2 changes: 1 addition & 1 deletion onedal/common/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.method(*args, **kwargs)

# use globally configured queue (from `target_offload` configuration or provided data)
queue = getattr(SyclQueueManager.get_global_queue(), "implementation", None)
queue = SyclQueueManager.get_global_queue()

if queue is not None and not (self.backend.is_dpc or self.backend.is_spmd):
raise RuntimeError("Operations using queues require the DPC/SPMD backend")
Expand Down
10 changes: 5 additions & 5 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _fit(self, X, y):

_fit_y = None
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = getattr(SyclQueueManager.get_global_queue(), "implementation")
queue = SyclQueueManager.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)

if _is_classifier(self) or (_is_regressor(self) and gpu_device):
Expand Down Expand Up @@ -446,7 +446,7 @@ def _get_daal_params(self, data):

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = getattr(SyclQueueManager.get_global_queue(), "implementation")
queue = SyclQueueManager.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
Expand Down Expand Up @@ -604,7 +604,7 @@ def _get_daal_params(self, data):

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = getattr(SyclQueueManager.get_global_queue(), "implementation")
queue = SyclQueueManager.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
Expand Down Expand Up @@ -632,7 +632,7 @@ def _onedal_predict(self, model, X, params):
return bf_knn_classification_prediction(**params).compute(X, model)

# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = getattr(SyclQueueManager.get_global_queue(), "implementation")
queue = SyclQueueManager.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
X = _convert_to_supported(X)

Expand Down Expand Up @@ -754,7 +754,7 @@ def _get_daal_params(self, data):

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = getattr(SyclQueueManager.get_global_queue(), "implementation")
queue = SyclQueueManager.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
Expand Down

0 comments on commit b39b852

Please sign in to comment.