Skip to content

Commit

Permalink
Fix bug in remove_join_expressions (apache#11693)
Browse files Browse the repository at this point in the history
* Fix bug in `remove_join_expressions`

* Update datafusion/optimizer/src/eliminate_cross_join.rs

Co-authored-by: Andrew Lamb <[email protected]>

* fmt

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jonahgao and alamb authored Jul 30, 2024
1 parent a591301 commit 2f5e73c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
34 changes: 28 additions & 6 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,7 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
None
}
// Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
Expr::BinaryExpr(BinaryExpr { left, op, right })
if matches!(op, Operator::And | Operator::Or) =>
{
Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => {
let l = remove_join_expressions(*left, join_keys);
let r = remove_join_expressions(*right, join_keys);
match (l, r) {
Expand All @@ -402,7 +400,20 @@ fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> {
_ => None,
}
}

Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => {
let l = remove_join_expressions(*left, join_keys);
let r = remove_join_expressions(*right, join_keys);
match (l, r) {
(Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
Box::new(ll),
op,
Box::new(rr),
))),
// When either `left` or `right` is empty, it means they are `true`
// so OR'ing anything with them will also be true
_ => None,
}
}
_ => Some(expr),
}
}
Expand Down Expand Up @@ -995,6 +1006,7 @@ mod tests {
let t4 = test_table_scan_with_name("t4")?;

// could eliminate to inner join
// filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688)
let plan1 = LogicalPlanBuilder::from(t1)
.cross_join(t2)?
.filter(binary_expr(
Expand All @@ -1012,6 +1024,10 @@ mod tests {
let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

// could eliminate to inner join
// filter:
// ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
// AND
// ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
let plan = LogicalPlanBuilder::from(plan1)
.cross_join(plan2)?
.filter(binary_expr(
Expand Down Expand Up @@ -1057,7 +1073,7 @@ mod tests {
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
Expand All @@ -1084,6 +1100,12 @@ mod tests {
let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?;

// could eliminate to inner join
// Filter:
// ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688))
// AND
// ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b))
// AND
// ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688))
let plan = LogicalPlanBuilder::from(plan1)
.cross_join(plan2)?
.filter(binary_expr(
Expand Down Expand Up @@ -1142,7 +1164,7 @@ mod tests {
.build()?;

let expected = vec![
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
Expand Down
15 changes: 14 additions & 1 deletion datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,6 @@ statement ok
DROP TABLE t3;


# Test issue: https://github.com/apache/datafusion/issues/11275
statement ok
CREATE TABLE t0 (v1 BOOLEAN) AS VALUES (false), (null);

Expand All @@ -1033,6 +1032,7 @@ CREATE TABLE t1 (v1 BOOLEAN) AS VALUES (false), (null), (false);
statement ok
CREATE TABLE t2 (v1 BOOLEAN) AS VALUES (false), (true);

# Test issue: https://github.com/apache/datafusion/issues/11275
query BB
SELECT t2.v1, t1.v1 FROM t0, t1, t2 WHERE t2.v1 IS DISTINCT FROM t0.v1 ORDER BY 1,2;
----
Expand All @@ -1046,6 +1046,19 @@ true false
true NULL
true NULL

# Test issue: https://github.com/apache/datafusion/issues/11621
query BB
SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE (t1.v1 == t2.v1) OR t1.v1;
----
false false
false false

query BB
SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE t1.v1 OR (t1.v1 == t2.v1);
----
false false
false false

statement ok
DROP TABLE t0;

Expand Down

0 comments on commit 2f5e73c

Please sign in to comment.