diff --git a/demo/dask/dask_learning_to_rank.py b/demo/dask/dask_learning_to_rank.py
new file mode 100644
index 000000000000..c08450fec56e
--- /dev/null
+++ b/demo/dask/dask_learning_to_rank.py
@@ -0,0 +1,201 @@
+"""
+Learning to rank with the Dask Interface
+========================================
+
+ .. versionadded:: 3.0.0
+
+This is a demonstration of using XGBoost for learning to rank tasks using the
+MSLR_10k_letor dataset. For more information about the dataset, please visit its
+`description page `_.
+
+See :ref:`ltr-dist` for a general description of distributed learning to rank and
+:ref:`ltr-dask` for Dask-specific features.
+
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+from contextlib import contextmanager
+from typing import Generator
+
+import dask
+import numpy as np
+from dask import dataframe as dd
+from distributed import Client, LocalCluster, wait
+from sklearn.datasets import load_svmlight_file
+
+from xgboost import dask as dxgb
+
+
+def load_mslr_10k(
+ device: str, data_path: str, cache_path: str
+) -> tuple[dd.DataFrame, dd.DataFrame, dd.DataFrame]:
+ """Load the MSLR10k dataset from data_path and save parquet files in the cache_path."""
+    root_path = os.path.expanduser(data_path)
+    cache_path = os.path.expanduser(cache_path)
+
+ # Use only the Fold1 for demo:
+ # Train, Valid, Test
+ # {S1,S2,S3}, S4, S5
+ fold = 1
+
+ if not os.path.exists(cache_path):
+ os.mkdir(cache_path)
+ fold_path = os.path.join(root_path, f"Fold{fold}")
+ train_path = os.path.join(fold_path, "train.txt")
+ valid_path = os.path.join(fold_path, "vali.txt")
+ test_path = os.path.join(fold_path, "test.txt")
+
+ X_train, y_train, qid_train = load_svmlight_file(
+ train_path, query_id=True, dtype=np.float32
+ )
+ columns = [f"f{i}" for i in range(X_train.shape[1])]
+ X_train = dd.from_array(X_train.toarray(), columns=columns)
+ y_train = y_train.astype(np.int32)
+ qid_train = qid_train.astype(np.int32)
+
+ X_train["y"] = dd.from_array(y_train)
+ X_train["qid"] = dd.from_array(qid_train)
+ X_train.to_parquet(os.path.join(cache_path, "train"), engine="pyarrow")
+
+ X_valid, y_valid, qid_valid = load_svmlight_file(
+ valid_path, query_id=True, dtype=np.float32
+ )
+ X_valid = dd.from_array(X_valid.toarray(), columns=columns)
+ y_valid = y_valid.astype(np.int32)
+ qid_valid = qid_valid.astype(np.int32)
+
+ X_valid["y"] = dd.from_array(y_valid)
+ X_valid["qid"] = dd.from_array(qid_valid)
+ X_valid.to_parquet(os.path.join(cache_path, "valid"), engine="pyarrow")
+
+ X_test, y_test, qid_test = load_svmlight_file(
+ test_path, query_id=True, dtype=np.float32
+ )
+
+ X_test = dd.from_array(X_test.toarray(), columns=columns)
+ y_test = y_test.astype(np.int32)
+ qid_test = qid_test.astype(np.int32)
+
+ X_test["y"] = dd.from_array(y_test)
+ X_test["qid"] = dd.from_array(qid_test)
+ X_test.to_parquet(os.path.join(cache_path, "test"), engine="pyarrow")
+
+ df_train = dd.read_parquet(
+ os.path.join(cache_path, "train"), calculate_divisions=True
+ )
+ df_valid = dd.read_parquet(
+ os.path.join(cache_path, "valid"), calculate_divisions=True
+ )
+ df_test = dd.read_parquet(
+ os.path.join(cache_path, "test"), calculate_divisions=True
+ )
+
+ return df_train, df_valid, df_test
+
+
+def ranking_demo(client: Client, args: argparse.Namespace) -> None:
+ """Learning to rank with data sorted locally."""
+ df_tr, df_va, _ = load_mslr_10k(args.device, args.data, args.cache)
+
+ X_train: dd.DataFrame = df_tr[df_tr.columns.difference(["y", "qid"])]
+ y_train = df_tr[["y", "qid"]]
+ Xy_train = dxgb.DaskQuantileDMatrix(client, X_train, y_train.y, qid=y_train.qid)
+
+ X_valid: dd.DataFrame = df_va[df_va.columns.difference(["y", "qid"])]
+ y_valid = df_va[["y", "qid"]]
+ Xy_valid = dxgb.DaskQuantileDMatrix(
+ client, X_valid, y_valid.y, qid=y_valid.qid, ref=Xy_train
+ )
+ # Upon training, you will see a performance warning about sorting data based on
+ # query groups.
+ dxgb.train(
+ client,
+ {"objective": "rank:ndcg", "device": args.device},
+ Xy_train,
+ evals=[(Xy_train, "Train"), (Xy_valid, "Valid")],
+ num_boost_round=100,
+ )
+
+
+def ranking_wo_split_demo(client: Client, args: argparse.Namespace) -> None:
+ """Learning to rank with data partitioned according to query groups."""
+ df_tr, df_va, df_te = load_mslr_10k(args.device, args.data, args.cache)
+
+ X_tr = df_tr[df_tr.columns.difference(["y", "qid"])]
+ X_va = df_va[df_va.columns.difference(["y", "qid"])]
+
+ # `allow_group_split=False` makes sure data is partitioned according to the query
+ # groups.
+ ltr = dxgb.DaskXGBRanker(allow_group_split=False, device=args.device)
+ ltr.client = client
+ ltr = ltr.fit(
+ X_tr,
+ df_tr.y,
+ qid=df_tr.qid,
+ eval_set=[(X_tr, df_tr.y), (X_va, df_va.y)],
+ eval_qid=[df_tr.qid, df_va.qid],
+ verbose=True,
+ )
+
+ df_te = df_te.persist()
+ wait([df_te])
+
+ X_te = df_te[df_te.columns.difference(["y", "qid"])]
+ predt = ltr.predict(X_te)
+ y = client.compute(df_te.y)
+ wait([predt, y])
+
+
+@contextmanager
+def gen_client(device: str) -> Generator[Client, None, None]:
+ match device:
+ case "cuda":
+ from dask_cuda import LocalCUDACluster
+
+ with LocalCUDACluster() as cluster:
+ with Client(cluster) as client:
+ with dask.config.set(
+ {
+ "array.backend": "cupy",
+ "dataframe.backend": "cudf",
+ }
+ ):
+ yield client
+ case "cpu":
+ with LocalCluster() as cluster:
+ with Client(cluster) as client:
+ yield client
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Demonstration of learning to rank using XGBoost."
+ )
+ parser.add_argument(
+ "--data",
+ type=str,
+ help="Root directory of the MSLR-WEB10K data.",
+ required=True,
+ )
+ parser.add_argument(
+ "--cache",
+ type=str,
+ help="Directory for caching processed data.",
+ required=True,
+ )
+ parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu")
+ parser.add_argument(
+ "--no-split",
+ action="store_true",
+ help="Flag to indicate query groups should not be split.",
+ )
+ args = parser.parse_args()
+
+ with gen_client(args.device) as client:
+ if args.no_split:
+ ranking_wo_split_demo(client, args)
+ else:
+ ranking_demo(client, args)
diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py
index b131b31f76f6..fbc1f44baf50 100644
--- a/demo/guide-python/learning_to_rank.py
+++ b/demo/guide-python/learning_to_rank.py
@@ -12,8 +12,8 @@
train on relevance degree, and the second part simulates click data and enable the
position debiasing training.
-For an overview of learning to rank in XGBoost, please see
-:doc:`Learning to Rank `.
+For an overview of learning to rank in XGBoost, please see :doc:`Learning to Rank
+`.
"""
from __future__ import annotations
@@ -31,7 +31,7 @@
from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
-def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
+def load_mslr_10k(data_path: str, cache_path: str) -> RelDataCV:
"""Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
Returns
@@ -89,7 +89,7 @@ def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
def ranking_demo(args: argparse.Namespace) -> None:
"""Demonstration for learning to rank with relevance degree."""
- data = load_mlsr_10k(args.data, args.cache)
+ data = load_mslr_10k(args.data, args.cache)
# Sort data according to query index
X_train, y_train, qid_train = data.train
@@ -123,7 +123,7 @@ def ranking_demo(args: argparse.Namespace) -> None:
def click_data_demo(args: argparse.Namespace) -> None:
"""Demonstration for learning to rank with click data."""
- data = load_mlsr_10k(args.data, args.cache)
+ data = load_mslr_10k(args.data, args.cache)
train, test = simulate_clicks(data)
assert test is not None
diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index 6e68d83a0083..036b1e725d47 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -355,15 +355,18 @@ Working with asyncio
.. versionadded:: 1.2.0
-XGBoost's dask interface supports the new ``asyncio`` in Python and can be integrated into
-asynchronous workflows. For using dask with asynchronous operations, please refer to
-`this dask example `_ and document in
-`distributed `_. To use XGBoost's
-dask interface asynchronously, the ``client`` which is passed as an argument for training and
-prediction must be operating in asynchronous mode by specifying ``asynchronous=True`` when the
-``client`` is created (example below). All functions (including ``DaskDMatrix``) provided
-by the functional interface will then return coroutines which can then be awaited to retrieve
-their result.
+XGBoost's dask interface supports the new :py:mod:`asyncio` in Python and can be
+integrated into asynchronous workflows. For using Dask with asynchronous operations,
+please refer to `this dask example
+`_ and the documentation of `distributed
+`_. To use XGBoost's Dask
+interface asynchronously, the ``client`` that is passed as an argument for training and
+prediction must operate in asynchronous mode, which is done by specifying
+``asynchronous=True`` when the ``client`` is created (example below). All functions
+(including ``DaskDMatrix``) provided by the functional interface will then return
+coroutines, which can be awaited to retrieve their result. Please note that XGBoost is a
+compute-bound application, where parallelism is more important than concurrency. The
+support for `asyncio` is more about compatibility than performance gains.
Functional interface:
@@ -526,6 +529,47 @@ See https://github.com/coiled/dask-xgboost-nyctaxi for a set of examples of usin
with dask and optuna.
+.. _ltr-dask:
+
+****************
+Learning to Rank
+****************
+
+ .. versionadded:: 3.0.0
+
+ .. note::
+
+ Position debiasing is not yet supported.
+
+For performance reasons, there are two operation modes for learning to rank with Dask. The
+difference between them is whether a distributed global sort is needed. Please see
+:ref:`ltr-dist` for how ranking works with distributed training in general. Below we
+discuss some of the Dask-specific features.
+
+First, if you use the :py:class:`~xgboost.dask.DaskQuantileDMatrix` interface or the
+:py:class:`~xgboost.dask.DaskXGBRanker` with ``allow_group_split`` set to ``True``,
+XGBoost tries to sort and group the samples for each worker based on the query ID. This
+mode skips the global sort and sorts only worker-local data, and hence requires no
+inter-worker data shuffle. Please note that even a worker-local sort is costly,
+particularly in terms of memory usage, as there's no spilling when
+:py:meth:`~pandas.DataFrame.sort_values` is used and the data needs to be
+concatenated. XGBoost first checks whether the QID is already sorted before actually
+performing the sorting operation. Choose this mode if the query groups are relatively
+consecutive, meaning most of the samples within a query group are close to each other and
+are likely to reside on the same worker. Don't use it if you have performed a random
+shuffle on your data.
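+
+A minimal sketch of this first mode, assuming a running Dask ``client`` and a Dask
+dataframe ``df_train`` that holds the feature columns together with ``y`` and ``qid``
+columns (the names are illustrative):
+
+.. code-block:: python
+
+    from xgboost import dask as dxgb
+
+    X = df_train[df_train.columns.difference(["y", "qid"])]
+    # Each worker sorts and groups its local samples by query ID while constructing the
+    # DMatrix; no inter-worker shuffle is performed in this mode.
+    Xy = dxgb.DaskQuantileDMatrix(client, X, df_train.y, qid=df_train.qid)
+    output = dxgb.train(
+        client,
+        {"objective": "rank:ndcg"},
+        Xy,
+        num_boost_round=100,
+    )
+    booster = output["booster"]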
+
+If the input data is randomly distributed, then there's no way to guarantee that most of
+the data within the same group lands on the same worker. For large query groups, this
+might not be an issue. But for small query groups, it's possible that each worker gets
+only one or two samples from each group, which can lead to disastrous performance. In that
+case, we can partition the data according to query group, which is the default behavior of
+the :py:class:`~xgboost.dask.DaskXGBRanker` unless ``allow_group_split`` is set to
+``True``. This mode performs a sort and a groupby on the entire dataset, in addition to an
+encoding operation for the query group IDs. Along with partition fragmentation, this
+option can lead to slow performance. See
+:ref:`sphx_glr_python_dask-examples_dask_learning_to_rank.py` for a worked example.
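+
+A corresponding sketch for the second mode, under the same assumptions (a running
+``client`` and a Dask dataframe ``df_train`` with ``y`` and ``qid`` columns):
+
+.. code-block:: python
+
+    from xgboost import dask as dxgb
+
+    X = df_train[df_train.columns.difference(["y", "qid"])]
+    # `allow_group_split=False` re-partitions the data so that no query group is split
+    # among workers; inputs must be Dask dataframes or series in this mode.
+    ranker = dxgb.DaskXGBRanker(objective="rank:ndcg", allow_group_split=False)
+    ranker.client = client
+    ranker = ranker.fit(X, df_train.y, qid=df_train.qid)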
+
.. _tracker-ip:
***************
diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst
index 4d2cbad4aa47..8743a672d219 100644
--- a/doc/tutorials/learning_to_rank.rst
+++ b/doc/tutorials/learning_to_rank.rst
@@ -165,10 +165,26 @@ On the other hand, if you have comparatively small amount of training data:
For any method chosen, you can modify ``lambdarank_num_pair_per_sample`` to control the amount of pairs generated.
+.. _ltr-dist:
+
********************
Distributed Training
********************
-XGBoost implements distributed learning-to-rank with integration of multiple frameworks including Dask, Spark, and PySpark. The interface is similar to the single-node counterpart. Please refer to document of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model accuracy. For most of the use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. As a result, users don't need to partition the data based on query groups. As long as each data partition is correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly.
+
+XGBoost implements distributed learning-to-rank with integration of multiple frameworks
+including :doc:`Dask `, :doc:`Spark `, and
+:doc:`PySpark `. The interface is similar to the single-node
+counterpart. Please refer to the documentation of the respective XGBoost interface for details.
+
+.. warning::
+
+ Position-debiasing is not yet supported for existing distributed interfaces.
+
+XGBoost works with collective operations, which means data is scattered to multiple workers. We can divide the data partitions by query group and ensure no query group is split among workers. However, this requires a costly sort and groupby operation and might only be necessary for selected use cases. Splitting and scattering a query group to multiple workers is theoretically sound but can affect the model's accuracy. If there are only a small number of groups sitting at the boundaries of workers, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used.
+
+For a longer explanation, assuming the pairwise ranking method is used, we calculate the gradient based on relevance degree by constructing pairs within a query group. If a single query group is split among workers and we use worker-local data for gradient calculation, then we are simply sampling pairs from a smaller group for each worker to calculate the gradient and the evaluation metric. The comparison between each pair doesn't change because a group is split into sub-groups; what changes is the number of total and effective pairs and normalizers like `IDCG`. One can generate more pairs from a large group than from two smaller subgroups. As a result, the obtained gradient is still valid from a theoretical standpoint but might not be optimal. As long as the data partitions within a worker are correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly. Both the (Py)Spark interface and the Dask interface can sort the data according to query ID; please see the respective tutorials for more information.
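+
+To make the effect on the number of pairs concrete, here is a small illustrative calculation; it ignores relevance-degree filtering and the ``lambdarank_num_pair_per_sample`` cap:
+
+.. code-block:: python
+
+    def n_candidate_pairs(group_size: int) -> int:
+        # Upper bound on the number of sample pairs within one query group.
+        return group_size * (group_size - 1) // 2
+
+    n_candidate_pairs(100)  # 4950 candidate pairs for one group of 100 samples
+    n_candidate_pairs(50) + n_candidate_pairs(50)  # 2450 when that group is split in half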
+
+However, it's possible that a distributed framework shuffles the data during map reduce and splits every query group across multiple workers. In that case, the performance would be disastrous. As a result, whether a sorted groupby is needed depends on the data and the framework.
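+
+If you manage the data partitions yourself, the requirement boils down to sorting the worker-local samples by query ID before constructing the ``DMatrix``. A rough sketch with NumPy inputs (the names are illustrative):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def sort_partition_by_qid(X, y, qid):
+        # Order worker-local samples so that rows of the same query group are contiguous;
+        # XGBoost can then infer the group boundaries from `qid`.
+        sorted_idx = np.argsort(qid)
+        return X[sorted_idx], y[sorted_idx], qid[sorted_idx]
+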
*******************
Reproducible Result
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index b21cf80aea56..07924623955d 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -3215,11 +3215,7 @@ def trees_to_dataframe(self, fmap: PathLike = "") -> DataFrame:
}
)
- if callable(getattr(df, "sort_values", None)):
- # pylint: disable=no-member
- return df.sort_values(["Tree", "Node"]).reset_index(drop=True)
- # pylint: disable=no-member
- return df.sort(["Tree", "Node"]).reset_index(drop=True)
+ return df.sort_values(["Tree", "Node"]).reset_index(drop=True)
def _assign_dmatrix_features(self, data: DMatrix) -> None:
if data.num_row() == 0:
diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py
index e0221310bc51..6c92e9205dc9 100644
--- a/python-package/xgboost/dask/__init__.py
+++ b/python-package/xgboost/dask/__init__.py
@@ -72,6 +72,7 @@
Tuple,
TypeAlias,
TypedDict,
+ TypeGuard,
TypeVar,
Union,
)
@@ -117,7 +118,7 @@
)
from ..tracker import RabitTracker
from ..training import train as worker_train
-from .data import _create_dmatrix, _create_quantile_dmatrix
+from .data import _create_dmatrix, _create_quantile_dmatrix, no_group_split
from .utils import get_address_from_user, get_n_threads
_DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series]
@@ -1898,10 +1899,21 @@ def _argmax(x: Any) -> Any:
""",
["estimators", "model"],
+ extra_parameters="""
+ allow_group_split :
+
+ .. versionadded:: 3.0.0
+
+        Whether a query group can be split among multiple workers. When set to `False`,
+        inputs must be Dask dataframes or series, and if there are many small query
+        groups, the data can become significantly more fragmented and the internal
+        DMatrix construction can take longer.
+
+""",
end_note="""
.. note::
- For dask implementation, group is not supported, use qid instead.
+            For the dask implementation, group is not supported; use qid instead.
""",
)
class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@@ -1910,36 +1922,36 @@ def __init__(
self,
*,
objective: str = "rank:pairwise",
+ allow_group_split: bool = False,
coll_cfg: Optional[CollConfig] = None,
**kwargs: Any,
) -> None:
if callable(objective):
raise ValueError("Custom objective function not supported by XGBRanker.")
+ self.allow_group_split = allow_group_split
super().__init__(objective=objective, coll_cfg=coll_cfg, **kwargs)
+ def _wrapper_params(self) -> Set[str]:
+ params = super()._wrapper_params()
+ params.add("allow_group_split")
+ return params
+
async def _fit_async(
self,
X: _DataT,
y: _DaskCollection,
*,
- group: Optional[_DaskCollection],
qid: Optional[_DaskCollection],
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]],
sample_weight_eval_set: Optional[Sequence[_DaskCollection]],
base_margin_eval_set: Optional[Sequence[_DaskCollection]],
- eval_group: Optional[Sequence[_DaskCollection]],
eval_qid: Optional[Sequence[_DaskCollection]],
verbose: Union[int, bool],
xgb_model: Optional[Union[XGBModel, Booster]],
feature_weights: Optional[_DaskCollection],
) -> "DaskXGBRanker":
- msg = "Use the `qid` instead of the `group` with the dask interface."
- if not (group is None and eval_group is None):
- raise ValueError(msg)
- if qid is None:
- raise ValueError("`qid` is required for ranking.")
params = self.get_xgb_params()
dtrain, evals = await _async_wrap_evaluation_matrices(
self.client,
@@ -2006,8 +2018,108 @@ def fit(
base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
) -> "DaskXGBRanker":
- args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
- return self._client_sync(self._fit_async, **args)
+ msg = "Use the `qid` instead of the `group` with the dask interface."
+ if not (group is None and eval_group is None):
+ raise ValueError(msg)
+ if qid is None:
+ raise ValueError("`qid` is required for ranking.")
+
+ def check_df(X: _DaskCollection) -> TypeGuard[dd.DataFrame]:
+ if not isinstance(X, dd.DataFrame):
+ raise TypeError(
+ "When `allow_group_split` is set to False, X is required to be"
+ " a dataframe."
+ )
+ return True
+
+ def check_ser(
+ qid: Optional[_DaskCollection], name: str
+ ) -> TypeGuard[Optional[dd.Series]]:
+ if not isinstance(qid, dd.Series) and qid is not None:
+ raise TypeError(
+ f"When `allow_group_split` is set to False, {name} is required to be"
+ " a series."
+ )
+ return True
+
+ if not self.allow_group_split:
+ assert (
+ check_df(X)
+ and check_ser(qid, "qid")
+ and check_ser(y, "y")
+ and check_ser(sample_weight, "sample_weight")
+ and check_ser(base_margin, "base_margin")
+ )
+ assert qid is not None and y is not None
+ X_id = id(X)
+ X, qid, y, sample_weight, base_margin = no_group_split(
+ self.device,
+ X,
+ qid,
+ y=y,
+ sample_weight=sample_weight,
+ base_margin=base_margin,
+ )
+
+ if eval_set is not None:
+ new_eval_set = []
+ new_eval_qid = []
+ new_sample_weight_eval_set = []
+ new_base_margin_eval_set = []
+ assert eval_qid
+ for i, (Xe, ye) in enumerate(eval_set):
+ we = sample_weight_eval_set[i] if sample_weight_eval_set else None
+ be = base_margin_eval_set[i] if base_margin_eval_set else None
+ assert check_df(Xe)
+ assert eval_qid
+ qe = eval_qid[i]
+ assert (
+ eval_qid
+ and check_ser(qe, "qid")
+ and check_ser(ye, "y")
+ and check_ser(we, "sample_weight")
+ and check_ser(be, "base_margin")
+ )
+ assert qe is not None and ye is not None
+ if id(Xe) != X_id:
+ Xe, qe, ye, we, be = no_group_split(
+ self.device, Xe, qe, ye, we, be
+ )
+ else:
+ Xe, qe, ye, we, be = X, qid, y, sample_weight, base_margin
+
+ new_eval_set.append((Xe, ye))
+ new_eval_qid.append(qe)
+
+ if we is not None:
+ new_sample_weight_eval_set.append(we)
+ if be is not None:
+ new_base_margin_eval_set.append(be)
+
+ eval_set = new_eval_set
+ eval_qid = new_eval_qid
+ sample_weight_eval_set = (
+ new_sample_weight_eval_set if new_sample_weight_eval_set else None
+ )
+ base_margin_eval_set = (
+ new_base_margin_eval_set if new_base_margin_eval_set else None
+ )
+
+ return self._client_sync(
+ self._fit_async,
+ X=X,
+ y=y,
+ qid=qid,
+ sample_weight=sample_weight,
+ base_margin=base_margin,
+ eval_set=eval_set,
+ eval_qid=eval_qid,
+ verbose=verbose,
+ xgb_model=xgb_model,
+ sample_weight_eval_set=sample_weight_eval_set,
+ base_margin_eval_set=base_margin_eval_set,
+ feature_weights=feature_weights,
+ )
# FIXME(trivialfis): arguments differ due to additional parameters like group and
# qid.
diff --git a/python-package/xgboost/dask/data.py b/python-package/xgboost/dask/data.py
index c4f0f138b298..f92f1666499f 100644
--- a/python-package/xgboost/dask/data.py
+++ b/python-package/xgboost/dask/data.py
@@ -3,15 +3,30 @@
import logging
from collections.abc import Sequence
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
-
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+import dask
import distributed
import numpy as np
+import pandas as pd
from dask import dataframe as dd
+from .. import collective as coll
from .._typing import _T, FeatureNames
-from ..compat import concat
+from ..compat import concat, import_cupy
from ..core import DataIter, DMatrix, QuantileDMatrix
+from ..data import is_on_cuda
LOGGER = logging.getLogger("[xgboost.dask]")
@@ -96,6 +111,153 @@ def next(self, input_data: Callable) -> bool:
return True
+@overload
+def _add_column(df: dd.DataFrame, col: dd.Series) -> Tuple[dd.DataFrame, str]: ...
+
+
+@overload
+def _add_column(df: dd.DataFrame, col: None) -> Tuple[dd.DataFrame, None]: ...
+
+
+def _add_column(
+ df: dd.DataFrame, col: Optional[dd.Series]
+) -> Tuple[dd.DataFrame, Optional[str]]:
+ if col is None:
+ return df, col
+
+    trials = 0
+    uid = f"{col.name}_{trials}"
+    while uid in df.columns:
+        trials += 1
+        uid = f"{col.name}_{trials}"
+
+ df = df.assign(**{uid: col})
+ return df, uid
+
+
+def no_group_split( # pylint: disable=too-many-positional-arguments
+ device: str | None,
+ df: dd.DataFrame,
+ qid: dd.Series,
+ y: dd.Series,
+ sample_weight: Optional[dd.Series],
+ base_margin: Optional[dd.Series],
+) -> Tuple[
+ dd.DataFrame, dd.Series, dd.Series, Optional[dd.Series], Optional[dd.Series]
+]:
+ """A function to prevent query group from being scattered to different
+ workers. Please see the tutorial in the document for the implication for not having
+ partition boundary based on query groups.
+
+ """
+
+ df, qid_uid = _add_column(df, qid)
+ df, y_uid = _add_column(df, y)
+ df, w_uid = _add_column(df, sample_weight)
+ df, bm_uid = _add_column(df, base_margin)
+
+ # `tasks` shuffle is required as of rapids 24.12
+ shuffle = "p2p" if device is None or device == "cpu" else "tasks"
+ with dask.config.set({"dataframe.shuffle.method": shuffle}):
+ df = df.persist()
+ # Encode the QID to make it dense.
+ df[qid_uid] = df[qid_uid].astype("category").cat.as_known().cat.codes
+ # The shuffle here is costly.
+ df = df.sort_values(by=qid_uid)
+ cnt = df.groupby(qid_uid)[qid_uid].count()
+ div = cnt.index.compute().values.tolist()
+ div = sorted(div)
+ div = tuple(div + [div[-1] + 1])
+
+ df = df.set_index(
+ qid_uid,
+ drop=False,
+ divisions=div,
+ ).persist()
+
+ qid = df[qid_uid]
+ y = df[y_uid]
+ sample_weight, base_margin = (
+ cast(dd.Series, df[uid]) if uid is not None else None for uid in (w_uid, bm_uid)
+ )
+
+ uids = [uid for uid in [qid_uid, y_uid, w_uid, bm_uid] if uid is not None]
+ df = df.drop(uids, axis=1).persist()
+ return df, qid, y, sample_weight, base_margin
+
+
+def sort_data_by_qid(**kwargs: List[Any]) -> Dict[str, List[Any]]:
+ """Sort worker-local data by query ID for learning to rank tasks."""
+ data_parts = kwargs.get("data")
+ assert data_parts is not None
+ n_parts = len(data_parts)
+
+ if is_on_cuda(data_parts[0]):
+ from cudf import DataFrame
+ else:
+ from pandas import DataFrame
+
+ def get_dict(i: int) -> Dict[str, list]:
+ """Return a dictionary containing all the meta info and all partitions."""
+
+ def _get(attr: Optional[List[Any]]) -> Optional[list]:
+ if attr is not None:
+ return attr[i]
+ return None
+
+ data_opt = {name: _get(kwargs.get(name, None)) for name in meta}
+ # Filter out None values.
+ data = {k: v for k, v in data_opt.items() if v is not None}
+ return data
+
+ def map_fn(i: int) -> pd.DataFrame:
+ data = get_dict(i)
+ return DataFrame(data)
+
+ meta_parts = [map_fn(i) for i in range(n_parts)]
+ dfq = concat(meta_parts)
+ if dfq.qid.is_monotonic_increasing:
+ return kwargs
+
+ LOGGER.warning(
+ "[r%d]: Sorting data with %d partitions for ranking. "
+ "This is a costly operation and will increase the memory usage significantly. "
+ "To avoid this warning, sort the data based on qid before passing it into "
+ "XGBoost. Alternatively, you can use set the `allow_group_split` to False.",
+ coll.get_rank(),
+ n_parts,
+ )
+ # I tried to construct a new dask DF to perform the sort, but it's quite difficult
+ # to get the partition alignment right. Along with the still maturing shuffle
+ # implementation and GPU compatibility, a simple concat is used.
+ #
+ # In case it might become useful one day, I managed to get a CPU version working,
+    # albeit quite slow (much slower than the concatenated sort). The implementation merges
+ # everything into a single Dask DF and runs `DF.sort_values`, then retrieve the
+ # individual X,y,qid, ... from calculated partition values `client.compute([p for p
+ # in df.partitions])`. It was to avoid creating mismatched partitions.
+ dfx = concat(data_parts)
+
+ if is_on_cuda(dfq):
+ cp = import_cupy()
+ sorted_idx = cp.argsort(dfq.qid)
+ else:
+ sorted_idx = np.argsort(dfq.qid)
+ dfq = dfq.iloc[sorted_idx, :]
+
+ if hasattr(dfx, "iloc"):
+ dfx = dfx.iloc[sorted_idx, :]
+ else:
+ dfx = dfx[sorted_idx, :]
+
+ kwargs.update({"data": [dfx]})
+ for i, c in enumerate(dfq.columns):
+ assert c in kwargs
+ kwargs.update({c: [dfq[c]]})
+
+ return kwargs
+
+
def _get_worker_parts(list_of_parts: _DataParts) -> Dict[str, List[Any]]:
assert isinstance(list_of_parts, list)
result: Dict[str, List[Any]] = {}
@@ -115,6 +277,9 @@ def append(i: int, name: str) -> None:
for k in meta:
append(i, k)
+ qid = result.get("qid", None)
+ if qid is not None:
+ result = sort_data_by_qid(**result)
return result
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 5fbafd6ec58f..80e0ad2db1f5 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -457,7 +457,11 @@ def make_categorical(
def make_ltr(
- n_samples: int, n_features: int, n_query_groups: int, max_rel: int
+ n_samples: int,
+ n_features: int,
+ n_query_groups: int,
+ max_rel: int,
+ sort_qid: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Make a dataset for testing LTR."""
rng = np.random.default_rng(1994)
@@ -470,7 +474,8 @@ def make_ltr(
w = rng.normal(0, 1.0, size=n_query_groups)
w -= np.min(w)
w /= np.max(w)
- qid = np.sort(qid)
+ if sort_qid:
+ qid = np.sort(qid)
return X, y, qid, w
@@ -637,6 +642,10 @@ def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool:
return all((y - x) < tolerance for x, y in zip(L, L[1:]))
+def non_decreasing(L: Sequence[float], tolerance: float = 1e-4) -> bool:
+ return all((y - x) >= -tolerance for x, y in zip(L, L[1:]))
+
+
def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:
"""Assert whether two DMatrices contain the same predictors."""
lcsr = lhs.get_data()
diff --git a/python-package/xgboost/testing/dask.py b/python-package/xgboost/testing/dask.py
index 541009a73c85..af0fc8bf0397 100644
--- a/python-package/xgboost/testing/dask.py
+++ b/python-package/xgboost/testing/dask.py
@@ -1,6 +1,6 @@
"""Tests for dask shared by different test modules."""
-from typing import Any, List, Literal, cast
+from typing import Any, List, Literal, Tuple, cast
import numpy as np
import pandas as pd
@@ -175,7 +175,82 @@ def get_rabit_args(client: Client, n_workers: int) -> Any:
return client.sync(_get_rabit_args, client, n_workers)
-def get_client_workers(client: Any) -> List[str]:
+def get_client_workers(client: Client) -> List[str]:
"Get workers from a dask client."
workers = client.scheduler_info()["workers"]
return list(workers.keys())
+
+
+def make_ltr( # pylint: disable=too-many-locals,too-many-arguments
+ client: Client,
+ n_samples: int,
+ n_features: int,
+ *,
+ n_query_groups: int,
+ max_rel: int,
+ device: str,
+) -> Tuple[dd.DataFrame, dd.Series, dd.Series]:
+ """Synthetic dataset for learning to rank."""
+ workers = get_client_workers(client)
+ n_samples_per_worker = n_samples // len(workers)
+
+ if device == "cpu":
+ from pandas import DataFrame as DF
+ else:
+ from cudf import DataFrame as DF
+
+ def make(n: int, seed: int) -> pd.DataFrame:
+ rng = np.random.default_rng(seed)
+ X, y = make_classification(
+ n, n_features, n_informative=n_features, n_redundant=0, n_classes=max_rel
+ )
+ qid = rng.integers(size=(n,), low=0, high=n_query_groups)
+ df = DF(X, columns=[f"f{i}" for i in range(n_features)])
+ df["qid"] = qid
+ df["y"] = y
+ return df
+
+ futures = []
+ i = 0
+ for k in range(0, n_samples, n_samples_per_worker):
+ fut = client.submit(
+ make, n=n_samples_per_worker, seed=k, workers=[workers[i % len(workers)]]
+ )
+ futures.append(fut)
+ i += 1
+
+ last = n_samples - (n_samples_per_worker * len(workers))
+ if last != 0:
+ fut = client.submit(make, n=last, seed=n_samples_per_worker * len(workers))
+ futures.append(fut)
+
+ meta = make(1, 0)
+ df = dd.from_delayed(futures, meta=meta)
+ assert isinstance(df, dd.DataFrame)
+ return df.drop(["qid", "y"], axis=1), df.y, df.qid
+
+
+def check_no_group_split(client: Client, device: str) -> None:
+ """Test for the allow_group_split parameter."""
+ X_tr, q_tr, y_tr = make_ltr(
+ client, 4096, 128, n_query_groups=4, max_rel=5, device=device
+ )
+ X_va, q_va, y_va = make_ltr(
+ client, 1024, 128, n_query_groups=4, max_rel=5, device=device
+ )
+
+ ltr = dxgb.DaskXGBRanker(allow_group_split=False, n_estimators=32, device=device)
+ ltr.fit(
+ X_tr,
+ y_tr,
+ qid=q_tr,
+ eval_set=[(X_tr, y_tr), (X_va, y_va)],
+ eval_qid=[q_tr, q_va],
+ verbose=True,
+ )
+
+ assert ltr.n_features_in_ == 128
+ assert X_tr.shape[1] == ltr.n_features_in_ # no change
+ ndcg = ltr.evals_result()["validation_0"]["ndcg@32"]
+ assert tm.non_decreasing(ndcg[:16], tolerance=1e-2), ndcg
+ np.testing.assert_allclose(ndcg[-1], 1.0, rtol=1e-2)
diff --git a/src/data/data.cc b/src/data/data.cc
index 47836bb5134b..713ad4a1a514 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -539,7 +539,9 @@ void MetaInfo::SetInfoFromHost(Context const* ctx, StringView key, Json arr) {
} else if (key == "label") {
CopyTensorInfoImpl(ctx, arr, &this->labels);
if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) {
- CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels.";
+ CHECK_EQ(this->labels.Size() % this->num_row_, 0)
+ << "Incorrect size for labels: (" << this->labels.Shape(0) << "," << this->labels.Shape(1)
+ << ") v.s. " << this->num_row_;
size_t n_targets = this->labels.Size() / this->num_row_;
this->labels.Reshape(this->num_row_, n_targets);
}
diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc
index 94acf5a238d9..c50a55b3a17c 100644
--- a/src/objective/lambdarank_obj.cc
+++ b/src/objective/lambdarank_obj.cc
@@ -1,5 +1,5 @@
/**
- * Copyright (c) 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
*/
#include "lambdarank_obj.h"
@@ -23,7 +23,6 @@
#include "../common/optional_weight.h" // for MakeOptionalWeights, OptionalWeights
#include "../common/ranking_utils.h" // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
#include "../common/threading_utils.h" // for ParallelFor, Sched
-#include "../common/transform_iterator.h" // for IndexTransformIter
#include "init_estimation.h" // for FitIntercept
#include "xgboost/base.h" // for bst_group_t, GradientPair, kRtEps, GradientPai...
#include "xgboost/context.h" // for Context
diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh
index 2e5724f7f1fd..e1a78f905434 100644
--- a/src/objective/lambdarank_obj.cuh
+++ b/src/objective/lambdarank_obj.cuh
@@ -1,5 +1,5 @@
/**
- * Copyright 2023 XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_
@@ -71,13 +71,13 @@ struct KernelInputs {
std::int32_t iter;
};
/**
- * \brief Functor for generating pairs
+ * @brief Functor for generating pairs
*/
template
struct MakePairsOp {
KernelInputs args;
/**
- * \brief Make pair for the topk pair method.
+ * @brief Make pair for the topk pair method.
*/
[[nodiscard]] XGBOOST_DEVICE std::tuple WithTruncation(
std::size_t idx, bst_group_t g) const {
@@ -86,9 +86,6 @@ struct MakePairsOp {
auto data_group_begin = static_cast(args.d_group_ptr[g]);
std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin;
- // obtain group segment data.
- auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0);
- auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data);
std::size_t i = 0, j = 0;
common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j);
@@ -97,7 +94,7 @@ struct MakePairsOp {
return std::make_tuple(rank_high, rank_low);
}
/**
- * \brief Make pair for the mean pair method
+ * @brief Make pair for the mean pair method
*/
XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx,
bst_group_t g) const {
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 76860d9d1e35..dfa67e757059 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -111,8 +111,7 @@ class LintersPaths:
"tests/test_distributed/test_with_dask/test_external_memory.py",
"tests/test_distributed/test_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_spark/test_data.py",
- "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
- "tests/test_distributed/test_gpu_with_dask/test_gpu_external_memory.py",
+ "tests/test_distributed/test_gpu_with_dask/",
# demo
"demo/dask/",
"demo/json-model/json_parser.py",
diff --git a/tests/test_distributed/test_gpu_with_dask/conftest.py b/tests/test_distributed/test_gpu_with_dask/conftest.py
index 0332dd945651..a066461303d3 100644
--- a/tests/test_distributed/test_gpu_with_dask/conftest.py
+++ b/tests/test_distributed/test_gpu_with_dask/conftest.py
@@ -1,4 +1,4 @@
-from typing import Generator, Sequence
+from typing import Any, Generator, Sequence
import pytest
@@ -6,12 +6,12 @@
@pytest.fixture(scope="session", autouse=True)
-def setup_rmm_pool(request, pytestconfig: pytest.Config) -> None:
+def setup_rmm_pool(request: Any, pytestconfig: pytest.Config) -> None:
tm.setup_rmm_pool(request, pytestconfig)
@pytest.fixture(scope="class")
-def local_cuda_client(request, pytestconfig: pytest.Config) -> Generator:
+def local_cuda_client(request: Any, pytestconfig: pytest.Config) -> Generator:
kwargs = {}
if hasattr(request, "param"):
kwargs.update(request.param)
diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_demos.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_demos.py
index 553b8746f0d0..848321ae4613 100644
--- a/tests/test_distributed/test_gpu_with_dask/test_gpu_demos.py
+++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_demos.py
@@ -14,14 +14,14 @@
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
-def test_dask_training():
+def test_dask_training() -> None:
script = os.path.join(tm.demo_dir(__file__), "dask", "gpu_training.py")
cmd = ["python", script]
subprocess.check_call(cmd)
@pytest.mark.mgpu
-def test_dask_sklearn_demo():
+def test_dask_sklearn_demo() -> None:
script = os.path.join(tm.demo_dir(__file__), "dask", "sklearn_gpu_training.py")
cmd = ["python", script]
subprocess.check_call(cmd)
@@ -29,7 +29,7 @@ def test_dask_sklearn_demo():
@pytest.mark.mgpu
@pytest.mark.skipif(**tm.no_cupy())
-def test_forward_logging_demo():
+def test_forward_logging_demo() -> None:
script = os.path.join(tm.demo_dir(__file__), "dask", "forward_logging.py")
cmd = ["python", script]
subprocess.check_call(cmd)
diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_ranking.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_ranking.py
new file mode 100644
index 000000000000..f8f586e39746
--- /dev/null
+++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_ranking.py
@@ -0,0 +1,18 @@
+"""Copyright 2024, XGBoost contributors"""
+
+import dask
+import pytest
+from distributed import Client
+
+from xgboost.testing import dask as dtm
+
+
+@pytest.mark.filterwarnings("error")
+def test_no_group_split(local_cuda_client: Client) -> None:
+ with dask.config.set(
+ {
+ "array.backend": "cupy",
+ "dataframe.backend": "cudf",
+ }
+ ):
+ dtm.check_no_group_split(local_cuda_client, "cuda")
diff --git a/tests/test_distributed/test_with_dask/test_ranking.py b/tests/test_distributed/test_with_dask/test_ranking.py
index 0b2ea404fde1..f806d61d2592 100644
--- a/tests/test_distributed/test_with_dask/test_ranking.py
+++ b/tests/test_distributed/test_with_dask/test_ranking.py
@@ -11,6 +11,7 @@
from xgboost import dask as dxgb
from xgboost import testing as tm
+from xgboost.testing import dask as dtm
@pytest.fixture(scope="module")
@@ -59,7 +60,10 @@ def test_dask_ranking(client: Client) -> None:
qid_test = qid_test.astype(np.uint32)
rank = dxgb.DaskXGBRanker(
- n_estimators=2500, eval_metric=["ndcg"], early_stopping_rounds=10
+ n_estimators=2500,
+ eval_metric=["ndcg"],
+ early_stopping_rounds=10,
+ allow_group_split=True,
)
rank.fit(
x_train,
@@ -71,3 +75,8 @@ def test_dask_ranking(client: Client) -> None:
)
assert rank.n_features_in_ == 46
assert rank.best_score > 0.98
+
+
+@pytest.mark.filterwarnings("error")
+def test_no_group_split(client: Client) -> None:
+ dtm.check_no_group_split(client, "cpu")