diff --git a/python/dask_cudf/dask_cudf/expr/_collection.py b/python/dask_cudf/dask_cudf/expr/_collection.py index 605a81f0fcd..d50dfb24256 100644 --- a/python/dask_cudf/dask_cudf/expr/_collection.py +++ b/python/dask_cudf/dask_cudf/expr/_collection.py @@ -18,6 +18,15 @@ import cudf +_LEGACY_WORKAROUND = ( + "To enable the 'legacy' dask-cudf API, set the " + "global 'dataframe.query-planning' config to " + "`False` before dask is imported. This can also " + "be done by setting an environment variable: " + "`DASK_DATAFRAME__QUERY_PLANNING=False` " +) + + ## ## Custom collection classes ## @@ -88,6 +97,21 @@ def groupby( f"`by` must be a column name or list of columns, got {by}." ) + if "as_index" in kwargs: + msg = ( + "The `as_index` argument is now deprecated. All groupby " + "results will be consistent with `as_index=True`." + ) + + if kwargs.pop("as_index") is not True: + raise NotImplementedError( + f"{msg} Please reset the index after aggregating, or " + "use the legacy API if `as_index=False` is required.\n" + f"{_LEGACY_WORKAROUND}" + ) + else: + warnings.warn(msg, FutureWarning) + return GroupBy( self, by, diff --git a/python/dask_cudf/dask_cudf/expr/_groupby.py b/python/dask_cudf/dask_cudf/expr/_groupby.py index 7f275151f75..116893891e3 100644 --- a/python/dask_cudf/dask_cudf/expr/_groupby.py +++ b/python/dask_cudf/dask_cudf/expr/_groupby.py @@ -3,13 +3,55 @@ from dask_expr._groupby import ( GroupBy as DXGroupBy, SeriesGroupBy as DXSeriesGroupBy, + SingleAggregation, ) from dask_expr._util import is_scalar +from dask.dataframe.groupby import Aggregation + ## ## Custom groupby classes ## + +class Collect(SingleAggregation): + @staticmethod + def groupby_chunk(arg): + return arg.agg("collect") + + @staticmethod + def groupby_aggregate(arg): + gb = arg.agg("collect") + if gb.ndim > 1: + for col in gb.columns: + gb[col] = gb[col].list.concat() + return gb + else: + return gb.list.concat() + + +collect_aggregation = Aggregation( + name="collect", + chunk=Collect.groupby_chunk, + agg=Collect.groupby_aggregate, +) + + +def _translate_arg(arg): + # Helper function to translate args so that + # they can be processed correctly by upstream + # dask & dask-expr. Right now, the only necessary + # translation is "collect" aggregations. + if isinstance(arg, dict): + return {k: _translate_arg(v) for k, v in arg.items()} + elif isinstance(arg, list): + return [_translate_arg(x) for x in arg] + elif arg in ("collect", "list", list): + return collect_aggregation + else: + return arg + + # TODO: These classes are mostly a work-around for missing # `observed=False` support. # See: https://github.com/rapidsai/cudf/issues/15173 @@ -41,8 +83,20 @@ def __getitem__(self, key): ) return g + def collect(self, **kwargs): + return self._single_agg(Collect, **kwargs) + + def aggregate(self, arg, **kwargs): + return super().aggregate(_translate_arg(arg), **kwargs) + class SeriesGroupBy(DXSeriesGroupBy): def __init__(self, *args, observed=None, **kwargs): observed = observed if observed is not None else True super().__init__(*args, observed=observed, **kwargs) + + def collect(self, **kwargs): + return self._single_agg(Collect, **kwargs) + + def aggregate(self, arg, **kwargs): + return super().aggregate(_translate_arg(arg), **kwargs) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 1e22dd95475..67fa045d3d0 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -14,16 +14,6 @@ from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr -# XFAIL "collect" tests for now -agg_params = [agg for agg in OPTIMIZED_AGGS if agg != "collect"] -if QUERY_PLANNING_ON: - agg_params.append( - # TODO: "collect" not supported with dask-expr yet - pytest.param("collect", marks=pytest.mark.xfail) - ) -else: - agg_params.append("collect") - def assert_cudf_groupby_layers(ddf): for prefix in ("cudf-aggregate-chunk", "cudf-aggregate-agg"): @@ -57,7 +47,7 @@ def pdf(request): return pdf -@pytest.mark.parametrize("aggregation", agg_params) +@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) @pytest.mark.parametrize("series", [False, True]) def test_groupby_basic(series, aggregation, pdf): gdf = cudf.DataFrame.from_pandas(pdf) @@ -110,7 +100,7 @@ def test_groupby_cumulative(aggregation, pdf, series): dd.assert_eq(a, b) -@pytest.mark.parametrize("aggregation", agg_params) +@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS) @pytest.mark.parametrize( "func", [ @@ -579,8 +569,16 @@ def test_groupby_categorical_key(): dd.assert_eq(expect, got) -@xfail_dask_expr("as_index not supported in dask-expr") -@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "as_index", + [ + True, + pytest.param( + False, + marks=xfail_dask_expr("as_index not supported in dask-expr"), + ), + ], +) @pytest.mark.parametrize("split_out", ["use_dask_default", 1, 2]) @pytest.mark.parametrize("split_every", [False, 4]) @pytest.mark.parametrize("npartitions", [1, 10]) @@ -603,10 +601,19 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index): if split_out == "use_dask_default": split_kwargs.pop("split_out") + # Avoid using as_index when query-planning is enabled + if QUERY_PLANNING_ON: + with pytest.warns(FutureWarning, match="argument is now deprecated"): + # Should warn when `as_index` is used + ddf.groupby(["name", "a"], sort=False, as_index=as_index) + maybe_as_index = {"as_index": as_index} if as_index is False else {} + else: + maybe_as_index = {"as_index": as_index} + # Check `sort=True` behavior if split_out == 1: gf = ( - ddf.groupby(["name", "a"], sort=True, as_index=as_index) + ddf.groupby(["name", "a"], sort=True, **maybe_as_index) .aggregate( agg_dict, **split_kwargs, @@ -628,7 +635,7 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index): ) # Full check (`sort=False`) - gr = ddf.groupby(["name", "a"], sort=False, as_index=as_index).aggregate( + gr = ddf.groupby(["name", "a"], sort=False, **maybe_as_index).aggregate( agg_dict, **split_kwargs, )