-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapt to scikit-learn 1.6 estimator tag changes #11021
base: master
Are you sure you want to change the base?
Changes from all commits
79ed32c
3106cf1
3af44be
a9e30b4
6a12576
52e6d83
816667a
abfc6a6
ef725c1
d845922
b7564a1
8364e92
a511848
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,8 @@ disable = [ | |
"import-error", | ||
"attribute-defined-outside-init", | ||
"import-outside-toplevel", | ||
"too-few-public-methods", | ||
"too-many-ancestors", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switching the placeholder classes in the scikit-learn-is-not-available branch of
It seems that there are already many other places in this codebase where those warnings are being suppressed with inline `# pylint: disable=...` comments.
I don't feel that strongly... if you'd prefer to keep suppressing individual cases of these, please let me know and I'll happily switch back to that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks good to me. @RAMitchell finds the pylint checks helpful. I myself prefer mypy checks and think pylint is not particularly suitable for ML libraries like XGBoost. In general, I don't have a strong opinion about these "structural" or naming warnings and care mostly about warnings like unused imports or use before definition. |
||
"too-many-nested-blocks", | ||
"unsubscriptable-object", | ||
"useless-object-inheritance" | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -43,32 +43,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool: | |||||||||||||
|
||||||||||||||
# sklearn | ||||||||||||||
try: | ||||||||||||||
from sklearn import __version__ as _sklearn_version | ||||||||||||||
from sklearn.base import BaseEstimator as XGBModelBase | ||||||||||||||
from sklearn.base import ClassifierMixin as XGBClassifierBase | ||||||||||||||
from sklearn.base import RegressorMixin as XGBRegressorBase | ||||||||||||||
from sklearn.preprocessing import LabelEncoder | ||||||||||||||
|
||||||||||||||
try: | ||||||||||||||
from sklearn.model_selection import KFold as XGBKFold | ||||||||||||||
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold | ||||||||||||||
except ImportError: | ||||||||||||||
from sklearn.cross_validation import KFold as XGBKFold | ||||||||||||||
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold | ||||||||||||||
|
||||||||||||||
# sklearn.utils Tags types can be imported unconditionally once | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can do that once the next sklearn is published. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we should. That'd effectively raise xgboost's minimum scikit-learn version to 1.6, because on older versions the failed import would result in the scikit-learn-is-not-available branch being taken (xgboost/python-package/xgboost/compat.py, lines 60 to 61 in 5826b02),
which would make all the estimators unusable on those versions (xgboost/python-package/xgboost/sklearn.py, lines 754 to 757 in 5826b02).
|
||||||||||||||
# xgboost's minimum scikit-learn version is 1.6 or higher | ||||||||||||||
try: | ||||||||||||||
from sklearn.utils import Tags as _sklearn_Tags | ||||||||||||||
except ImportError: | ||||||||||||||
_sklearn_Tags = object | ||||||||||||||
|
||||||||||||||
SKLEARN_INSTALLED = True | ||||||||||||||
|
||||||||||||||
except ImportError: | ||||||||||||||
SKLEARN_INSTALLED = False | ||||||||||||||
|
||||||||||||||
# used for compatibility without sklearn | ||||||||||||||
XGBModelBase = object | ||||||||||||||
XGBClassifierBase = object | ||||||||||||||
XGBRegressorBase = object | ||||||||||||||
LabelEncoder = object | ||||||||||||||
class XGBModelBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.BaseEstimator.""" | ||||||||||||||
|
||||||||||||||
class XGBClassifierBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.ClassifierMixin.""" | ||||||||||||||
|
||||||||||||||
class XGBRegressorBase: # type: ignore[no-redef] | ||||||||||||||
"""Dummy class for sklearn.base.RegressorMixin.""" | ||||||||||||||
|
||||||||||||||
XGBKFold = None | ||||||||||||||
XGBStratifiedKFold = None | ||||||||||||||
|
||||||||||||||
_sklearn_Tags = object | ||||||||||||||
_sklearn_version = object | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
_logger = logging.getLogger(__name__) | ||||||||||||||
|
||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -35,6 +35,8 @@ | |||
XGBClassifierBase, | ||||
XGBModelBase, | ||||
XGBRegressorBase, | ||||
_sklearn_Tags, | ||||
_sklearn_version, | ||||
import_cupy, | ||||
) | ||||
from .config import config_context | ||||
|
@@ -54,7 +56,7 @@ | |||
from .training import train | ||||
|
||||
|
||||
class XGBRankerMixIn: # pylint: disable=too-few-public-methods | ||||
class XGBRankerMixIn: | ||||
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn | ||||
base classes. | ||||
|
||||
|
@@ -79,7 +81,7 @@ def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool: | |||
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl | ||||
|
||||
|
||||
class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods | ||||
class _SklObjWProto(Protocol): | ||||
def __call__( | ||||
self, | ||||
y_true: ArrayLike, | ||||
|
@@ -805,6 +807,41 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["non_deterministic"] = True | ||||
return tags | ||||
|
||||
@staticmethod | ||||
def _update_sklearn_tags_from_dict( | ||||
*, | ||||
tags: _sklearn_Tags, | ||||
tags_dict: Dict[str, bool], | ||||
) -> _sklearn_Tags: | ||||
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes. | ||||
|
||||
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags. | ||||
ref: https://github.com/scikit-learn/scikit-learn/pull/29677 | ||||
|
||||
This method handles updating that instance based on the values in ``self._more_tags()``. | ||||
""" | ||||
tags.non_deterministic = tags_dict.get("non_deterministic", False) | ||||
tags.no_validation = tags_dict["no_validation"] | ||||
tags.input_tags.allow_nan = tags_dict["allow_nan"] | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
# XGBModelBase.__sklearn_tags__() cannot be called unconditionally, | ||||
# because that method isn't defined for scikit-learn<1.6 | ||||
if not hasattr(XGBModelBase, "__sklearn_tags__"): | ||||
err_msg = ( | ||||
"__sklearn_tags__() should not be called when using scikit-learn<1.6. " | ||||
f"Detected version: {_sklearn_version}" | ||||
) | ||||
raise AttributeError(err_msg) | ||||
|
||||
# take whatever tags are provided by BaseEstimator, then modify | ||||
# them with XGBoost-specific values | ||||
return self._update_sklearn_tags_from_dict( | ||||
tags=super().__sklearn_tags__(), # pylint: disable=no-member | ||||
tags_dict=self._more_tags(), | ||||
) | ||||
|
||||
def __sklearn_is_fitted__(self) -> bool: | ||||
return hasattr(self, "_Booster") | ||||
|
||||
|
@@ -898,13 +935,30 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: | |||
"""Get parameters.""" | ||||
# Based on: https://stackoverflow.com/questions/59248211 | ||||
# The basic flow in `get_params` is: | ||||
# 0. Return parameters in subclass first, by using inspect. | ||||
# 1. Return parameters in `XGBModel` (the base class). | ||||
# 0. Return parameters in subclass (self.__class__) first, by using inspect. | ||||
# 1. Return parameters in all parent classes (especially `XGBModel`). | ||||
# 2. Return whatever in `**kwargs`. | ||||
# 3. Merge them. | ||||
# | ||||
# This needs to accommodate being called recursively in the following | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please help add a test for this? The hierarchy and the Python introspection are getting a bit confusing now. ;-( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! I just added one in a511848, let me know if there are other conditions you'd like to see tested. Between that and the existing test (xgboost/tests/python/test_with_sklearn.py, line 758 in 5826b02),
I think this behavior should be well covered. |
||||
# inheritance graphs (and similar for classification and ranking): | ||||
# | ||||
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator | ||||
# XGBRegressor -> XGBModel -> BaseEstimator | ||||
# XGBModel -> BaseEstimator | ||||
# | ||||
params = super().get_params(deep) | ||||
cp = copy.copy(self) | ||||
cp.__class__ = cp.__class__.__bases__[0] | ||||
# if the immediate parent is a mixin, skip it (mixins don't define get_params()) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think it's more general to check for the |
||||
if cp.__class__.__bases__[0] in ( | ||||
XGBClassifierBase, | ||||
XGBRankerMixIn, | ||||
XGBRegressorBase, | ||||
): | ||||
cp.__class__ = cp.__class__.__bases__[1] | ||||
# otherwise, run get_params() from the immediate parent class | ||||
else: | ||||
cp.__class__ = cp.__class__.__bases__[0] | ||||
params.update(cp.__class__.get_params(cp, deep)) | ||||
# if kwargs is a dict, update params accordingly | ||||
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict): | ||||
|
@@ -1481,7 +1535,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> | |||
Number of boosting rounds. | ||||
""", | ||||
) | ||||
class XGBClassifier(XGBModel, XGBClassifierBase): | ||||
class XGBClassifier(XGBClassifierBase, XGBModel): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As of scikit-learn/scikit-learn#30234 (which will be in
That check is new, but it enforced behavior that's been documented in
That new check's error led to these inheritance-order changes, which in turn led to the `get_params()` changes in this PR. |
||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes | ||||
@_deprecate_positional_args | ||||
def __init__( | ||||
|
@@ -1497,6 +1551,12 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["multilabel"] = True | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
tags = super().__sklearn_tags__() | ||||
tags_dict = self._more_tags() | ||||
tags.classifier_tags.multi_label = tags_dict["multilabel"] | ||||
return tags | ||||
|
||||
@_deprecate_positional_args | ||||
def fit( | ||||
self, | ||||
|
@@ -1769,7 +1829,7 @@ def fit( | |||
"Implementation of the scikit-learn API for XGBoost regression.", | ||||
["estimators", "model", "objective"], | ||||
) | ||||
class XGBRegressor(XGBModel, XGBRegressorBase): | ||||
class XGBRegressor(XGBRegressorBase, XGBModel): | ||||
# pylint: disable=missing-docstring | ||||
@_deprecate_positional_args | ||||
def __init__( | ||||
|
@@ -1783,6 +1843,13 @@ def _more_tags(self) -> Dict[str, bool]: | |||
tags["multioutput_only"] = False | ||||
return tags | ||||
|
||||
def __sklearn_tags__(self) -> _sklearn_Tags: | ||||
tags = super().__sklearn_tags__() | ||||
tags_dict = self._more_tags() | ||||
tags.target_tags.multi_output = tags_dict["multioutput"] | ||||
tags.target_tags.single_output = not tags_dict["multioutput_only"] | ||||
return tags | ||||
|
||||
|
||||
@xgboost_model_doc( | ||||
"scikit-learn API for XGBoost random forest regression.", | ||||
|
@@ -1910,7 +1977,7 @@ def _get_qid( | |||
`qid` can be a special column of input `X` instead of a separated parameter, see | ||||
:py:meth:`fit` for more info.""", | ||||
) | ||||
class XGBRanker(XGBModel, XGBRankerMixIn): | ||||
class XGBRanker(XGBRankerMixIn, XGBModel): | ||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name | ||||
@_deprecate_positional_args | ||||
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any): | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noticed some model files left behind from running all the Python tests locally while developing this. These `.gitignore` rules prevent checking them into source control.