Skip to content

Commit

Permalink
Fix bug where we had extra groups in expected_groups. (#112)
Browse files Browse the repository at this point in the history
* Fix bug where we had extra groups in expected_groups.

This affected _factorize_multiple.

Closes #111

* Fix extra expected groups (#113)

* fix dask case

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Co-authored-by: LunarLanding <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 23, 2022
1 parent a0b9d1f commit 5e0b793
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
9 changes: 6 additions & 3 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,11 +1310,12 @@ def _lazy_factorize_wrapper(*by, **kwargs):
return group_idx


def _factorize_multiple(by, expected_groups, by_is_dask):
def _factorize_multiple(by, expected_groups, by_is_dask, reindex):
kwargs = dict(
expected_groups=expected_groups,
axis=None, # always None, we offset later if necessary.
fastpath=True,
reindex=reindex,
)
if by_is_dask:
import dask.array
Expand All @@ -1325,7 +1326,9 @@ def _factorize_multiple(by, expected_groups, by_is_dask):
meta=np.array((), dtype=np.int64),
**kwargs,
)
found_groups = tuple(None if is_duck_dask_array(b) else pd.unique(b) for b in by)
found_groups = tuple(
None if is_duck_dask_array(b) else pd.unique(np.array(b).reshape(-1)) for b in by
)
grp_shape = tuple(len(e) for e in expected_groups)
else:
group_idx, found_groups, grp_shape = factorize_(by, **kwargs)
Expand Down Expand Up @@ -1489,7 +1492,7 @@ def groupby_reduce(
)
if factorize_early:
by, final_groups, grp_shape = _factorize_multiple(
by, expected_groups, by_is_dask=by_is_dask
by, expected_groups, by_is_dask=by_is_dask, reindex=reindex
)
expected_groups = (pd.RangeIndex(np.prod(grp_shape)),)

Expand Down
34 changes: 34 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,37 @@ def test_groupby_bins_indexed_coordinate():
method="split-reduce",
)
xr.testing.assert_allclose(expected, actual)


@pytest.mark.parametrize("chunk", (True, False))
def test_mixed_grouping(chunk):
    """Group by a categorical variable and a binned variable at once.

    ``v1`` only ever holds the labels 0, 1 and 2 (``arange % 3``), while the
    expected groups passed for it are ``0..5``. Counting along ``b`` must
    therefore report zero for the labels 3, 4 and 5 that never occur.
    When ``chunk`` is true, ``v0`` is chunked along ``a`` before reducing.
    """
    if chunk and not has_dask:
        pytest.skip()
    # regression test for https://github.com/dcherian/flox/pull/111
    n_a, n_b, n_c = 10, 13, 3

    # Fractional values in [0, 1); used both as the data and as the binned "by".
    values = ((np.arange(n_a * n_b * n_c) / n_a) % 1).reshape((n_a, n_b, n_c))
    # Integer labels restricted to {0, 1, 2}.
    labels = (np.arange(n_a * n_b) % 3).reshape(n_a, n_b)

    ds = xr.Dataset(
        {
            "v0": xr.DataArray(values, dims=("a", "b", "c")),
            "v1": xr.DataArray(labels, dims=("a", "b")),
        }
    )
    if chunk:
        ds["v0"] = ds["v0"].chunk({"a": 5})

    counts = xarray_reduce(
        ds["v0"],
        ds["v1"],
        ds["v0"],
        expected_groups=(np.arange(6), np.linspace(0, 1, num=5)),
        isbin=[False, True],
        func="count",
        dim="b",
        fill_value=0,
    )

    # Labels 3, 4, 5 are expected groups that never appear in v1.
    absent = counts.sel(v1=[3, 4, 5])
    assert (absent == 0).all().data

0 comments on commit 5e0b793

Please sign in to comment.