Skip to content

Commit

Permalink
couple fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jan 10, 2024
1 parent f045f23 commit 331193d
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def dask(self):
# cause very unfortunate materializations. Even a mere hasattr(obj,
# "dask") check already triggers this since it's a property, not even a
# method.
return self.__dask_graph_factory__().materialize()
return self.__dask_graph_factory__().lower_completely().materialize()

def finalize_compute(self):
return new_collection(Repartition(self.expr, 1))
Expand Down
10 changes: 5 additions & 5 deletions dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,9 +864,9 @@ def test_persist(pdf, df):
b = a.persist()

assert_eq(a, b)
assert len(a.__dask_graph__()) > len(b.__dask_graph__())
assert len(a.dask) > len(b.dask)

assert len(b.__dask_graph__()) == b.npartitions
assert len(b.dask) == b.npartitions

assert_eq(b.y.sum(), (pdf + 2).y.sum())

Expand Down Expand Up @@ -1113,7 +1113,7 @@ def test_serialization(pdf, df):

part = df.partitions[0].compute()
assert (
len(pickle.dumps(df.__dask_graph__()))
len(pickle.dumps(df.__dask_graph_factory__()))
< 1000 + len(pickle.dumps(part)) * df.npartitions
)

Expand Down Expand Up @@ -1185,7 +1185,7 @@ def test_tree_repr(fuse):

def test_simple_graphs(df):
expr = (df + 1).expr
graph = expr.__dask_graph__()
graph = expr.materialize()

assert graph[(expr._name, 0)] == (operator.add, (df.expr._name, 0), 1)

Expand Down Expand Up @@ -1239,7 +1239,7 @@ def test_repartition_divisions(df, opt):
assert_eq((df + 1)["x"], df2)

# Check partitions
for p, part in enumerate(dask.compute(list(df2.index.partitions))[0]):
for p, part in enumerate(dask.compute(list(df2.index.partitions))):
if len(part):
assert part.min() >= df2.divisions[p]
assert part.max() < df2.divisions[p + 1]
Expand Down
2 changes: 1 addition & 1 deletion dask_expr/tests/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_fuse_broadcast_deps():
df3 = from_pandas(pdf3, npartitions=2)

query = df.merge(df2).merge(df3)
assert len(query.optimize().__dask_graph__()) == 2
assert len(query.optimize().materialize()) == 2
assert_eq(query, pdf.merge(pdf2).merge(pdf3))


Expand Down
4 changes: 1 addition & 3 deletions dask_expr/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,7 @@ def test_merge_combine_similar(npartitions_left, npartitions_right):
query = df.merge(df2)
query["new"] = query.b + query.c
query = query.groupby(["a", "e", "x"]).new.sum()
assert (
len(query.optimize().__dask_graph__()) <= 25
) # 45 is the non-combined version
assert len(query.optimize().materialize()) <= 25 # 45 is the non-combined version

expected = pdf.merge(pdf2)
expected["new"] = expected.b + expected.c
Expand Down

0 comments on commit 331193d

Please sign in to comment.