diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 8fb767b7..43e18eb0 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -162,7 +162,7 @@ def _wrap_expr_op(self, other, op=None): elif ( expr.are_co_aligned(self.expr, other, allow_broadcast=False) or other.npartitions == self.npartitions == 1 - and (self.ndim > other.ndim or self.ndim == 0) + or min(self.ndim, other.ndim) == 0 ): return new_collection(getattr(self.expr, op)(other)) else: diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index d1d80e7c..85500046 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -1927,8 +1927,7 @@ def test_avoid_alignment(): assert_eq(a.x + b.y, da.x + db.y) assert not any(isinstance(ex, AlignPartitions) for ex in (db.y + db.z).walk()) - # TODO: We can potentially do better here - assert any(isinstance(ex, AlignPartitions) for ex in (da.x + db.y.sum()).walk()) + assert not any(isinstance(ex, AlignPartitions) for ex in (da.x + db.y.sum()).walk()) @pytest.mark.xfail(reason="can't hash HLG") diff --git a/dask_expr/tests/test_resample.py b/dask_expr/tests/test_resample.py index 4276c22d..716c6b04 100644 --- a/dask_expr/tests/test_resample.py +++ b/dask_expr/tests/test_resample.py @@ -15,7 +15,7 @@ def resample(df, freq, how="mean", **kwargs): @pytest.fixture def pdf(): - idx = pd.date_range("2000-01-01", periods=12, freq="T") + idx = pd.date_range("2000-01-01", periods=12, freq="min") pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx) pdf["bar"] = 1 yield pdf @@ -49,20 +49,20 @@ def df(pdf): ], ) def test_resample_apis(df, pdf, api, kwargs): - result = getattr(df.resample("2T", **kwargs), api)() - expected = getattr(pdf.resample("2T", **kwargs), api)() + result = getattr(df.resample("2min", **kwargs), api)() + expected = getattr(pdf.resample("2min", **kwargs), api)() assert_eq(result, expected) # No column output if api not in ("size",): - result = getattr(df.resample("2T"), api)()["foo"] - expected = getattr(pdf.resample("2T"), api)()["foo"] + result = getattr(df.resample("2min"), api)()["foo"] + expected = getattr(pdf.resample("2min"), api)()["foo"] assert_eq(result, expected) if api != "ohlc": # ohlc actually gives back a DataFrame, so this doesn't work q = result.simplify() - eq = getattr(df["foo"].resample("2T"), api)().simplify() + eq = getattr(df["foo"].resample("2min"), api)().simplify() assert q._name == eq._name @@ -73,7 +73,7 @@ def test_resample_apis(df, pdf, api, kwargs): ["series", "frame"], ["count", "mean", "ohlc"], [2, 5], - ["30min", "h", "d", "w"], + ["30min", "h", "d", "W"], ["right", "left"], ["right", "left"], ) @@ -104,17 +104,17 @@ def test_resample_agg(df, pdf): def my_sum(vals, foo=None, *, bar=None): return vals.sum() - result = df.resample("2T").agg(my_sum, "foo", bar="bar") - expected = pdf.resample("2T").agg(my_sum, "foo", bar="bar") + result = df.resample("2min").agg(my_sum, "foo", bar="bar") + expected = pdf.resample("2min").agg(my_sum, "foo", bar="bar") assert_eq(result, expected) - result = df.resample("2T").agg(my_sum)["foo"] - expected = pdf.resample("2T").agg(my_sum)["foo"] + result = df.resample("2min").agg(my_sum)["foo"] + expected = pdf.resample("2min").agg(my_sum)["foo"] assert_eq(result, expected) # simplify up disabled for `agg`, function may access other columns - q = df.resample("2T").agg(my_sum)["foo"].simplify() - eq = df["foo"].resample("2T").agg(my_sum).simplify() + q = df.resample("2min").agg(my_sum)["foo"].simplify() + eq = df["foo"].resample("2min").agg(my_sum).simplify() assert q._name != eq._name diff --git a/dask_expr/tests/test_rolling.py b/dask_expr/tests/test_rolling.py index fe4bcc50..b8091909 100644 --- a/dask_expr/tests/test_rolling.py +++ b/dask_expr/tests/test_rolling.py @@ -11,7 +11,7 @@ @pytest.fixture def pdf(): - idx = pd.date_range("2000-01-01", periods=12, freq="T") + idx = pd.date_range("2000-01-01", periods=12, freq="min") pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx) pdf["bar"] = 1 yield pdf diff --git a/dask_expr/tests/test_shuffle.py b/dask_expr/tests/test_shuffle.py index d581cb77..3c3ee42a 100644 --- a/dask_expr/tests/test_shuffle.py +++ b/dask_expr/tests/test_shuffle.py @@ -611,7 +611,7 @@ def test_index_nulls(null_value): ).compute() -@pytest.mark.parametrize("freq", ["16H", "-16H"]) +@pytest.mark.parametrize("freq", ["16h", "-16h"]) def test_set_index_with_dask_dt_index(freq): values = { "x": [1, 2, 3, 4] * 3,