From 4d0e54ae59e07400f4e44ed6e16c1f17a097ef42 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sat, 7 Oct 2023 09:42:16 +0400 Subject: [PATCH] fix(rust,python,cli): handle unary operators applied to numbers used in SQL `IN` clauses (#11574) --- crates/polars-sql/src/sql_expr.rs | 35 +++++++++++++++++++++++----- py-polars/tests/unit/sql/test_sql.py | 20 ++++++++++++++++ 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 230c1dd89453..b566ceba0a12 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -372,17 +372,32 @@ impl SqlExprVisitor<'_> { }) } - // similar to visit_literal, but returns an AnyValue instead of Expr - fn visit_anyvalue(&self, value: &SqlValue) -> PolarsResult { + /// Visit a SQL literal (like [visit_literal]), but return AnyValue instead of Expr + fn visit_anyvalue( + &self, + value: &SqlValue, + op: Option<&UnaryOperator>, + ) -> PolarsResult { Ok(match value { SqlValue::Boolean(b) => AnyValue::Boolean(*b), SqlValue::Null => AnyValue::Null, SqlValue::Number(s, _) => { + let negate = match op { + Some(UnaryOperator::Minus) => true, + Some(UnaryOperator::Plus) => false, + _ => { + polars_bail!(ComputeError: "Unary op {:?} not supported for numeric SQL value", op) + }, + }; // Check for existence of decimal separator dot if s.contains('.') { - s.parse::().map(AnyValue::Float64).map_err(|_| ()) + s.parse::() + .map(|n: f64| AnyValue::Float64(if negate { -n } else { n })) + .map_err(|_| ()) } else { - s.parse::().map(AnyValue::Int64).map_err(|_| ()) + s.parse::() + .map(|n: i64| AnyValue::Int64(if negate { -n } else { n })) + .map_err(|_| ()) } .map_err(|_| polars_err!(ComputeError: "cannot parse literal: {s:?}"))? }, @@ -483,9 +498,17 @@ impl SqlExprVisitor<'_> { .iter() .map(|e| { if let SqlExpr::Value(v) = e { - let av = self.visit_anyvalue(v)?; + let av = self.visit_anyvalue(v, None)?; Ok(av) - } else { + } else if let SqlExpr::UnaryOp {op, expr} = e { + match expr.as_ref() { + SqlExpr::Value(v) => { + let av = self.visit_anyvalue(v, Some(op))?; + Ok(av) + }, + _ => Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)) + } + }else{ Err(polars_err!(ComputeError: "SQL expression {:?} is not yet supported", e)) } }) diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index cf94aff147d4..15f30b8ae6aa 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -1184,3 +1184,23 @@ def test_sql_expr() -> None: pl.InvalidOperationError, match=r"Unable to parse 'xyz\.\*' as Expr" ): pl.sql_expr("xyz.*") + + +@pytest.mark.parametrize("match_float", [False, True]) +def test_sql_unary_ops_8890(match_float: bool) -> None: + with pl.SQLContext( + df=pl.DataFrame({"a": [-2, -1, 1, 2], "b": ["w", "x", "y", "z"]}), + ) as ctx: + in_values = "(-3.0, -1.0, +2.0, +4.0)" if match_float else "(-3, -1, +2, +4)" + res = ctx.execute( + f""" + SELECT *, -(3) as c, (+4) as d + FROM df WHERE a IN {in_values} + """ + ) + assert res.collect().to_dict(False) == { + "a": [-1, 2], + "b": ["x", "z"], + "c": [-3, -3], + "d": [4, 4], + }