Skip to content

Commit

Permalink
feat: compat with new sklearn version
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 12, 2024
1 parent 85b4f40 commit 3d23dc0
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
_suggest_keywords,
)

try:
from sklearn.utils.validation import validate_data

OLD_SKLEARN = False
except ImportError:
OLD_SKLEARN = True

ALREADY_RAN = False


Expand Down Expand Up @@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params(
)

def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
if OLD_SKLEARN:
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
else:
raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore
return cast(tuple[ndarray, ndarray], raw_out)

def _validate_data_X(self, X: Any) -> ndarray:
raw_out = self._validate_data(X=X, reset=False) # type: ignore
if OLD_SKLEARN:
raw_out = self._validate_data(X=X, reset=False) # type: ignore
else:
raw_out = validate_data(self, X=X, reset=False) # type: ignore
return cast(ndarray, raw_out)

def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:
Expand Down

0 comments on commit 3d23dc0

Please sign in to comment.