ENH: groupby.nunique supports by series (#726)
ChengjieLi28 authored Oct 8, 2023
1 parent 9bff550 commit 882f1b9
Showing 3 changed files with 164 additions and 22 deletions.
48 changes: 35 additions & 13 deletions python/xorbits/_mars/dataframe/groupby/aggregation.py
@@ -27,6 +27,7 @@
from ...core import ENTITY_TYPE, OutputType
from ...core.context import get_context
from ...core.custom_log import redirect_custom_log
from ...core.entity.utils import recursive_tile
from ...core.operand import OperandStage
from ...serialization.serializables import (
AnyField,
@@ -480,6 +481,7 @@ def _gen_map_chunks(
# force as_index=True for map phase
map_op.output_types = op.output_types
map_op.groupby_params = map_op.groupby_params.copy()
map_op.raw_groupby_params = map_op.raw_groupby_params.copy()
map_op.groupby_params["as_index"] = True
if isinstance(map_op.groupby_params["by"], list):
by = []
@@ -491,6 +493,7 @@
else:
by.append(v)
map_op.groupby_params["by"] = by
map_op.raw_groupby_params["by"] = by
map_op.stage = OperandStage.map
map_op.pre_funcs = func_infos.pre_funcs
map_op.agg_funcs = func_infos.agg_funcs
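
Copying `raw_groupby_params` here mirrors the existing copy of `groupby_params`: both dicts are shared across the generated map-chunk ops, so rewriting `by` in place for one chunk would otherwise leak into all of them. A minimal Python sketch of that aliasing hazard (toy dict, hypothetical values):

shared = {"by": ["original"], "as_index": True}
aliased = shared              # what happens without .copy()
copied = shared.copy()        # what the map phase does

aliased["by"] = ["chunk_0_by"]          # per-chunk rewrite of `by`
print(shared["by"])   # ['chunk_0_by'] -- the rewrite leaked into the shared dict
print(copied["by"])   # ['original']   -- the copy stays isolated, as intended
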
@@ -926,6 +929,20 @@ def tile(cls, op: "DataFrameGroupByAgg"):
in_df = build_concatenated_rows_frame(in_df)
out_df = op.outputs[0]

by = op.groupby_params["by"]
in_df_nsplits_settled: bool = all([not np.isnan(v) for v in in_df.nsplits[0]])
if isinstance(by, list):
for i, _by in enumerate(by):
if (
isinstance(_by, ENTITY_TYPE)
and all([not np.isnan(v) for v in _by.nsplits[0]])
and in_df_nsplits_settled
):
by[i] = yield from recursive_tile(
_by.rechunk({0: in_df.nsplits[0]})
)
yield by[i].chunks

func_infos = cls._compile_funcs(op, in_df)

if op.method == "auto":
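
This new tiling step is the heart of the change: when an element of `by` is itself a Mars entity, its row chunks must line up one-to-one with the input DataFrame's row chunks, since each map task pairs chunk i of the data with chunk i of the grouper. Rechunking to `in_df.nsplits[0]` enforces that pairing, and is only attempted once both sides have settled (non-NaN) nsplits. A minimal plain-pandas sketch of the invariant, with made-up chunk sizes:

import numpy as np
import pandas as pd

df = pd.DataFrame({"a": range(10)})
by = pd.Series(np.repeat([0, 1], 5))

nsplits = (4, 6)  # row chunk sizes of the DataFrame; `by` must be split the same way
offsets = np.concatenate([[0], np.cumsum(nsplits)])
for lo, hi in zip(offsets[:-1], offsets[1:]):
    data_chunk, by_chunk = df.iloc[lo:hi], by.iloc[lo:hi]
    # each map task sees an aligned (data chunk, by chunk) pair
    print(data_chunk.groupby(by_chunk.values).size())
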
@@ -943,6 +960,10 @@ def tile(cls, op: "DataFrameGroupByAgg"):
else: # pragma: no cover
raise NotImplementedError

@classmethod
def _get_new_by_data(cls, by: List, ctx: Dict):
return [ctx[v.key] if isinstance(v, ENTITY_TYPE) else v for v in by]
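
`_get_new_by_data` is the execution-time counterpart of the tiling work above: entries of `by` that are entities get replaced by their materialized data, looked up in the execution context by key, while plain labels pass through untouched. A hedged sketch with a stand-in entity class (`StubEntity` is hypothetical, not the real ENTITY_TYPE):

import pandas as pd
from dataclasses import dataclass

@dataclass
class StubEntity:   # hypothetical stand-in for ENTITY_TYPE
    key: str

ctx = {"by_key": pd.Series([0, 1, 0])}          # execution context: key -> data
by = [StubEntity("by_key"), "some_column_label"]

resolved = [ctx[v.key] if isinstance(v, StubEntity) else v for v in by]
# resolved -> [pd.Series([0, 1, 0]), "some_column_label"]
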

@classmethod
def _get_grouped(cls, op: "DataFrameGroupByAgg", df, ctx, copy=False, grouper=None):
if copy:
@@ -956,13 +977,7 @@ def _get_grouped(cls, op: "DataFrameGroupByAgg", df, ctx, copy=False, grouper=None):
params["by"] = grouper
params.pop("level", None)
elif isinstance(params.get("by"), list):
new_by = []
for v in params["by"]:
if isinstance(v, ENTITY_TYPE):
new_by.append(ctx[v.key])
else:
new_by.append(v)
params["by"] = new_by
params["by"] = cls._get_new_by_data(params["by"], ctx)

grouped = df.groupby(**params)

@@ -984,16 +999,23 @@ def _pack_inputs(agg_funcs: List[ReductionAggStep], in_data):
pos += step.output_limit
return out_dict

@staticmethod
@classmethod
def _do_custom_agg(
func_name: str, op: "DataFrameGroupByAgg", in_data: pd.DataFrame
cls, func_name: str, op: "DataFrameGroupByAgg", in_data: pd.DataFrame, ctx: Dict
) -> Union[pd.Series, pd.DataFrame]:
# The func must be given in tuple form, e.g. x=('col', 'agg_func_name').
# See the `is_funcs_aggregate` function: otherwise execution never reaches
# here, or is switched to transform execution instead.
if op.raw_func is None:
func_name = list(op.raw_func_kw.values())[0][1]

if (
func_name == "nunique"
and "by" in op.groupby_params
and isinstance(op.groupby_params["by"], list)
):
op.raw_groupby_params["by"] = cls._get_new_by_data(
op.groupby_params["by"], ctx
)
if op.stage == OperandStage.map:
return custom_agg_functions[func_name].execute_map(op, in_data)
elif op.stage == OperandStage.combine:
@@ -1111,7 +1133,7 @@ def _wrapped_func(col):
) in op.agg_funcs:
input_obj = ret_map_groupbys[input_key]
if map_func_name == "custom_reduction":
agg_dfs.append(cls._do_custom_agg(raw_func_name, op, in_data))
agg_dfs.append(cls._do_custom_agg(raw_func_name, op, in_data, ctx))
else:
single_func = map_func_name == op.raw_func
agg_dfs.append(
@@ -1159,7 +1181,7 @@ def _execute_combine(cls, ctx, op: "DataFrameGroupByAgg"):
) in zip(ctx[op.inputs[0].key], op.agg_funcs):
input_obj = in_data_dict[output_key]
if agg_func_name == "custom_reduction":
combines.append(cls._do_custom_agg(raw_func_name, op, raw_input))
combines.append(cls._do_custom_agg(raw_func_name, op, raw_input, ctx))
else:
combines.append(
cls._do_predefined_agg(input_obj, agg_func_name, gpu=op.gpu, **kwds)
@@ -1200,7 +1222,7 @@ def _execute_agg(cls, ctx, op: "DataFrameGroupByAgg"):
) in op.agg_funcs:
if agg_func_name == "custom_reduction":
in_data_dict[output_key] = cls._do_custom_agg(
raw_func_name, op, in_data_dict[output_key]
raw_func_name, op, in_data_dict[output_key], ctx
)
else:
input_obj = cls._get_grouped(op, in_data_dict[output_key], ctx)
61 changes: 52 additions & 9 deletions python/xorbits/_mars/dataframe/groupby/nunique.py
@@ -15,8 +15,9 @@

import pandas as pd

from ...core import OutputType
from ...core import ENTITY_TYPE, OutputType
from ...utils import implements
from ..utils import is_dataframe
from .aggregation import DataFrameGroupByAgg
from .custom_aggregation import (
DataFrameCustomGroupByAggMixin,
@@ -79,17 +80,46 @@ def _get_selection_columns(cls, op: DataFrameGroupByAgg) -> Union[None, List]:
selection = [selection]
return selection

@classmethod
def _drop_duplicates_by_series(cls, in_data: pd.DataFrame, origin_cols: List):
if isinstance(in_data.index, pd.MultiIndex):
origin_index_name = in_data.index.names
else:
origin_index_name = in_data.index.name
res = in_data.reset_index()
new_cols = list(res.columns)
index_cols = [v for v in new_cols if v not in origin_cols]
res = res.drop_duplicates().set_index(index_cols)
if isinstance(res.index, pd.MultiIndex):
res.index.names = origin_index_name
else:
res.index.name = origin_index_name
return res

@classmethod
def _get_execute_map_result(
cls, op: DataFrameGroupByAgg, in_data: pd.DataFrame
) -> Union[pd.DataFrame, pd.Series]:
selections = cls._get_selection_columns(op)
by_cols = op.raw_groupby_params["by"]
if by_cols is not None:
cols = (
[*selections, *by_cols] if selections is not None else in_data.columns
)
res = in_data[cols].drop_duplicates(subset=cols).set_index(by_cols)
# When `by` contains series, those series determine the groups.
# First set them as the index of the data, then `reset_index` so they
# become data columns, carry those columns through `drop_duplicates`,
# and finally restore them as the index.
if isinstance(by_cols, list) and any(
[isinstance(v, pd.Series) for v in by_cols]
):
origin_cols = list(in_data.columns)
res = in_data.set_index(by_cols)
res = cls._drop_duplicates_by_series(res, origin_cols)
else:
cols = (
[*selections, *by_cols]
if selections is not None
else in_data.columns
)
res = in_data[cols].drop_duplicates(subset=cols).set_index(by_cols)
else: # group by level
selections = selections if selections is not None else in_data.columns
level_indexes = cls._get_level_indexes(op, in_data)
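
A minimal plain-pandas walk-through of the `_drop_duplicates_by_series` round-trip described in the comments above, on toy data: the by-series sits in the index, is moved into columns so `drop_duplicates` judges (group key, row values) pairs together, then is restored as the index under its original name.

import pandas as pd

# toy frame already indexed by the by-series, as _get_execute_map_result builds it
data = pd.DataFrame({"a": [1, 1, 2]}, index=pd.Index([0, 0, 1], name="g"))
origin_cols = ["a"]

origin_index_name = data.index.name          # remember the grouper's name
res = data.reset_index()                     # grouper becomes an ordinary column
index_cols = [c for c in res.columns if c not in origin_cols]   # -> ["g"]
res = res.drop_duplicates().set_index(index_cols)
res.index.name = origin_index_name
# the duplicate (g=0, a=1) row is gone; unique values now survive as rows
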
@@ -111,9 +141,17 @@ def _get_execute_map_result(
def _get_execute_combine_result(
cls, op: DataFrameGroupByAgg, in_data: pd.DataFrame
) -> Union[pd.DataFrame, pd.Series]:
# in_data.index.names means MultiIndex (groupby on multi cols)
index_col = in_data.index.name or in_data.index.names
res = in_data.reset_index().drop_duplicates().set_index(index_col)
by = op.raw_groupby_params["by"]
if isinstance(by, list) and any([isinstance(v, ENTITY_TYPE) for v in by]):
# `in_data` may be a series when an indexing op follows the groupby
origin_cols = (
list(in_data.columns) if is_dataframe(in_data) else [in_data.name]
)
res = cls._drop_duplicates_by_series(in_data, origin_cols)
else:
# in_data.index.names means MultiIndex (groupby on multi cols)
index_col = in_data.index.name or in_data.index.names
res = in_data.reset_index().drop_duplicates().set_index(index_col)
if op.output_types[0] == OutputType.series:
res = res.squeeze()
return res
@@ -127,7 +165,12 @@ def _get_execute_agg_result(
by = op.raw_groupby_params["by"]

if by is not None:
if op.output_types[0] == OutputType.dataframe:
if isinstance(by, list) and any(
[isinstance(_by, ENTITY_TYPE) for _by in by]
):
# nothing to do here; grouping by level is already correct
pass
elif op.output_types[0] == OutputType.dataframe:
groupby_params.pop("level", None)
groupby_params["by"] = cols
in_data = in_data.reset_index()
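
Taken together, the map/combine/agg stages implement nunique as staged deduplication: unique (group key, row) pairs survive the map and combine phases, so the final aggregation is a plain grouped nunique. A compact plain-pandas simulation of the tree method under an assumed two-chunk split — an illustration of the idea, not the Mars execution path:

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2, 2, 3], "b": [9, 9, 9, 8, 8]})
by = pd.Series([0, 0, 0, 1, 1])

# map: per chunk, index by the grouper values and drop duplicate (key, row) pairs
mapped = []
for lo, hi in [(0, 3), (3, 5)]:
    part = df.iloc[lo:hi].copy()
    part.index = by.iloc[lo:hi].values
    mapped.append(part.reset_index().drop_duplicates().set_index("index"))

# combine: concatenate map outputs and deduplicate again
combined = pd.concat(mapped).reset_index().drop_duplicates().set_index("index")

# agg: only unique values survived as rows, so a grouped nunique by level finishes
result = combined.groupby(level=0).nunique()
assert result.equals(df.groupby(by.values).nunique())
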
@@ -405,3 +405,80 @@ def test_groupby_agg_nunique_with_tuple_kwargs(
e=("a", "nunique"), f=("c", "nunique")
)
pd.testing.assert_frame_equal(res.execute().fetch(), expected)


@pytest.mark.parametrize(
"chunk_size, as_index, sort",
itertools.product([None, 13], [True, False], [True, False]),
)
def test_groupby_nunique_by_series(setup, gen_data2, chunk_size, as_index, sort):
df = gen_data2
mdf = md.DataFrame(df, chunk_size=chunk_size)

by1 = pd.Series([i + 100 for i in range(100)])
mby1 = md.Series(by1)

by2 = pd.Series([i + 200 for i in range(100)])
mby2 = md.Series(by2)

res = mdf.groupby(mby1, as_index=as_index, sort=sort).nunique()
expected = df.groupby(by1, as_index=as_index, sort=sort).nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

res = mdf.groupby([mby1, mby2], as_index=as_index, sort=sort).nunique()
expected = df.groupby([by1, by2], as_index=as_index, sort=sort).nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

res = mdf.groupby([mby1, mby2], as_index=as_index, sort=sort).agg(
e=("a", "nunique"), f=("c", "nunique")
)
expected = df.groupby([by1, by2], as_index=as_index, sort=sort).agg(
e=("a", "nunique"), f=("c", "nunique")
)
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

# test a by-series containing duplicate values
rs = np.random.RandomState(0)
by3 = pd.Series(rs.choice([i for i in range(1, 6)], size=(100,)))
mby3 = md.Series(by3)

res = mdf.groupby(mby3, as_index=as_index, sort=sort).nunique()
expected = df.groupby(by3, as_index=as_index, sort=sort).nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

# test a by-series with a different chunk size
by4 = pd.Series(rs.choice([i for i in range(10)], size=(100,)))
mby4 = md.Series(by4, chunk_size=21)

res = mdf.groupby(mby4, as_index=as_index, sort=sort).nunique()
expected = df.groupby(by4, as_index=as_index, sort=sort).nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

# test column selection after groupby
res = mdf.groupby(mby3, as_index=as_index, sort=sort)[["a", "b"]].nunique()
expected = df.groupby(by3, as_index=as_index, sort=sort)[["a", "b"]].nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

res = mdf.groupby(mby3, as_index=as_index, sort=sort)[["a"]].nunique()
expected = df.groupby(by3, as_index=as_index, sort=sort)[["a"]].nunique()
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

res = mdf.groupby(mby3, as_index=as_index, sort=sort)["a"].nunique()
expected = df.groupby(by3, as_index=as_index, sort=sort)["a"].nunique()
if as_index:
pd.testing.assert_series_equal(res.execute().fetch(), expected)
else:
pd.testing.assert_frame_equal(res.execute().fetch(), expected)

# test different methods
for method in ["auto", "tree", "shuffle"]:
res = mdf.groupby(mby3, as_index=as_index, sort=sort).nunique(method=method)
expected = df.groupby(by3, as_index=as_index, sort=sort).nunique()
real = res.execute().fetch()
if method == "shuffle":
pd.testing.assert_frame_equal(
real.sort_values(["a", "b", "c", "d"]).reset_index(drop=True),
expected.sort_values(["a", "b", "c", "d"]).reset_index(drop=True),
)
else:
pd.testing.assert_frame_equal(real, expected)
