[bug] fix issues with sklearn conformance for preview Ridge for 2024.6 (uxlfoundation#1958)

* Update _ridge.py

* Update ridge.py

* formatting

* Update ridge.py

* fix __init__
icfaust authored Jul 29, 2024
1 parent e6cb66c commit 53159f9
Showing 2 changed files with 53 additions and 7 deletions.
4 changes: 2 additions & 2 deletions daal4py/sklearn/linear_model/_ridge.py
@@ -179,7 +179,7 @@ def _fit_ridge(self, _X, _y, sample_weight=None):
     if not _dal_ready:
         if hasattr(self, "daal_model_"):
             del self.daal_model_
-        return super(Ridge, self).fit(_X, _y, sample_weight=sample_weight)
+        return Ridge_original.fit(self, _X, _y, sample_weight=sample_weight)
     self.n_iter_ = None
     res = _daal4py_fit(self, X, y)
     if res is None:
@@ -188,7 +188,7 @@ def _fit_ridge(self, _X, _y, sample_weight=None):
         )
         if hasattr(self, "daal_model_"):
             del self.daal_model_
-        return super(Ridge, self).fit(_X, _y, sample_weight=sample_weight)
+        return Ridge_original.fit(self, _X, _y, sample_weight=sample_weight)
     return res


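Note on the _ridge.py hunks above: on the fallback path the patched estimator now calls the stock implementation through the original class's unbound fit method instead of super(), which keeps the fallback independent of the instance's method resolution order. A minimal standalone sketch of that pattern follows; the class name and the fallback condition are illustrative, not the actual daal4py internals.

import numpy as np
from sklearn.linear_model import Ridge as Ridge_original


class PatchedRidge(Ridge_original):
    # Hypothetical accelerated subclass; the condition below merely stands in
    # for daal4py's _dal_ready check.
    def fit(self, X, y, sample_weight=None):
        use_accelerated = sample_weight is None
        if not use_accelerated:
            # Fall back to stock scikit-learn without re-entering this override.
            return Ridge_original.fit(self, X, y, sample_weight=sample_weight)
        # The accelerated path would run here; the sketch just delegates again.
        return Ridge_original.fit(self, X, y)


est = PatchedRidge(alpha=1.0).fit(
    np.array([[0.0], [1.0], [2.0]]), np.array([0.0, 1.0, 2.0])
)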
56 changes: 51 additions & 5 deletions sklearnex/preview/linear_model/ridge.py
@@ -31,6 +31,8 @@

 if sklearn_check_version("1.0") and not sklearn_check_version("1.2"):
     from sklearn.linear_model._base import _deprecate_normalize
+if sklearn_check_version("1.1") and not sklearn_check_version("1.2"):
+    from sklearn.utils import check_scalar
 
 from onedal.linear_model import Ridge as onedal_Ridge
 from onedal.utils import _num_features, _num_samples
@@ -78,16 +80,16 @@ def __init__(
                 random_state=random_state,
             )
 
-    else:
+    elif sklearn_check_version("1.0"):
 
         def __init__(
             self,
             alpha=1.0,
             fit_intercept=True,
-            normalize="deprecated" if sklearn_check_version("1.0") else False,
+            normalize="deprecated",
             copy_X=True,
             max_iter=None,
-            tol=1e-4,
+            tol=1e-3,
             solver="auto",
             positive=False,
             random_state=None,
@@ -104,6 +106,30 @@ def __init__(
                 random_state=random_state,
             )
 
+    else:
+
+        def __init__(
+            self,
+            alpha=1.0,
+            fit_intercept=True,
+            normalize=False,
+            copy_X=True,
+            max_iter=None,
+            tol=1e-3,
+            solver="auto",
+            random_state=None,
+        ):
+            super().__init__(
+                alpha=alpha,
+                fit_intercept=fit_intercept,
+                normalize=normalize,
+                copy_X=copy_X,
+                max_iter=max_iter,
+                tol=tol,
+                solver=solver,
+                random_state=random_state,
+            )
+
     def fit(self, X, y, sample_weight=None):
         # It is necessary to properly update coefs for predict if we
         # fallback to sklearn in dispatch
@@ -274,6 +300,27 @@ def _onedal_fit(self, X, y, sample_weight, queue=None):
         # `Sample weight` is not supported. Expected to be None value.
         assert sample_weight is None
 
+        if sklearn_check_version("1.2"):
+            self._validate_params()
+        elif sklearn_check_version("1.1"):
+            if self.max_iter is not None:
+                self.max_iter = check_scalar(
+                    self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1
+                )
+            self.tol = check_scalar(
+                self.tol, "tol", target_type=numbers.Real, min_val=0.0
+            )
+            if self.alpha is not None and not isinstance(
+                self.alpha, (np.ndarray, tuple)
+            ):
+                self.alpha = check_scalar(
+                    self.alpha,
+                    "alpha",
+                    target_type=numbers.Real,
+                    min_val=0.0,
+                    include_boundaries="left",
+                )
+
         check_params = {
             "X": X,
             "y": y,
@@ -282,9 +329,8 @@ def _onedal_fit(self, X, y, sample_weight, queue=None):
"y_numeric": True,
"multi_output": True,
}
if sklearn_check_version("1.2"):
if sklearn_check_version("1.0"):
X, y = self._validate_data(**check_params)
self._validate_params()
else:
X, y = check_X_y(**check_params)

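Note on the ridge.py hunks above: they add a scikit-learn-version-gated __init__ (so the tol and normalize defaults match the installed scikit-learn) and reproduce stock parameter validation before dispatching to oneDAL, via _validate_params() on scikit-learn >= 1.2 and explicit check_scalar calls on 1.1. A minimal standalone sketch of the 1.1-style checks follows; the literal parameter values are illustrative only.

import numbers

from sklearn.utils import check_scalar

alpha, tol, max_iter = 0.5, 1e-3, 100

if max_iter is not None:
    max_iter = check_scalar(
        max_iter, "max_iter", target_type=numbers.Integral, min_val=1
    )
tol = check_scalar(tol, "tol", target_type=numbers.Real, min_val=0.0)
# A negative alpha (e.g. -1.0) would raise ValueError here, matching stock Ridge.
alpha = check_scalar(
    alpha, "alpha", target_type=numbers.Real, min_val=0.0, include_boundaries="left"
)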
