From 85b4f40e52d929c6f80d93fa2a55ebe8a5e2113d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 11 Dec 2024 19:10:05 -0800 Subject: [PATCH 1/2] test: skip new sklearn checks --- pysr/test/test_main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 52772042..2f5f440a 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -876,8 +876,14 @@ def test_scikit_learn_compatibility(self): check_generator = check_estimator(model, generate_only=True) exception_messages = [] for _, check in check_generator: - if check.func.__name__ == "check_complex_data": - # We can use complex data, so avoid this check. + if check.func.__name__ in { + # We can use complex data, so avoid this check + "check_complex_data", + # We handle kwargs manually, so skip this check + "check_do_not_raise_errors_in_init_or_set_params", + # TODO: + "check_n_features_in_after_fitting", + }: continue try: with warnings.catch_warnings(): From 3d23dc079d765e41224bd3b24f031344858fc73c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 11 Dec 2024 19:27:23 -0800 Subject: [PATCH 2/2] feat: compat with new sklearn version --- pysr/sr.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index 3cae1735..c23805b4 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -58,6 +58,13 @@ _suggest_keywords, ) +try: + from sklearn.utils.validation import validate_data + + OLD_SKLEARN = False +except ImportError: + OLD_SKLEARN = True + ALREADY_RAN = False @@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params( ) def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]: - raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore + if OLD_SKLEARN: + raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore + else: + raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore return cast(tuple[ndarray, ndarray], raw_out) def _validate_data_X(self, X: Any) -> ndarray: - raw_out = self._validate_data(X=X, reset=False) # type: ignore + if OLD_SKLEARN: + raw_out = self._validate_data(X=X, reset=False) # type: ignore + else: + raw_out = validate_data(self, X=X, reset=False) # type: ignore return cast(ndarray, raw_out) def _get_precision_mapped_dtype(self, X: np.ndarray) -> type: