Skip to content

Commit

Permalink
Merge pull request #776 from MilesCranmer/fix-sklearn-tests
Browse files Browse the repository at this point in the history
test: skip new sklearn checks
  • Loading branch information
MilesCranmer authored Dec 12, 2024
2 parents 89b5a89 + 3d23dc0 commit 8c8695b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
17 changes: 15 additions & 2 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
_suggest_keywords,
)

try:
from sklearn.utils.validation import validate_data

OLD_SKLEARN = False
except ImportError:
OLD_SKLEARN = True

ALREADY_RAN = False


Expand Down Expand Up @@ -1604,11 +1611,17 @@ def _validate_and_set_fit_params(
)

def _validate_data_X_y(self, X: Any, y: Any) -> tuple[ndarray, ndarray]:
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
if OLD_SKLEARN:
raw_out = self._validate_data(X=X, y=y, reset=True, multi_output=True) # type: ignore
else:
raw_out = validate_data(self, X=X, y=y, reset=True, multi_output=True) # type: ignore
return cast(tuple[ndarray, ndarray], raw_out)

def _validate_data_X(self, X: Any) -> ndarray:
raw_out = self._validate_data(X=X, reset=False) # type: ignore
if OLD_SKLEARN:
raw_out = self._validate_data(X=X, reset=False) # type: ignore
else:
raw_out = validate_data(self, X=X, reset=False) # type: ignore
return cast(ndarray, raw_out)

def _get_precision_mapped_dtype(self, X: np.ndarray) -> type:
Expand Down
10 changes: 8 additions & 2 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,14 @@ def test_scikit_learn_compatibility(self):
check_generator = check_estimator(model, generate_only=True)
exception_messages = []
for _, check in check_generator:
if check.func.__name__ == "check_complex_data":
# We can use complex data, so avoid this check.
if check.func.__name__ in {
# We can use complex data, so avoid this check
"check_complex_data",
# We handle kwargs manually, so skip this check
"check_do_not_raise_errors_in_init_or_set_params",
# TODO:
"check_n_features_in_after_fitting",
}:
continue
try:
with warnings.catch_warnings():
Expand Down

0 comments on commit 8c8695b

Please sign in to comment.