diff --git a/crates/core/src/delta_datafusion/expr.rs b/crates/core/src/delta_datafusion/expr.rs index b633cae141..2c127f010d 100644 --- a/crates/core/src/delta_datafusion/expr.rs +++ b/crates/core/src/delta_datafusion/expr.rs @@ -502,6 +502,10 @@ impl fmt::Display for ScalarValueFormat<'_> { ScalarValue::UInt16(e) => format_option!(f, e)?, ScalarValue::UInt32(e) => format_option!(f, e)?, ScalarValue::UInt64(e) => format_option!(f, e)?, + ScalarValue::Decimal128(e, precision, scale) => match e { + Some(e) => write!(f, "'{e}'::decimal({precision}, {scale})",)?, + None => write!(f, "NULL")?, + }, ScalarValue::Date32(e) => match e { Some(e) => write!( f, @@ -657,6 +661,11 @@ mod test { DataType::Primitive(PrimitiveType::Binary), true, ), + StructField::new( + "_decimal".to_string(), + DataType::Primitive(PrimitiveType::Decimal(2, 2)), + true, + ), StructField::new( "_struct".to_string(), DataType::Struct(Box::new(StructType::new(vec![ @@ -887,6 +896,18 @@ mod test { ) )), }, + ParseTest { + expr: col("_decimal").eq(lit(ScalarValue::Decimal128(Some(1),2,2))), + expected: "_decimal = '1'::decimal(2, 2)".to_string(), + override_expected_expr: Some(col("_decimal").eq( + Expr::Cast( + Cast { + expr: Box::from(lit("1")), + data_type: arrow_schema::DataType::Decimal128(2, 2) + } + ) + )), + }, ]; let session: SessionContext = DeltaSessionContext::default().into(); @@ -908,11 +929,6 @@ mod test { } let unsupported_types = vec![ - /* TODO: Determine proper way to display decimal values in an sql expression*/ - simple!( - col("money").gt(lit(ScalarValue::Decimal128(Some(100), 12, 2))), - "money > 0.1".to_string() - ), simple!( col("_timestamp").gt(lit(ScalarValue::TimestampMillisecond(Some(100), None))), "".to_string() diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py index 1f81e81142..b90eecae88 100644 --- a/python/tests/test_merge.py +++ b/python/tests/test_merge.py @@ -1,6 +1,7 @@ import datetime import os import pathlib +from decimal import Decimal import pyarrow as pa import pyarrow.parquet as pq @@ -1149,3 +1150,44 @@ def test_merge_when_wrong_but_castable_type_passed_while_merge( tmp_path / dt.get_add_actions().column(0)[0].as_py() ).schema assert table_schema.field("price").type == sample_table["price"].type + + +def test_merge_on_decimal_3033(tmp_path): + data = { + "timestamp": [datetime.datetime(2024, 3, 20, 12, 30, 0)], + "altitude": [Decimal("150.5")], + } + + table = pa.Table.from_pydict(data) + + schema = pa.schema( + [ + ("timestamp", pa.timestamp("us")), + ("altitude", pa.decimal128(6, 1)), + ] + ) + + dt = DeltaTable.create(tmp_path, schema=schema) + + write_deltalake(dt, table, mode="append") + + dt.merge( + source=table, + predicate="target.timestamp = source.timestamp", + source_alias="source", + target_alias="target", + ).when_matched_update_all().when_not_matched_insert_all().execute() + + dt.merge( + source=table, + predicate="target.timestamp = source.timestamp AND target.altitude = source.altitude", + source_alias="source", + target_alias="target", + ).when_matched_update_all().when_not_matched_insert_all().execute() + + string_predicate = dt.history(1)[0]["operationParameters"]["predicate"] + + assert ( + string_predicate + == "timestamp BETWEEN arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND arrow_cast('2024-03-20T12:30:00.000000', 'Timestamp(Microsecond, None)') AND altitude BETWEEN '1505'::decimal(4, 1) AND '1505'::decimal(4, 1)" + )