diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 7a7d07fef04a..230c1dd89453 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -648,20 +648,21 @@ pub(super) fn process_join_constraint( right_name: &str, ) -> PolarsResult<(Vec, Vec)> { if let JoinConstraint::On(SqlExpr::BinaryOp { left, op, right }) = constraint { + if op != &BinaryOperator::Eq { + polars_bail!(InvalidOperation: + "SQL interface (currently) only supports basic equi-join \ + constraints; found '{:?}' op in\n{:?}", op, constraint) + } match (left.as_ref(), right.as_ref()) { (SqlExpr::CompoundIdentifier(left), SqlExpr::CompoundIdentifier(right)) => { if left.len() == 2 && right.len() == 2 { - let tbl_a = &left[0].value; - let col_a = &left[1].value; - let tbl_b = &right[0].value; - let col_b = &right[1].value; - - if let BinaryOperator::Eq = op { - if left_name == tbl_a && right_name == tbl_b { - return Ok((vec![col(col_a)], vec![col(col_b)])); - } else if left_name == tbl_b && right_name == tbl_a { - return Ok((vec![col(col_b)], vec![col(col_a)])); - } + let (tbl_a, col_a) = (&left[0].value, &left[1].value); + let (tbl_b, col_b) = (&right[0].value, &right[1].value); + + if left_name == tbl_a && right_name == tbl_b { + return Ok((vec![col(col_a)], vec![col(col_b)])); + } else if left_name == tbl_b && right_name == tbl_a { + return Ok((vec![col(col_b)], vec![col(col_a)])); } } }, @@ -678,7 +679,7 @@ pub(super) fn process_join_constraint( return Ok((using.clone(), using.clone())); } } - polars_bail!(InvalidOperation: "SQL join constraint {:?} is not yet supported", constraint); + polars_bail!(InvalidOperation: "Unsupported SQL join constraint:\n{:?}", constraint); } /// parse a SQL expression to a polars expression diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index c4d597c29061..f0bef2312644 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -119,14 +119,14 @@ def test_sql_distinct() -> None: "b": [1, 2, 3, 4, 5, 6], } ) - c = pl.SQLContext(register_globals=True, eager_execution=True) - res1 = c.execute("SELECT DISTINCT a FROM df ORDER BY a DESC") + ctx = pl.SQLContext(register_globals=True, eager_execution=True) + res1 = ctx.execute("SELECT DISTINCT a FROM df ORDER BY a DESC") assert_frame_equal( left=df.select("a").unique().sort(by="a", descending=True), right=res1, ) - res2 = c.execute( + res2 = ctx.execute( """ SELECT DISTINCT a*2 AS two_a, @@ -141,9 +141,9 @@ def test_sql_distinct() -> None: } # test unregistration - c.unregister("df") + ctx.unregister("df") with pytest.raises(pl.ComputeError, match=".*'df'.*not found"): - c.execute("SELECT * FROM df") + ctx.execute("SELECT * FROM df") def test_sql_div() -> None: @@ -242,8 +242,8 @@ def test_sql_trig() -> None: } ) - c = pl.SQLContext(df=df) - res = c.execute( + ctx = pl.SQLContext(df=df) + res = ctx.execute( """ SELECT asin(1.0)/a as "pi values", @@ -457,10 +457,10 @@ def test_sql_trig() -> None: def test_sql_group_by(foods_ipc_path: Path) -> None: lf = pl.scan_ipc(foods_ipc_path) - c = pl.SQLContext(eager_execution=True) - c.register("foods", lf) + ctx = pl.SQLContext(eager_execution=True) + ctx.register("foods", lf) - out = c.execute( + out = ctx.execute( """ SELECT category, @@ -486,12 +486,12 @@ def test_sql_group_by(foods_ipc_path: Path) -> None: "att": ["x", "y", "x", "y", "y"], } ) - assert c.tables() == ["foods"] + assert ctx.tables() == ["foods"] - c.register("test", lf) - assert c.tables() == ["foods", "test"] + ctx.register("test", lf) + assert ctx.tables() == ["foods", "test"] - out = c.execute( + out = ctx.execute( """ SELECT grp, @@ -527,15 +527,17 @@ def test_sql_left() -> None: def test_sql_limit_offset() -> None: n_values = 11 lf = pl.LazyFrame({"a": range(n_values), "b": reversed(range(n_values))}) - c = pl.SQLContext(tbl=lf) + ctx = pl.SQLContext(tbl=lf) - assert c.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [ + assert ctx.execute("SELECT * FROM tbl LIMIT 3 OFFSET 4", eager=True).rows() == [ (4, 6), (5, 5), (6, 4), ] for offset, limit in [(0, 3), (1, n_values), (2, 3), (5, 3), (8, 5), (n_values, 1)]: - out = c.execute(f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True) + out = ctx.execute( + f"SELECT * FROM tbl LIMIT {limit} OFFSET {offset}", eager=True + ) assert_frame_equal(out, lf.slice(offset, limit).collect()) assert len(out) == min(limit, n_values - offset) @@ -583,8 +585,8 @@ def test_sql_join_anti_semi(sql: str, expected: pl.DataFrame) -> None: "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), } - c = pl.SQLContext(frames, eager_execution=True) - assert_frame_equal(expected, c.execute(sql)) + ctx = pl.SQLContext(frames, eager_execution=True) + assert_frame_equal(expected, ctx.execute(sql)) @pytest.mark.parametrize( @@ -598,10 +600,10 @@ def test_sql_join_anti_semi(sql: str, expected: pl.DataFrame) -> None: def test_sql_join_inner(foods_ipc_path: Path, join_clause: str) -> None: lf = pl.scan_ipc(foods_ipc_path) - c = pl.SQLContext() - c.register_many(foods1=lf, foods2=lf) + ctx = pl.SQLContext() + ctx.register_many(foods1=lf, foods2=lf) - out = c.execute( + out = ctx.execute( f""" SELECT * FROM foods1 @@ -626,8 +628,8 @@ def test_sql_join_left() -> None: "tbl_b": pl.DataFrame({"a": [3, 2, 1], "b": [6, 5, 4], "c": ["x", "y", "z"]}), "tbl_c": pl.DataFrame({"c": ["w", "y", "z"], "d": [10.5, -50.0, 25.5]}), } - c = pl.SQLContext(frames) - out = c.execute( + ctx = pl.SQLContext(frames) + out = ctx.execute( """ SELECT a, b, c, d FROM tbl_a @@ -641,14 +643,32 @@ def test_sql_join_left() -> None: (2, None, None, None), (1, 4, "z", 25.5), ] - assert c.tables() == ["tbl_a", "tbl_b", "tbl_c"] + assert ctx.tables() == ["tbl_a", "tbl_b", "tbl_c"] + + +@pytest.mark.parametrize( + "constraint", ["tbl.a != tbl.b", "tbl.a > tbl.b", "a >= b", "a < b", "b <= a"] +) +def test_sql_non_equi_joins(constraint: str) -> None: + # no support (yet) for non equi-joins in polars joins + with pytest.raises( + pl.InvalidOperationError, + match=r"SQL interface \(currently\) only supports basic equi-join constraints", + ), pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx: + ctx.execute( + f""" + SELECT * + FROM tbl + LEFT JOIN tbl ON {constraint} -- not an equi-join + """ + ) def test_sql_is_between(foods_ipc_path: Path) -> None: lf = pl.scan_ipc(foods_ipc_path) - c = pl.SQLContext(foods1=lf, eager_execution=True) - out = c.execute( + ctx = pl.SQLContext(foods1=lf, eager_execution=True) + out = ctx.execute( """ SELECT * FROM foods1 @@ -665,8 +685,7 @@ def test_sql_is_between(foods_ipc_path: Path) -> None: ("vegetables", 25, 0.0, 2), ("vegetables", 22, 0.0, 3), ] - - out = c.execute( + out = ctx.execute( """ SELECT * FROM foods1