Merge remote-tracking branch 'upstream/branch-24.06' into fix-type-limits-errors
bdice committed May 1, 2024
2 parents 3d5285a + 7458a6e commit d9c8c9e
Showing 3 changed files with 101 additions and 16 deletions.
24 changes: 24 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_collection.py
@@ -18,6 +18,15 @@

import cudf

_LEGACY_WORKAROUND = (
"To enable the 'legacy' dask-cudf API, set the "
"global 'dataframe.query-planning' config to "
"`False` before dask is imported. This can also "
"be done by setting an environment variable: "
"`DASK_DATAFRAME__QUERY_PLANNING=False` "
)


##
## Custom collection classes
##
@@ -88,6 +97,21 @@ def groupby(
f"`by` must be a column name or list of columns, got {by}."
)

if "as_index" in kwargs:
msg = (
"The `as_index` argument is now deprecated. All groupby "
"results will be consistent with `as_index=True`."
)

if kwargs.pop("as_index") is not True:
raise NotImplementedError(
f"{msg} Please reset the index after aggregating, or "
"use the legacy API if `as_index=False` is required.\n"
f"{_LEGACY_WORKAROUND}"
)
else:
warnings.warn(msg, FutureWarning)

return GroupBy(
self,
by,
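For context, a rough sketch of how the `as_index` deprecation above surfaces to users when query planning is enabled (hypothetical session; assumes a GPU environment with dask-cudf installed):

import cudf
import dask_cudf

df = cudf.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]})
ddf = dask_cudf.from_cudf(df, npartitions=2)

# Deprecated but tolerated: emits a FutureWarning, since results
# already follow as_index=True semantics.
ddf.groupby("a", as_index=True).sum()

# Rejected: raises NotImplementedError with the _LEGACY_WORKAROUND hint.
ddf.groupby("a", as_index=False).sum()

# Per _LEGACY_WORKAROUND, the legacy API remains available by setting
# DASK_DATAFRAME__QUERY_PLANNING=False before dask is first imported.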
54 changes: 54 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_groupby.py
@@ -3,13 +3,55 @@
from dask_expr._groupby import (
GroupBy as DXGroupBy,
SeriesGroupBy as DXSeriesGroupBy,
SingleAggregation,
)
from dask_expr._util import is_scalar

from dask.dataframe.groupby import Aggregation

##
## Custom groupby classes
##


class Collect(SingleAggregation):
@staticmethod
def groupby_chunk(arg):
return arg.agg("collect")

@staticmethod
def groupby_aggregate(arg):
gb = arg.agg("collect")
if gb.ndim > 1:
for col in gb.columns:
gb[col] = gb[col].list.concat()
return gb
else:
return gb.list.concat()


collect_aggregation = Aggregation(
name="collect",
chunk=Collect.groupby_chunk,
agg=Collect.groupby_aggregate,
)


def _translate_arg(arg):
# Helper function to translate args so that
# they can be processed correctly by upstream
# dask & dask-expr. Right now, the only necessary
# translation is for "collect" aggregations.
if isinstance(arg, dict):
return {k: _translate_arg(v) for k, v in arg.items()}
elif isinstance(arg, list):
return [_translate_arg(x) for x in arg]
elif arg in ("collect", "list", list):
return collect_aggregation
else:
return arg


# TODO: These classes are mostly a work-around for missing
# `observed=False` support.
# See: https://github.com/rapidsai/cudf/issues/15173
@@ -41,8 +83,20 @@ def __getitem__(self, key):
)
return g

def collect(self, **kwargs):
return self._single_agg(Collect, **kwargs)

def aggregate(self, arg, **kwargs):
return super().aggregate(_translate_arg(arg), **kwargs)


class SeriesGroupBy(DXSeriesGroupBy):
def __init__(self, *args, observed=None, **kwargs):
observed = observed if observed is not None else True
super().__init__(*args, observed=observed, **kwargs)

def collect(self, **kwargs):
return self._single_agg(Collect, **kwargs)

def aggregate(self, arg, **kwargs):
return super().aggregate(_translate_arg(arg), **kwargs)
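A minimal usage sketch of the collect machinery above (hypothetical data; assumes a GPU environment with dask-cudf installed). The string "collect", the string "list", and the list builtin are all routed through `_translate_arg` to `collect_aggregation`, and `collect()` is exposed directly on the grouped objects:

import cudf
import dask_cudf

df = cudf.DataFrame({"k": [1, 1, 2], "v": [10, 20, 30]})
ddf = dask_cudf.from_cudf(df, npartitions=2)

# Each spelling resolves to the same custom aggregation and
# produces one list per group:
ddf.groupby("k").aggregate("collect").compute()
ddf.groupby("k").aggregate({"v": list}).compute()
ddf.groupby("k")["v"].collect().compute()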
39 changes: 23 additions & 16 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
@@ -14,16 +14,6 @@
from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized
from dask_cudf.tests.utils import QUERY_PLANNING_ON, xfail_dask_expr

# XFAIL "collect" tests for now
agg_params = [agg for agg in OPTIMIZED_AGGS if agg != "collect"]
if QUERY_PLANNING_ON:
agg_params.append(
# TODO: "collect" not supported with dask-expr yet
pytest.param("collect", marks=pytest.mark.xfail)
)
else:
agg_params.append("collect")


def assert_cudf_groupby_layers(ddf):
for prefix in ("cudf-aggregate-chunk", "cudf-aggregate-agg"):
@@ -57,7 +47,7 @@ def pdf(request):
return pdf


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation, pdf):
gdf = cudf.DataFrame.from_pandas(pdf)
@@ -110,7 +100,7 @@ def test_groupby_cumulative(aggregation, pdf, series):
dd.assert_eq(a, b)


@pytest.mark.parametrize("aggregation", agg_params)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize(
"func",
[
@@ -579,8 +569,16 @@ def test_groupby_categorical_key():
dd.assert_eq(expect, got)


@xfail_dask_expr("as_index not supported in dask-expr")
@pytest.mark.parametrize("as_index", [True, False])
@pytest.mark.parametrize(
"as_index",
[
True,
pytest.param(
False,
marks=xfail_dask_expr("as_index not supported in dask-expr"),
),
],
)
@pytest.mark.parametrize("split_out", ["use_dask_default", 1, 2])
@pytest.mark.parametrize("split_every", [False, 4])
@pytest.mark.parametrize("npartitions", [1, 10])
@@ -603,10 +601,19 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index):
if split_out == "use_dask_default":
split_kwargs.pop("split_out")

# Avoid using as_index when query-planning is enabled
if QUERY_PLANNING_ON:
with pytest.warns(FutureWarning, match="argument is now deprecated"):
# Should warn when `as_index` is used
ddf.groupby(["name", "a"], sort=False, as_index=as_index)
maybe_as_index = {"as_index": as_index} if as_index is False else {}
else:
maybe_as_index = {"as_index": as_index}

# Check `sort=True` behavior
if split_out == 1:
gf = (
ddf.groupby(["name", "a"], sort=True, as_index=as_index)
ddf.groupby(["name", "a"], sort=True, **maybe_as_index)
.aggregate(
agg_dict,
**split_kwargs,
Expand All @@ -628,7 +635,7 @@ def test_groupby_agg_params(npartitions, split_every, split_out, as_index):
)

# Full check (`sort=False`)
gr = ddf.groupby(["name", "a"], sort=False, as_index=as_index).aggregate(
gr = ddf.groupby(["name", "a"], sort=False, **maybe_as_index).aggregate(
agg_dict,
**split_kwargs,
)
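To run just this parameter sweep locally, the test module can be filtered with pytest (a sketch; assumes a GPU machine with the dask_cudf test dependencies installed):

import pytest

# Equivalent to: pytest <module> -k test_groupby_agg_params
pytest.main([
    "python/dask_cudf/dask_cudf/tests/test_groupby.py",
    "-k", "test_groupby_agg_params",
])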
