Skip to content

Commit

Permalink
Use the estimator type tags as required by scikit-learn 1.6.0 (nilear…
Browse files Browse the repository at this point in the history
  • Loading branch information
man-shu authored Nov 20, 2024
1 parent 94fd3b3 commit 722514f
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions nilearn/decoding/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from packaging.version import parse
from sklearn import __version__ as sklearn_version
from sklearn import clone
from sklearn.base import BaseEstimator, MultiOutputMixin
from sklearn.base import (
BaseEstimator,
ClassifierMixin,
MultiOutputMixin,
RegressorMixin,
)
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.linear_model import (
LassoCV,
Expand Down Expand Up @@ -1125,7 +1130,7 @@ def __sklearn_tags__(self):


@fill_doc
class Decoder(_BaseDecoder):
class Decoder(ClassifierMixin, _BaseDecoder):
"""A wrapper for popular classification strategies in neuroimaging.
The `Decoder` object supports classification methods.
Expand Down Expand Up @@ -1281,9 +1286,16 @@ def __init__(
n_jobs=n_jobs,
)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
ver = parse(sklearn_version)
if ver.release[1] >= 6:
tags.estimator_type = "classifier"
return tags


@fill_doc
class DecoderRegressor(MultiOutputMixin, _BaseDecoder):
class DecoderRegressor(MultiOutputMixin, RegressorMixin, _BaseDecoder):
"""A wrapper for popular regression strategies in neuroimaging.
The `DecoderRegressor` object supports regression methods.
Expand Down Expand Up @@ -1451,6 +1463,8 @@ def __sklearn_tags__(self):
if ver.release[1] < 6:
return {"multioutput": True}
tags = super().__sklearn_tags__()
if ver.release[1] >= 6:
tags.estimator_type = "regressor"
tags.target_tags.required = True
return tags

Expand Down

0 comments on commit 722514f

Please sign in to comment.