diff --git a/flox/xarray.py b/flox/xarray.py index 87fbd78d7..95edb843c 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import xarray as xr +from packaging.version import Version from .aggregations import Aggregation, _atleast_1d from .core import ( @@ -345,12 +346,16 @@ def wrapper(array, *by, func, skipna, **kwargs): expect = expect.to_numpy() if isinstance(actual, xr.Dataset) and name in actual: actual = actual.drop_vars(name) - actual[name] = expect - - # if grouping by multi-indexed variable, then restore it - for name, index in ds.indexes.items(): - if name in actual.indexes and isinstance(index, pd.MultiIndex): - actual[name] = index + # When grouping by MultiIndex, expect is a pd.Index wrapping + # an object array of tuples + if name in ds.indexes and isinstance(ds.indexes[name], pd.MultiIndex): + levelnames = ds.indexes[name].names + expect = pd.MultiIndex.from_tuples(expect.values, names=levelnames) + actual[name] = expect + if Version(xr.__version__) > Version("2022.03.0"): + actual = actual.set_coords(levelnames) + else: + actual[name] = expect if unindexed_dims: actual = actual.drop_vars(unindexed_dims) @@ -361,7 +366,8 @@ def wrapper(array, *by, func, skipna, **kwargs): template = obj else: template = obj[var] - actual[var] = _restore_dim_order(actual[var], template, by[0]) + if actual[var].ndim > 1: + actual[var] = _restore_dim_order(actual[var], template, by[0]) if missing_dim: for k, v in missing_dim.items(): @@ -370,9 +376,9 @@ def wrapper(array, *by, func, skipna, **kwargs): } # The expand_dims is for backward compat with xarray's questionable behaviour if missing_group_dims: - actual[k] = v.expand_dims(missing_group_dims) + actual[k] = v.expand_dims(missing_group_dims).variable else: - actual[k] = v + actual[k] = v.variable if isinstance(obj, xr.DataArray): return obj._from_temp_dataset(actual) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 52a28e549..a25bb5559 100644 --- 
a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -321,6 +321,17 @@ def test_multi_index_groupby_sum(engine): actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine) assert_equal(expected, actual.unstack("space")) + actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine) + assert_equal(expected.foo, actual.unstack("space")) + + ds = xr.Dataset( + dict(a=(("z",), np.ones(10))), + coords=dict(b=(("z"), np.arange(2).repeat(5)), c=(("z"), np.arange(5).repeat(2))), + ).set_index(bc=["b", "c"]) + expected = ds.groupby("bc").sum() + actual = xarray_reduce(ds, "bc", func="sum") + assert_equal(expected, actual) + @pytest.mark.parametrize("chunks", (None, 2)) def test_xarray_groupby_bins(chunks, engine):