Skip to content

Commit

Permalink
Make alignment conditions smarter (#820)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Jan 29, 2024
1 parent b2956e9 commit edee3b1
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
2 changes: 1 addition & 1 deletion dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _wrap_expr_op(self, other, op=None):
elif (
expr.are_co_aligned(self.expr, other, allow_broadcast=False)
or other.npartitions == self.npartitions == 1
and (self.ndim > other.ndim or self.ndim == 0)
or min(self.ndim, other.ndim) == 0
):
return new_collection(getattr(self.expr, op)(other))
else:
Expand Down
3 changes: 1 addition & 2 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,8 +1927,7 @@ def test_avoid_alignment():
assert_eq(a.x + b.y, da.x + db.y)

assert not any(isinstance(ex, AlignPartitions) for ex in (db.y + db.z).walk())
# TODO: We can potentially do better here
assert any(isinstance(ex, AlignPartitions) for ex in (da.x + db.y.sum()).walk())
assert not any(isinstance(ex, AlignPartitions) for ex in (da.x + db.y.sum()).walk())


@pytest.mark.xfail(reason="can't hash HLG")
Expand Down
26 changes: 13 additions & 13 deletions dask_expr/tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def resample(df, freq, how="mean", **kwargs):

@pytest.fixture
def pdf():
idx = pd.date_range("2000-01-01", periods=12, freq="T")
idx = pd.date_range("2000-01-01", periods=12, freq="min")
pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx)
pdf["bar"] = 1
yield pdf
Expand Down Expand Up @@ -49,20 +49,20 @@ def df(pdf):
],
)
def test_resample_apis(df, pdf, api, kwargs):
result = getattr(df.resample("2T", **kwargs), api)()
expected = getattr(pdf.resample("2T", **kwargs), api)()
result = getattr(df.resample("2min", **kwargs), api)()
expected = getattr(pdf.resample("2min", **kwargs), api)()
assert_eq(result, expected)

# No column output
if api not in ("size",):
result = getattr(df.resample("2T"), api)()["foo"]
expected = getattr(pdf.resample("2T"), api)()["foo"]
result = getattr(df.resample("2min"), api)()["foo"]
expected = getattr(pdf.resample("2min"), api)()["foo"]
assert_eq(result, expected)

if api != "ohlc":
# ohlc actually gives back a DataFrame, so this doesn't work
q = result.simplify()
eq = getattr(df["foo"].resample("2T"), api)().simplify()
eq = getattr(df["foo"].resample("2min"), api)().simplify()
assert q._name == eq._name


Expand All @@ -73,7 +73,7 @@ def test_resample_apis(df, pdf, api, kwargs):
["series", "frame"],
["count", "mean", "ohlc"],
[2, 5],
["30min", "h", "d", "w"],
["30min", "h", "d", "W"],
["right", "left"],
["right", "left"],
)
Expand Down Expand Up @@ -104,17 +104,17 @@ def test_resample_agg(df, pdf):
def my_sum(vals, foo=None, *, bar=None):
return vals.sum()

result = df.resample("2T").agg(my_sum, "foo", bar="bar")
expected = pdf.resample("2T").agg(my_sum, "foo", bar="bar")
result = df.resample("2min").agg(my_sum, "foo", bar="bar")
expected = pdf.resample("2min").agg(my_sum, "foo", bar="bar")
assert_eq(result, expected)

result = df.resample("2T").agg(my_sum)["foo"]
expected = pdf.resample("2T").agg(my_sum)["foo"]
result = df.resample("2min").agg(my_sum)["foo"]
expected = pdf.resample("2min").agg(my_sum)["foo"]
assert_eq(result, expected)

# simplify up disabled for `agg`, function may access other columns
q = df.resample("2T").agg(my_sum)["foo"].simplify()
eq = df["foo"].resample("2T").agg(my_sum).simplify()
q = df.resample("2min").agg(my_sum)["foo"].simplify()
eq = df["foo"].resample("2min").agg(my_sum).simplify()
assert q._name != eq._name


Expand Down
2 changes: 1 addition & 1 deletion dask_expr/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@pytest.fixture
def pdf():
idx = pd.date_range("2000-01-01", periods=12, freq="T")
idx = pd.date_range("2000-01-01", periods=12, freq="min")
pdf = pd.DataFrame({"foo": range(len(idx))}, index=idx)
pdf["bar"] = 1
yield pdf
Expand Down
2 changes: 1 addition & 1 deletion dask_expr/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_index_nulls(null_value):
).compute()


@pytest.mark.parametrize("freq", ["16H", "-16H"])
@pytest.mark.parametrize("freq", ["16h", "-16h"])
def test_set_index_with_dask_dt_index(freq):
values = {
"x": [1, 2, 3, 4] * 3,
Expand Down

0 comments on commit edee3b1

Please sign in to comment.