Skip to content

Commit

Permalink
fix(rust,python,cli): catch use of non equi-joins in SQL interface an…
Browse files Browse the repository at this point in the history
…d raise appropriate error (#11526)
  • Loading branch information
alexander-beedie authored Oct 5, 2023
1 parent db45f5d commit 9b632bd
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 41 deletions.
25 changes: 13 additions & 12 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -648,20 +648,21 @@ pub(super) fn process_join_constraint(
right_name: &str,
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint {
if op != &BinaryOperator::Eq {
polars_bail!(InvalidOperation:
"SQL interface (currently) only supports basic equi-join \
constraints; found '{:?}' op in\n{:?}", op, constraint)
}
match (left.as_ref(), right.as_ref()) {
(SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right)) => {
if left.len() == 2 && right.len() == 2 {
let tbl_a = &left[0].value;
let col_a = &left[1].value;
let tbl_b = &right[0].value;
let col_b = &right[1].value;

if let BinaryOperator::Eq = op {
if left_name == tbl_a && right_name == tbl_b {
return Ok((vec![col(col_a)], vec![col(col_b)]));
} else if left_name == tbl_b && right_name == tbl_a {
return Ok((vec![col(col_b)], vec![col(col_a)]));
}
let (tbl_a, col_a) = (&left[0].value, &left[1].value);
let (tbl_b, col_b) = (&right[0].value, &right[1].value);

if left_name == tbl_a && right_name == tbl_b {
return Ok((vec![col(col_a)], vec![col(col_b)]));
} else if left_name == tbl_b && right_name == tbl_a {
return Ok((vec![col(col_b)], vec![col(col_a)]));
}
}
},
Expand All @@ -678,7 +679,7 @@ pub(super) fn process_join_constraint(
return Ok((using.clone(), using.clone()));
}
}
polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint);
polars_bail!(InvalidOperation: "Unsupported SQL join constraint:\n{:?}", constraint);
}

/// parse a SQL expression to a polars expression
Expand Down
77 changes: 48 additions & 29 deletions py-polars/tests/unit/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def test_sql_distinct() -> None:
"b": [1, 2, 3, 4, 5, 6],
}
)
c = pl.SQLContext(register_globals=True, eager_execution=True)
res1 = c.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
ctx = pl.SQLContext(register_globals=True, eager_execution=True)
res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC")
assert_frame_equal(
left=df.select("a").unique().sort(by="a", descending=True),
right=res1,
)

res2 = c.execute(
res2 = ctx.execute(
"""
SELECT DISTINCT
a*2 AS two_a,
Expand All @@ -141,9 +141,9 @@ def test_sql_distinct() -> None:
}

# test unregistration
c.unregister("df")
ctx.unregister("df")
with pytest.raises(pl.ComputeError, match=".*'df'.*not found"):
c.execute("SELECT * FROM df")
ctx.execute("SELECT * FROM df")


def test_sql_div() -> None:
Expand Down Expand Up @@ -242,8 +242,8 @@ def test_sql_trig() -> None:
}
)

c = pl.SQLContext(df=df)
res = c.execute(
ctx = pl.SQLContext(df=df)
res = ctx.execute(
"""
SELECT
asin(1.0)/a as "pi values",
Expand Down Expand Up @@ -457,10 +457,10 @@ def test_sql_trig() -> None:
def test_sql_group_by(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)

c = pl.SQLContext(eager_execution=True)
c.register("foods", lf)
ctx = pl.SQLContext(eager_execution=True)
ctx.register("foods", lf)

out = c.execute(
out = ctx.execute(
"""
SELECT
category,
Expand All @@ -486,12 +486,12 @@ def test_sql_group_by(foods_ipc_path: Path) -> None:
"att": ["x", "y", "x", "y", "y"],
}
)
assert c.tables() == ["foods"]
assert ctx.tables() == ["foods"]

c.register("test", lf)
assert c.tables() == ["foods", "test"]
ctx.register("test", lf)
assert ctx.tables() == ["foods", "test"]

out = c.execute(
out = ctx.execute(
"""
SELECT
grp,
Expand Down Expand Up @@ -527,15 +527,17 @@ def test_sql_left() -> None:
def test_sql_limit_offset() -> None:
n_values = 11
lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))})
c = pl.SQLContext(tbl=lf)
ctx = pl.SQLContext(tbl=lf)

assert c.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [
(4, 6),
(5, 5),
(6, 4),
]
for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]:
out = c.execute(f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True)
out = ctx.execute(
f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True
)
assert_frame_equal(out, lf.slice(offset, limit).collect())
assert len(out) == min(limit, n_values - offset)

Expand Down Expand Up @@ -583,8 +585,8 @@ def test_sql_join_anti_semi(sql: str, expected: pl.DataFrame) -> None:
"tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
"tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
}
c = pl.SQLContext(frames, eager_execution=True)
assert_frame_equal(expected, c.execute(sql))
ctx = pl.SQLContext(frames, eager_execution=True)
assert_frame_equal(expected, ctx.execute(sql))


@pytest.mark.parametrize(
Expand All @@ -598,10 +600,10 @@ def test_sql_join_anti_semi(sql: str, expected: pl.DataFrame) -> None:
def test_sql_join_inner(foods_ipc_path: Path, join_clause: str) -> None:
lf = pl.scan_ipc(foods_ipc_path)

c = pl.SQLContext()
c.register_many(foods1=lf, foods2=lf)
ctx = pl.SQLContext()
ctx.register_many(foods1=lf, foods2=lf)

out = c.execute(
out = ctx.execute(
f"""
SELECT *
FROM foods1
Expand All @@ -626,8 +628,8 @@ def test_sql_join_left() -> None:
"tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}),
"tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}),
}
c = pl.SQLContext(frames)
out = c.execute(
ctx = pl.SQLContext(frames)
out = ctx.execute(
"""
SELECT a, b, c, d
FROM tbl_a
Expand All @@ -641,14 +643,32 @@ def test_sql_join_left() -> None:
(2, None, None, None),
(1, 4, "z", 25.5),
]
assert c.tables() == ["tbl_a", "tbl_b", "tbl_c"]
assert ctx.tables() == ["tbl_a", "tbl_b", "tbl_c"]


@pytest.mark.parametrize(
"constraint", ["tbl.a != tbl.b", "tbl.a > tbl.b", "a >= b", "a < b", "b <= a"]
)
def test_sql_non_equi_joins(constraint: str) -> None:
# no support (yet) for non equi-joins in polars joins
with pytest.raises(
pl.InvalidOperationError,
match=r"SQL interface \(currently\) only supports basic equi-join constraints",
), pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx:
ctx.execute(
f"""
SELECT *
FROM tbl
LEFT JOIN tbl ON {constraint} -- not an equi-join
"""
)


def test_sql_is_between(foods_ipc_path: Path) -> None:
lf = pl.scan_ipc(foods_ipc_path)

c = pl.SQLContext(foods1=lf, eager_execution=True)
out = c.execute(
ctx = pl.SQLContext(foods1=lf, eager_execution=True)
out = ctx.execute(
"""
SELECT *
FROM foods1
Expand All @@ -665,8 +685,7 @@ def test_sql_is_between(foods_ipc_path: Path) -> None:
("vegetables", 25, 0.0, 2),
("vegetables", 22, 0.0, 3),
]

out = c.execute(
out = ctx.execute(
"""
SELECT *
FROM foods1
Expand Down

0 comments on commit 9b632bd

Please sign in to comment.