Skip to content

Commit

Permalink
Compatibility for xgboost>=1.6.0 (#167)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored Oct 29, 2021
1 parent 03e0a34 commit a1089c6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
44 changes: 33 additions & 11 deletions xgboost_ray/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,17 @@ def _ray_predict(
**compat_predict_kwargs,
)

def _ray_get_wrap_evaluation_matrices_compat_kwargs(
        self, label_transform=None) -> dict:
    """Build version-dependent kwargs for ``_wrap_evaluation_matrices``.

    XGBoost < 1.6.0 declares a ``label_transform`` parameter on
    ``_wrap_evaluation_matrices`` which was removed in 1.6.0, so it is
    only forwarded when the installed version's signature actually
    accepts it. ``enable_categorical`` is forwarded whenever this
    estimator defines that attribute.

    Args:
        label_transform: Optional callable applied to labels
            (pre-1.6 xgboost only). Falls back to the identity
            function when not provided.

    Returns:
        dict: Keyword arguments safe to splat into
        ``_wrap_evaluation_matrices`` for the installed xgboost version.
    """
    compat_kwargs = {}
    accepted_params = inspect.signature(
        _wrap_evaluation_matrices).parameters
    if "label_transform" in accepted_params:
        # XGBoost < 1.6.0
        compat_kwargs["label_transform"] = (
            label_transform or (lambda x: x))  # noqa
    if hasattr(self, "enable_categorical"):
        compat_kwargs["enable_categorical"] = self.enable_categorical
    return compat_kwargs

# copied from the file in the top comment
# provided here for compatibility with legacy xgboost versions
Expand Down Expand Up @@ -450,8 +457,13 @@ def fit(
else:
obj = None

model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
try:
model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
except TypeError:
# XGBoost >= 1.6.0
model, feval, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds)

# remove those as they will be set in RayXGBoostActor
params.pop("n_jobs", None)
Expand Down Expand Up @@ -638,8 +650,13 @@ def fit(
params["objective"] = "multi:softprob"
params["num_class"] = self.n_classes_

model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
try:
model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
except TypeError:
# XGBoost >= 1.6.0
model, feval, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds)

if train_dmatrix is None:
train_dmatrix, evals = _wrap_evaluation_matrices(
Expand All @@ -656,13 +673,13 @@ def fit(
base_margin_eval_set=base_margin_eval_set,
eval_group=None,
eval_qid=None,
label_transform=label_transform,
# changed in xgboost-ray:
create_dmatrix=lambda **kwargs: RayDMatrix(**{
**kwargs,
**ray_dmatrix_params
}),
**self._ray_get_wrap_evaluation_matrices_compat_kwargs())
**self._ray_get_wrap_evaluation_matrices_compat_kwargs(
label_transform=label_transform))

# remove those as they will be set in RayXGBoostActor
params.pop("n_jobs", None)
Expand Down Expand Up @@ -970,8 +987,13 @@ def fit(
evals_result = {}
params = self.get_xgb_params()

model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
try:
model, feval, params = self._configure_fit(xgb_model, eval_metric,
params)
except TypeError:
# XGBoost >= 1.6.0
model, feval, params, early_stopping_rounds = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds)
if callable(feval):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker.")
Expand Down
23 changes: 13 additions & 10 deletions xgboost_ray/tests/test_sklearn_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

from xgboost_ray.main import XGBOOST_VERSION_TUPLE

# True only for xgboost versions in [1.0.0, 1.6.0) — presumably the range
# in which the sklearn label encoder exists (removed in 1.6.0; TODO confirm
# against xgboost release notes). Used below to gate which classifier test
# variants run.
has_label_encoder = (XGBOOST_VERSION_TUPLE >= (1, 0, 0)
                     and XGBOOST_VERSION_TUPLE < (1, 6, 0))


class XGBoostRaySklearnMatrixTest(unittest.TestCase):
def setUp(self):
Expand All @@ -26,9 +29,9 @@ def _init_ray(self):
if not ray.is_initialized():
ray.init(num_cpus=4)

@unittest.skipIf(XGBOOST_VERSION_TUPLE < (1, 0, 0),
@unittest.skipIf(not has_label_encoder,
f"not supported in xgb version {xgb.__version__}")
def testClassifier(self, n_class=2):
def testClassifierLabelEncoder(self, n_class=2):
self._init_ray()

from sklearn.datasets import load_digits
Expand Down Expand Up @@ -74,14 +77,14 @@ def testClassifier(self, n_class=2):
clf.predict(test_matrix)
clf.predict_proba(test_matrix)

@unittest.skipIf(not has_label_encoder,
                 f"not supported in xgb version {xgb.__version__}")
def testClassifierMulticlassLabelEncoder(self):
    # Re-run the binary label-encoder classifier test with three classes
    # to cover the multi:softprob path.
    self.testClassifierLabelEncoder(n_class=3)

@unittest.skipIf(XGBOOST_VERSION_TUPLE >= (1, 0, 0),
@unittest.skipIf(has_label_encoder,
f"not supported in xgb version {xgb.__version__}")
def testClassifierLegacy(self, n_class=2):
def testClassifierNoLabelEncoder(self, n_class=2):
self._init_ray()

from sklearn.datasets import load_digits
Expand Down Expand Up @@ -118,10 +121,10 @@ def testClassifierLegacy(self, n_class=2):
clf.predict(test_matrix)
clf.predict_proba(test_matrix)

@unittest.skipIf(has_label_encoder,
                 f"not supported in xgb version {xgb.__version__}")
def testClassifierMulticlassNoLabelEncoder(self):
    # Re-run the binary no-label-encoder classifier test with three
    # classes to cover the multi:softprob path.
    self.testClassifierNoLabelEncoder(n_class=3)

def testRegressor(self):
self._init_ray()
Expand Down

0 comments on commit a1089c6

Please sign in to comment.