diff --git a/.gitignore b/.gitignore index 082e85e2c67f..d53f3f1f255d 100644 --- a/.gitignore +++ b/.gitignore @@ -144,11 +144,13 @@ credentials.csv .bloop # python tests +*.bin demo/**/*.txt *.dmatrix .hypothesis __MACOSX/ model*.json +/tests/python/models/models/ # R tests *.htm diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 0420b2672e1e..7d79d6726eec 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -63,6 +63,8 @@ disable = [ "import-error", "attribute-defined-outside-init", "import-outside-toplevel", + "too-few-public-methods", + "too-many-ancestors", "too-many-nested-blocks", "unsubscriptable-object", "useless-object-inheritance" diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 26399f0da2f8..bcd3d8d4ee54 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -43,32 +43,43 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool: # sklearn try: + from sklearn import __version__ as _sklearn_version from sklearn.base import BaseEstimator as XGBModelBase from sklearn.base import ClassifierMixin as XGBClassifierBase from sklearn.base import RegressorMixin as XGBRegressorBase - from sklearn.preprocessing import LabelEncoder try: - from sklearn.model_selection import KFold as XGBKFold from sklearn.model_selection import StratifiedKFold as XGBStratifiedKFold except ImportError: - from sklearn.cross_validation import KFold as XGBKFold from sklearn.cross_validation import StratifiedKFold as XGBStratifiedKFold + # sklearn.utils Tags types can be imported unconditionally once + # xgboost's minimum scikit-learn version is 1.6 or higher + try: + from sklearn.utils import Tags as _sklearn_Tags + except ImportError: + _sklearn_Tags = object + SKLEARN_INSTALLED = True except ImportError: SKLEARN_INSTALLED = False # used for compatibility without sklearn - XGBModelBase = object - XGBClassifierBase = object - XGBRegressorBase = object - LabelEncoder = object + class XGBModelBase: # type: ignore[no-redef] + """Dummy class for sklearn.base.BaseEstimator.""" + + class XGBClassifierBase: # type: ignore[no-redef] + """Dummy class for sklearn.base.ClassifierMixin.""" + + class XGBRegressorBase: # type: ignore[no-redef] + """Dummy class for sklearn.base.RegressorMixin.""" - XGBKFold = None XGBStratifiedKFold = None + _sklearn_Tags = object + _sklearn_version = object + _logger = logging.getLogger(__name__) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index c2034652322d..99072c7a3c58 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -430,7 +430,7 @@ def c_array( def from_array_interface(interface: dict) -> NumpyOrCupy: """Convert array interface to numpy or cupy array""" - class Array: # pylint: disable=too-few-public-methods + class Array: """Wrapper type for communicating with numpy and cupy.""" _interface: Optional[dict] = None diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index 76fcc1a6ad92..e0221310bc51 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -1,7 +1,6 @@ # pylint: disable=too-many-arguments, too-many-locals # pylint: disable=missing-class-docstring, invalid-name # pylint: disable=too-many-lines -# pylint: disable=too-few-public-methods """ Dask extensions for distributed training ---------------------------------------- diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 25448657c8ad..c337505f7641 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -35,6 +35,8 @@ XGBClassifierBase, XGBModelBase, XGBRegressorBase, + _sklearn_Tags, + _sklearn_version, import_cupy, ) from .config import config_context @@ -54,7 +56,7 @@ from .training import train -class XGBRankerMixIn: # pylint: disable=too-few-public-methods +class XGBRankerMixIn: """MixIn for ranking, defines the _estimator_type usually defined in scikit-learn base classes. @@ -79,7 +81,7 @@ def _can_use_qdm(tree_method: Optional[str], device: Optional[str]) -> bool: return tree_method in ("hist", "gpu_hist", None, "auto") and not_sycl -class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods +class _SklObjWProto(Protocol): def __call__( self, y_true: ArrayLike, @@ -805,6 +807,41 @@ def _more_tags(self) -> Dict[str, bool]: tags["non_deterministic"] = True return tags + @staticmethod + def _update_sklearn_tags_from_dict( + *, + tags: _sklearn_Tags, + tags_dict: Dict[str, bool], + ) -> _sklearn_Tags: + """Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes. + + ``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags. + ref: https://github.com/scikit-learn/scikit-learn/pull/29677 + + This method handles updating that instance based on the values in ``self._more_tags()``. + """ + tags.non_deterministic = tags_dict.get("non_deterministic", False) + tags.no_validation = tags_dict["no_validation"] + tags.input_tags.allow_nan = tags_dict["allow_nan"] + return tags + + def __sklearn_tags__(self) -> _sklearn_Tags: + # XGBModelBase.__sklearn_tags__() cannot be called unconditionally, + # because that method isn't defined for scikit-learn<1.6 + if not hasattr(XGBModelBase, "__sklearn_tags__"): + err_msg = ( + "__sklearn_tags__() should not be called when using scikit-learn<1.6. " + f"Detected version: {_sklearn_version}" + ) + raise AttributeError(err_msg) + + # take whatever tags are provided by BaseEstimator, then modify + # them with XGBoost-specific values + return self._update_sklearn_tags_from_dict( + tags=super().__sklearn_tags__(), # pylint: disable=no-member + tags_dict=self._more_tags(), + ) + def __sklearn_is_fitted__(self) -> bool: return hasattr(self, "_Booster") @@ -898,13 +935,27 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]: """Get parameters.""" # Based on: https://stackoverflow.com/questions/59248211 # The basic flow in `get_params` is: - # 0. Return parameters in subclass first, by using inspect. - # 1. Return parameters in `XGBModel` (the base class). + # 0. Return parameters in subclass (self.__class__) first, by using inspect. + # 1. Return parameters in all parent classes (especially `XGBModel`). # 2. Return whatever in `**kwargs`. # 3. Merge them. + # + # This needs to accommodate being called recursively in the following + # inheritance graphs (and similar for classification and ranking): + # + # XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator + # XGBRegressor -> XGBModel -> BaseEstimator + # XGBModel -> BaseEstimator + # params = super().get_params(deep) cp = copy.copy(self) - cp.__class__ = cp.__class__.__bases__[0] + # If the immediate parent defines get_params(), use that. + if callable(getattr(cp.__class__.__bases__[0], "get_params", None)): + cp.__class__ = cp.__class__.__bases__[0] + # Otherwise, skip it and assume the next class will have it. + # This is here primarily for cases where the first class in MRO is a scikit-learn mixin. + else: + cp.__class__ = cp.__class__.__bases__[1] params.update(cp.__class__.get_params(cp, deep)) # if kwargs is a dict, update params accordingly if hasattr(self, "kwargs") and isinstance(self.kwargs, dict): @@ -1481,7 +1532,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) -> Number of boosting rounds. """, ) -class XGBClassifier(XGBModel, XGBClassifierBase): +class XGBClassifier(XGBClassifierBase, XGBModel): # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes @_deprecate_positional_args def __init__( @@ -1497,6 +1548,12 @@ def _more_tags(self) -> Dict[str, bool]: tags["multilabel"] = True return tags + def __sklearn_tags__(self) -> _sklearn_Tags: + tags = super().__sklearn_tags__() + tags_dict = self._more_tags() + tags.classifier_tags.multi_label = tags_dict["multilabel"] + return tags + @_deprecate_positional_args def fit( self, @@ -1769,7 +1826,7 @@ def fit( "Implementation of the scikit-learn API for XGBoost regression.", ["estimators", "model", "objective"], ) -class XGBRegressor(XGBModel, XGBRegressorBase): +class XGBRegressor(XGBRegressorBase, XGBModel): # pylint: disable=missing-docstring @_deprecate_positional_args def __init__( @@ -1783,6 +1840,13 @@ def _more_tags(self) -> Dict[str, bool]: tags["multioutput_only"] = False return tags + def __sklearn_tags__(self) -> _sklearn_Tags: + tags = super().__sklearn_tags__() + tags_dict = self._more_tags() + tags.target_tags.multi_output = tags_dict["multioutput"] + tags.target_tags.single_output = not tags_dict["multioutput_only"] + return tags + @xgboost_model_doc( "scikit-learn API for XGBoost random forest regression.", @@ -1910,7 +1974,7 @@ def _get_qid( `qid` can be a special column of input `X` instead of a separated parameter, see :py:meth:`fit` for more info.""", ) -class XGBRanker(XGBModel, XGBRankerMixIn): +class XGBRanker(XGBRankerMixIn, XGBModel): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @_deprecate_positional_args def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any): diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index 166acbe1764b..32d7c1e490c8 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -2,8 +2,8 @@ import base64 -# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name -# pylint: disable=too-few-public-methods, too-many-lines, too-many-branches +# pylint: disable=fixme, protected-access, no-member, invalid-name +# pylint: disable=too-many-lines, too-many-branches import json import logging import os diff --git a/python-package/xgboost/spark/estimator.py b/python-package/xgboost/spark/estimator.py index f53ef72eb99e..011f7ea0b715 100644 --- a/python-package/xgboost/spark/estimator.py +++ b/python-package/xgboost/spark/estimator.py @@ -1,7 +1,6 @@ """Xgboost pyspark integration submodule for estimator API.""" -# pylint: disable=too-many-ancestors -# pylint: disable=fixme, too-many-ancestors, protected-access, no-member, invalid-name +# pylint: disable=fixme, protected-access, no-member, invalid-name # pylint: disable=unused-argument, too-many-locals import warnings diff --git a/python-package/xgboost/spark/params.py b/python-package/xgboost/spark/params.py index a177c73fe413..f173d3301286 100644 --- a/python-package/xgboost/spark/params.py +++ b/python-package/xgboost/spark/params.py @@ -2,7 +2,6 @@ from typing import Dict -# pylint: disable=too-few-public-methods from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import Param, Params diff --git a/python-package/xgboost/spark/utils.py b/python-package/xgboost/spark/utils.py index c96ec284abe3..e0d3e094a805 100644 --- a/python-package/xgboost/spark/utils.py +++ b/python-package/xgboost/spark/utils.py @@ -47,7 +47,7 @@ def _get_default_params_from_func( return filtered_params_dict -class CommunicatorContext(CCtx): # pylint: disable=too-few-public-methods +class CommunicatorContext(CCtx): """Context with PySpark specific task ID.""" def __init__(self, context: BarrierTaskContext, **args: CollArgsVals) -> None: diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index d9a4c85af326..34f55c077a85 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -566,7 +566,7 @@ def is_binary(self) -> bool: return self.max_rel == 1 -class PBM: # pylint: disable=too-few-public-methods +class PBM: """Simulate click data with position bias model. There are other models available in `ULTRA `_ like the cascading model. diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 6c4540301432..937e59095863 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -2,6 +2,7 @@ import os import pickle import random +import re import tempfile import warnings from typing import Callable, Optional @@ -825,6 +826,32 @@ def get_tm(clf: xgb.XGBClassifier) -> str: assert clf.get_params()["tree_method"] is None +def test_get_params_works_as_expected(): + # XGBModel -> BaseEstimator + params = xgb.XGBModel(max_depth=2).get_params() + assert params["max_depth"] == 2 + # 'objective' defaults to None in the signature of XGBModel + assert params["objective"] is None + + # XGBRegressor -> XGBModel -> BaseEstimator + params = xgb.XGBRegressor(max_depth=3).get_params() + assert params["max_depth"] == 3 + # 'objective' defaults to 'reg:squarederror' in the signature of XGBRegressor + assert params["objective"] == "reg:squarederror" + # 'colsample_bynode' defaults to 'None' for XGBModel (which XGBRegressor inherits from), so it + # should be in get_params() output + assert params["colsample_bynode"] is None + + # XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator + params = xgb.XGBRFRegressor(max_depth=4, objective="reg:tweedie").get_params() + assert params["max_depth"] == 4 + # 'objective' is a keyword argument for XGBRegressor, so it should be in get_params() output + # ... but values passed through kwargs should override the default from the signature of XGBRegressor + assert params["objective"] == "reg:tweedie" + # 'colsample_bynode' defaults to 0.8 for XGBRFRegressor...that should be preferred to the None from XGBRegressor + assert params["colsample_bynode"] == 0.8 + + def test_kwargs_error(): params = {'updater': 'grow_gpu_hist', 'subsample': .5, 'n_jobs': -1} with pytest.raises(TypeError): @@ -1517,7 +1544,7 @@ def test_tags() -> None: assert tags["multioutput"] is True assert tags["multioutput_only"] is False - for clf in [xgb.XGBClassifier()]: + for clf in [xgb.XGBClassifier(), xgb.XGBRFClassifier()]: tags = clf._more_tags() assert "multioutput" not in tags assert tags["multilabel"] is True @@ -1526,6 +1553,58 @@ def test_tags() -> None: assert "multioutput" not in tags +# the try-excepts in this test should be removed once xgboost's +# minimum supported scikit-learn version is at least 1.6 +def test_sklearn_tags(): + + def _assert_has_xgbmodel_tags(tags): + # values set by XGBModel.__sklearn_tags__() + assert tags.non_deterministic is False + assert tags.no_validation is True + assert tags.input_tags.allow_nan is True + + for reg in [xgb.XGBRegressor(), xgb.XGBRFRegressor()]: + try: + # if no AttributeError was thrown, we must be using scikit-learn>=1.6, + # and so the actual effects of __sklearn_tags__() should be tested + tags = reg.__sklearn_tags__() + _assert_has_xgbmodel_tags(tags) + # regressor-specific values + assert tags.estimator_type == "regressor" + assert tags.regressor_tags is not None + assert tags.classifier_tags is None + assert tags.target_tags.multi_output is True + assert tags.target_tags.single_output is True + except AttributeError as err: + # only the exact error we expected to be raised should be raised + assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) + + for clf in [xgb.XGBClassifier(), xgb.XGBRFClassifier()]: + try: + # if no AttributeError was thrown, we must be using scikit-learn>=1.6, + # and so the actual effects of __sklearn_tags__() should be tested + tags = clf.__sklearn_tags__() + _assert_has_xgbmodel_tags(tags) + # classifier-specific values + assert tags.estimator_type == "classifier" + assert tags.regressor_tags is None + assert tags.classifier_tags is not None + assert tags.classifier_tags.multi_label is True + except AttributeError as err: + # only the exact error we expected to be raised should be raised + assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) + + for rnk in [xgb.XGBRanker(),]: + try: + # if no AttributeError was thrown, we must be using scikit-learn>=1.6, + # and so the actual effects of __sklearn_tags__() should be tested + tags = rnk.__sklearn_tags__() + _assert_has_xgbmodel_tags(tags) + except AttributeError as err: + # only the exact error we expected to be raised should be raised + assert bool(re.search(r"__sklearn_tags__.* should not be called", str(err))) + + def test_doc_link() -> None: for est in [ xgb.XGBRegressor(),