diff --git a/flox/core.py b/flox/core.py index 7613c1f24..b4e68c23f 100644 --- a/flox/core.py +++ b/flox/core.py @@ -462,7 +462,7 @@ def factorize_( group_idx = factorized[0] if fastpath: - return group_idx, found_groups, grp_shape + return group_idx.reshape(by[0].shape), found_groups, grp_shape if np.isscalar(axis) and groupvar.ndim > 1: # Not reducing along all dimensions of by diff --git a/tests/test_xarray.py b/tests/test_xarray.py index cb74617ea..5f779d0ad 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -430,3 +430,24 @@ def test_datetime_array_reduce(use_cftime, func): expected = getattr(time.resample(time="YS"), func)() actual = resample_reduce(time.resample(time="YS"), func=func, engine="flox") assert_equal(expected, actual) + + +@requires_dask +def test_groupby_bins_indexed_coordinate(): + ds = ( + xr.tutorial.open_dataset("air_temperature") + .isel(time=slice(100)) + .chunk({"time": 20, "lat": 5}) + ) + bins = [40, 50, 60, 70] + expected = ds.groupby_bins("lat", bins=bins).mean(keep_attrs=True, dim=...) + actual = xarray_reduce( + ds, + ds.lat, + dim=ds.air.dims, + expected_groups=([40, 50, 60, 70],), + isbin=(True,), + func="mean", + method="split-reduce", + ) + xr.testing.assert_allclose(expected, actual)