Skip to content

Commit

Permalink
fix(rust,python,sql): rework SQL join constraint processing to proper…
Browse files Browse the repository at this point in the history
…ly account for all `USING` columns (#11518)
  • Loading branch information
alexander-beedie authored Oct 5, 2023
1 parent b657f8d commit dd88d2b
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 109 deletions.
41 changes: 13 additions & 28 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::sql_expr::{parse_sql_expr, process_join};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
Expand Down Expand Up @@ -266,50 +266,36 @@ impl SQLContext {

/// execute the 'FROM' part of the query
fn execute_from_statement(&mut self, tbl_expr: &TableWithJoins) -> PolarsResult<LazyFrame> {
let (tbl_name, mut lf) = self.get_table(&tbl_expr.relation)?;
let (l_name, mut lf) = self.get_table(&tbl_expr.relation)?;
if !tbl_expr.joins.is_empty() {
for tbl in &tbl_expr.joins {
let (join_tbl_name, join_tbl) = self.get_table(&tbl.relation)?;
let (r_name, rf) = self.get_table(&tbl.relation)?;
lf = match &tbl.join_operator {
JoinOperator::CrossJoin => lf.cross_join(join_tbl),
JoinOperator::CrossJoin => lf.cross_join(rf),
JoinOperator::FullOuter(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
lf.outer_join(join_tbl, left_on, right_on)
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Outer)?
},
JoinOperator::Inner(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
lf.inner_join(join_tbl, left_on, right_on)
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
},
JoinOperator::LeftOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftAnti(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
lf.anti_join(join_tbl, left_on, right_on)
},
JoinOperator::LeftOuter(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
lf.left_join(join_tbl, left_on, right_on)
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftSemi(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
lf.semi_join(join_tbl, left_on, right_on)
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightAnti(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
join_tbl.anti_join(lf, right_on, left_on)
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightSemi(constraint) => {
let (left_on, right_on) =
process_join_constraint(constraint, &tbl_name, &join_tbl_name)?;
join_tbl.semi_join(lf, right_on, left_on)
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)?
},
join_type => {
polars_bail!(
Expand All @@ -320,7 +306,6 @@ impl SQLContext {
}
}
};

Ok(lf)
}

Expand Down
32 changes: 26 additions & 6 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,30 @@ pub(crate) fn parse_sql_expr(expr: &SqlExpr, ctx: &mut SQLContext) -> PolarsResu
visitor.visit_expr(expr)
}

pub(super) fn process_join(
left_tbl: LazyFrame,
right_tbl: LazyFrame,
constraint: &JoinConstraint,
tbl_name: &str,
join_tbl_name: &str,
join_type: JoinType,
) -> PolarsResult<LazyFrame> {
let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?;

Ok(left_tbl
.join_builder()
.with(right_tbl)
.left_on(left_on)
.right_on(right_on)
.how(join_type)
.finish())
}

pub(super) fn process_join_constraint(
constraint: &JoinConstraint,
left_name: &str,
right_name: &str,
) -> PolarsResult<(Expr, Expr)> {
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint {
match (left.as_ref(), right.as_ref()) {
(SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right)) => {
Expand All @@ -639,23 +658,24 @@ pub(super) fn process_join_constraint(

if let BinaryOperator::Eq = op {
if left_name == tbl_a && right_name == tbl_b {
return Ok((col(col_a), col(col_b)));
return Ok((vec![col(col_a)], vec![col(col_b)]));
} else if left_name == tbl_b && right_name == tbl_a {
return Ok((col(col_b), col(col_a)));
return Ok((vec![col(col_b)], vec![col(col_a)]));
}
}
}
},
(SqlExpr::Identifier(left), SqlExpr::Identifier(right)) => {
return Ok((col(&left.value), col(&right.value)))
return Ok((vec![col(&left.value)], vec![col(&right.value)]))
},
_ => {},
}
}
if let JoinConstraint::Using(idents) = constraint {
if !idents.is_empty() {
let cols = &idents[0].value;
return Ok((col(cols), col(cols)));
let mut using = Vec::with_capacity(idents.len());
using.extend(idents.iter().map(|id| col(&id.value)));
return Ok((using.clone(), using.clone()));
}
}
polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint);
Expand Down
142 changes: 67 additions & 75 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,92 +540,84 @@ def test_sql_limit_offset() -> None:
assert len(out) == min(limit, n_values - offset)


def test_sql_join_anti_semi() -> None:
@pytest.mark.parametrize(
("sql", "expected"),
[
(
"SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a,c)",
pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}),
),
(
"SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (a)",
pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}),
),
(
"SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (a)",
pl.DataFrame(schema={"a": pl.Int64, "b": pl.Int64, "c": pl.Utf8}),
),
(
"SELECT * FROM tbl_a LEFT SEMI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
pl.DataFrame({"a": [1, 3], "b": [4, 6], "c": ["w", "z"]}),
),
(
"SELECT * FROM tbl_a LEFT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}),
),
(
"SELECT * FROM tbl_a RIGHT ANTI JOIN tbl_b USING (b) LEFT SEMI JOIN tbl_c USING (c)",
pl.DataFrame({"a": [2], "b": [5], "c": ["y"]}),
),
(
"SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT SEMI JOIN tbl_c USING (c)",
pl.DataFrame({"c": ["z"], "d": [25.5]}),
),
(
"SELECT * FROM tbl_a RIGHT SEMI JOIN tbl_b USING (b) RIGHT ANTI JOIN tbl_c USING (c)",
pl.DataFrame({"c": ["w", "y"], "d": [10.5, -50.0]}),
),
],
)
def test_sql_join_anti_semi(sql: str, expected: pl.DataFrame) -> None:
frames = {
"tbl_a": pl.DataFrame({"a": [1, 2, 3], "b": [4, 0, 6], "c": ["w", "y", "z"]}),
"tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
"tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
}
c = pl.SQLContext(frames, eager_execution=True)
assert_frame_equal(expected, c.execute(sql))

out = c.execute(
"""
SELECT *
FROM tbl_a
LEFT SEMI JOIN tbl_b USING (b)
LEFT SEMI JOIN tbl_c USING (c)
"""
)
assert_frame_equal(pl.DataFrame({"a": [1, 3], "b": [4, 6], "c": ["w", "z"]}), out)

out = c.execute(
"""
SELECT *
FROM tbl_a
LEFT ANTI JOIN tbl_b USING (b)
LEFT SEMI JOIN tbl_c USING (c)
"""
)
assert_frame_equal(pl.DataFrame({"a": [2], "b": [0], "c": ["y"]}), out)

out = c.execute(
"""
SELECT *
FROM tbl_a
RIGHT ANTI JOIN tbl_b USING (b)
LEFT SEMI JOIN tbl_c USING (c)
"""
)
assert_frame_equal(pl.DataFrame({"a": [2], "b": [5], "c": ["y"]}), out)
@pytest.mark.parametrize(
"join_clause",
[
"ON foods1.category = foods2.category",
"ON foods2.category = foods1.category",
"USING (category)",
],
)
def test_sql_join_inner(foods_ipc_path: Path, join_clause: str) -> None:
lf = pl.scan_ipc(foods_ipc_path)

out = c.execute(
"""
SELECT *
FROM tbl_a
RIGHT SEMI JOIN tbl_b USING (b)
RIGHT SEMI JOIN tbl_c USING (c)
"""
)
assert_frame_equal(pl.DataFrame({"c": ["z"], "d": [25.5]}), out)
c = pl.SQLContext()
c.register_many(foods1=lf, foods2=lf)

out = c.execute(
"""
f"""
SELECT *
FROM tbl_a
RIGHT SEMI JOIN tbl_b USING (b)
RIGHT ANTI JOIN tbl_c USING (c)
FROM foods1
INNER JOIN foods2 {join_clause}
LIMIT 2
"""
)
assert_frame_equal(pl.DataFrame({"c": ["w", "y"], "d": [10.5, -50.0]}), out)


def test_sql_join_inner(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)

c = pl.SQLContext()
c.register_many(foods1=lf, foods2=lf)

for join_clause in (
"ON foods1.category = foods2.category",
"USING (category)",
):
out = c.execute(
f"""
SELECT *
FROM foods1
INNER JOIN foods2 {join_clause}
LIMIT 2
"""
)
assert out.collect().to_dict(False) == {
"category": ["vegetables", "vegetables"],
"calories": [45, 20],
"fats_g": [0.5, 0.0],
"sugars_g": [2, 2],
"calories_right": [45, 45],
"fats_g_right": [0.5, 0.5],
"sugars_g_right": [2, 2],
}
assert out.collect().to_dict(False) == {
"category": ["vegetables", "vegetables"],
"calories": [45, 20],
"fats_g": [0.5, 0.0],
"sugars_g": [2, 2],
"calories_right": [45, 45],
"fats_g_right": [0.5, 0.5],
"sugars_g_right": [2, 2],
}


def test_sql_join_left() -> None:
Expand All @@ -641,13 +633,13 @@ def test_sql_join_left() -> None:
FROM tbl_a
LEFT JOIN tbl_b USING (a,b)
LEFT JOIN tbl_c USING (c)
ORDER BY c DESC
ORDER BY a DESC
"""
)
assert out.collect().rows() == [
(1, 4, "z", 25.5),
(2, None, "y", -50.0),
(3, 6, "x", None),
(2, None, None, None),
(1, 4, "z", 25.5),
]
assert c.tables() == ["tbl_a", "tbl_b", "tbl_c"]

Expand Down

0 comments on commit dd88d2b

Please sign in to comment.