Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
d916d13
update: update neighbors.py to use pybind11 instead of daal4py
yuejiaointel Jan 28, 2025
2478b7f
test: revert changes and only change ondal fit for classifier
yuejiaointel Jan 28, 2025
736183d
test: revert everything
yuejiaointel Jan 28, 2025
e8df51b
fix: add failing validate data test to DESIGN_RULE_VIOLATIONS to avoi…
yuejiaointel Jan 30, 2025
6ce1b87
fix: format changed files to pass ci
yuejiaointel Jan 30, 2025
8476aac
fix: format fix
yuejiaointel Jan 30, 2025
03eea09
fix: move validate data to onedal fit call
yuejiaointel Jan 31, 2025
e463bf6
fix: remove uncessary changes
yuejiaointel Jan 31, 2025
31c2d3c
fix: remove uncessary changes
yuejiaointel Jan 31, 2025
cb7ea08
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Mar 12, 2025
fa95c87
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Jun 25, 2025
cabfcd4
fix: add changes
yuejiaointel Jun 25, 2025
fd9fbc2
fix: format
yuejiaointel Jun 25, 2025
4766304
fix: add import
yuejiaointel Jun 26, 2025
58c8811
fix: fix check feature names
yuejiaointel Jun 28, 2025
3a7cf19
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Jun 30, 2025
067d508
wip: getting errors about feature anmes
yuejiaointel Jul 2, 2025
eece097
buggy: when remove functions from _onedal_fit of NearestNeighbors
yuejiaointel Jul 2, 2025
11480e0
fix: got rid of all daal4py functions for knn regression, classificai…
yuejiaointel Jul 2, 2025
4625190
fix: format
yuejiaointel Jul 2, 2025
e8ae01f
fix: remove some violoations from desgin rule
yuejiaointel Jul 2, 2025
1a9de0e
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Jul 9, 2025
b4f1fad
fix: add some validate data violoations to rules
yuejiaointel Jul 9, 2025
2f70196
fix: fix the kneightbors calls
yuejiaointel Jul 9, 2025
2143bf3
fix: remove addtional tests
yuejiaointel Jul 10, 2025
da76c1d
fix: format
yuejiaointel Jul 10, 2025
152f74d
fix: try use array api xp
yuejiaointel Jul 10, 2025
e082648
fix: add if check for nearest neighbors
yuejiaointel Jul 10, 2025
239b94e
test: test np for knn regression predict
yuejiaointel Jul 10, 2025
fdd9d6a
fix: format
yuejiaointel Jul 10, 2025
13626f8
fix: format
yuejiaointel Jul 10, 2025
b6cf1fd
fix: revert previous
yuejiaointel Jul 11, 2025
f4a25a9
fix: fix validate data
yuejiaointel Jul 11, 2025
d7967b8
fix: fix score
yuejiaointel Jul 14, 2025
73f7352
fix: fix predict
yuejiaointel Jul 14, 2025
b17ef98
fix: fix score
yuejiaointel Jul 14, 2025
57532ba
fix: should only have 1 error now
yuejiaointel Jul 15, 2025
c308d52
fix: fix error not raised error
yuejiaointel Jul 15, 2025
805197e
fix: fix predict
yuejiaointel Jul 15, 2025
5fb9c8e
fix: test fix
yuejiaointel Jul 15, 2025
053d903
fix: add flag to fit avoid type change
yuejiaointel Jul 17, 2025
96549fa
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Jul 17, 2025
ef59b38
fix: remove ensure finit
yuejiaointel Jul 17, 2025
202e945
fix: format
yuejiaointel Jul 17, 2025
171cff6
test: tst ensure all finite = false
yuejiaointel Jul 26, 2025
2a6d5df
test: fix
yuejiaointel Aug 3, 2025
fbc3cab
Merge remote-tracking branch 'upstream/main' into replace_daal4py_wit…
Alexsandruss Aug 20, 2025
ec6a49f
fix: try use explict convert to numpy
yuejiaointel Sep 19, 2025
2ee8122
fix: format
yuejiaointel Sep 19, 2025
2fc3988
fix: fix as numpy
yuejiaointel Sep 19, 2025
ae26c77
fix: revert changes in dataframe support
yuejiaointel Sep 19, 2025
2b1fddb
Merge branch 'main' into replace_daal4py_with_pybind11_obj_knn
yuejiaointel Sep 19, 2025
827270a
fix: try fix as numbpy
yuejiaointel Sep 22, 2025
2f7d52e
fix: try fix as numbpy
yuejiaointel Sep 22, 2025
9d2ef72
fix: format
yuejiaointel Sep 22, 2025
4e1f6c0
fix: try fix as numbpy
yuejiaointel Sep 22, 2025
248092d
fix: don't change as numbpy
yuejiaointel Sep 22, 2025
b59aeba
fix: try without as numpy
yuejiaointel Sep 23, 2025
c99b761
fix: try don't use xp
yuejiaointel Sep 23, 2025
080707a
fix: try comment out xp again
yuejiaointel Sep 23, 2025
7213e70
fix: try comment out xp again
yuejiaointel Sep 23, 2025
fccd174
fix: comment out array api import
yuejiaointel Sep 23, 2025
6ddb7b9
fix: as numpy in lof
yuejiaointel Sep 24, 2025
9294c77
fix: fresh start and try step by step again
yuejiaointel Sep 24, 2025
7e9eba6
fix: just get rid of the daal4py functions
yuejiaointel Sep 24, 2025
480dd6f
fix: remove ck feature names
yuejiaointel Sep 24, 2025
055136e
fix: format
yuejiaointel Sep 24, 2025
4ca908b
fix: add valudate tests to violation array
yuejiaointel Sep 24, 2025
bb1d9da
fix: dpn't delete check featuer names
yuejiaointel Sep 25, 2025
c9e97db
fix: remove daal functions from onedal
yuejiaointel Oct 2, 2025
4f76b78
fix: format
yuejiaointel Oct 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 19 additions & 130 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@

import numpy as np

from daal4py import (
bf_knn_classification_model,
bf_knn_classification_prediction,
bf_knn_classification_training,
kdtree_knn_classification_model,
kdtree_knn_classification_prediction,
kdtree_knn_classification_training,
)
from onedal._device_offload import supports_queue
from onedal.common._backend import bind_default_backend
from onedal.utils import _sycl_queue_manager as QM
Expand Down Expand Up @@ -166,25 +158,6 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
"result_option": "indices|distances" if y is None else "responses",
}

def _get_daal_params(self, data, n_neighbors=None):
class_count = 0 if self.classes_ is None else len(self.classes_)
weights = getattr(self, "weights", "uniform")
params = {
"fptype": "float" if data.dtype == np.float32 else "double",
"method": "defaultDense",
"k": self.n_neighbors if n_neighbors is None else n_neighbors,
"voteWeights": "voteUniform" if weights == "uniform" else "voteDistance",
"resultsToCompute": "computeIndicesOfNeighbors|computeDistances",
"resultsToEvaluate": (
"none"
if getattr(self, "_y", None) is None or _is_regressor(self)
else "computeClassLabels"
),
}
if class_count != 0:
params["nClasses"] = class_count
return params


class NeighborsBase(NeighborsCommonBase, metaclass=ABCMeta):
def __init__(
Expand Down Expand Up @@ -348,19 +321,10 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
self._fit_method, self.n_samples_fit_, n_features
)

if type(self._onedal_model) in (
kdtree_knn_classification_model,
bf_knn_classification_model,
):
params = super()._get_daal_params(X, n_neighbors=n_neighbors)
prediction_results = self._onedal_predict(self._onedal_model, X, params)
distances = prediction_results.distances
indices = prediction_results.indices
else:
params = super()._get_onedal_params(X, n_neighbors=n_neighbors)
prediction_results = self._onedal_predict(self._onedal_model, X, params)
distances = from_table(prediction_results.distances)
indices = from_table(prediction_results.indices)
params = super()._get_onedal_params(X, n_neighbors=n_neighbors)
prediction_results = self._onedal_predict(self._onedal_model, X, params)
distances = from_table(prediction_results.distances)
indices = from_table(prediction_results.indices)

if method == "kd_tree":
for i in range(distances.shape[0]):
Expand Down Expand Up @@ -443,43 +407,21 @@ def train(self, *args, **kwargs): ...
@bind_default_backend("neighbors.classification")
def infer(self, *args, **kwargs): ...

def _get_daal_params(self, data):
params = super()._get_daal_params(data)
params["resultsToEvaluate"] = "computeClassLabels"
params["resultsToCompute"] = ""
return params

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = QM.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
if self._fit_method == "brute":
train_alg = bf_knn_classification_training

else:
train_alg = kdtree_knn_classification_training

return train_alg(**params).compute(X, y).model
else:
params = self._get_onedal_params(X, y)
X_table, y_table = to_table(X, y, queue=queue)
return self.train(params, X_table, y_table).model
params = self._get_onedal_params(X, y)
X_table, y_table = to_table(X, y, queue=queue)
return self.train(params, X_table, y_table).model

def _onedal_predict(self, model, X, params):
if type(self._onedal_model) is kdtree_knn_classification_model:
return kdtree_knn_classification_prediction(**params).compute(X, model)
elif type(self._onedal_model) is bf_knn_classification_model:
return bf_knn_classification_prediction(**params).compute(X, model)
else:
X = to_table(X, queue=QM.get_global_queue())
if "responses" not in params["result_option"]:
params["result_option"] += "|responses"
params["fptype"] = X.dtype
result = self.infer(params, model, X)
X = to_table(X, queue=QM.get_global_queue())
if "responses" not in params["result_option"]:
params["result_option"] += "|responses"
params["fptype"] = X.dtype
result = self.infer(params, model, X)

return result
return result

@supports_queue
def fit(self, X, y, queue=None):
Expand Down Expand Up @@ -511,17 +453,9 @@ def predict(self, X, queue=None):

self._validate_n_classes()

if (
type(onedal_model) is kdtree_knn_classification_model
or type(onedal_model) is bf_knn_classification_model
):
params = self._get_daal_params(X)
prediction_result = self._onedal_predict(onedal_model, X, params)
responses = prediction_result.prediction
else:
params = self._get_onedal_params(X)
prediction_result = self._onedal_predict(onedal_model, X, params)
responses = from_table(prediction_result.responses)
params = self._get_onedal_params(X)
prediction_result = self._onedal_predict(onedal_model, X, params)
responses = from_table(prediction_result.responses)

result = self.classes_.take(np.asarray(responses.ravel(), dtype=np.intp))
return result
Expand Down Expand Up @@ -603,25 +537,10 @@ def train(self, *args, **kwargs): ...
@bind_default_backend("neighbors.regression")
def infer(self, *args, **kwargs): ...

def _get_daal_params(self, data):
params = super()._get_daal_params(data)
params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
params["resultsToEvaluate"] = "none"
return params

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = QM.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
if self._fit_method == "brute":
train_alg = bf_knn_classification_training
else:
train_alg = kdtree_knn_classification_training

return train_alg(**params).compute(X, y).model

X_table, y_table = to_table(X, y, queue=queue)
params = self._get_onedal_params(X_table, y)

Expand All @@ -633,11 +552,6 @@ def _onedal_fit(self, X, y):
def _onedal_predict(self, model, X, params):
assert self._onedal_model is not None, "Model is not trained"

if type(model) is kdtree_knn_classification_model:
return kdtree_knn_classification_prediction(**params).compute(X, model)
elif type(model) is bf_knn_classification_model:
return bf_knn_classification_prediction(**params).compute(X, model)

# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = QM.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
Expand Down Expand Up @@ -753,39 +667,14 @@ def train(self, *args, **kwargs): ...
@bind_default_backend("neighbors.search")
def infer(self, *arg, **kwargs): ...

def _get_daal_params(self, data):
params = super()._get_daal_params(data)
params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
params["resultsToEvaluate"] = (
"none" if getattr(self, "_y", None) is None else "computeClassLabels"
)
return params

def _onedal_fit(self, X, y):
# global queue is set as per user configuration (`target_offload`) or from data prior to calling this internal function
queue = QM.get_global_queue()
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
if self._fit_method == "brute":
train_alg = bf_knn_classification_training

else:
train_alg = kdtree_knn_classification_training

return train_alg(**params).compute(X, y).model

else:
params = self._get_onedal_params(X, y)
X, y = to_table(X, y, queue=queue)
return self.train(params, X).model
params = self._get_onedal_params(X, y)
X, y = to_table(X, y, queue=queue)
return self.train(params, X).model

def _onedal_predict(self, model, X, params):
if type(self._onedal_model) is kdtree_knn_classification_model:
return kdtree_knn_classification_prediction(**params).compute(X, model)
elif type(self._onedal_model) is bf_knn_classification_model:
return bf_knn_classification_prediction(**params).compute(X, model)

X = to_table(X, queue=QM.get_global_queue())

params["fptype"] = X.dtype
Expand Down
35 changes: 35 additions & 0 deletions sklearnex/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,41 @@
"LogisticRegression(solver='newton-cg')-predict-n_jobs_check": "uses daal4py for cpu in sklearnex",
"LogisticRegression(solver='newton-cg')-predict_log_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
"LogisticRegression(solver='newton-cg')-predict_proba-n_jobs_check": "uses daal4py for cpu in sklearnex",
# KNeighborsClassifier validate_data issues - will be fixed later
"KNeighborsClassifier-fit-call_validate_data": "validate_data implementation needs fixing",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yuejiaointel Could rebase this off from the unmerged PR that has the fix, remove these deselections, and verify that they pass?

Copy link
Contributor Author

@yuejiaointel yuejiaointel Sep 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The followup PR that I am doing next will fix this by using the array api structure to add validate_data properly, but this PR needs to be merged first

"KNeighborsClassifier-predict_proba-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier-score-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier-predict-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor-fit-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor-score-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor-predict-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors-fit-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor-fit-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor(novelty=True)-fit-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor(novelty=True)-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"LocalOutlierFactor(novelty=True)-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-predict_proba-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-score-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsClassifier(algorithm='brute')-predict-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor(algorithm='brute')-score-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
"KNeighborsRegressor(algorithm='brute')-predict-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors(algorithm='brute')-fit-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors(algorithm='brute')-kneighbors-call_validate_data": "validate_data implementation needs fixing",
"NearestNeighbors(algorithm='brute')-kneighbors_graph-call_validate_data": "validate_data implementation needs fixing",
}


Expand Down
Loading