From c3591a258eaca7492cb8f94910ac441de8b49938 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 26 Feb 2025 13:19:22 -0600 Subject: [PATCH] fix: sql round without precision (#3863) --- src/daft-sql/src/modules/numeric.rs | 8 ++++++-- tests/sql/test_exprs.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/daft-sql/src/modules/numeric.rs b/src/daft-sql/src/modules/numeric.rs index ab8b41a529..2c4c6677e2 100644 --- a/src/daft-sql/src/modules/numeric.rs +++ b/src/daft-sql/src/modules/numeric.rs @@ -175,8 +175,11 @@ fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult Ok(sign(args[0].clone())) } SQLNumericExpr::Round => { - ensure!(args.len() == 2, "round takes exactly two arguments"); - let precision = match args[1].as_ref().as_literal() { + ensure!( + args.len() == 2 || args.len() == 1, + "round takes one or two arguments" + ); + let precision = match args.get(1).and_then(|arg| arg.as_literal()) { Some(LiteralValue::Int8(i)) => *i as i32, Some(LiteralValue::UInt8(u)) => *u as i32, Some(LiteralValue::Int16(i)) => *i as i32, @@ -185,6 +188,7 @@ fn to_expr(expr: &SQLNumericExpr, args: &[ExprRef]) -> SQLPlannerResult Some(LiteralValue::UInt32(u)) => *u as i32, Some(LiteralValue::Int64(i)) => *i as i32, Some(LiteralValue::UInt64(u)) => *u as i32, + None => 0, _ => invalid_operation_err!("round precision must be an integer"), }; Ok(round(args[0].clone(), Some(precision))) diff --git a/tests/sql/test_exprs.py b/tests/sql/test_exprs.py index 56eb6d6994..93c4b67493 100644 --- a/tests/sql/test_exprs.py +++ b/tests/sql/test_exprs.py @@ -248,3 +248,23 @@ def test_coalesce(): ).to_pydict() assert actual == expected + + +@pytest.mark.parametrize( + "precision, value, expected", + [ + (None, 3.14159, 3), + (None, 3, 3), + (1, 3.14159, 3.1), + (2, 3.14159, 3.14), + ], +) +def test_round(precision, value, expected): + if precision is None: + query = f"select round({value})" + else: + query = f"select round({value}, {precision})" + actual = daft.sql(query).to_pydict() + expected = {"literal": [expected]} + + assert actual == expected