Skip to content

Commit

Permalink
[FIX] update checks estimators - fix compatibility issues with scik…
Browse files Browse the repository at this point in the history
…it-learn >1.5.2 (nilearn#4724)

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Dimitri Papadopoulos Orfanos <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 19, 2024
1 parent cbd0ba2 commit 4bcd033
Show file tree
Hide file tree
Showing 18 changed files with 230 additions and 75 deletions.
52 changes: 30 additions & 22 deletions nilearn/_utils/class_inspect.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
"""Small utilities to inspect classes."""

from sklearn.utils.estimator_checks import (
check_decision_proba_consistency,
check_estimator_get_tags_default_keys,
check_estimators_partial_fit_n_features,
check_get_params_invariance,
check_non_transformer_estimators_n_iter,
check_set_params,
)
from sklearn import __version__ as sklearn_version
from sklearn.utils.estimator_checks import (
check_estimator as sklearn_check_estimator,
)

from nilearn._utils import compare_version

# List of sklearn estimators checks that are valid
# for all nilearn estimators.
VALID_CHECKS = [
x.__name__
for x in [
check_estimator_get_tags_default_keys,
check_estimators_partial_fit_n_features,
check_non_transformer_estimators_n_iter,
check_decision_proba_consistency,
check_get_params_invariance,
check_set_params,
]
"check_estimator_cloneable",
"check_estimators_partial_fit_n_features",
"check_estimator_repr",
"check_estimator_tags_renamed",
"check_mixin_order",
"check_non_transformer_estimators_n_iter",
"check_decision_proba_consistency",
"check_get_params_invariance",
"check_set_params",
]

if compare_version(sklearn_version, ">", "1.5.2"):
VALID_CHECKS.append("check_valid_tag_types")
else:
VALID_CHECKS.append("check_estimator_get_tags_default_keys")


# TODO
# remove when bumping to sklearn >= 1.3
try:
Expand All @@ -37,9 +40,11 @@


def check_estimator(estimator=None, valid=True, extra_valid_checks=None):
"""Check compatibility with scikit-learn estimators.
"""Return a valid or invalid scikit-learn estimators check.
As some of Nilearn estimators cannot fit Numpy arrays,
As some of Nilearn estimators do not comply
with sklearn recommendations
(cannot fit Numpy arrays, do input validation in the constructor...)
we cannot directly use
sklearn.utils.estimator_checks.check_estimator.
Expand All @@ -51,6 +56,10 @@ def check_estimator(estimator=None, valid=True, extra_valid_checks=None):
If new 'valid' checks are added to scikit-learn,
then tests marked as xfail will start passing.
See this section rolling-your-own-estimator in
the scikit-learn doc for more info:
https://scikit-learn.org/stable/developers/develop.html
Parameters
----------
estimator : estimator object or list of estimator object
Expand All @@ -62,9 +71,8 @@ def check_estimator(estimator=None, valid=True, extra_valid_checks=None):
extra_valid_checks : list of strings
Names of checks to be tested as valid for this estimator.
"""
if extra_valid_checks is None:
valid_checks = VALID_CHECKS
else:
valid_checks = VALID_CHECKS
if extra_valid_checks is not None:
valid_checks = VALID_CHECKS + extra_valid_checks

if not isinstance(estimator, list):
Expand Down
36 changes: 23 additions & 13 deletions nilearn/connectome/tests/test_group_sparse_cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
extra_valid_checks = [
"check_parameters_default_constructible",
"check_no_attributes_set_in_init",
"check_estimators_unfitted",
"check_do_not_raise_errors_in_init_or_set_params",
]


Expand All @@ -29,16 +31,16 @@ def test_check_estimator_group_sparse_covariance_cv(estimator, check, name): #
check(estimator)


@pytest.mark.xfail(reason="invalid checks should fail")
@pytest.mark.parametrize(
"estimator, check, name",
(
check_estimator(
estimator=[GroupSparseCovariance()],
extra_valid_checks=["check_no_attributes_set_in_init"],
)
check_estimator(
estimator=[GroupSparseCovarianceCV()],
valid=False,
extra_valid_checks=extra_valid_checks,
),
)
def test_check_estimator_group_sparse_covariance(
def test_check_estimator_invalid_group_sparse_covariance_cv(
estimator,
check,
name, # noqa: ARG001
Expand All @@ -47,16 +49,20 @@ def test_check_estimator_group_sparse_covariance(
check(estimator)


@pytest.mark.xfail(reason="invalid checks should fail")
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[GroupSparseCovarianceCV()],
valid=False,
extra_valid_checks=extra_valid_checks,
(
check_estimator(
estimator=[GroupSparseCovariance()],
extra_valid_checks=[
"check_no_attributes_set_in_init",
"check_estimators_unfitted",
"check_do_not_raise_errors_in_init_or_set_params",
],
)
),
)
def test_check_estimator_invalid_group_sparse_covariance_cv(
def test_check_estimator_group_sparse_covariance(
estimator,
check,
name, # noqa: ARG001
Expand All @@ -71,7 +77,11 @@ def test_check_estimator_invalid_group_sparse_covariance_cv(
check_estimator(
estimator=[GroupSparseCovariance()],
valid=False,
extra_valid_checks=["check_no_attributes_set_in_init"],
extra_valid_checks=[
"check_no_attributes_set_in_init",
"check_estimators_unfitted",
"check_do_not_raise_errors_in_init_or_set_params",
],
),
)
def test_check_estimator_invalid_group_sparse_covariance(
Expand Down
23 changes: 20 additions & 3 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,15 +1109,17 @@ def _predict_dummy(self, n_samples):
return scores.ravel() if scores.shape[1] == 1 else scores

def _more_tags(self):
return self.__sklearn_tags__()

def __sklearn_tags__(self):
# TODO
# rename method to '__sklearn_tags__'
# and get rid of if block
# get rid of if block
# bumping sklearn_version > 1.5
# see https://github.com/scikit-learn/scikit-learn/pull/29677
ver = parse(sklearn_version)
if ver.release[1] < 6:
return {"require_y": True}
tags = self.__sklearn_tags__()
tags = super().__sklearn_tags__()
tags.target_tags.required = True
return tags

Expand Down Expand Up @@ -1437,6 +1439,21 @@ def __init__(
n_jobs=n_jobs,
)

def _more_tags(self):
    """Return the estimator tags via the older ``_more_tags`` entry point.

    Delegates to ``__sklearn_tags__`` so that both entry points
    report the same values.
    """
    return self.__sklearn_tags__()

def __sklearn_tags__(self):
    """Return the scikit-learn tags for this regressor.

    Returns
    -------
    dict or tags object
        ``{"multioutput": True}`` when the installed scikit-learn
        is older than 1.6; otherwise the tags object built by the parent
        class, with ``target_tags.required`` and
        ``target_tags.multi_output`` set to True.
    """
    # TODO
    # get rid of if block
    # bumping sklearn_version > 1.5
    # see https://github.com/scikit-learn/scikit-learn/pull/29677
    #
    # Compare (major, minor) as a tuple: looking only at the minor number
    # would send a future 2.x release down the legacy dict branch.
    if parse(sklearn_version).release < (1, 6):
        return {"multioutput": True}
    tags = super().__sklearn_tags__()
    tags.target_tags.required = True
    # Mirror the legacy "multioutput" flag in the tag-object branch
    # so both branches describe the same capability.
    tags.target_tags.multi_output = True
    return tags


@fill_doc
class FREMRegressor(_BaseDecoder):
Expand Down
4 changes: 2 additions & 2 deletions nilearn/decoding/searchlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import numpy as np
from joblib import Parallel, cpu_count, delayed
from sklearn import svm
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import KFold, cross_val_score

Expand Down Expand Up @@ -237,7 +237,7 @@ def _group_iter_search_light(
# Class for search_light #####################################################
##############################################################################
@fill_doc
class SearchLight(BaseEstimator):
class SearchLight(TransformerMixin, BaseEstimator):
"""Implement search_light analysis using an arbitrary type of classifier.
Parameters
Expand Down
79 changes: 68 additions & 11 deletions nilearn/decoding/tests/test_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,19 @@

ESTIMATOR_REGRESSION = ("ridge", "svr")

extra_valid_checks = [
"check_do_not_raise_errors_in_init_or_set_params",
]


@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[DecoderRegressor()],
extra_valid_checks=["check_parameters_default_constructible"],
extra_valid_checks=[
*extra_valid_checks,
"check_parameters_default_constructible",
],
),
)
def test_check_estimator_decoder_regressor(estimator, check, name): # noqa: ARG001
Expand All @@ -90,7 +97,10 @@ def test_check_estimator_decoder_regressor(estimator, check, name): # noqa: ARG
"estimator, check, name",
check_estimator(
estimator=[DecoderRegressor()],
extra_valid_checks=["check_parameters_default_constructible"],
extra_valid_checks=[
*extra_valid_checks,
"check_parameters_default_constructible",
],
valid=False,
),
)
Expand All @@ -103,6 +113,7 @@ def test_check_estimator_invalid_decoder_regressor(estimator, check, name): # n
"estimator, check, name",
check_estimator(
estimator=[FREMRegressor()],
extra_valid_checks=extra_valid_checks,
),
)
def test_check_estimator_frem_regressor(estimator, check, name): # noqa: ARG001
Expand All @@ -113,7 +124,11 @@ def test_check_estimator_frem_regressor(estimator, check, name): # noqa: ARG001
@pytest.mark.xfail(reason="invalid checks should fail")
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(estimator=[FREMRegressor()], valid=False),
check_estimator(
estimator=[FREMRegressor()],
valid=False,
extra_valid_checks=extra_valid_checks,
),
)
def test_check_estimator_invalid_frem_regressor(estimator, check, name): # noqa: ARG001
"""Check compliance with sklearn estimators."""
Expand All @@ -123,8 +138,9 @@ def test_check_estimator_invalid_frem_regressor(estimator, check, name): # noqa
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[_BaseDecoder(), Decoder()],
estimator=[Decoder()],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
"check_parameters_default_constructible",
],
Expand All @@ -139,8 +155,9 @@ def test_check_estimator_decoder(estimator, check, name): # noqa: ARG001
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[_BaseDecoder(), Decoder()],
estimator=[Decoder()],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
"check_parameters_default_constructible",
],
Expand All @@ -152,11 +169,48 @@ def test_check_estimator_invalid_decoder(estimator, check, name): # noqa: ARG00
check(estimator)


@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[_BaseDecoder()],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
"check_parameters_default_constructible",
],
),
)
def test_check_estimator_base_decoder(estimator, check, name): # noqa: ARG001
"""Run each scikit-learn check expected to pass on _BaseDecoder."""
check(estimator)


@pytest.mark.xfail(reason="invalid checks should fail")
@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[_BaseDecoder()],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
"check_parameters_default_constructible",
],
valid=False,
),
)
def test_check_estimator_invalid_base_decoder(estimator, check, name): # noqa: ARG001
"""Run each scikit-learn check expected to fail on _BaseDecoder (xfail)."""
check(estimator)


@pytest.mark.parametrize(
"estimator, check, name",
check_estimator(
estimator=[FREMClassifier()],
extra_valid_checks=["check_no_attributes_set_in_init"],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
],
),
)
def test_check_estimator_frem_classifier(estimator, check, name): # noqa: ARG001
Expand All @@ -169,7 +223,10 @@ def test_check_estimator_frem_classifier(estimator, check, name): # noqa: ARG00
"estimator, check, name",
check_estimator(
estimator=[FREMClassifier()],
extra_valid_checks=["check_no_attributes_set_in_init"],
extra_valid_checks=[
*extra_valid_checks,
"check_no_attributes_set_in_init",
],
valid=False,
),
)
Expand Down Expand Up @@ -1120,9 +1177,9 @@ def test_decoder_tags_classification():
# remove if block when bumping sklearn_version to > 1.5
ver = parse(sklearn_version)
if ver.release[1] < 6:
assert model._more_tags()["require_y"] is True
assert model.__sklearn_tags__()["require_y"] is True
else:
assert model._more_tags().target_tags.required is True
assert model.__sklearn_tags__().target_tags.required is True


def test_decoder_tags_regression():
Expand All @@ -1131,9 +1188,9 @@ def test_decoder_tags_regression():
# remove if block when bumping sklearn_version to > 1.5
ver = parse(sklearn_version)
if ver.release[1] < 6:
assert model._more_tags()["multioutput"] is True
assert model.__sklearn_tags__()["multioutput"] is True
else:
assert model._more_tags().target_tags.multi_output is True
assert model.__sklearn_tags__().target_tags.multi_output is True


def test_decoder_decision_function(binary_classification_data):
Expand Down
Loading

0 comments on commit 4bcd033

Please sign in to comment.