Skip to content

Commit

Permalink
check result and fix divisions
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Oct 17, 2024
1 parent e277a13 commit 7bf60d2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
22 changes: 17 additions & 5 deletions dask_expr/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,21 @@ def _divisions(self):
meta_index_names = set(self._meta.index.names)
if (
self.broadcast_side == "left"
and self.right_index
and set(self.right._meta.index.names) == meta_index_names
):
return self._bcast_right.divisions
if self.right_index:
return self._bcast_right.divisions
_npartitions = self._bcast_right.npartitions
elif (
self.broadcast_side == "right"
and set(self.left._meta.index.names) == meta_index_names
):
return self._bcast_left.divisions
_npartitions = max(self.left.npartitions, self.right.npartitions)
if self.left_index:
return self._bcast_left.divisions
_npartitions = self._bcast_left.npartitions
else:
_npartitions = max(self.left.npartitions, self.right.npartitions)

else:
_npartitions = self._npartitions
Expand Down Expand Up @@ -718,8 +724,14 @@ class BroadcastJoin(Merge, PartitionsFiltered):

def _divisions(self):
if self.broadcast_side == "left":
return self.right.divisions
return self.left.divisions
if self.right_index:
return self.right.divisions
npartitions = self.right.npartitions
else:
if self.left_index:
return self.left.divisions
npartitions = self.left.npartitions
return (None,) * (npartitions + 1)

def _simplify_up(self, parent, dependents):
return
Expand Down
21 changes: 12 additions & 9 deletions dask_expr/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def test_broadcast_merge(how, npartitions):

# Check result with/without fusion
expect = pdf1.merge(pdf2, on="x", how=how)
# TODO: This is incorrect, but consistent with dask/dask
assert_eq(df3, expect, check_index=False, check_divisions=False)
assert_eq(df3.optimize(), expect, check_index=False, check_divisions=False)
assert_eq(df3, expect, check_index=False)
assert_eq(df3.optimize(), expect, check_index=False)


def test_merge_column_projection():
Expand Down Expand Up @@ -436,8 +435,7 @@ def test_recursive_join():

ddf_pairwise = ddf_pairwise.join(dfs_to_merge, how="left")

# TODO: divisions is None for recursive join for now
assert_eq(ddf_pairwise, ddf_loop, check_divisions=False)
assert_eq(ddf_pairwise, ddf_loop)


def test_merge_repartition():
Expand Down Expand Up @@ -1086,14 +1084,19 @@ def test_merge_tuple_left_on():


def test_merged_partitions_filtered():
a = from_dict({"x": range(1000), "y": [1, 2, 3, 4] * 250}, npartitions=10)
a = from_dict(
{"x": range(1000), "y": [1, 2, 3, 4] * 250}, npartitions=10
).partitions[:5]
b = from_dict({"xx": range(100), "yy": [1, 2] * 50}, npartitions=3)
result = a.merge(b, left_on=["y"], right_on=["yy"], how="inner", broadcast=True)

result = a.partitions[:5].merge(
b, left_on=["y"], right_on=["yy"], how="inner", broadcast=True
)
# Check expression properties
expr = result.optimize(fuse=False).expr
assert not expr._filtered
assert expr.left._filtered
assert expr.divisions == expr._divisions()
assert len(expr.divisions) == 6

# Check result
expect = a.compute().merge(b.compute(), left_on=["y"], right_on=["yy"], how="inner")
assert_eq(result, expect, check_index=False)

0 comments on commit 7bf60d2

Please sign in to comment.