Skip to content

Commit

Permalink
Fix grouping by multiindex (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian authored May 17, 2022
1 parent 5b7edbe commit 227ce04
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
24 changes: 15 additions & 9 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
import xarray as xr
from packaging.version import Version

from .aggregations import Aggregation, _atleast_1d
from .core import (
Expand Down Expand Up @@ -345,12 +346,16 @@ def wrapper(array, *by, func, skipna, **kwargs):
expect = expect.to_numpy()
if isinstance(actual, xr.Dataset) and name in actual:
actual = actual.drop_vars(name)
actual[name] = expect

# if grouping by multi-indexed variable, then restore it
for name, index in ds.indexes.items():
if name in actual.indexes and isinstance(index, pd.MultiIndex):
actual[name] = index
# When grouping by MultiIndex, expect is an pd.Index wrapping
# an object array of tuples
if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex):
levelnames = ds.indexes[name].names
expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames)
actual[name] = expect
if Version(xr.__version__) > Version("2022.03.0"):
actual = actual.set_coords(levelnames)
else:
actual[name] = expect

if unindexed_dims:
actual = actual.drop_vars(unindexed_dims)
Expand All @@ -361,7 +366,8 @@ def wrapper(array, *by, func, skipna, **kwargs):
template = obj
else:
template = obj[var]
actual[var] = _restore_dim_order(actual[var], template, by[0])
if actual[var].ndim > 1:
actual[var] = _restore_dim_order(actual[var], template, by[0])

if missing_dim:
for k, v in missing_dim.items():
Expand All @@ -370,9 +376,9 @@ def wrapper(array, *by, func, skipna, **kwargs):
}
# The expand_dims is for backward compat with xarray's questionable behaviour
if missing_group_dims:
actual[k] = v.expand_dims(missing_group_dims)
actual[k] = v.expand_dims(missing_group_dims).variable
else:
actual[k] = v
actual[k] = v.variable

if isinstance(obj, xr.DataArray):
return obj._from_temp_dataset(actual)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,17 @@ def test_multi_index_groupby_sum(engine):
actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine)
assert_equal(expected, actual.unstack("space"))

actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)
assert_equal(expected.foo, actual.unstack("space"))

ds = xr.Dataset(
dict(a=(("z",), np.ones(10))),
coords=dict(b=(("z"), np.arange(2).repeat(5)), c=(("z"), np.arange(5).repeat(2))),
).set_index(bc=["b", "c"])
expected = ds.groupby("bc").sum()
actual = xarray_reduce(ds, "bc", func="sum")
assert_equal(expected, actual)


@pytest.mark.parametrize("chunks", (None, 2))
def test_xarray_groupby_bins(chunks, engine):
Expand Down

0 comments on commit 227ce04

Please sign in to comment.