
Adapt to scikit-learn 1.6 estimator tag changes #11021

Open: wants to merge 13 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -144,11 +144,13 @@ credentials.csv
.bloop

# python tests
*.bin
demo/**/*.txt
*.dmatrix
.hypothesis
__MACOSX/
model*.json
/tests/python/models/models/
Contributor Author:
I noticed some model files left behind after running all the Python tests locally while developing this. These .gitignore rules keep them from being checked into source control.


# R tests
*.htm
2 changes: 2 additions & 0 deletions python-package/pyproject.toml
@@ -63,6 +63,8 @@ disable = [
"import-error",
"attribute-defined-outside-init",
"import-outside-toplevel",
"too-few-public-methods",
"too-many-ancestors",
Contributor Author:
Switching the placeholder classes in the scikit-learn-is-not-available branch of compat.py to individual, differently-named, empty classes (instead of all using object) led to several of these pylint errors in sklearn.py and dask.py:

R0901: Too many ancestors (9/7) (too-many-ancestors)
R0903: Too few public methods (0/2) (too-few-public-methods)

(build link)

It seems that there are already many other places in this codebase where those warnings are being suppressed with # pylint: disable comments. So instead of adding more such comments (some of which would have to share a line with # type: ignore comments for mypy), I'm proposing:

  • just globally ignore these pylint warnings for the whole project
  • remove any existing # pylint: disable comments about them

I don't feel strongly about this... if you'd prefer to keep suppressing individual cases of these, please let me know and I'll happily switch back to # pylint: disable comments.
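
For reference, a rough sketch of the inline-suppression alternative mentioned above (the pragma placement is hypothetical, not code from this PR; the class itself is one of the compat.py placeholders):

# Hypothetical inline alternative: one pragma per offending class, sometimes
# sharing a line with a mypy pragma.
class XGBModelBase:  # type: ignore[no-redef]  # pylint: disable=too-few-public-methods
    """Dummy class for sklearn.base.BaseEstimator."""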

Member:
Looks good to me.

@RAMitchell finds the pylint checks helpful. I myself prefer mypy checks and think pylint is not particularly well suited to ML libraries like XGBoost. In general, I don't have a strong opinion about these "structural" or naming warnings; I care mostly about warnings like unused imports or use before definition.

"too-many-nested-blocks",
"unsubscriptable-object",
"useless-object-inheritance"
27 changes: 19 additions & 8 deletions python-package/xgboost/compat.py
@@ -43,32 +43,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool:

# sklearn
try:
from sklearn import __version__ as _sklearn_version
from sklearn.base import BaseEstimator as XGBModelBase
from sklearn.base import ClassifierMixin as XGBClassifierBase
from sklearn.base import RegressorMixin as XGBRegressorBase
from sklearn.preprocessing import LabelEncoder

try:
from sklearn.model_selection import KFold as XGBKFold
from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold
except ImportError:
from sklearn.cross_validation import KFold as XGBKFold
from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold

# sklearn.utils Tags types can be imported unconditionally once
Member:
We can do that once the next sklearn is published.

Contributor Author:
I don't think we should.

That'd effectively raise xgboost's requirement all the way to scikit-learn>=1.6.

Because it would result in compat.SKLEARN_INSTALLED being False for scikit-learn < 1.6:

except ImportError:
SKLEARN_INSTALLED = False

Which would make all the estimators unusable on those versions.

if not SKLEARN_INSTALLED:
raise ImportError(
"sklearn needs to be installed in order to use this module"
)

# xgboost's minimum scikit-learn version is 1.6 or higher
try:
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_Tags = object

SKLEARN_INSTALLED = True

except ImportError:
SKLEARN_INSTALLED = False

# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object
LabelEncoder = object
class XGBModelBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.BaseEstimator."""

class XGBClassifierBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.ClassifierMixin."""

class XGBRegressorBase: # type: ignore[no-redef]
"""Dummy class for sklearn.base.RegressorMixin."""

XGBKFold = None
XGBStratifiedKFold = None

_sklearn_Tags = object
_sklearn_version = object


_logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion python-package/xgboost/core.py
@@ -430,7 +430,7 @@ def c_array(
def from_array_interface(interface: dict) -> NumpyOrCupy:
"""Convert array interface to numpy or cupy array"""

class Array: # pylint: disable=too-few-public-methods
class Array:
"""Wrapper type for communicating with numpy and cupy."""

_interface: Optional[dict] = None
1 change: 0 additions & 1 deletion python-package/xgboost/dask/__init__.py
@@ -1,7 +1,6 @@
# pylint: disable=too-many-arguments, too-many-locals
# pylint: disable=missing-class-docstring, invalid-name
# pylint: disable=too-many-lines
# pylint: disable=too-few-public-methods
"""
Dask extensions for distributed training
----------------------------------------
83 changes: 75 additions & 8 deletions python-package/xgboost/sklearn.py
@@ -35,6 +35,8 @@
XGBClassifierBase,
XGBModelBase,
XGBRegressorBase,
_sklearn_Tags,
_sklearn_version,
import_cupy,
)
from .config import config_context
@@ -54,7 +56,7 @@
from .training import train


class XGBRankerMixIn: # pylint: disable=too-few-public-methods
class XGBRankerMixIn:
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
base classes.

@@ -79,7 +81,7 @@ def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool:
return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl


class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods
class _SklObjWProto(Protocol):
def __call__(
self,
y_true: ArrayLike,
@@ -805,6 +807,41 @@ def _more_tags(self) -> Dict[str, bool]:
tags["non_deterministic"] = True
return tags

@staticmethod
def _update_sklearn_tags_from_dict(
*,
tags: _sklearn_Tags,
tags_dict: Dict[str, bool],
) -> _sklearn_Tags:
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.

``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
ref: https://github.com/scikit-learn/scikit-learn/pull/29677

This method handles updating that instance based on the values in ``self._more_tags()``.
"""
tags.non_deterministic = tags_dict.get("non_deterministic", False)
tags.no_validation = tags_dict["no_validation"]
tags.input_tags.allow_nan = tags_dict["allow_nan"]
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
# XGBModelBase.__sklearn_tags__() cannot be called unconditionally,
# because that method isn't defined for scikit-learn<1.6
if not hasattr(XGBModelBase, "__sklearn_tags__"):
err_msg = (
"__sklearn_tags__() should not be called when using scikit-learn<1.6. "
f"Detected version: {_sklearn_version}"
)
raise AttributeError(err_msg)

# take whatever tags are provided by BaseEstimator, then modify
# them with XGBoost-specific values
return self._update_sklearn_tags_from_dict(
tags=super().__sklearn_tags__(), # pylint: disable=no-member
tags_dict=self._more_tags(),
)

def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
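
For context (not part of the diff): with scikit-learn >= 1.6 installed, the dataclass returned by the new __sklearn_tags__() can be inspected directly. A rough sketch of what that looks like:

from xgboost import XGBRegressor

reg = XGBRegressor()
# On scikit-learn >= 1.6 this returns a sklearn.utils.Tags dataclass;
# on older versions it raises the AttributeError guarded above.
tags = reg.__sklearn_tags__()
# These fields are filled in from _more_tags() by _update_sklearn_tags_from_dict().
print(tags.input_tags.allow_nan)
print(tags.no_validation)
print(tags.non_deterministic)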

@@ -898,13 +935,30 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
"""Get parameters."""
# Based on: https://stackoverflow.com/questions/59248211
# The basic flow in `get_params` is:
# 0. Return parameters in subclass first, by using inspect.
# 1. Return parameters in `XGBModel` (the base class).
# 0. Return parameters in subclass (self.__class__) first, by using inspect.
# 1. Return parameters in all parent classes (especially `XGBModel`).
# 2. Return whatever in `**kwargs`.
# 3. Merge them.
#
# This needs to accommodate being called recursively in the following
Member:
Could you please help add a test for this? The hierarchy and the Python introspection are getting a bit confusing now. ;-(

Contributor Author:
Sure! I just added one in a511848, let me know if there are other conditions you'd like to see tested.

Between that and the existing test:

def test_parameters_access():

I think this behavior should be well-covered.

# inheritance graphs (and similar for classification and ranking):
#
# XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
# XGBRegressor -> XGBModel -> BaseEstimator
# XGBModel -> BaseEstimator
#
params = super().get_params(deep)
cp = copy.copy(self)
cp.__class__ = cp.__class__.__bases__[0]
# if the immediate parent is a mixin, skip it (mixins don't define get_params())
Member (@trivialfis, Dec 2, 2024):
Do you think it's more general to check for the get_params attribute instead of checking hardcoded mixins? The current check seems to defeat the purpose of having a polymorphic structure (inheritance).

if cp.__class__.__bases__[0] in (
XGBClassifierBase,
XGBRankerMixIn,
XGBRegressorBase,
):
cp.__class__ = cp.__class__.__bases__[1]
# otherwise, run get_params() from the immediate parent class
else:
cp.__class__ = cp.__class__.__bases__[0]
params.update(cp.__class__.get_params(cp, deep))
# if kwargs is a dict, update params accordingly
if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
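
A rough sketch of the kind of test discussed in the review thread above (illustrative only; the test actually added in a511848 may look different):

from xgboost import XGBRegressor, XGBRFRegressor

def test_get_params_walks_inheritance_chain():
    # XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator (see the comment above)
    params = XGBRFRegressor(max_depth=3).get_params()
    # explicitly-set parameters survive the recursive get_params() calls ...
    assert params["max_depth"] == 3
    # ... alongside parameters defined on the XGBModel base class
    assert "n_estimators" in params
    assert "objective" in params
    # the one-level-shorter hierarchy goes through the same code path
    assert "max_depth" in XGBRegressor().get_params()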
@@ -1481,7 +1535,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
Number of boosting rounds.
""",
)
class XGBClassifier(XGBModel, XGBClassifierBase):
class XGBClassifier(XGBClassifierBase, XGBModel):
Contributor Author:
As of scikit-learn/scikit-learn#30234 (which will be in scikit-learn 1.6), the estimator checks raise an error like the following:

XGBRegressor is inheriting from mixins in the wrong order. In general, in mixin inheritance, more specialized mixins must come before more general ones. This means, for instance, BaseEstimator should be on the right side of most other mixins. You need to change the order...

That check is new, but it enforces behavior that has been documented in scikit-learn's estimator development docs for a long time. See the "BaseEstimator and mixins" section in https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator.

It is particularly important to notice that mixins should be “on the left” while the BaseEstimator should be “on the right” in the inheritance list for proper MRO.

That new check is what led to these inheritance-order changes, which in turn led to the XGBModel.get_params() changes.
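
As a minimal, generic illustration of the ordering rule being enforced (based on scikit-learn's developer docs, not code from this PR):

from sklearn.base import BaseEstimator, RegressorMixin

# Mixins go on the left, BaseEstimator on the right, so mixin methods
# come first in the MRO.
class TinyRegressor(RegressorMixin, BaseEstimator):
    def fit(self, X, y):
        self.fitted_ = True
        return self

    def predict(self, X):
        return [0.0 for _ in X]

# Declaring it as `class TinyRegressor(BaseEstimator, RegressorMixin)` instead
# would be flagged by scikit-learn 1.6's estimator checks with an error like
# the one quoted above.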

# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@_deprecate_positional_args
def __init__(
@@ -1497,6 +1551,12 @@ def _more_tags(self) -> Dict[str, bool]:
tags["multilabel"] = True
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.classifier_tags.multi_label = tags_dict["multilabel"]
return tags

@_deprecate_positional_args
def fit(
self,
@@ -1769,7 +1829,7 @@ def fit(
"Implementation of the scikit-learn API for XGBoost regression.",
["estimators", "model", "objective"],
)
class XGBRegressor(XGBModel, XGBRegressorBase):
class XGBRegressor(XGBRegressorBase, XGBModel):
# pylint: disable=missing-docstring
@_deprecate_positional_args
def __init__(
@@ -1783,6 +1843,13 @@ def _more_tags(self) -> Dict[str, bool]:
tags["multioutput_only"] = False
return tags

def __sklearn_tags__(self) -> _sklearn_Tags:
tags = super().__sklearn_tags__()
tags_dict = self._more_tags()
tags.target_tags.multi_output = tags_dict["multioutput"]
tags.target_tags.single_output = not tags_dict["multioutput_only"]
return tags


@xgboost_model_doc(
"scikit-learn API for XGBoost random forest regression.",
@@ -1910,7 +1977,7 @@ def _get_qid(
`qid` can be a special column of input `X` instead of a separated parameter, see
:py:meth:`fit` for more info.""",
)
class XGBRanker(XGBModel, XGBRankerMixIn):
class XGBRanker(XGBRankerMixIn, XGBModel):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):
4 changes: 2 additions & 2 deletions python-package/xgboost/spark/core.py
@@ -2,8 +2,8 @@

import base64

# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches
# pylint: disable=fixme, protected-access, no-member, invalid-name
# pylint: disable=too-many-lines, too-many-branches
import json
import logging
import os
3 changes: 1 addition & 2 deletions python-package/xgboost/spark/estimator.py
@@ -1,7 +1,6 @@
"""Xgboost pyspark integration submodule for estimator API."""

# pylint: disable=too-many-ancestors
# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name
# pylint: disable=fixme, protected-access, no-member, invalid-name
# pylint: disable=unused-argument, too-many-locals

import warnings
Expand Down
1 change: 0 additions & 1 deletion python-package/xgboost/spark/params.py
@@ -2,7 +2,6 @@

from typing import Dict

# pylint: disable=too-few-public-methods
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import Param, Params

2 changes: 1 addition & 1 deletion python-package/xgboost/spark/utils.py
@@ -47,7 +47,7 @@ def _get_default_params_from_func(
return filtered_params_dict


class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods
class CommunicatorContext(CCtx):
"""Context with PySpark specific task ID."""

def __init__(self, context: BarrierTaskContext, **args: CollArgsVals) -> None:
2 changes: 1 addition & 1 deletion python-package/xgboost/testing/data.py
@@ -566,7 +566,7 @@ def is_binary(self) -> bool:
return self.max_rel == 1


class PBM: # pylint: disable=too-few-public-methods
class PBM:
"""Simulate click data with position bias model. There are other models available in
`ULTRA <https://github.com/ULTR-Community/ULTRA.git>`_ like the cascading model.
