Skip to content

Commit

Permalink
Make split_out for categorical default smarter (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Aug 16, 2024
1 parent c0b2b9a commit 8d2c1be
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
7 changes: 5 additions & 2 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4235,8 +4235,11 @@ def value_counts(
):
if split_out is no_default:
if isinstance(self.dtype, CategoricalDtype):
# unobserved categories are a pain
split_out = 1
# unobserved or huge categories will lead to oom errors
if self.cat.known:
split_out = 1 + len(self.dtype.categories) // 100_000
else:
split_out = True
else:
split_out = True
if split_out == 1 and split_out is not True and sort is None:
Expand Down
4 changes: 2 additions & 2 deletions dask_expr/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,9 +1424,9 @@ def _meta(self):
def aggregate(cls, inputs, **kwargs):
func = cls.reduction_aggregate or cls.reduction_chunk
if is_scalar(inputs[-1]):
return func(_concat(inputs[:-1]), inputs[-1], **kwargs)
return func(_concat(inputs[:-1]), inputs[-1], observed=True, **kwargs)
else:
return func(_concat(inputs), **kwargs)
return func(_concat(inputs), observed=True, **kwargs)

@property
def split_by(self):
Expand Down
16 changes: 16 additions & 0 deletions dask_expr/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,19 @@ def test_unique_numerical_columns(key):
assert_eq(
df[key].unique(), pd.Series(pdf[key].unique(), name=key), check_index=False
)


def test_cat_value_counts_large_unknown_categories():
    """Check the default partitioning of ``value_counts`` on a large categorical.

    Two scenarios are exercised on the same random integer column:

    * the column is cast to ``category`` on the dask collection (marked
      "unknown" below) -- the result is asserted to keep all 50 partitions;
    * the collection is built from an already-categorical pandas frame
      (marked "known but large" below) -- the result is asserted to have
      3 partitions.

    In both cases the counts must agree with plain pandas, ignoring index
    and dtype differences.
    """
    n_partitions = 50
    pdf = pd.DataFrame({"x": np.random.randint(1, 1_000_000, (250_000,))})

    # Scenario 1: cast to category after the dask collection was created.
    lazy_cast = from_pandas(pdf, npartitions=n_partitions)
    lazy_cast["x"] = lazy_cast["x"].astype("category")
    counts_unknown = lazy_cast.x.value_counts()
    assert counts_unknown.npartitions == n_partitions  # unknown

    # The pandas reference is computed on the categorical version of the column.
    pdf["x"] = pdf["x"].astype("category")
    expected = pdf.x.value_counts()
    assert_eq(counts_unknown, expected, check_index=False, check_dtype=False)

    # Scenario 2: the pandas frame is already categorical when it is wrapped.
    eager_cast = from_pandas(pdf, npartitions=n_partitions)
    counts_known = eager_cast.x.value_counts()
    assert counts_known.npartitions == 3  # known but large
    assert_eq(counts_known, expected, check_index=False, check_dtype=False)

0 comments on commit 8d2c1be

Please sign in to comment.