Commit
Tests.
trivialfis committed May 30, 2023
1 parent 2ce97b5 commit ad34df1
Showing 5 changed files with 61 additions and 25 deletions.
4 changes: 1 addition & 3 deletions python-package/xgboost/testing/data.py
@@ -361,9 +361,7 @@ class ClickFold:


 class RelDataCV(NamedTuple):
-    """Simple data struct for holding a train-test split of a learning to rank dataset.
-
-    """
+    """Simple data struct for holding a train-test split of a learning to rank dataset."""

     train: RelData
     test: RelData
54 changes: 53 additions & 1 deletion python-package/xgboost/testing/metrics.py
@@ -1,9 +1,61 @@
 """Tests for evaluation metrics."""
-from typing import Dict
+from typing import Dict, List

 import numpy as np
 import pytest

 import xgboost as xgb
+from xgboost.compat import concat
+from xgboost.core import _parse_eval_str


+def check_precision_score(tree_method: str) -> None:
+    """Test for precision with ranking and classification."""
+    datasets = pytest.importorskip("sklearn.datasets")
+
+    X, y = datasets.make_classification(
+        n_samples=1024, n_features=4, n_classes=2, random_state=2023
+    )
+    qid = np.zeros(shape=y.shape)  # same group
+
+    ltr = xgb.XGBRanker(n_estimators=2, tree_method=tree_method)
+    ltr.fit(X, y, qid=qid)
+
+    # re-generate so that XGBoost doesn't evaluate the result to 1.0
+    X, y = datasets.make_classification(
+        n_samples=512, n_features=4, n_classes=2, random_state=1994
+    )
+
+    ltr.set_params(eval_metric="pre@32")
+    result = _parse_eval_str(
+        ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y), "Xy")])
+    )
+    score_0 = result[1][1]
+
+    X_list = []
+    y_list = []
+    n_query_groups = 3
+    q_list: List[np.ndarray] = []
+    for i in range(n_query_groups):
+        # same for all groups
+        X, y = datasets.make_classification(
+            n_samples=512, n_features=4, n_classes=2, random_state=1994
+        )
+        X_list.append(X)
+        y_list.append(y)
+        q = np.full(shape=y.shape, fill_value=i, dtype=np.uint64)
+        q_list.append(q)
+
+    qid = concat(q_list)
+    X = concat(X_list)
+    y = concat(y_list)
+
+    result = _parse_eval_str(
+        ltr.get_booster().eval_set(evals=[(xgb.DMatrix(X, y, qid=qid), "Xy")])
+    )
+    assert result[1][0].endswith("pre@32")
+    score_1 = result[1][1]
+    assert score_1 == score_0
+
+
 def check_quantile_error(tree_method: str) -> None:
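The new helper exercises `pre@32` (precision over the top 32 documents), which for learning-to-rank data is computed per query group and then averaged across groups; that is why three identical copies of one group must score exactly the same as a single group, giving the `score_1 == score_0` assertion. As the assertions suggest, `_parse_eval_str` turns the booster's evaluation string into `(name, value)` pairs, with `result[1]` holding the metric entry. A minimal NumPy sketch of the metric's definition for one group (an illustration only, not XGBoost's implementation; `precision_at_k` is a hypothetical name):

    import numpy as np

    def precision_at_k(y_true: np.ndarray, scores: np.ndarray, k: int) -> float:
        """Fraction of relevant documents among the k highest-scored ones."""
        top_k = np.argsort(scores)[::-1][:k]  # indices of the k largest scores
        return float(np.sum(y_true[top_k] > 0)) / k

Averaging this quantity over groups that hold identical data trivially reproduces the single-group value, which is exactly what the test pins down.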
2 changes: 1 addition & 1 deletion tests/ci_build/lint_python.py
@@ -92,7 +92,7 @@ def check_cmd_print_failure_assistance(cmd: List[str]) -> bool:

     subprocess.run([cmd[0], "--version"])
     msg = """
-Please run the following command on your machine to address the formatting error:
+Please run the following command on your machine to address the error:
 """
     msg += " ".join(cmd)
5 changes: 4 additions & 1 deletion tests/python-gpu/test_gpu_eval_metrics.py
@@ -5,7 +5,7 @@

 import xgboost
 from xgboost import testing as tm
-from xgboost.testing.metrics import check_quantile_error
+from xgboost.testing.metrics import check_precision_score, check_quantile_error

 sys.path.append("tests/python")
 import test_eval_metrics as test_em  # noqa

@@ -59,6 +59,9 @@ def test_pr_auc_multi(self):
     def test_pr_auc_ltr(self):
         self.cpu_test.run_pr_auc_ltr("gpu_hist")

+    def test_precision_score(self):
+        check_precision_score("gpu_hist")
+
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_quantile_error(self) -> None:
         check_quantile_error("gpu_hist")
21 changes: 2 additions & 19 deletions tests/python/test_eval_metrics.py
@@ -3,7 +3,7 @@

 import xgboost as xgb
 from xgboost import testing as tm
-from xgboost.testing.metrics import check_quantile_error
+from xgboost.testing.metrics import check_precision_score, check_quantile_error

 rng = np.random.RandomState(1337)

@@ -318,24 +318,7 @@ def test_pr_auc_ltr(self):
         self.run_pr_auc_ltr("hist")

     def test_precision_score(self):
-        from sklearn.metrics import precision_score
-        from sklearn.datasets import make_classification
-
-        x, y = make_classification(n_samples=128, n_features=4, n_classes=2)
-        qid = np.zeros(shape=y.shape)  # same group
-
-        ltr = xgb.XGBRanker()
-        ltr.fit(x, y, qid=qid)
-        p = ltr.predict(x)
-        sorted_idx = np.argsort(p)
-
-        score_0 = precision_score(y, y[sorted_idx])
-        print(score_0)
-
-        Xy = xgb.DMatrix(x, y)
-        ltr.set_params(eval_metric="pre")
-        e = ltr.get_booster().eval_set(evals=[(Xy, "Xy")])
-        print(e)
+        check_precision_score("hist")

     @pytest.mark.skipif(**tm.no_sklearn())
     def test_quantile_error(self) -> None:
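With the inline sklearn-based check replaced by the shared helper, the CPU and GPU suites now run the same logic and differ only in the tree method they pass. A hypothetical standalone invocation, outside pytest (the second call assumes a CUDA-capable device):

    from xgboost.testing.metrics import check_precision_score

    check_precision_score("hist")      # CPU path, as in tests/python/test_eval_metrics.py
    check_precision_score("gpu_hist")  # GPU path, as in tests/python-gpu/test_gpu_eval_metrics.py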
