Skip to content

Commit

Permalink
Fix tests with the latest scikit-learn. (#11086)
Browse files Browse the repository at this point in the history
* Fix tests with the latest scikit-learn.

* dask.

* Remove scikit-learn pin

---------

Co-authored-by: Philip Hyunsu Cho <[email protected]>
  • Loading branch information
trivialfis and hcho3 authored Dec 11, 2024
1 parent aceb193 commit f4f3bd4
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion ops/conda_env/macos_cpu_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- numpy
- scipy
- llvm-openmp
- scikit-learn>=1.4.1,<=1.5.2
- scikit-learn>=1.4.1
- pandas
- matplotlib
- dask<=2024.10.0
Expand Down
2 changes: 1 addition & 1 deletion ops/conda_env/win64_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- numpy
- scipy
- matplotlib
- scikit-learn<=1.5.2
- scikit-learn
- pandas
- pytest
- boto3
Expand Down
6 changes: 3 additions & 3 deletions python-package/xgboost/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,7 +1616,7 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
)
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
class DaskXGBRegressor(XGBRegressorBase, DaskScikitLearnBase):
"""dummy doc string to workaround pylint, replaced by the decorator."""

async def _fit_async(
Expand Down Expand Up @@ -1707,7 +1707,7 @@ def fit(
"Implementation of the scikit-learn API for XGBoost classification.",
["estimators", "model"],
)
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
class DaskXGBClassifier(XGBClassifierBase, DaskScikitLearnBase):
# pylint: disable=missing-class-docstring
async def _fit_async(
self,
Expand Down Expand Up @@ -1911,7 +1911,7 @@ def _argmax(x: Any) -> Any:
For the dask implementation, group is not supported, use qid instead.
""",
)
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
class DaskXGBRanker(XGBRankerMixIn, DaskScikitLearnBase):
@_deprecate_positional_args
def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions tests/python-gpu/test_gpu_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
)
@settings(deadline=None, max_examples=20, print_blob=True)
def test_predict_categorical_split(self, df):
from sklearn.metrics import mean_squared_error
from sklearn.metrics import root_mean_squared_error

df = df.astype("category")
x0, x1 = df["x0"].to_numpy(), df["x1"].to_numpy()
Expand All @@ -504,7 +504,7 @@ def test_predict_categorical_split(self, df):
)
bst.set_param({"device": "cuda:0"})
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
rmse = root_mean_squared_error(y_true=y, y_pred=pred)
np.testing.assert_almost_equal(
rmse, eval_history["train"]["rmse"][-1], decimal=5
)
Expand Down

0 comments on commit f4f3bd4

Please sign in to comment.