Support median in Groupby.aggregate (#766)
Co-authored-by: crusaderky <[email protected]>
hendrikmakait and crusaderky authored Jan 29, 2024
1 parent 7ad58b0 commit 71e5237
Showing 4 changed files with 149 additions and 54 deletions.
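In practice, the change lets `median` appear anywhere `GroupBy.aggregate` accepts a reducer. A minimal usage sketch (the column names and partition count here are illustrative, not taken from the commit):

    import pandas as pd
    import dask_expr as dx

    pdf = pd.DataFrame({"x": [0, 0, 1, 1], "y": [1.0, 2.0, 3.0, 4.0]})
    df = dx.from_pandas(pdf, npartitions=2)

    # Before this commit, "median" in an aggregation spec raised
    # NotImplementedError; it now lowers to a shuffle-based aggregation.
    result = df.groupby("x").aggregate({"y": ["median", "mean"]})
    print(result.compute())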
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -44,7 +44,7 @@ jobs:
# Wipe cache every 24 hours or whenever environment.yml changes. This means it
# may take up to a day before changes to unpinned packages are picked up.
# To force a cache refresh, change the hardcoded numerical suffix below.
cache-environment-key: environment-${{ steps.date.outputs.date }}-0
cache-environment-key: environment-${{ steps.date.outputs.date }}-1

- name: Install dask-expr
run: python -m pip install -e . --no-deps
188 changes: 138 additions & 50 deletions dask_expr/_groupby.py
@@ -27,13 +27,15 @@
_cumcount_aggregate,
_determine_levels,
_groupby_aggregate,
_groupby_aggregate_spec,
_groupby_apply_funcs,
_groupby_get_group,
_groupby_slice_apply,
_groupby_slice_shift,
_groupby_slice_transform,
_head_aggregate,
_head_chunk,
_non_agg_chunk,
_normalize_spec,
_nunique_df_chunk,
_nunique_df_combine,
@@ -249,10 +251,10 @@ def _simplify_up(self, parent, dependents):
return groupby_projection(self, parent, dependents)


class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase):
"""General groupby aggregation
class GroupbyAggregationBase(GroupByApplyConcatApply, GroupByBase):
"""Base class for groupby aggregation
This class can be used directly to perform a general
This class can be subclassed to perform a general
groupby aggregation by passing in a `str`, `list` or
`dict`-based specification using the `arg` operand.
@@ -269,10 +271,6 @@ class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase):
Passed through to dataframe backend.
dropna:
Whether rows with NA values should be dropped.
chunk_kwargs:
Key-word arguments to pass to `groupby_chunk`.
aggregate_kwargs:
Key-word arguments to pass to `aggregate_chunk`.
"""

_parameters = [
@@ -295,7 +293,19 @@ class GroupbyAggregation(GroupByApplyConcatApply, GroupByBase):
"shuffle_method": None,
"_slice": None,
}
chunk = staticmethod(_groupby_apply_funcs)

@functools.cached_property
def _meta(self):
meta = meta_nonempty(self.frame._meta)
meta = meta.groupby(
self._by_meta,
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
)
if self._slice is not None:
meta = meta[self._slice]
meta = meta.aggregate(self.arg)
return make_meta(meta)

@functools.cached_property
def spec(self):
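The new `_meta` property above derives the output schema by running the actual aggregation spec against a non-empty stand-in for the frame's metadata. A rough, self-contained illustration of that trick using plain pandas plus dask utilities (this is not the class's code path, just the idea behind it):

    import pandas as pd
    from dask.dataframe.utils import make_meta, meta_nonempty

    # Build an empty "meta" frame, inflate it with placeholder values, then
    # run the aggregation on it to learn the output columns and dtypes.
    meta = meta_nonempty(make_meta(pd.DataFrame({"x": [0], "y": [0.0]})))
    out = meta.groupby("x").aggregate({"y": ["median", "mean"]})
    print(out.dtypes)  # schema only; the values are placeholders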
@@ -329,13 +339,121 @@ def spec(self):
else:
raise ValueError(f"aggregate on unknown object {self.frame._meta}")

# Median not supported yet
has_median = any(s[1] in ("median", np.median) for s in spec)
if has_median:
raise NotImplementedError("median not yet supported")
return spec

@functools.cached_property
def agg_args(self):
keys = ["chunk_funcs", "aggregate_funcs", "finalizers"]
return dict(zip(keys, _build_agg_args(spec)))
return dict(zip(keys, _build_agg_args(self.spec)))

def _simplify_down(self):
if not isinstance(self.arg, dict):
return

# Use agg-spec information to add column projection
required_columns = (
set(self._by_columns)
.union(self.arg.keys())
.intersection(self.frame.columns)
)
column_projection = [
column for column in self.frame.columns if column in required_columns
]
if column_projection != self.frame.columns:
return type(self)(self.frame[column_projection], *self.operands[1:])


class GroupbyAggregation(GroupbyAggregationBase):
"""Logical groupby aggregation class
This class lowers itself to concrete implementations for decomposable
or holistic aggregations.
"""

@functools.cached_property
def _is_decomposable(self):
return not any(s[1] in ("median", np.median) for s in self.spec)

def _lower(self):
cls = (
DecomposableGroupbyAggregation
if self._is_decomposable
else HolisticGroupbyAggregation
)
return cls(
self.frame,
self.arg,
self.observed,
self.dropna,
self.split_every,
self.split_out,
self.sort,
self.shuffle_method,
self._slice,
*self.by,
)


class HolisticGroupbyAggregation(GroupbyAggregationBase):
"""Groupby aggregation for both decomposable and non-decomposable aggregates
This class always calculates the aggregates by first collecting all the data for
the groups and then aggregating at once.
"""

chunk = staticmethod(_non_agg_chunk)

@property
def should_shuffle(self):
return True

@classmethod
def combine(cls, inputs, **kwargs):
return _groupby_aggregate_spec(_concat(inputs), **kwargs)

@classmethod
def aggregate(cls, inputs, **kwargs):
return _groupby_aggregate_spec(_concat(inputs), **kwargs)

@property
def chunk_kwargs(self) -> dict:
return {
"by": self._by_columns,
"key": [col for col in self.frame.columns if col not in self._by_columns],
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
}

@property
def combine_kwargs(self) -> dict:
return {
"spec": self.arg,
"levels": _determine_levels(self.by),
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
}

@property
def aggregate_kwargs(self) -> dict:
return {
"spec": self.arg,
"levels": _determine_levels(self.by),
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
}


class DecomposableGroupbyAggregation(GroupbyAggregationBase):
"""Groupby aggregation for decomposable aggregates
The results may be calculated via tree or shuffle reduction.
"""

chunk = staticmethod(_groupby_apply_funcs)

@classmethod
def combine(cls, inputs, **kwargs):
@@ -348,7 +466,7 @@ def aggregate(cls, inputs, **kwargs):
@property
def chunk_kwargs(self) -> dict:
return {
"funcs": self.spec["chunk_funcs"],
"funcs": self.agg_args["chunk_funcs"],
"sort": self.sort,
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
@@ -357,7 +475,7 @@ def chunk_kwargs(self) -> dict:
@property
def combine_kwargs(self) -> dict:
return {
"funcs": self.spec["aggregate_funcs"],
"funcs": self.agg_args["aggregate_funcs"],
"level": self.levels,
"sort": self.sort,
**_as_dict("observed", self.observed),
@@ -367,26 +485,17 @@ def combine_kwargs(self) -> dict:
@property
def aggregate_kwargs(self) -> dict:
return {
"aggregate_funcs": self.spec["aggregate_funcs"],
"finalize_funcs": self.spec["finalizers"],
"aggregate_funcs": self.agg_args["aggregate_funcs"],
"arg": self.arg,
"columns": self._slice,
"finalize_funcs": self.agg_args["finalizers"],
"is_series": self._meta.ndim == 1,
"level": self.levels,
"sort": self.sort,
**_as_dict("observed", self.observed),
**_as_dict("dropna", self.dropna),
}

def _simplify_down(self):
# Use agg-spec information to add column projection
column_projection = None
if isinstance(self.arg, dict):
column_projection = (
set(self._by_columns)
.union(self.arg.keys())
.intersection(self.frame.columns)
)
if column_projection and column_projection < set(self.frame.columns):
return type(self)(self.frame[list(column_projection)], *self.operands[1:])


class Sum(SingleAggregation):
groupby_chunk = M.sum
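The `_simplify_down` hook that moved onto `GroupbyAggregationBase` (see above) prunes columns the aggregation never touches whenever the spec is a dict. A runnable sketch of the projection rule, with hypothetical column names:

    import pandas as pd

    pdf = pd.DataFrame({"x": [0, 1], "y": [1.0, 2.0], "unused": ["a", "b"]})
    by, spec = ["x"], {"y": "median"}

    # Keep only the group keys plus the columns named in the spec,
    # preserving the frame's original column order.
    required = set(by).union(spec).intersection(pdf.columns)
    projection = [col for col in pdf.columns if col in required]
    assert projection == ["x", "y"]  # "unused" is never read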
@@ -1781,27 +1890,6 @@ def __init__(
obj, by=by, slice=slice, observed=observed, dropna=dropna, sort=sort
)

def aggregate(self, arg=None, split_every=8, split_out=1, **kwargs):
result = super().aggregate(
arg=arg, split_every=split_every, split_out=split_out
)
if self._slice:
try:
result = result[self._slice]
except KeyError:
pass

if (
arg is not None
and not isinstance(arg, (list, dict))
and is_dataframe_like(result._meta)
):
result = result[result.columns[0]]

return result

agg = aggregate

def idxmin(
self, split_every=None, split_out=1, skipna=True, numeric_only=False, **kwargs
):
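The split into `DecomposableGroupbyAggregation` and `HolisticGroupbyAggregation` reflects a basic property of the reducers: mean-like aggregates can be rebuilt from per-partition pieces, while a median of partial medians is generally wrong, so every row of a group has to be shuffled into a single partition before aggregating. A small pandas demonstration of the difference:

    import pandas as pd

    s = pd.Series([1, 2, 3, 4, 100])
    left, right = s[:2], s[2:]

    # Decomposable: the global mean is recoverable from per-chunk
    # (sum, count) pairs, which is what the tree reduction exploits.
    mean = (left.sum() + right.sum()) / (left.count() + right.count())
    assert mean == s.mean()

    # Holistic: combining partial medians gives the wrong answer.
    median_of_medians = pd.Series([left.median(), right.median()]).median()
    assert median_of_medians != s.median()  # 2.75 != 3.0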
12 changes: 9 additions & 3 deletions dask_expr/_reductions.py
@@ -447,6 +447,13 @@ def _divisions(self):
def _chunk_cls_args(self):
return []

@property
def should_shuffle(self):
sort = getattr(self, "sort", False)
return not (
not isinstance(self.split_out, bool) and self.split_out == 1 or sort
)

def _lower(self):
# Normalize functions in case not all are defined
chunk = self.chunk
@@ -465,12 +472,11 @@ def _lower(self):
combine = aggregate
combine_kwargs = aggregate_kwargs

sort = getattr(self, "sort", False)
split_every = getattr(self, "split_every", None)
chunked = self._chunk_cls(
self.frame, type(self), chunk, chunk_kwargs, *self._chunk_cls_args
)
if not isinstance(self.split_out, bool) and self.split_out == 1 or sort:
if not self.should_shuffle:
# Lower into TreeReduce(Chunk)
return TreeReduce(
chunked,
@@ -496,7 +502,7 @@
split_by=self.split_by,
split_out=self.split_out,
split_every=split_every,
sort=sort,
sort=getattr(self, "sort", False),
shuffle_by_index=getattr(self, "shuffle_by_index", None),
shuffle_method=getattr(self, "shuffle_method", None),
ignore_index=getattr(self, "ignore_index", True),
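Factoring the tree-versus-shuffle condition out of `_lower` into a `should_shuffle` property is what allows `HolisticGroupbyAggregation` to override it and always return True. A standalone paraphrase of the predicate and its edge cases (the function form is mine, for illustration only):

    # Tree-reduce when split_out is exactly the integer 1 or when sorting;
    # shuffle otherwise. The isinstance check keeps split_out=True distinct
    # from split_out=1, mirroring the expression in the diff.
    def should_shuffle(split_out, sort):
        return not ((not isinstance(split_out, bool) and split_out == 1) or sort)

    assert should_shuffle(split_out=4, sort=False) is True   # shuffle reduction
    assert should_shuffle(split_out=1, sort=False) is False  # tree reduction
    assert should_shuffle(split_out=4, sort=True) is False   # sort forces tree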
1 change: 1 addition & 0 deletions dask_expr/tests/test_groupby.py
@@ -238,6 +238,7 @@ def test_dataframe_aggregations_multilevel(df, pdf):
{"x": ["sum", "mean"]},
["min", "mean"],
"sum",
"median",
],
)
def test_groupby_agg(pdf, df, spec):
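Adding "median" to the parametrized spec list routes the existing `test_groupby_agg` assertions through the new holistic code path. Assuming a standard pytest setup, just those cases can be run with `python -m pytest dask_expr/tests/test_groupby.py -k "test_groupby_agg and median"`.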
