Skip to content

Commit

Permalink
test check.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Mar 1, 2023
1 parent 290027c commit aabe787
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
13 changes: 3 additions & 10 deletions python-package/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1967,7 +1967,7 @@ def fit(
provide qid.
qid :
Query ID for each training sample. Should have the size of n_samples. If
this is set to None, then user must provide group.
this is set to None, then user must provide group or a special column in X.
sample_weight :
Query group weights
Expand All @@ -1988,7 +1988,8 @@ def fit(
query groups in the ``i``-th pair in **eval_set**.
eval_qid :
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
pair in **eval_set**.
pair in **eval_set**. The special column convention in `X` applies to
validation datasets as well.
eval_metric : str, list of str, optional
.. deprecated:: 1.6.0
Expand Down Expand Up @@ -2031,15 +2032,7 @@ def fit(
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
# check if group information is provided

with config_context(verbosity=self.verbosity):
if eval_set is not None:
if eval_group is None and eval_qid is None:
raise ValueError(
"eval_group or eval_qid is required if eval_set is not None"
)

train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
X=X,
Expand Down
8 changes: 8 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ def test_ranking_qid_df():
s = ranker.score(df, y)
assert s > 0.7

# works with validation datasets as well
valid_df = df.copy()
valid_df.iloc[0, 0] = 3.0
ranker.fit(df, y, eval_set=[(valid_df, y)])

# same as passing qid directly
ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg")
ranker.fit(X, y, qid=q)
Expand All @@ -238,6 +243,9 @@ def test_ranking_qid_df():
results = cross_val_score(ranker, df, y)
assert len(results) == 5

with pytest.raises(ValueError, match="Either `group` or `qid`."):
ranker.fit(df, y, eval_set=[(X, y)])


def test_stacking_regression():
from sklearn.datasets import load_diabetes
Expand Down

0 comments on commit aabe787

Please sign in to comment.