Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Nov 19, 2024
1 parent 34d5f4b commit 0aa1669
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 51 deletions.
2 changes: 2 additions & 0 deletions onedal/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def predict(self, X, queue=None):
return self._predict(X, queue).ravel()


@bind_default_backend("decision_forest.classification", ["_get_policy", "train", "infer"])
class ExtraTreesClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):
def __init__(
self,
Expand Down Expand Up @@ -637,6 +638,7 @@ def predict_proba(self, X, queue=None):
)


@bind_default_backend("decision_forest.regression", ["_get_policy", "train", "infer"])
class ExtraTreesRegressor(RegressorMixin, BaseForest, metaclass=ABCMeta):
def __init__(
self,
Expand Down
69 changes: 35 additions & 34 deletions onedal/linear_model/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np

from daal4py.sklearn._utils import daal_check_version, get_dtype, make2d
from onedal.common._backend import bind_default_backend

from ..common._estimator_checks import _check_is_fitted
from ..common._mixin import ClassifierMixin
Expand All @@ -34,6 +35,7 @@
)


@bind_default_backend("logistic_regression.regression", ["_get_policy"])
class BaseLogisticRegression(metaclass=ABCMeta):
@abstractmethod
def __init__(self, tol, C, fit_intercept, solver, max_iter, algorithm):
Expand All @@ -44,6 +46,19 @@ def __init__(self, tol, C, fit_intercept, solver, max_iter, algorithm):
self.max_iter = max_iter
self.algorithm = algorithm

@abstractmethod
def _get_policy(self, queue, *data): ...

@abstractmethod
def train(self, policy, params, X, y): ...

@abstractmethod
def infer(self, policy, params, X): ...

# direct access to the backend model constructor
@abstractmethod
def model(self): ...

def _get_onedal_params(self, is_csr, dtype=np.float32):
intercept = "intercept|" if self.fit_intercept else ""
return {
Expand All @@ -61,7 +76,7 @@ def _get_onedal_params(self, is_csr, dtype=np.float32):
),
}

def _fit(self, X, y, module, queue):
def _fit(self, X, y, queue):
sparsity_enabled = daal_check_version((2024, "P", 700))
X, y = _check_X_y(
X,
Expand All @@ -86,7 +101,7 @@ def _fit(self, X, y, module, queue):
params = self._get_onedal_params(is_csr, get_dtype(X))
X_table, y_table = to_table(X, y)

result = module.train(policy, params, X_table, y_table)
result = self.train(policy, params, X_table, y_table)

self._onedal_model = result.model
self.n_iter_ = np.array([result.iterations_count])
Expand All @@ -100,8 +115,8 @@ def _fit(self, X, y, module, queue):

return self

def _create_model(self, module, policy):
m = module.model()
def _create_model(self, policy):
m = self.model()

coefficients = self.coef_
dtype = get_dtype(coefficients)
Expand Down Expand Up @@ -151,7 +166,7 @@ def _create_model(self, module, policy):

return m

def _infer(self, X, module, queue):
def _infer(self, X, queue):
_check_is_fitted(self)
sparsity_enabled = daal_check_version((2024, "P", 700))

Expand All @@ -172,33 +187,36 @@ def _infer(self, X, module, queue):
if hasattr(self, "_onedal_model"):
model = self._onedal_model
else:
model = self._create_model(module, policy)
model = self._create_model(policy)

X = _convert_to_supported(policy, X)
params = self._get_onedal_params(is_csr, get_dtype(X))

X_table = to_table(X)
result = module.infer(policy, params, model, X_table)
result = self.infer(policy, params, model, X_table)
return result

def _predict(self, X, module, queue):
result = self._infer(X, module, queue)
def _predict(self, X, queue):
result = self._infer(X, queue)
y = from_table(result.responses)
y = np.take(self.classes_, y.ravel(), axis=0)
return y

def _predict_proba(self, X, module, queue):
result = self._infer(X, module, queue)
def _predict_proba(self, X, queue):
result = self._infer(X, queue)

y = from_table(result.probabilities)
y = y.reshape(-1, 1)
return np.hstack([1 - y, y])

def _predict_log_proba(self, X, module, queue):
y_proba = self._predict_proba(X, module, queue)
def _predict_log_proba(self, X, queue):
y_proba = self._predict_proba(X, queue)
return np.log(y_proba)


@bind_default_backend(
"logistic_regression.classification", ["_get_policy", "train", "infer", "model"]
)
class LogisticRegression(ClassifierMixin, BaseLogisticRegression):
"""
Logistic Regression oneDAL implementation.
Expand All @@ -225,33 +243,16 @@ def __init__(
)

def fit(self, X, y, queue=None):
return self._fit(
X,
y,
self._get_backend_component("logistic_regression", "classification", None),
queue,
)
return self._fit(X, y, queue)

def predict(self, X, queue=None):
y = self._predict(
X,
self._get_backend_component("logistic_regression", "classification", None),
queue,
)
y = self._predict(X, queue)
return y

def predict_proba(self, X, queue=None):
y = self._predict_proba(
X,
self._get_backend_component("logistic_regression", "classification", None),
queue,
)
y = self._predict_proba(X, queue)
return y

def predict_log_proba(self, X, queue=None):
y = self._predict_log_proba(
X,
self._get_backend_component("logistic_regression", "classification", None),
queue,
)
y = self._predict_log_proba(X, queue)
return y
39 changes: 22 additions & 17 deletions onedal/neighbors/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def train(self, *args): ...
@abstractmethod
def infer(self, *args): ...

@abstractmethod
def _onedal_fit(self, X, y, queue): ...

def _validate_data(
self, X, y=None, reset=True, validate_separately=None, **check_params
):
Expand Down Expand Up @@ -341,9 +344,9 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None
self._fit_method, self.n_samples_fit_, n_features
)

if (
type(self._onedal_model) is kdtree_knn_classification_model
or type(self._onedal_model) is bf_knn_classification_model
if type(self._onedal_model) in (
kdtree_knn_classification_model,
bf_knn_classification_model,
):
params = super()._get_daal_params(X, n_neighbors=n_neighbors)
prediction_results = self._onedal_predict(
Expand Down Expand Up @@ -445,9 +448,12 @@ def _onedal_fit(self, X, y, queue):
if self.effective_metric_ == "euclidean" and not gpu_device:
params = self._get_daal_params(X)
if self._fit_method == "brute":
return bf_knn_classification_training(**params).compute(X, y).model
train_alg = bf_knn_classification_training

else:
return kdtree_knn_classification_training(**params).compute(X, y).model
train_alg = kdtree_knn_classification_training

return train_alg(**params).compute(X, y).model
else:
policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
Expand Down Expand Up @@ -585,10 +591,6 @@ def train_search(self, *args): ...
@abstractmethod
def infer_search(self, *args): ...

def _get_onedal_params(self, X, y=None):
params = self._get_onedal_params(X, y)
return params

def _get_daal_params(self, data):
params = super()._get_daal_params(data)
params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
Expand All @@ -601,7 +603,6 @@ def _onedal_fit(self, X, y, queue):
params = self._get_daal_params(X)
if self._fit_method == "brute":
train_alg = bf_knn_classification_training

else:
train_alg = kdtree_knn_classification_training

Expand All @@ -612,11 +613,13 @@ def _onedal_fit(self, X, y, queue):
params = self._get_onedal_params(X, y)

if gpu_device:
self.train(policy, params, *to_table(X, y)).model
return self.train(policy, params, *to_table(X, y)).model
else:
self.train_search(policy, params, to_table(X, y)).model
return self.train_search(policy, params, to_table(X, y)).model

def _onedal_predict(self, model, X, params, queue):
assert self._onedal_model is not None, "Model is not trained"

if type(model) is kdtree_knn_classification_model:
return kdtree_knn_classification_prediction(**params).compute(X, model)
elif type(model) is bf_knn_classification_model:
Expand Down Expand Up @@ -704,6 +707,7 @@ def predict(self, X, queue=None):
)


@bind_default_backend("neighbors.search", ["train", "infer"])
class NearestNeighbors(NeighborsBase):
def __init__(
self,
Expand All @@ -727,7 +731,7 @@ def __init__(
self.weights = weights

def _get_daal_params(self, data):
params = self._get_daal_params(data)
params = super()._get_daal_params(data)
params["resultsToCompute"] = "computeIndicesOfNeighbors|computeDistances"
params["resultsToEvaluate"] = (
"none" if getattr(self, "_y", None) is None else "computeClassLabels"
Expand All @@ -746,10 +750,11 @@ def _onedal_fit(self, X, y, queue):

return train_alg(**params).compute(X, y).model

policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(X, y)
return self.train(policy, params, to_table(X)).model
else:
policy = self._get_policy(queue, X, y)
X, y = _convert_to_supported(policy, X, y)
params = self._get_onedal_params(X, y)
return self.train(policy, params, to_table(X)).model

def _onedal_predict(self, model, X, params, queue):
if type(self._onedal_model) is kdtree_knn_classification_model:
Expand Down

0 comments on commit 0aa1669

Please sign in to comment.