diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index 83563099cad6..f8ad3ac6c4ad 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -103,15 +103,28 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { // the actual ops here are largely unimportant as they are just a sample // of ops that could occur on a dataframe df = df - .with_column(&c_name, cast(c.clone(), DataType::Utf8)) + .with_column( + &c_name, + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Utf8, true)), + ), + ) .unwrap() .with_column( &c_name, when( - cast(c.clone(), DataType::Int32).gt(lit(135)), cast( - cast(c.clone(), DataType::Int32) - lit(i + 3), - DataType::Utf8, + c.clone(), + Arc::new(Field::new(&c_name, DataType::Int32, true)), + ) + .gt(lit(135)), + cast( + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Int32, true)), + ) - lit(i + 3), + Arc::new(Field::new(&c_name, DataType::Utf8, true)), ), ) .otherwise(c.clone()) @@ -122,15 +135,25 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { &c_name, when( c.clone().is_not_null().and( - cast(c.clone(), DataType::Int32) - .between(lit(120), lit(130)), + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Int32, true)), + ) + .between(lit(120), lit(130)), ), Literal(ScalarValue::Utf8(None), None), ) .otherwise( when( c.clone().is_not_null().and(regexp_like( - cast(c.clone(), DataType::Utf8View), + cast( + c.clone(), + Arc::new(Field::new( + &c_name, + DataType::Utf8View, + true, + )), + ), lit("[0-9]*"), None, )), @@ -146,10 +169,16 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { &c_name, when( c.clone().is_not_null().and( - cast(c.clone(), DataType::Int32) - .between(lit(90), lit(100)), + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Int32, true)), + ) + .between(lit(90), lit(100)), + ), + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Utf8View, true)), ), - cast(c.clone(), DataType::Utf8View), ) .otherwise(Literal(ScalarValue::Date32(None), None)) .unwrap(), @@ -159,10 +188,22 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { &c_name, when( c.clone().is_not_null().and( - cast(c.clone(), DataType::Int32).rem(lit(10)).gt(lit(7)), + cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Int32, true)), + ) + .rem(lit(10)) + .gt(lit(7)), ), regexp_replace( - cast(c.clone(), DataType::Utf8View), + cast( + c.clone(), + Arc::new(Field::new( + &c_name, + DataType::Utf8View, + true, + )), + ), lit("1"), lit("a"), None, @@ -179,11 +220,21 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame { &c_name, try_cast( to_timestamp(vec![c.clone(), lit("%Y-%m-%d %H:%M:%S")]), - DataType::Timestamp(Nanosecond, Some("UTC".into())), + Arc::new(Field::new( + &c_name, + DataType::Timestamp(Nanosecond, Some("UTC".into())), + true, + )), ), ) .unwrap() - .with_column(&c_name, try_cast(c.clone(), DataType::Date32)) + .with_column( + &c_name, + try_cast( + c.clone(), + Arc::new(Field::new(&c_name, DataType::Date32, true)), + ), + ) .unwrap() } diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index c80c0b4bf54b..ca01a0657988 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -183,7 +183,7 @@ 
impl TableProvider for CustomProvider { Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64, Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i, - Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() { + Expr::Cast(Cast { expr, field: _ }) => match expr.deref() { Expr::Literal(lit_value, _) => match lit_value { ScalarValue::Int8(Some(v)) => *v as i64, ScalarValue::Int16(Some(v)) => *v as i64, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index d95eb38c19e1..7ae93820035f 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -443,7 +443,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { // the arg2 parameter is a complex expr, but it can be evaluated to the literal value let alias_expr = Expr::Alias(Alias::new( - cast(lit(0.5), DataType::Float32), + cast( + lit(0.5), + Arc::new(Field::new("arg_2", DataType::Float32, true)), + ), None::<&str>, "arg_2".to_string(), )); @@ -463,7 +466,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { ); let alias_expr = Expr::Alias(Alias::new( - cast(lit(0.1), DataType::Float32), + cast( + lit(0.1), + Arc::new(Field::new("arg_2", DataType::Float32, true)), + ), None::<&str>, "arg_2".to_string(), )); @@ -1129,7 +1135,7 @@ async fn test_fn_substr() -> Result<()> { #[tokio::test] async fn test_cast() -> Result<()> { - let expr = cast(col("b"), DataType::Float64); + let expr = cast(col("b"), Arc::new(Field::new("b", DataType::Float64, true))); let batches = get_batches(expr).await?; assert_snapshot!( diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 979ada2bc6bb..f60aae57cfa9 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -2071,7 +2071,13 @@ async fn cast_expr_test() -> Result<()> { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? - .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column( + "sum", + cast( + col("c2") + col("c3"), + Arc::new(Field::new("sum", DataType::Int64, false)), + ), + )?; let df_results = df.clone().collect().await?; df.clone().show().await?; @@ -2173,7 +2179,13 @@ async fn cache_test() -> Result<()> { .await? .select_columns(&["c2", "c3"])? .limit(0, Some(1))? 
- .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?; + .with_column( + "sum", + cast( + col("c2") + col("c3"), + Arc::new(Field::new("sum", DataType::Int64, false)), + ), + )?; let cached_df = df.clone().cache().await?; @@ -2671,8 +2683,13 @@ async fn write_table_with_order() -> Result<()> { .unwrap(); // Ensure the column type matches the target table - write_df = - write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?; + write_df = write_df.with_column( + "tablecol1", + cast( + col("tablecol1"), + Arc::new(Field::new("tablecol1", DataType::Utf8View, false)), + ), + )?; let sql_str = "create external table data(tablecol1 varchar) stored as parquet location '" @@ -4687,7 +4704,10 @@ async fn consecutive_projection_same_schema() -> Result<()> { let df = df .with_column( "t", - cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32), + cast( + Expr::Literal(ScalarValue::Null, None), + Arc::new(Field::new("t", DataType::Int32, true)), + ), ) .unwrap(); df.clone().show().await.unwrap(); diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 572a7e2b335c..1503dcf54e67 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -192,7 +192,10 @@ fn make_udf_add(volatility: Volatility) -> Arc { } fn cast_to_int64_expr(expr: Expr) -> Expr { - Expr::Cast(Cast::new(expr.into(), DataType::Int64)) + Expr::Cast(Cast::new( + expr.into(), + Arc::new(Field::new("cast_to_i64", DataType::Int64, true)), + )) } fn to_timestamp_expr(arg: impl Into) -> Expr { @@ -747,8 +750,14 @@ fn test_simplify_concat() -> Result<()> { #[test] fn test_simplify_cycles() { // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX - let expr = cast(now(), DataType::Int64) - .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX)); + let expr = cast( + now(), + Arc::new(Field::new("cast_to_i64", DataType::Int64, true)), + ) + .lt(cast( + to_timestamp(vec![lit(0)]), + Arc::new(Field::new("cast_to_i64", DataType::Int64, true)), + ) + lit(i64::MAX)); let expected = lit(true); test_simplify_with_cycle_count(expr, expected, 3); } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index fb1371da6ceb..ebd53f855801 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -713,8 +713,12 @@ impl ScalarUDFImpl for CastToI64UDF { } else { // need to use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: DataType::Int64, + expr: Box::new(arg.clone()), + field: Arc::new(Field::new( + "cast_to_i64", + DataType::Int64, + info.nullable(&arg)?, + )), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 51d50d780f10..ec3ab9686c12 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -868,9 +868,13 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", Decimal128(11, 2), true)), + ) + .gt(cast( lit(ScalarValue::Decimal128(Some(500), 
5, 2)), - Decimal128(11, 2), + Arc::new(Field::new("c1", Decimal128(11, 2), true)), )); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1023,7 +1027,10 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), Decimal128(28, 3)); + let left = cast( + col("c1"), + Arc::new(Field::new("c1", Decimal128(28, 3), true)), + ); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -1101,7 +1108,10 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), Decimal128(28, 3)); + let left = cast( + col("c1"), + Arc::new(Field::new("c1", Decimal128(28, 3), true)), + ); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 6077b3c1e5bb..3ce1261b1780 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -794,14 +794,14 @@ pub enum GetFieldAccess { pub struct Cast { /// The expression being cast pub expr: Box<Expr>, - /// The `DataType` the expression will yield - pub data_type: DataType, + /// Field describing the result of the cast, including metadata + pub field: FieldRef, } impl Cast { /// Create a new Cast expression - pub fn new(expr: Box<Expr>, data_type: DataType) -> Self { - Self { expr, data_type } + pub fn new(expr: Box<Expr>, field: FieldRef) -> Self { + Self { expr, field } } } @@ -810,14 +810,14 @@ impl Cast { pub struct TryCast { /// The expression being cast pub expr: Box<Expr>, - /// The `DataType` the expression will yield - pub data_type: DataType, + /// Field describing the result of the cast, including metadata + pub field: FieldRef, } impl TryCast { /// Create a new TryCast expression - pub fn new(expr: Box<Expr>, data_type: DataType) -> Self { - Self { expr, data_type } + pub fn new(expr: Box<Expr>, field: FieldRef) -> Self { + Self { expr, field } } } @@ -2252,23 +2252,26 @@ impl NormalizeEq for Expr { ( Expr::Cast(Cast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::Cast(Cast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), ) | ( Expr::TryCast(TryCast { expr: self_expr, - data_type: self_data_type, + field: self_field, }), Expr::TryCast(TryCast { expr: other_expr, - data_type: other_data_type, + field: other_field, }), - ) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ) => { + self_field.data_type() == other_field.data_type() + && self_expr.normalize_eq(other_expr) + } ( Expr::ScalarFunction(ScalarFunction { func: self_func, @@ -2584,15 +2587,9 @@ impl HashNode for Expr { when_then_expr: _when_then_expr, else_expr: _else_expr, }) => {} - Expr::Cast(Cast { - expr: _expr, - data_type, - }) - | Expr::TryCast(TryCast { - expr: _expr, - data_type, - }) => { - data_type.hash(state); + Expr::Cast(Cast { expr: _expr, field }) + | Expr::TryCast(TryCast { expr: _expr, field }) => { + field.data_type().hash(state); } Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { func.hash(state); @@ -3283,11 +3280,11 @@ impl Display for Expr { } write!(f, "END") } -
Expr::Cast(Cast { expr, data_type }) => { - write!(f, "CAST({expr} AS {data_type})") + Expr::Cast(Cast { expr, field }) => { + write!(f, "CAST({expr} AS {})", field.data_type()) } - Expr::TryCast(TryCast { expr, data_type }) => { - write!(f, "TRY_CAST({expr} AS {data_type})") + Expr::TryCast(TryCast { expr, field }) => { + write!(f, "TRY_CAST({expr} AS {})", field.data_type()) } Expr::Not(expr) => write!(f, "NOT {expr}"), Expr::Negative(expr) => write!(f, "(- {expr})"), @@ -3673,7 +3670,7 @@ mod test { fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)), - data_type: DataType::Utf8, + field: Arc::new(Field::new("cast", DataType::Utf8, false)), }); let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, format!("{expr}")); @@ -3700,7 +3697,10 @@ mod test { fn test_collect_expr() -> Result<()> { // single column { - let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); + let expr = &Expr::Cast(Cast::new( + Box::new(col("a")), + Arc::new(Field::new("cast", DataType::Float64, false)), + )); let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c777c4978f99..88fde4080e53 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -326,14 +326,75 @@ pub fn rollup(exprs: Vec<Expr>) -> Expr { Expr::GroupingSet(GroupingSet::Rollup(exprs)) } +/// Types that can be used to describe the result of a cast expression. +pub trait IntoCastField { + fn into_cast_field(self, expr: &Expr) -> FieldRef; +} + +impl IntoCastField for FieldRef { + fn into_cast_field(self, _expr: &Expr) -> FieldRef { + self + } +} + +impl IntoCastField for &FieldRef { + fn into_cast_field(self, _expr: &Expr) -> FieldRef { + Arc::clone(self) + } +} + +impl IntoCastField for Field { + fn into_cast_field(self, _expr: &Expr) -> FieldRef { + Arc::new(self) + } +} + +impl IntoCastField for &Field { + fn into_cast_field(self, _expr: &Expr) -> FieldRef { + Arc::new(self.clone()) + } +} + +impl IntoCastField for DataType { + fn into_cast_field(self, expr: &Expr) -> FieldRef { + let nullable = infer_cast_nullability(expr); + Arc::new(Field::new("", self, nullable)) + } +} + +fn infer_cast_nullability(expr: &Expr) -> bool { + match expr { + Expr::Literal(value, _) => value.is_null(), + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, ..
}) => { + field.is_nullable() + } + _ => true, + } +} + /// Create a cast expression -pub fn cast(expr: Expr, data_type: DataType) -> Expr { - Expr::Cast(Cast::new(Box::new(expr), data_type)) +pub fn cast<F>(expr: Expr, field: F) -> Expr +where + F: IntoCastField, +{ + let field = field.into_cast_field(&expr); + Expr::Cast(Cast::new(Box::new(expr), field)) } /// Create a try cast expression -pub fn try_cast(expr: Expr, data_type: DataType) -> Expr { - Expr::TryCast(TryCast::new(Box::new(expr), data_type)) +pub fn try_cast<F>(expr: Expr, field: F) -> Expr +where + F: IntoCastField, +{ + let field = field.into_cast_field(&expr); + if field.is_nullable() { + Expr::TryCast(TryCast::new(Box::new(expr), field)) + } else { + Expr::TryCast(TryCast::new( + Box::new(expr), + Arc::new(field.as_ref().clone().with_nullable(true)), + )) + } } /// Create is null expression diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index d9fb9f7219c6..4c5cec349263 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -510,7 +510,10 @@ mod test { // cast data types test_rewrite( col("a"), - Expr::Cast(Cast::new(Box::new(col("a")), DataType::Int32)), + Expr::Cast(Cast::new( + Box::new(col("a")), + Arc::new(Field::new("a", DataType::Int32, false)), + )), ); // change literal type from i32 to i64 diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6db95555502d..0349bf9315bd 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection( if let Some(found) = found { return Ok(Transformed::yes(match normalized_expr { - Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { + Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast { expr: Box::new(found), - data_type, + field, }), - Expr::TryCast(TryCast { expr: _, data_type }) => Expr::TryCast(TryCast { + Expr::TryCast(TryCast { expr: _, field }) => Expr::TryCast(TryCast { expr: Box::new(found), - data_type, + field, }), _ => found, })); @@ -268,13 +268,25 @@ mod test { let cases = vec![ TestCase { desc: "Cast is preserved by rewrite_sort_cols_by_aggs", - input: sort(cast(col("c2"), DataType::Int64)), - expected: sort(cast(col("c2").alias("c2"), DataType::Int64)), + input: sort(cast( + col("c2"), + Arc::new(Field::new("cast", DataType::Int64, false)), + )), + expected: sort(cast( + col("c2").alias("c2"), + Arc::new(Field::new("cast", DataType::Int64, false)), + )), }, TestCase { desc: "TryCast is preserved by rewrite_sort_cols_by_aggs", - input: sort(try_cast(col("c2"), DataType::Int64)), - expected: sort(try_cast(col("c2").alias("c2"), DataType::Int64)), + input: sort(try_cast( + col("c2"), + Arc::new(Field::new("try_cast", DataType::Int64, false)), + )), + expected: sort(try_cast( + col("c2").alias("c2"), + Arc::new(Field::new("try_cast", DataType::Int64, false)), + )), }, ]; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 8c557a5630f0..9fe04146ec03 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -127,8 +127,9 @@ impl ExprSchemable for Expr { .as_ref() .map_or(Ok(DataType::Null), |e| e.get_type(schema)) } - Expr::Cast(Cast { data_type, .. }) - | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, ..
}) => { + Ok(field.data_type().clone()) + } Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list @@ -579,10 +580,13 @@ impl ExprSchemable for Expr { func.return_field_from_args(args) } // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - Expr::Cast(Cast { expr, data_type }) => expr - .to_field(schema) - .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())) - .map(Arc::new), + Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, field }) => { + let (_, input_field) = expr.to_field(schema)?; + let mut combined_metadata = FieldMetadata::from(input_field.metadata()); + combined_metadata.extend(FieldMetadata::from(field.metadata())); + let field = combined_metadata.add_to_field(field.as_ref().clone()); + Ok(Arc::new(field)) + } Expr::Placeholder(Placeholder { id: _, field: Some(field), @@ -592,7 +596,6 @@ impl ExprSchemable for Expr { | Expr::Not(_) | Expr::Between(_) | Expr::Case(_) - | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) | Expr::Wildcard { .. } @@ -632,7 +635,14 @@ impl ExprSchemable for Expr { Expr::ScalarSubquery(subquery) => { Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?)) } - _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))), + _ => { + let field = Arc::new(Field::new( + "", + cast_to_type.clone(), + self.nullable(schema)?, + )); + Ok(Expr::Cast(Cast::new(Box::new(self), field))) + } } } else { plan_err!("Cannot automatically convert {this_type} to {cast_to_type}") @@ -772,6 +782,7 @@ mod tests { use std::collections::HashMap; use super::*; + use crate::expr::{Cast, TryCast}; use crate::{col, lit, out_ref_col_with_metadata}; use datafusion_common::{internal_err, DFSchema, ScalarValue}; @@ -962,6 +973,63 @@ mod tests { ); } + #[test] + fn test_cast_metadata_overrides() { + let source_meta = FieldMetadata::from(HashMap::from([ + ("source".to_string(), "value".to_string()), + ("shared".to_string(), "source".to_string()), + ])); + let cast_meta = FieldMetadata::from(HashMap::from([ + ("shared".to_string(), "cast".to_string()), + ("cast".to_string(), "value".to_string()), + ])); + + let schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_metadata(source_meta.clone()); + + let cast_field = Arc::new( + Field::new("ignored", DataType::Utf8, true) + .with_metadata(cast_meta.to_hashmap()), + ); + + let expr = col("foo"); + let cast_expr = Expr::Cast(Cast::new(Box::new(expr), Arc::clone(&cast_field))); + + let mut expected = source_meta.clone(); + expected.extend(cast_meta.clone()); + assert_eq!(expected, cast_expr.metadata(&schema).unwrap()); + } + + #[test] + fn test_try_cast_metadata_overrides() { + let source_meta = FieldMetadata::from(HashMap::from([ + ("source".to_string(), "value".to_string()), + ("shared".to_string(), "source".to_string()), + ])); + let cast_meta = FieldMetadata::from(HashMap::from([ + ("shared".to_string(), "cast".to_string()), + ("cast".to_string(), "value".to_string()), + ])); + + let schema = MockExprSchema::new() + .with_data_type(DataType::Int32) + .with_metadata(source_meta.clone()); + + let cast_field = Arc::new( + Field::new("ignored", DataType::Utf8, true) + .with_metadata(cast_meta.to_hashmap()), + ); + + let expr = col("foo"); + let cast_expr = + Expr::TryCast(TryCast::new(Box::new(expr), Arc::clone(&cast_field))); + + let mut expected = source_meta.clone(); + expected.extend(cast_meta.clone()); + assert_eq!(expected, cast_expr.metadata(&schema).unwrap()); + } + 
#[derive(Debug)] struct MockExprSchema { field: Field, diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f8060..139417377721 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -220,12 +220,12 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), - Expr::Cast(Cast { expr, data_type }) => expr + Expr::Cast(Cast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => expr + .update_data(|be| Expr::Cast(Cast::new(be, field))), + Expr::TryCast(TryCast { expr, field }) => expr .map_elements(f)? - .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new(be, field))), Expr::ScalarFunction(ScalarFunction { func, args }) => { args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b91db4527b3a..e82dd7e97c84 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1686,11 +1686,17 @@ mod tests { fn test_collect_expr() -> Result<()> { let mut accum: HashSet<Column> = HashSet::new(); expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::Cast(Cast::new( + Box::new(col("a")), + Arc::new(Field::new("cast", DataType::Float64, false)), + )), &mut accum, )?; expr_to_columns( - &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)), + &Expr::Cast(Cast::new( + Box::new(col("a")), + Arc::new(Field::new("cast", DataType::Float64, false)), + )), &mut accum, )?; assert_eq!(1, accum.len()); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 94a41ba4bb25..a0ad2e8e23bb 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -26,6 +26,7 @@ use datafusion_common::{ exec_datafusion_err, utils::take_function_args, DataFusionError, }; use std::any::Any; +use std::sync::Arc; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ @@ -160,8 +161,12 @@ impl ScalarUDFImpl for ArrowCastFunc { } else { // Use an actual cast to get the correct type Expr::Cast(datafusion_expr::Cast { - expr: Box::new(arg), - data_type: target_type, + expr: Box::new(arg.clone()), + field: Arc::new(Field::new( + "arrow_cast", + target_type, + info.nullable(&arg)?, + )), }) }; // return the newly written argument to DataFusion diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index d75eb9141c05..c7d2f5176469 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -19,8 +19,8 @@ use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; use arrow::compute::kernels::regexp; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::types::logical_string; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, plan_err, DataFusionError, Result, @@ -184,13 +184,27 @@ impl ScalarUDFImpl for RegexpLikeFunc { Ok(ExprSimplifyResult::Simplified(binary_expr( if string_type != coerced_string_type { - cast(string, coerced_string_type) + cast( + string.clone(), +
Arc::new(Field::new( + "", + coerced_string_type, + info.nullable(&string)?, + )), + ) } else { string }, op, if regexp_type != coerced_regexp_type { - cast(regexp, coerced_regexp_type) + cast( + regexp.clone(), + Arc::new(Field::new( + "", + coerced_regexp_type, + info.nullable(®exp)?, + )), + ) } else { regexp }, diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index c4159cba86f3..71ba29a022e6 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::sync::Arc; use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::type_coercion::binary::{ binary_to_string_coercion, string_coercion, @@ -158,13 +158,27 @@ impl ScalarUDFImpl for StartsWithFunc { let expr = if expr_data_type == coercion_data_type { args[0].clone() } else { - cast(args[0].clone(), coercion_data_type.clone()) + cast( + args[0].clone(), + Arc::new(Field::new( + "", + coercion_data_type.clone(), + info.nullable(&args[0])?, + )), + ) }; let pattern = if pattern_data_type == coercion_data_type { like_expr } else { - cast(like_expr, coercion_data_type) + cast( + like_expr.clone(), + Arc::new(Field::new( + "", + coercion_data_type, + info.nullable(&like_expr)?, + )), + ) }; return Ok(ExprSimplifyResult::Simplified(Expr::Like(Like { diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index fa7ff1b8b19d..38530fdb8dae 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ @@ -214,7 +214,10 @@ fn grouping_function_on_id( .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + return Ok(cast( + grouping_id_column, + Arc::new(Field::new("", DataType::Int32, true)), + )); } args.iter() @@ -240,7 +243,7 @@ fn grouping_function_on_id( bit_exprs .into_iter() .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::Int32)) + .map(|expr| cast(expr, Arc::new(Field::new("", DataType::Int32, true)))) }) .ok_or_else(|| { internal_datafusion_err!("Grouping sets should contains at least one element") diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3d5dee3a7255..37c41c0a86c6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1722,7 +1722,10 @@ mod test { let empty = empty_with_type(DataType::Int32); let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( avg_udaf(), - vec![cast(col("a"), DataType::Float64)], + vec![cast( + col("a"), + Arc::new(Field::new("a", DataType::Float64, true)), + )], false, None, vec![], @@ -1761,8 +1764,10 @@ mod test { #[test] fn binary_op_date32_op_interval() -> Result<()> { // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") - let expr = cast(lit("1998-03-18"), DataType::Date32) - + 
lit(ScalarValue::new_interval_dt(123, 456)); + let expr = cast( + lit("1998-03-18"), + Arc::new(Field::new("date", DataType::Date32, true)), + ) + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); @@ -1811,8 +1816,10 @@ mod test { let expr = col("a").between( lit("2002-05-08"), // (cast('2002-05-08' as date) + interval '1 months') - cast(lit("2002-05-08"), DataType::Date32) - + lit(ScalarValue::new_interval_ym(0, 1)), + cast( + lit("2002-05-08"), + Arc::new(Field::new("date", DataType::Date32, true)), + ) + lit(ScalarValue::new_interval_ym(0, 1)), ); let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); @@ -1830,8 +1837,10 @@ mod test { fn between_infer_cheap_type() -> Result<()> { let expr = col("a").between( // (cast('2002-05-08' as date) + interval '1 months') - cast(lit("2002-05-08"), DataType::Date32) - + lit(ScalarValue::new_interval_ym(0, 1)), + cast( + lit("2002-05-08"), + Arc::new(Field::new("date", DataType::Date32, true)), + ) + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); let empty = empty_with_type(Utf8); @@ -2108,9 +2117,16 @@ mod test { fn binary_op_date32_eq_ts() -> Result<()> { let expr = cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + Arc::new(Field::new( + "date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )), ) - .eq(cast(lit("1998-03-18"), DataType::Date32)); + .eq(cast( + lit("1998-03-18"), + Arc::new(Field::new("date", DataType::Date32, true)), + )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); @@ -2458,7 +2474,10 @@ mod test { let fields = Field::new("key_value", DataType::Struct(struct_fields), false); let may_type_custom = DataType::Map(Arc::new(fields), false); - let expr = col("a").eq(cast(col("a"), may_type_custom)); + let expr = col("a").eq(cast( + col("a"), + Arc::new(Field::new("a", may_type_custom, true)), + )); let empty = empty_with_type(map_type_entries); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); @@ -2479,7 +2498,11 @@ mod test { Operator::Plus, Box::new(cast( lit("2000-01-01T00:00:00"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + Arc::new(Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )), )), )); let empty = empty(); @@ -2499,12 +2522,20 @@ mod test { let expr = Expr::BinaryExpr(BinaryExpr::new( Box::new(cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + Arc::new(Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )), )), Operator::Minus, Box::new(cast( lit("1998-03-18"), - DataType::Timestamp(TimeUnit::Nanosecond, None), + Arc::new(Field::new( + "timestamp", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + )), )), )); let empty = empty(); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 251006849459..4b3612669c46 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -1205,9 +1205,12 @@ mod test { let plan = table_scan(Some("table"), &schema, None) .unwrap() .filter( - cast(col("a"), DataType::Int64) + cast(col("a"), Arc::new(Field::new("a", DataType::Int64, false))) .lt(lit(1_i64)) - .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))), + .and( + cast(col("a"), Arc::new(Field::new("a", 
DataType::Int64, false))) + .not_eq(lit(1_i64)), + ), ) .unwrap() .build() diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f276..b58193d02686 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -289,8 +289,8 @@ fn extract_non_nullable_columns( false, ) } - Expr::Cast(Cast { expr, data_type: _ }) - | Expr::TryCast(TryCast { expr, data_type: _ }) => extract_non_nullable_columns( + Expr::Cast(Cast { expr, field: _ }) + | Expr::TryCast(TryCast { expr, field: _ }) => extract_non_nullable_columns( expr, non_nullable_cols, left_schema, @@ -308,6 +308,7 @@ mod tests { use crate::test::*; use crate::OptimizerContext; use arrow::datatypes::DataType; + use arrow::datatypes::Field; use datafusion_expr::{ binary_expr, cast, col, lit, logical_plan::builder::LogicalPlanBuilder, @@ -449,9 +450,11 @@ mod tests { None, )? .filter(binary_expr( - cast(col("t1.b"), DataType::Int64).gt(lit(10u32)), + cast(col("t1.b"), Arc::new(Field::new("", DataType::Int64, true))) + .gt(lit(10u32)), And, - try_cast(col("t2.c"), DataType::Int64).lt(lit(20u32)), + try_cast(col("t2.c"), Arc::new(Field::new("", DataType::Int64, true))) + .lt(lit(20u32)), ))? .build()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 5db71417bc8f..3fc880850ec9 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1425,7 +1425,10 @@ mod tests { fn test_try_cast() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![try_cast(col("a"), DataType::Float64)])? + .project(vec![try_cast( + col("a"), + Arc::new(Field::new("a", DataType::Float64, true)), + )])? .build()?; assert_optimized_plan_equal!( @@ -1996,7 +1999,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .project(vec![Expr::Cast(Cast::new( Box::new(col("c")), - DataType::Float64, + Arc::new(Field::new("c", DataType::Float64, true)), ))])? .build()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 204ce14e37d8..81a33ac18f00 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2154,7 +2154,10 @@ fn simplify_right_is_one_case( Ok(result_type) => { // Only cast if the types differ if left_type != result_type { - Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type)))) + Ok(Transformed::yes(Expr::Cast(Cast::new( + left.clone(), + Arc::new(Field::new("", result_type, info.nullable(&left)?)), + )))) } else { Ok(Transformed::yes(*left)) } diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 4faf9389cfac..224d9ed66050 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -467,7 +467,10 @@ mod tests { #[test] fn cast_expr() -> Result<()> { let table_scan = test_table_scan(); - let proj = vec![Expr::Cast(Cast::new(Box::new(lit("0")), DataType::Int32))]; + let proj = vec![Expr::Cast(Cast::new( + Box::new(lit("0")), + Arc::new(Field::new("c1", DataType::Int32, true)), + ))]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj)? 
.build()?; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs index e811ce731310..0564b23fa262 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_predicates.rs @@ -249,8 +249,10 @@ fn extract_column_from_expr(expr: &Expr) -> Option<Column> { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_expr::{cast, col, lit}; #[test] let predicates = vec![ col("a").lt(lit(5i32)), - cast(col("a"), DataType::Utf8).lt(lit("abc")), + cast(col("a"), Arc::new(Field::new("a", DataType::Utf8, true))) + .lt(lit("abc")), col("a").lt(lit(6i32)), ]; @@ -295,7 +298,7 @@ #[test] fn test_extract_column_ignores_cast() { // Test that extract_column_from_expr does not extract columns from cast expressions - let cast_expr = cast(col("a"), DataType::Utf8); + let cast_expr = cast(col("a"), Arc::new(Field::new("a", DataType::Utf8, true))); assert_eq!(extract_column_from_expr(&cast_expr), None); // Test that it still extracts from direct column references diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 5286cbd7bdf6..b0940b451bc3 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -250,7 +250,8 @@ mod tests { fn test_not_unwrap_cast_comparison() { let schema = expr_test_schema(); // cast(INT32(c1), INT64) > INT64(c2) - let c1_gt_c2 = cast(col("c1"), DataType::Int64).gt(col("c2")); + let c1_gt_c2 = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .gt(col("c2")); assert_eq!(optimize_test(c1_gt_c2.clone(), &schema), c1_gt_c2); // INT32(c1) < INT32(16), the type is same let expr_lt = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64)); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(lit(99999999999i64)); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // cast(c1, UTF8) < '123', only eq/not_eq should be optimized - let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123")); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .lt(lit("123")); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not // be casted - let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("0123")); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .lt(lit("0123")); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // cast(c1, UTF8) = 'not a number', should not be able to cast to column type - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("not a number")); + let expr_input = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .eq(lit("not a number")); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will // not be optimized to integer comparison
- let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("99999999999")); + let expr_input = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .eq(lit("99999999999")); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); } @@ -285,40 +293,55 @@ mod tests { let schema = expr_test_schema(); // cast(c1, INT64) < INT64(16) -> INT32(c1) < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); - let expr_lt = try_cast(col("c1"), DataType::Int64).lt(lit(16i64)); + let expr_lt = + try_cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(lit(16i64)); let expected = col("c1").lt(lit(16i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); // cast(c2, INT32) = INT32(16) => INT64(c2) = INT64(16) - let c2_eq_lit = cast(col("c2"), DataType::Int32).eq(lit(16i32)); + let c2_eq_lit = + cast(col("c2"), Arc::new(Field::new("c2", DataType::Int32, true))) + .eq(lit(16i32)); let expected = col("c2").eq(lit(16i64)); assert_eq!(optimize_test(c2_eq_lit, &schema), expected); // cast(c1, INT64) < INT64(NULL) => NULL - let c1_lt_lit_null = cast(col("c1"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(null_i64()); let expected = null_bool(); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // cast(INT8(NULL), INT32) < INT32(12) => INT8(NULL) < INT8(12) => BOOL(NULL) - let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32)); + let lit_lt_lit = + cast(null_i8(), Arc::new(Field::new("c1", DataType::Int32, true))) + .lt(lit(12i32)); let expected = null_bool(); assert_eq!(optimize_test(lit_lt_lit, &schema), expected); // cast(c1, UTF8) = '123' => c1 = 123 - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123")); + let expr_input = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .eq(lit("123")); let expected = col("c1").eq(lit(123i32)); assert_eq!(optimize_test(expr_input, &schema), expected); // cast(c1, UTF8) != '123' => c1 != 123 - let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123")); + let expr_input = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .not_eq(lit("123")); let expected = col("c1").not_eq(lit(123i32)); assert_eq!(optimize_test(expr_input, &schema), expected); // cast(c1, UTF8) = NULL => NULL - let expr_input = cast(col("c1"), DataType::Utf8).eq(lit(ScalarValue::Utf8(None))); + let expr_input = + cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, true))) + .eq(lit(ScalarValue::Utf8(None))); let expected = null_bool(); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -327,17 +350,25 @@ mod tests { fn test_unwrap_cast_comparison_unsigned() { // "cast(c6, UINT64) = 0u64 => c6 = 0u32 let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64)); + let expr_input = cast( + col("c6"), + Arc::new(Field::new("c6", DataType::UInt64, true)), + ) + .eq(lit(0u64)); let expected = col("c6").eq(lit(0u32)); assert_eq!(optimize_test(expr_input, &schema), expected); // cast(c6, UTF8) = "123" => c6 = 123 - let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123")); + let expr_input = + cast(col("c6"), 
Arc::new(Field::new("c6", DataType::Utf8, true))) + .eq(lit("123")); let expected = col("c6").eq(lit(123u32)); assert_eq!(optimize_test(expr_input, &schema), expected); // cast(c6, UTF8) != "123" => c6 != 123 - let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123")); + let expr_input = + cast(col("c6"), Arc::new(Field::new("c6", DataType::Utf8, true))) + .not_eq(lit("123")); let expected = col("c6").not_eq(lit(123u32)); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -351,18 +382,29 @@ mod tests { ); // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') - let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); + let expr_input = cast( + col("str1"), + Arc::new(Field::new("str1", dict.data_type(), true)), + ) + .eq(lit(dict.clone())); let expected = col("str1").eq(lit("value")); assert_eq!(optimize_test(expr_input, &schema), expected); // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') - let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); + let expr_input = cast( + col("tag"), + Arc::new(Field::new("tag", DataType::Utf8, true)), + ) + .eq(lit("value")); let expected = col("tag").eq(lit(dict.clone())); assert_eq!(optimize_test(expr_input, &schema), expected); // Verify reversed argument order // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 - let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); + let expr_input = lit(dict.clone()).eq(cast( + col("str1"), + Arc::new(Field::new("str1", dict.data_type(), true)), + )); let expected = col("str1").eq(lit("value")); assert_eq!(optimize_test(expr_input, &schema), expected); } @@ -375,7 +417,11 @@ mod tests { Box::new(DataType::Int32), Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); + let expr_input = cast( + col("largestr"), + Arc::new(Field::new("largestr", dict.data_type(), true)), + ) + .eq(lit(dict)); let expected = col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); assert_eq!(optimize_test(expr_input, &schema), expected); @@ -386,28 +432,39 @@ mod tests { let schema = expr_test_schema(); // integer to decimal: value is out of the bounds of the decimal // cast(c3, INT64) = INT64(100000000000000000) - let expr_eq = cast(col("c3"), DataType::Int64).eq(lit(100000000000000000i64)); + let expr_eq = cast(col("c3"), Arc::new(Field::new("c3", DataType::Int64, true))) + .eq(lit(100000000000000000i64)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // cast(c4, INT64) = INT64(1000) will overflow the i128 - let expr_eq = cast(col("c4"), DataType::Int64).eq(lit(1000i64)); + let expr_eq = cast(col("c4"), Arc::new(Field::new("c4", DataType::Int64, true))) + .eq(lit(1000i64)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to decimal: value will lose the scale when convert to the target data type // c3 = DECIMAL(12340,20,4) - let expr_eq = - cast(col("c3"), DataType::Decimal128(20, 4)).eq(lit_decimal(12340, 20, 4)); + let expr_eq = cast( + col("c3"), + Arc::new(Field::new("c3", DataType::Decimal128(20, 4), true)), + ) + .eq(lit_decimal(12340, 20, 4)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // decimal to integer // c1 = DECIMAL(123, 10, 1): value will lose the scale when convert to the target data type - let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 1)).eq(lit_decimal(123, 10, 1)); + let 
expr_eq = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Decimal128(10, 1), true)), + ) + .eq(lit_decimal(123, 10, 1)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); // c1 = DECIMAL(1230, 10, 2): value will lose the scale when convert to the target data type - let expr_eq = - cast(col("c1"), DataType::Decimal128(10, 2)).eq(lit_decimal(1230, 10, 2)); + let expr_eq = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Decimal128(10, 2), true)), + ) + .eq(lit_decimal(1230, 10, 2)); assert_eq!(optimize_test(expr_eq.clone(), &schema), expr_eq); } @@ -416,32 +473,45 @@ mod tests { let schema = expr_test_schema(); // integer to decimal // c3 < INT64(16) -> c3 < (CAST(INT64(16) AS DECIMAL(18,2)); - let expr_lt = try_cast(col("c3"), DataType::Int64).lt(lit(16i64)); + let expr_lt = + try_cast(col("c3"), Arc::new(Field::new("c3", DataType::Int64, true))) + .lt(lit(16i64)); let expected = col("c3").lt(lit_decimal(1600, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < INT64(NULL) - let c1_lt_lit_null = cast(col("c3"), DataType::Int64).lt(null_i64()); + let c1_lt_lit_null = + cast(col("c3"), Arc::new(Field::new("c3", DataType::Int64, true))) + .lt(null_i64()); let expected = null_bool(); assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected); // decimal to decimal // c3 < Decimal(123,10,0) -> c3 < CAST(DECIMAL(123,10,0) AS DECIMAL(18,2)) -> c3 < DECIMAL(12300,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 0)).lt(lit_decimal(123, 10, 0)); + let expr_lt = cast( + col("c3"), + Arc::new(Field::new("c3", DataType::Decimal128(10, 0), true)), + ) + .lt(lit_decimal(123, 10, 0)); let expected = col("c3").lt(lit_decimal(12300, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // c3 < Decimal(1230,10,3) -> c3 < CAST(DECIMAL(1230,10,3) AS DECIMAL(18,2)) -> c3 < DECIMAL(123,18,2) - let expr_lt = - cast(col("c3"), DataType::Decimal128(10, 3)).lt(lit_decimal(1230, 10, 3)); + let expr_lt = cast( + col("c3"), + Arc::new(Field::new("c3", DataType::Decimal128(10, 3), true)), + ) + .lt(lit_decimal(1230, 10, 3)); let expected = col("c3").lt(lit_decimal(123, 18, 2)); assert_eq!(optimize_test(expr_lt, &schema), expected); // decimal to integer // c1 < Decimal(12300, 10, 2) -> c1 < CAST(DECIMAL(12300,10,2) AS INT32) -> c1 < INT32(123) - let expr_lt = - cast(col("c1"), DataType::Decimal128(10, 2)).lt(lit_decimal(12300, 10, 2)); + let expr_lt = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Decimal128(10, 2), true)), + ) + .lt(lit_decimal(12300, 10, 2)); let expected = col("c1").lt(lit(123i32)); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -451,22 +521,29 @@ mod tests { let schema = expr_test_schema(); // internal left type is not supported // FLOAT32(C5) in ... 
- let expr_lt = - cast(col("c5"), DataType::Int64).in_list(vec![lit(12i64), lit(12i64)], false); + let expr_lt = cast(col("c5"), Arc::new(Field::new("c5", DataType::Int64, true))) + .in_list(vec![lit(12i64), lit(12i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // cast(INT32(C1), Float32) in (FLOAT32(1.23), Float32(12), Float32(12)) - let expr_lt = cast(col("c1"), DataType::Float32) - .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); + let expr_lt = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Float32, true)), + ) + .in_list(vec![lit(12.0f32), lit(12.0f32), lit(1.23f32)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // INT32(C1) in (INT64(99999999999), INT64(12)) - let expr_lt = cast(col("c1"), DataType::Int64) + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) .in_list(vec![lit(12i32), lit(99999999999i64)], false); assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt); // DECIMAL(C3) in (INT64(12), INT32(12), DECIMAL(128,12,3)) - let expr_lt = cast(col("c3"), DataType::Decimal128(12, 3)).in_list( + let expr_lt = cast( + col("c3"), + Arc::new(Field::new("c3", DataType::Decimal128(12, 3), true)), + ) + .in_list( vec![ lit_decimal(12, 12, 3), lit_decimal(12, 12, 3), @@ -482,10 +559,11 @@ mod tests { let schema = expr_test_schema(); // INT32(C1) IN (INT32(12),INT64(23),INT64(34),INT64(56),INT64(78)) -> // INT32(C1) IN (INT32(12),INT32(23),INT32(34),INT32(56),INT32(78)) - let expr_lt = cast(col("c1"), DataType::Int64).in_list( - vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], - false, - ); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .in_list( + vec![lit(12i64), lit(23i64), lit(34i64), lit(56i64), lit(78i64)], + false, + ); let expected = col("c1").in_list( vec![lit(12i32), lit(23i32), lit(34i32), lit(56i32), lit(78i32)], false, @@ -493,10 +571,11 @@ mod tests { assert_eq!(optimize_test(expr_lt, &schema), expected); // INT32(C2) IN (INT64(NULL),INT64(24),INT64(34),INT64(56),INT64(78)) -> // INT32(C2) IN (INT32(NULL),INT32(24),INT32(34),INT32(56),INT32(78)) - let expr_lt = cast(col("c2"), DataType::Int32).in_list( - vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], - false, - ); + let expr_lt = cast(col("c2"), Arc::new(Field::new("c2", DataType::Int32, true))) + .in_list( + vec![null_i32(), lit(24i32), lit(34i64), lit(56i64), lit(78i64)], + false, + ); let expected = col("c2").in_list( vec![null_i64(), lit(24i64), lit(34i64), lit(56i64), lit(78i64)], false, @@ -506,7 +585,11 @@ mod tests { // decimal test case // c3 is decimal(18,2) - let expr_lt = cast(col("c3"), DataType::Decimal128(19, 3)).in_list( + let expr_lt = cast( + col("c3"), + Arc::new(Field::new("c3", DataType::Decimal128(19, 3), true)), + ) + .in_list( vec![ lit_decimal(12000, 19, 3), lit_decimal(24000, 19, 3), @@ -529,7 +612,11 @@ mod tests { // cast(INT32(12), INT64) IN (.....) 
=> // INT64(12) IN (INT64(12),INT64(13),INT64(14),INT64(15),INT64(16)) // => true - let expr_lt = cast(lit(12i32), DataType::Int64).in_list( + let expr_lt = cast( + lit(12i32), + Arc::new(Field::new("c1", DataType::Int64, true)), + ) + .in_list( vec![lit(12i64), lit(13i64), lit(14i64), lit(15i64), lit(16i64)], false, ); @@ -542,7 +629,9 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) -> c1 < cast(INT32(16)) // the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16) - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).alias("x"); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(lit(16i64)) + .alias("x"); let expected = col("c1").lt(lit(16i32)).alias("x"); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -552,11 +641,12 @@ mod tests { let schema = expr_test_schema(); // c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32) // the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32 - let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(16i64)).or(cast( - col("c1"), - DataType::Int64, - ) - .gt(lit(32i64))); + let expr_lt = cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .lt(lit(16i64)) + .or( + cast(col("c1"), Arc::new(Field::new("c1", DataType::Int64, true))) + .gt(lit(32i64)), + ); let expected = col("c1").lt(lit(16i32)).or(col("c1").gt(lit(32i32))); assert_eq!(optimize_test(expr_lt, &schema), expected); } @@ -567,12 +657,19 @@ mod tests { // but the type of c6 is uint32 // the rewriter will not throw error and just return the original expr let schema = expr_test_schema(); - let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64)); + let expr_input = cast( + col("c6"), + Arc::new(Field::new("c6", DataType::Float64, true)), + ) + .eq(lit(0f64)); assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input); // inlist for unsupported data type let expr_input = in_list( - cast(col("c6"), DataType::Float64), + cast( + col("c6"), + Arc::new(Field::new("c6", DataType::Float64, true)), + ), // need more literals to avoid rewriting to binary expr vec![lit(0f64), lit(1f64), lit(2f64), lit(3f64), lit(4f64)], false, @@ -585,8 +682,11 @@ mod tests { fn test_unwrap_cast_with_timestamp_nanos() { let schema = expr_test_schema(); // cast(ts_nano as Timestamp(Nanosecond, UTC)) < 1666612093000000000::Timestamp(Nanosecond, Utc)) - let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type()) - .lt(lit_timestamp_nano_utc(1666612093000000000)); + let expr_lt = try_cast( + col("ts_nano_none"), + Arc::new(Field::new("ts_nano_none", timestamp_nano_utc_type(), false)), + ) + .lt(lit_timestamp_nano_utc(1666612093000000000)); let expected = col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000)); assert_eq!(optimize_test(expr_lt, &schema), expected); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 7790380dffd5..1c4407e86ceb 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -287,15 +287,15 @@ pub fn create_physical_expr( }; Ok(expressions::case(expr, when_then_expr, else_expr)?) 
} - Expr::Cast(Cast { expr, data_type }) => expressions::cast( + Expr::Cast(Cast { expr, field }) => expressions::cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + field.data_type().clone(), ), - Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( + Expr::TryCast(TryCast { expr, field }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, execution_props)?, input_schema, - data_type.clone(), + field.data_type().clone(), ), Expr::Not(expr) => { expressions::not(create_physical_expr(expr, input_dfschema, execution_props)?) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f9400d14a59c..f0263ed03e67 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -70,17 +70,13 @@ message LogicalExtensionNode { repeated LogicalPlanNode inputs = 2; } -message ProjectionColumns { - repeated string columns = 1; -} +message ProjectionColumns { repeated string columns = 1; } message LogicalExprNodeCollection { repeated LogicalExprNode logical_expr_nodes = 1; } -message SortExprNodeCollection { - repeated SortExprNode sort_expr_nodes = 1; -} +message SortExprNodeCollection { repeated SortExprNode sort_expr_nodes = 1; } message ListingTableScanNode { reserved 1; // was string table_name @@ -125,9 +121,7 @@ message CustomTableScanNode { message ProjectionNode { LogicalPlanNode input = 1; repeated LogicalExprNode expr = 2; - oneof optional_alias { - string alias = 3; - } + oneof optional_alias { string alias = 3; } } message SelectionNode { @@ -155,9 +149,7 @@ message HashRepartition { uint64 partition_count = 2; } -message EmptyRelationNode { - bool produce_one_row = 1; -} +message EmptyRelationNode { bool produce_one_row = 1; } message CreateExternalTableNode { reserved 1; // was string name @@ -213,8 +205,9 @@ message CreateViewNode { string definition = 4; } -// a node containing data for defining values list. unlike in SQL where it's two dimensional, here -// the list is flattened, and with the field n_cols it can be parsed and partitioned into rows +// a node containing data for defining values list. 
unlike in SQL where it's two +// dimensional, here the list is flattened, and with the field n_cols it can be +// parsed and partitioned into rows message ValuesNode { uint64 n_cols = 1; repeated LogicalExprNode values_list = 2; @@ -252,9 +245,7 @@ message JoinNode { LogicalExprNode filter = 8; } -message DistinctNode { - LogicalPlanNode input = 1; -} +message DistinctNode { LogicalPlanNode input = 1; } message DistinctOnNode { repeated LogicalExprNode on_expr = 1; @@ -270,8 +261,8 @@ message CopyToNode { repeated string partition_by = 7; } -message DmlNode{ - enum Type { +message DmlNode { + enum Type { UPDATE = 0; DELETE = 1; CTAS = 2; @@ -319,9 +310,7 @@ message RecursionUnnestOption { uint32 depth = 3; } -message UnionNode { - repeated LogicalPlanNode inputs = 1; -} +message UnionNode { repeated LogicalPlanNode inputs = 1; } message CrossJoinNode { LogicalPlanNode left = 1; @@ -336,9 +325,7 @@ message LimitNode { int64 fetch = 3; } -message SelectionExecNode { - LogicalExprNode expr = 1; -} +message SelectionExecNode { LogicalExprNode expr = 1; } message SubqueryAliasNode { reserved 2; // Was string alias @@ -360,7 +347,6 @@ message LogicalExprNode { // binary expressions BinaryExprNode binary_expr = 4; - // null checks IsNull is_null_expr = 6; IsNotNull is_not_null_expr = 7; @@ -405,13 +391,10 @@ message LogicalExprNode { PlaceholderNode placeholder = 34; Unnest unnest = 35; - } } -message Wildcard { - TableReference qualifier = 1; -} +message Wildcard { TableReference qualifier = 1; } message PlaceholderNode { string id = 1; @@ -422,29 +405,17 @@ message PlaceholderNode { map metadata = 4; } -message LogicalExprList { - repeated LogicalExprNode expr = 1; -} +message LogicalExprList { repeated LogicalExprNode expr = 1; } -message GroupingSetNode { - repeated LogicalExprList expr = 1; -} +message GroupingSetNode { repeated LogicalExprList expr = 1; } -message CubeNode { - repeated LogicalExprNode expr = 1; -} +message CubeNode { repeated LogicalExprNode expr = 1; } -message RollupNode { - repeated LogicalExprNode expr = 1; -} +message RollupNode { repeated LogicalExprNode expr = 1; } -message NamedStructField { - datafusion_common.ScalarValue name = 1; -} +message NamedStructField { datafusion_common.ScalarValue name = 1; } -message ListIndex { - LogicalExprNode key = 1; -} +message ListIndex { LogicalExprNode key = 1; } message ListRange { LogicalExprNode start = 1; @@ -452,41 +423,23 @@ message ListRange { LogicalExprNode stride = 3; } -message IsNull { - LogicalExprNode expr = 1; -} +message IsNull { LogicalExprNode expr = 1; } -message IsNotNull { - LogicalExprNode expr = 1; -} +message IsNotNull { LogicalExprNode expr = 1; } -message IsTrue { - LogicalExprNode expr = 1; -} +message IsTrue { LogicalExprNode expr = 1; } -message IsFalse { - LogicalExprNode expr = 1; -} +message IsFalse { LogicalExprNode expr = 1; } -message IsUnknown { - LogicalExprNode expr = 1; -} +message IsUnknown { LogicalExprNode expr = 1; } -message IsNotTrue { - LogicalExprNode expr = 1; -} +message IsNotTrue { LogicalExprNode expr = 1; } -message IsNotFalse { - LogicalExprNode expr = 1; -} +message IsNotFalse { LogicalExprNode expr = 1; } -message IsNotUnknown { - LogicalExprNode expr = 1; -} +message IsNotUnknown { LogicalExprNode expr = 1; } -message Not { - LogicalExprNode expr = 1; -} +message Not { LogicalExprNode expr = 1; } message AliasNode { LogicalExprNode expr = 1; @@ -503,13 +456,9 @@ message BinaryExprNode { string op = 3; } -message NegativeNode { - LogicalExprNode expr = 1; -} +message 
NegativeNode { LogicalExprNode expr = 1; } -message Unnest { - repeated LogicalExprNode exprs = 1; -} +message Unnest { repeated LogicalExprNode exprs = 1; } message InListNode { LogicalExprNode expr = 1; @@ -517,7 +466,6 @@ message InListNode { bool negated = 3; } - message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; @@ -592,12 +540,12 @@ message WhenThen { message CastNode { LogicalExprNode expr = 1; - datafusion_common.ArrowType arrow_type = 2; + datafusion_common.Field field = 2; } message TryCastNode { LogicalExprNode expr = 1; - datafusion_common.ArrowType arrow_type = 2; + datafusion_common.Field field = 2; } message SortExprNode { @@ -615,11 +563,12 @@ enum WindowFrameUnits { message WindowFrame { WindowFrameUnits window_frame_units = 1; WindowFrameBound start_bound = 2; - // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/tokio-rs/prost/issues/430 and https://github.com/tokio-rs/prost/pull/455) - // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) - oneof end_bound { - WindowFrameBound bound = 3; - } + // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see + // https://github.com/tokio-rs/prost/issues/430 and + // https://github.com/tokio-rs/prost/pull/455) this syntax is ugly but is + // binary compatible with the "optional" keyword (see + // https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) + oneof end_bound { WindowFrameBound bound = 3; } } enum WindowFrameBoundType { @@ -642,26 +591,18 @@ enum NullTreatment { // Arrow Data Types /////////////////////////////////////////////////////////////////////////////////////////////////// -message FixedSizeBinary{ - int32 length = 1; -} +message FixedSizeBinary { int32 length = 1; } -enum DateUnit{ +enum DateUnit { Day = 0; DateMillisecond = 1; } -message AnalyzedLogicalPlanType { - string analyzer_name = 1; -} +message AnalyzedLogicalPlanType { string analyzer_name = 1; } -message OptimizedLogicalPlanType { - string optimizer_name = 1; -} +message OptimizedLogicalPlanType { string optimizer_name = 1; } -message OptimizedPhysicalPlanType { - string optimizer_name = 1; -} +message OptimizedPhysicalPlanType { string optimizer_name = 1; } message PlanType { oneof plan_type_enum { @@ -686,9 +627,7 @@ message StringifiedPlan { string plan = 2; } -message BareTableReference { - string table = 1; -} +message BareTableReference { string table = 1; } message PartialTableReference { string schema = 1; @@ -756,7 +695,6 @@ message PartitionColumn { datafusion_common.ArrowType arrow_type = 2; } - message FileSinkConfig { reserved 6; // writer_mode reserved 8; // was `overwrite` which has been superseded by `insert_op` @@ -884,9 +822,7 @@ message PhysicalScalarUdfNode { } message PhysicalAggregateExprNode { - oneof AggregateFunction { - string user_defined_aggr_function = 4; - } + oneof AggregateFunction { string user_defined_aggr_function = 4; } repeated PhysicalExprNode expr = 2; repeated PhysicalSortExprNode ordering_req = 5; bool distinct = 3; @@ -911,17 +847,11 @@ message PhysicalWindowExprNode { bool distinct = 12; } -message PhysicalIsNull { - PhysicalExprNode expr = 1; -} +message PhysicalIsNull { PhysicalExprNode expr = 1; } -message PhysicalIsNotNull { - PhysicalExprNode expr = 1; -} +message PhysicalIsNotNull { PhysicalExprNode expr = 1; } -message PhysicalNot { - 
PhysicalExprNode expr = 1; -} +message PhysicalNot { PhysicalExprNode expr = 1; } message PhysicalAliasNode { PhysicalExprNode expr = 1; @@ -980,9 +910,7 @@ message PhysicalCastNode { datafusion_common.ArrowType arrow_type = 2; } -message PhysicalNegativeNode { - PhysicalExprNode expr = 1; -} +message PhysicalNegativeNode { PhysicalExprNode expr = 1; } message PhysicalExtensionExprNode { bytes expr = 1; @@ -996,9 +924,7 @@ message FilterExecNode { repeated uint32 projection = 9; } -message FileGroup { - repeated PartitionedFile files = 1; -} +message FileGroup { repeated PartitionedFile files = 1; } message ScanLimit { // wrap into a message to make it optional @@ -1042,23 +968,15 @@ message CsvScanExecNode { bool has_header = 2; string delimiter = 3; string quote = 4; - oneof optional_escape { - string escape = 5; - } - oneof optional_comment { - string comment = 6; - } + oneof optional_escape { string escape = 5; } + oneof optional_comment { string comment = 6; } bool newlines_in_values = 7; bool truncate_rows = 8; } -message JsonScanExecNode { - FileScanExecConf base_conf = 1; -} +message JsonScanExecNode { FileScanExecConf base_conf = 1; } -message AvroScanExecNode { - FileScanExecConf base_conf = 1; -} +message AvroScanExecNode { FileScanExecConf base_conf = 1; } message MemoryScanExecNode { repeated bytes partitions = 1; @@ -1069,9 +987,7 @@ message MemoryScanExecNode { optional uint32 fetch = 6; } -message CooperativeExecNode { - PhysicalPlanNode input = 1; -} +message CooperativeExecNode { PhysicalPlanNode input = 1; } enum PartitionMode { COLLECT_LEFT = 0; @@ -1107,13 +1023,9 @@ message SymmetricHashJoinExecNode { repeated PhysicalSortExprNode right_sort_exprs = 10; } -message InterleaveExecNode { - repeated PhysicalPlanNode inputs = 1; -} +message InterleaveExecNode { repeated PhysicalPlanNode inputs = 1; } -message UnionExecNode { - repeated PhysicalPlanNode inputs = 1; -} +message UnionExecNode { repeated PhysicalPlanNode inputs = 1; } message ExplainExecNode { datafusion_common.Schema schema = 1; @@ -1138,22 +1050,16 @@ message PhysicalColumn { uint32 index = 2; } -message UnknownColumn { - string name = 1; -} +message UnknownColumn { string name = 1; } message JoinOn { PhysicalExprNode left = 1; PhysicalExprNode right = 2; } -message EmptyExecNode { - datafusion_common.Schema schema = 1; -} +message EmptyExecNode { datafusion_common.Schema schema = 1; } -message PlaceholderRowExecNode { - datafusion_common.Schema schema = 1; -} +message PlaceholderRowExecNode { datafusion_common.Schema schema = 1; } message ProjectionExecNode { PhysicalPlanNode input = 1; @@ -1169,9 +1075,7 @@ enum AggregateMode { SINGLE_PARTITIONED = 4; } -message PartiallySortedInputOrderMode { - repeated uint64 columns = 6; -} +message PartiallySortedInputOrderMode { repeated uint64 columns = 6; } message WindowAggExecNode { PhysicalPlanNode input = 1; @@ -1185,13 +1089,9 @@ message WindowAggExecNode { } } -message MaybeFilter { - PhysicalExprNode expr = 1; -} +message MaybeFilter { PhysicalExprNode expr = 1; } -message MaybePhysicalSortExprs { - repeated PhysicalSortExprNode sort_expr = 1; -} +message MaybePhysicalSortExprs { repeated PhysicalSortExprNode sort_expr = 1; } message AggLimit { // wrap into a message to make it optional @@ -1205,7 +1105,8 @@ message AggregateExecNode { PhysicalPlanNode input = 4; repeated string group_expr_name = 5; repeated string aggr_expr_name = 6; - // we need the input schema to the partial aggregate to pass to the final aggregate + // we need the input schema to the 
partial aggregate to pass to the final + // aggregate datafusion_common.Schema input_schema = 7; repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; @@ -1265,7 +1166,7 @@ message PhysicalHashRepartition { uint64 partition_count = 2; } -message RepartitionExecNode{ +message RepartitionExecNode { PhysicalPlanNode input = 1; // oneof partition_method { // uint64 round_robin = 2; @@ -1283,13 +1184,13 @@ message Partitioning { } } -message JoinFilter{ +message JoinFilter { PhysicalExprNode expression = 1; repeated ColumnIndex column_indices = 2; datafusion_common.Schema schema = 3; } -message ColumnIndex{ +message ColumnIndex { uint32 index = 1; datafusion_common.JoinSide side = 2; } @@ -1323,8 +1224,8 @@ message RecursiveQueryNode { } message CteWorkTableScanNode { - string name = 1; - datafusion_common.Schema schema = 2; + string name = 1; + datafusion_common.Schema schema = 2; } enum GenerateSeriesName { @@ -1332,45 +1233,43 @@ enum GenerateSeriesName { GS_RANGE = 1; } -message GenerateSeriesArgsContainsNull { - GenerateSeriesName name = 1; -} +message GenerateSeriesArgsContainsNull { GenerateSeriesName name = 1; } message GenerateSeriesArgsInt64 { - int64 start = 1; - int64 end = 2; - int64 step = 3; - bool include_end = 4; - GenerateSeriesName name = 5; + int64 start = 1; + int64 end = 2; + int64 step = 3; + bool include_end = 4; + GenerateSeriesName name = 5; } message GenerateSeriesArgsTimestamp { - int64 start = 1; - int64 end = 2; - datafusion_common.IntervalMonthDayNanoValue step = 3; - optional string tz = 4; - bool include_end = 5; - GenerateSeriesName name = 6; + int64 start = 1; + int64 end = 2; + datafusion_common.IntervalMonthDayNanoValue step = 3; + optional string tz = 4; + bool include_end = 5; + GenerateSeriesName name = 6; } message GenerateSeriesArgsDate { - int64 start = 1; - int64 end = 2; - datafusion_common.IntervalMonthDayNanoValue step = 3; - bool include_end = 4; - GenerateSeriesName name = 5; + int64 start = 1; + int64 end = 2; + datafusion_common.IntervalMonthDayNanoValue step = 3; + bool include_end = 4; + GenerateSeriesName name = 5; } message GenerateSeriesNode { - datafusion_common.Schema schema = 1; - uint32 target_batch_size = 2; - - oneof args { - GenerateSeriesArgsContainsNull contains_null = 3; - GenerateSeriesArgsInt64 int64_args = 4; - GenerateSeriesArgsTimestamp timestamp_args = 5; - GenerateSeriesArgsDate date_args = 6; - } + datafusion_common.Schema schema = 1; + uint32 target_batch_size = 2; + + oneof args { + GenerateSeriesArgsContainsNull contains_null = 3; + GenerateSeriesArgsInt64 int64_args = 4; + GenerateSeriesArgsTimestamp timestamp_args = 5; + GenerateSeriesArgsDate date_args = 6; + } } message SortMergeJoinExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4cf834d0601e..dfbde2f623da 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1831,15 +1831,15 @@ impl serde::Serialize for CastNode { if self.expr.is_some() { len += 1; } - if self.arrow_type.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; } struct_ser.end() } @@ -1852,14 +1852,13 @@ impl<'de> serde::Deserialize<'de> for 
CastNode { { const FIELDS: &[&str] = &[ "expr", - "arrow_type", - "arrowType", + "field", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - ArrowType, + Field, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1882,7 +1881,7 @@ impl<'de> serde::Deserialize<'de> for CastNode { { match value { "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "field" => Ok(GeneratedField::Field), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1903,7 +1902,7 @@ impl<'de> serde::Deserialize<'de> for CastNode { V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut arrow_type__ = None; + let mut field__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -1912,17 +1911,17 @@ impl<'de> serde::Deserialize<'de> for CastNode { } expr__ = map_.next_value()?; } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); } - arrow_type__ = map_.next_value()?; + field__ = map_.next_value()?; } } } Ok(CastNode { expr: expr__, - arrow_type: arrow_type__, + field: field__, }) } } @@ -22046,15 +22045,15 @@ impl serde::Serialize for TryCastNode { if self.expr.is_some() { len += 1; } - if self.arrow_type.is_some() { + if self.field.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.TryCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; } struct_ser.end() } @@ -22067,14 +22066,13 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { const FIELDS: &[&str] = &[ "expr", - "arrow_type", - "arrowType", + "field", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - ArrowType, + Field, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22097,7 +22095,7 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { { match value { "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "field" => Ok(GeneratedField::Field), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22118,7 +22116,7 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut arrow_type__ = None; + let mut field__ = None; while let Some(k) = map_.next_key()? 
{ match k { GeneratedField::Expr => { @@ -22127,17 +22125,17 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { } expr__ = map_.next_value()?; } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); } - arrow_type__ = map_.next_value()?; + field__ = map_.next_value()?; } } } Ok(TryCastNode { expr: expr__, - arrow_type: arrow_type__, + field: field__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 12b417627411..900630cb3f2b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -327,8 +327,9 @@ pub struct CreateViewNode { #[prost(string, tag = "4")] pub definition: ::prost::alloc::string::String, } -/// a node containing data for defining values list. unlike in SQL where it's two dimensional, here -/// the list is flattened, and with the field n_cols it can be parsed and partitioned into rows +/// a node containing data for defining values list. unlike in SQL where it's two +/// dimensional, here the list is flattened, and with the field n_cols it can be +/// parsed and partitioned into rows #[derive(Clone, PartialEq, ::prost::Message)] pub struct ValuesNode { #[prost(uint64, tag = "1")] @@ -918,14 +919,14 @@ pub struct CastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub field: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct TryCastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub field: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { @@ -942,15 +943,21 @@ pub struct WindowFrame { pub window_frame_units: i32, #[prost(message, optional, tag = "2")] pub start_bound: ::core::option::Option, - /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see and ) - /// this syntax is ugly but is binary compatible with the "optional" keyword (see ) + /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see + /// and + /// ) this syntax is ugly but is + /// binary compatible with the "optional" keyword (see + /// ) #[prost(oneof = "window_frame::EndBound", tags = "3")] pub end_bound: ::core::option::Option, } /// Nested message and enum types in `WindowFrame`. 
pub mod window_frame { - /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see and ) - /// this syntax is ugly but is binary compatible with the "optional" keyword (see ) + /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see + /// and + /// ) this syntax is ugly but is + /// binary compatible with the "optional" keyword (see + /// ) #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EndBound { #[prost(message, tag = "3")] @@ -1812,7 +1819,8 @@ pub struct AggregateExecNode { pub group_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(string, repeated, tag = "6")] pub aggr_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - /// we need the input schema to the partial aggregate to pass to the final aggregate + /// we need the input schema to the partial aggregate to pass to the final + /// aggregate #[prost(message, optional, tag = "7")] pub input_schema: ::core::option::Option, #[prost(message, repeated, tag = "8")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 598a77f5420e..0e2c107cc421 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -527,8 +527,11 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::Cast(Cast::new(expr, data_type))) + let field = cast + .field + .as_ref() + .ok_or_else(|| Error::required("field"))?; + Ok(Expr::Cast(Cast::new(expr, Arc::new(field.try_into()?)))) } ExprType::TryCast(cast) => { let expr = Box::new(parse_required_expr( @@ -537,8 +540,14 @@ pub fn parse_expr( "expr", codec, )?); - let data_type = cast.arrow_type.as_ref().required("arrow_type")?; - Ok(Expr::TryCast(TryCast::new(expr, data_type))) + let field = cast + .field + .as_ref() + .ok_or_else(|| Error::required("field"))?; + Ok(Expr::TryCast(TryCast::new( + expr, + Arc::new(field.try_into()?), + ))) } ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c..cb0a9c368ad9 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -522,19 +522,19 @@ pub fn serialize_expr( expr_type: Some(ExprType::Case(expr)), } } - Expr::Cast(Cast { expr, data_type }) => { + Expr::Cast(Cast { expr, field }) => { let expr = Box::new(protobuf::CastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + field: Some(field.as_ref().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::Cast(expr)), } } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let expr = Box::new(protobuf::TryCastNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - arrow_type: Some(data_type.try_into()?), + field: Some(field.as_ref().try_into()?), }); protobuf::LogicalExprNode { expr_type: Some(ExprType::TryCast(expr)), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index bfd693e6a0f8..51aca96c85c7 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2032,7 +2032,10 @@ fn roundtrip_null_literal() { #[test] 
fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + let test_expr = Expr::Cast(Cast::new( + Box::new(lit(1.0_f32)), + Arc::new(Field::new("a", DataType::Boolean, true)), + )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2040,14 +2043,18 @@ fn roundtrip_cast() { #[test] fn roundtrip_try_cast() { - let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + let test_expr = Expr::TryCast(TryCast::new( + Box::new(lit(1.0_f32)), + Arc::new(Field::new("a", DataType::Boolean, true)), + )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); - let test_expr = - Expr::TryCast(TryCast::new(Box::new(lit("not a bool")), DataType::Boolean)); + let test_expr = Expr::TryCast(TryCast::new( + Box::new(lit("not a bool")), + Arc::new(Field::new("a", DataType::Boolean, true)), + )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index fa3454ce5644..0b5e32fc8488 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -2990,13 +2990,20 @@ mod tests { // test cast(c1 as int64) = 1 // test column on the left - let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + ) + .eq(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); + let expr = lit(ScalarValue::Int64(Some(1))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3005,15 +3012,20 @@ mod tests { "c1_null_count@1 != row_count@2 AND TRY_CAST(c1_max@0 AS Int64) > 1"; // test column on the left - let expr = - try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); + let expr = try_cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + ) + .gt(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = - lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); + let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3027,15 +3039,17 @@ mod tests { let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; // test column on the left - let expr = cast(col("c1"), DataType::Utf8) + let expr = cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, false))) .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = 
lit(ScalarValue::Utf8(Some("1".to_string()))) - .eq(cast(col("c1"), DataType::Utf8)); + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Utf8, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3049,13 +3063,20 @@ mod tests { let expected_expr = "true"; // test column on the left - let expr = cast(col("c1"), DataType::Int32).eq(lit(ScalarValue::Int32(Some(1)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int32, false)), + ) + .eq(lit(ScalarValue::Int32(Some(1)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = lit(ScalarValue::Int32(Some(1))).eq(cast(col("c1"), DataType::Int32)); + let expr = lit(ScalarValue::Int32(Some(1))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int32, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3069,15 +3090,17 @@ mod tests { let expected_expr = "true"; // test column on the left - let expr = cast(col("c1"), DataType::Utf8) + let expr = cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, false))) .eq(lit(ScalarValue::Utf8(Some("1".to_string())))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = lit(ScalarValue::Utf8(Some("1".to_string()))) - .eq(cast(col("c1"), DataType::Utf8)); + let expr = lit(ScalarValue::Utf8(Some("1".to_string()))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Utf8, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3091,15 +3114,20 @@ mod tests { let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; // test column on the left - let expr = - cast(col("c1"), DataType::Date64).eq(lit(ScalarValue::Date64(Some(123)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date64, false)), + ) + .eq(lit(ScalarValue::Date64(Some(123)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = - lit(ScalarValue::Date64(Some(123))).eq(cast(col("c1"), DataType::Date64)); + let expr = lit(ScalarValue::Date64(Some(123))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date64, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3116,7 +3144,11 @@ mod tests { // test column on the left let expr = cast( col("c1"), - DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + Arc::new(Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )), ) .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); let predicate_expr = @@ -3126,7 +3158,11 @@ mod tests { // test column on the right let expr = 
lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( col("c1"), - DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + Arc::new(Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )), )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); @@ -3146,15 +3182,20 @@ mod tests { let expected_expr = "true"; // test column on the left - let expr = - cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date32, false)), + ) + .eq(lit(ScalarValue::Date32(Some(123)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = - lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let expr = lit(ScalarValue::Date32(Some(123))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date32, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3182,7 +3223,14 @@ mod tests { // Test with column cast to a dictionary with different key type let expr = cast( col("c1"), - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + Arc::new(Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + false, + )), ) .eq(lit(ScalarValue::Utf8(Some("test".to_string())))); let predicate_expr = @@ -3204,8 +3252,11 @@ mod tests { let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; // Test with literal of a different type - let expr = - cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(123)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + ) + .eq(lit(ScalarValue::Int64(Some(123)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3252,7 +3303,14 @@ mod tests { // Test with a cast to a different date type let expr = cast( col("c1"), - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Date64)), + Arc::new(Field::new( + "c1", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Date64), + ), + false, + )), ) .eq(lit(ScalarValue::Date64(Some(123)))); let predicate_expr = @@ -3268,15 +3326,20 @@ mod tests { let expected_expr = "true"; // test column on the left - let expr = - cast(col("c1"), DataType::Date32).eq(lit(ScalarValue::Date32(Some(123)))); + let expr = cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date32, false)), + ) + .eq(lit(ScalarValue::Date32(Some(123)))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = - lit(ScalarValue::Date32(Some(123))).eq(cast(col("c1"), DataType::Date32)); + let expr = lit(ScalarValue::Date32(Some(123))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Date32, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3290,15 +3353,17 @@ mod tests { let 
expected_expr = "true"; // test column on the left - let expr = cast(col("c1"), DataType::Utf8) + let expr = cast(col("c1"), Arc::new(Field::new("c1", DataType::Utf8, false))) .eq(lit(ScalarValue::Utf8(Some("2024-01-01".to_string())))); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right - let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))) - .eq(cast(col("c1"), DataType::Utf8)); + let expr = lit(ScalarValue::Utf8(Some("2024-01-01".to_string()))).eq(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Utf8, false)), + )); let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -3311,7 +3376,10 @@ mod tests { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); // test cast(c1 as int64) in int64(1, 2, 3) let expr = Expr::InList(InList::new( - Box::new(cast(col("c1"), DataType::Int64)), + Box::new(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + )), vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -3325,7 +3393,10 @@ mod tests { assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList(InList::new( - Box::new(cast(col("c1"), DataType::Int64)), + Box::new(cast( + col("c1"), + Arc::new(Field::new("c1", DataType::Int64, false)), + )), vec![ lit(ScalarValue::Int64(Some(1))), lit(ScalarValue::Int64(Some(2))), @@ -3368,8 +3439,11 @@ mod tests { prune_with_expr( // with cast column to other type - cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + cast( + col("s1"), + Arc::new(Field::new("s1", DataType::Decimal128(14, 3), false)), + ) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), &schema, &TestStatistics::new().with( "s1", @@ -3383,8 +3457,11 @@ mod tests { prune_with_expr( // with try cast column to other type - try_cast(col("s1"), DataType::Decimal128(14, 3)) - .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), + try_cast( + col("s1"), + Arc::new(Field::new("s1", DataType::Decimal128(14, 3), false)), + ) + .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))), &schema, &TestStatistics::new().with( "s1", @@ -3470,7 +3547,11 @@ mod tests { prune_with_expr( // filter with cast - cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))), + cast( + col("s2"), + Arc::new(Field::new("s2", DataType::Int64, false)), + ) + .gt(lit(ScalarValue::Int64(Some(5)))), &schema, &statistics, &[false, true, true, true], @@ -3696,7 +3777,8 @@ mod tests { prune_with_expr( // cast(i as utf8) <= 0 - cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + cast(col("i"), Arc::new(Field::new("i", DataType::Utf8, false))) + .lt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3704,7 +3786,8 @@ mod tests { prune_with_expr( // try_cast(i as utf8) <= 0 - try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")), + try_cast(col("i"), Arc::new(Field::new("i", DataType::Utf8, false))) + .lt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3712,7 +3795,11 @@ mod tests { prune_with_expr( // cast(-i as utf8) >= 0 - cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + cast( + Expr::Negative(Box::new(col("i"))), + Arc::new(Field::new("i", DataType::Utf8, false)), + ) + .gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3720,7 +3807,11 @@ mod tests { prune_with_expr( // try_cast(-i as utf8) >= 0 
- try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")), + try_cast( + Expr::Negative(Box::new(col("i"))), + Arc::new(Field::new("i", DataType::Utf8, false)), + ) + .gt_eq(lit("0")), &schema, &statistics, expected_ret, @@ -3761,14 +3852,16 @@ mod tests { let expected_ret = &[true, false, false, true, false]; prune_with_expr( - cast(col("i"), DataType::Int64).eq(lit(0i64)), + cast(col("i"), Arc::new(Field::new("i", DataType::Int64, false))) + .eq(lit(0i64)), &schema, &statistics, expected_ret, ); prune_with_expr( - try_cast(col("i"), DataType::Int64).eq(lit(0i64)), + try_cast(col("i"), Arc::new(Field::new("i", DataType::Int64, false))) + .eq(lit(0i64)), &schema, &statistics, expected_ret, @@ -3791,7 +3884,7 @@ mod tests { let expected_ret = &[true, true, true, true, true]; prune_with_expr( - cast(col("i"), DataType::Utf8).eq(lit("0")), + cast(col("i"), Arc::new(Field::new("i", DataType::Utf8, false))).eq(lit("0")), &schema, &statistics, expected_ret, @@ -3946,7 +4039,10 @@ mod tests { prune_with_expr( // i > int64(0) - col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)), + col("i").gt(cast( + lit(ScalarValue::Int64(Some(0))), + Arc::new(Field::new("i", DataType::Int32, false)), + )), &schema, &statistics, expected_ret, @@ -3954,7 +4050,8 @@ mod tests { prune_with_expr( // cast(i as int64) > int64(0) - cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + cast(col("i"), Arc::new(Field::new("i", DataType::Int64, false))) + .gt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, expected_ret, @@ -3962,7 +4059,8 @@ mod tests { prune_with_expr( // try_cast(i as int64) > int64(0) - try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))), + try_cast(col("i"), Arc::new(Field::new("i", DataType::Int64, false))) + .gt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, expected_ret, @@ -3970,8 +4068,11 @@ mod tests { prune_with_expr( // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` - Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) - .lt(lit(ScalarValue::Int64(Some(0)))), + Expr::Negative(Box::new(cast( + col("i"), + Arc::new(Field::new("i", DataType::Int64, false)), + ))) + .lt(lit(ScalarValue::Int64(Some(0)))), &schema, &statistics, expected_ret, @@ -4544,7 +4645,10 @@ mod tests { assert_eq!(result_right.to_string(), right_input.to_string()); // cast op lit - let left_input = cast(col("a"), DataType::Decimal128(20, 3)); + let left_input = cast( + col("a"), + Arc::new(Field::new("a", DataType::Decimal128(20, 3), true)), + ); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); let right_input = logical2physical(&right_input, &schema); @@ -4559,7 +4663,8 @@ mod tests { assert_eq!(result_right.to_string(), right_input.to_string()); // try_cast op lit - let left_input = try_cast(col("a"), DataType::Int64); + let left_input = + try_cast(col("a"), Arc::new(Field::new("a", DataType::Int64, true))); let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); @@ -4646,7 +4751,7 @@ mod tests { // this cast is not supported let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let df_schema = DFSchema::try_from(schema.clone()).unwrap(); - let left_input = cast(col("a"), DataType::Int64); + let left_input = cast(col("a"), Arc::new(Field::new("a", DataType::Int64, true))); let left_input = 
logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); let right_input = logical2physical(&right_input, &schema); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 715a02db8b02..e8dd81b5a155 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, TimeUnit}; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_expr::planner::{ PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, }; @@ -288,9 +290,7 @@ impl SqlToRel<'_, S> { schema, planner_context, )?), - self.convert_data_type_to_field(&data_type)? - .data_type() - .clone(), + self.convert_data_type_to_field(&data_type)?, ))) } @@ -298,12 +298,17 @@ impl SqlToRel<'_, S> { data_type, value, uses_odbc_syntax: _, - }) => Ok(Expr::Cast(Cast::new( - Box::new(lit(value.into_string().unwrap())), - self.convert_data_type_to_field(&data_type)? - .data_type() - .clone(), - ))), + }) => { + let literal = value.clone().into_string().unwrap(); + let field_name = value.into_string().unwrap(); + let field = Arc::new( + self.convert_data_type_to_field(&data_type)? + .as_ref() + .clone() + .with_name(field_name), + ); + Ok(Expr::Cast(Cast::new(Box::new(lit(literal)), field))) + } SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new( self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, @@ -571,7 +576,11 @@ impl SqlToRel<'_, S> { SQLExpr::Value(ValueWithSpan { value: Value::SingleQuotedString(s), span: _, - }) => DataType::Timestamp(TimeUnit::Nanosecond, Some(s.into())), + }) => Arc::new(Field::new( + s.clone(), + DataType::Timestamp(TimeUnit::Nanosecond, Some(s.into())), + true, + )), _ => { return not_impl_err!( "Unsupported ast node in sqltorel: {time_zone:?}" @@ -999,18 +1008,17 @@ impl SqlToRel<'_, S> { { Expr::Cast(Cast::new( Box::new(expr), - DataType::Timestamp(TimeUnit::Second, tz.clone()), + Arc::new(Field::new( + "", + DataType::Timestamp(TimeUnit::Second, tz.clone()), + true, + )), )) } _ => expr, }; - // Currently drops metadata attached to the type - // https://github.com/apache/datafusion/issues/18060 - Ok(Expr::Cast(Cast::new( - Box::new(expr), - dt.data_type().clone(), - ))) + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) } /// Extracts the root expression and access chain from a compound expression. diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 81381bf49fc5..eeae336a610e 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -499,11 +499,8 @@ impl SqlToRel<'_, S> { .iter() .zip(input_fields) .map(|(field, input_field)| { - cast( - col(input_field.name()), - field.data_type().clone(), - ) - .alias(field.name()) + cast(col(input_field.name()), Arc::clone(field)) + .alias(field.name()) }) .collect::>(); diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a7fe8efa153c..a4688ffa173e 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -189,8 +189,8 @@ impl Unparser<'_> { end_token: AttachedToken::empty(), }) } - Expr::Cast(Cast { expr, data_type }) => { - Ok(self.cast_to_sql(expr, data_type)?) + Expr::Cast(Cast { expr, field }) => { + Ok(self.cast_to_sql(expr, field.data_type())?) } Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. 
}) => self.expr_to_sql_inner(expr), @@ -464,12 +464,12 @@ impl Unparser<'_> { ) }) } - Expr::TryCast(TryCast { expr, data_type }) => { + Expr::TryCast(TryCast { expr, field }) => { let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + data_type: self.arrow_dtype_to_ast_dtype(field.data_type())?, format: None, }) } @@ -1887,31 +1887,36 @@ mod tests { ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Date64, + field: Arc::new(Field::new("a", DataType::Date64, true)), }), r#"CAST(a AS DATETIME)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Timestamp( - TimeUnit::Nanosecond, - Some("+08:00".into()), - ), + field: Arc::new(Field::new( + "a", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), + true, + )), }), r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + field: Arc::new(Field::new( + "a", + DataType::Timestamp(TimeUnit::Second, None), + true, + )), }), r#"CAST(a AS TIMESTAMP)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::UInt32, + field: Arc::new(Field::new("a", DataType::UInt32, true)), }), r#"CAST(a AS INTEGER UNSIGNED)"#, ), @@ -2161,11 +2166,11 @@ mod tests { r#"NOT EXISTS (SELECT * FROM t WHERE (t.a = 1))"#, ), ( - try_cast(col("a"), DataType::Date64), + try_cast(col("a"), Arc::new(Field::new("a", DataType::Date64, false))), r#"TRY_CAST(a AS DATETIME)"#, ), ( - try_cast(col("a"), DataType::UInt32), + try_cast(col("a"), Arc::new(Field::new("a", DataType::UInt32, false))), r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( @@ -2229,7 +2234,7 @@ mod tests { ( Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Decimal128(10, -2), + field: Arc::new(Field::new("a", DataType::Decimal128(10, -2), true)), }), r#"CAST(a AS DECIMAL(12,0))"#, ), @@ -2369,7 +2374,7 @@ mod tests { let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Date64, + field: Arc::new(Field::new("a", DataType::Date64, true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2394,7 +2399,7 @@ mod tests { let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Float64, + field: Arc::new(Field::new("a", DataType::Float64, true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2630,7 +2635,7 @@ mod tests { ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::Binary, + field: Arc::new(Field::new("blah", DataType::Binary, true)), }), "'blah'", ), @@ -2640,7 +2645,7 @@ mod tests { ScalarValue::Utf8(Some("blah".to_string())), None, )), - data_type: DataType::BinaryView, + field: Arc::new(Field::new("blah", DataType::BinaryView, true)), }), "'blah'", ), @@ -2674,7 +2679,7 @@ mod tests { let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type, + field: Arc::new(Field::new("a", data_type, true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2760,7 +2765,7 @@ mod tests { let unparser = Unparser::new(&dialect); let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Int64, + field: Arc::new(Field::new("a", DataType::Int64, true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2788,7 +2793,7 @@ mod tests { let unparser = Unparser::new(&dialect); let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Int32, + field: Arc::new(Field::new("a", DataType::Int32, true)), }); let ast 
= unparser.expr_to_sql(&expr)?; @@ -2827,7 +2832,7 @@ mod tests { let unparser = Unparser::new(dialect); let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: data_type.clone(), + field: Arc::new(Field::new("a", data_type.clone(), true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2883,7 +2888,7 @@ mod tests { let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type, + field: Arc::new(Field::new("a", data_type, true)), }); let ast = unparser.expr_to_sql(&expr)?; @@ -2929,7 +2934,14 @@ mod tests { ScalarValue::Utf8(Some("variation".to_string())), None, )), - data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), + field: Arc::new(Field::new( + "dict_col", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + false, + )), }), "'variation'", )]; @@ -2964,7 +2976,7 @@ mod tests { args: vec![ Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Float64, + field: Arc::new(Field::new("a", DataType::Float64, false)), }), Expr::Literal(ScalarValue::Int64(Some(2)), None), ], @@ -3148,7 +3160,10 @@ mod tests { assert_eq!(ast_dtype, ast::DataType::Char(None)); - let expr = cast(col("a"), DataType::Utf8View); + let expr = cast( + col("a"), + Arc::new(Field::new("a", DataType::Utf8View, false)), + ); let ast = unparser.expr_to_sql(&expr)?; let actual = format!("{ast}"); @@ -3214,7 +3229,11 @@ mod tests { let unparser = Unparser::new(dialect.as_ref()); let expr = Expr::Cast(Cast { expr: Box::new(col("a")), - data_type: DataType::Timestamp(TimeUnit::Nanosecond, None), + field: Arc::new(Field::new( + "a", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + )), }); let ast = unparser.expr_to_sql(&expr)?; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 5f76afb763cf..9bb9eacd0c95 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -2238,9 +2238,12 @@ fn test_unparse_subquery_alias_with_table_pushdown() -> Result<()> { .alias("customer")? .project(vec![ col("customer.c_custkey"), - cast(col("customer.c_custkey"), DataType::Int64) - .add(lit(1)) - .alias("custkey_plus"), + cast( + col("customer.c_custkey"), + Arc::new(Field::new("customer.c_custkey", DataType::Int64, false)), + ) + .add(lit(1)) + .alias("custkey_plus"), col("customer.c_name"), ])? .alias("customer")? diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs index 5e8d3d93065f..dcb1951b3227 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. 
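
The Substrait consumer below has no field name to work with: a Substrait type is just a type, so the decoded `DataType` is wrapped in an anonymous, nullable `Field` (name `""`), and the cast's failure behavior decides between `Cast` and `TryCast`. A hypothetical helper condensing that mapping (the function name is illustrative, not part of the change):

use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::logical_expr::{Cast, Expr, TryCast};

// ReturnNull failure behavior maps to TryCast; anything else is a plain Cast.
fn cast_expr_for(input: Box<Expr>, data_type: DataType, returns_null_on_failure: bool) -> Expr {
    // Substrait output types carry no field name, so the target Field is
    // anonymous and nullable.
    let field = Arc::new(Field::new("", data_type, true));
    if returns_null_on_failure {
        Expr::TryCast(TryCast::new(input, field))
    } else {
        Expr::Cast(Cast::new(input, field))
    }
}
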
+use std::sync::Arc; + use crate::logical_plan::consumer::types::from_substrait_type_without_names; use crate::logical_plan::consumer::SubstraitConsumer; +use datafusion::arrow::datatypes::Field; use datafusion::common::{substrait_err, DFSchema}; use datafusion::logical_expr::{Cast, Expr, TryCast}; use substrait::proto::expression as substrait_expression; @@ -39,9 +42,15 @@ pub async fn from_cast( ); let data_type = from_substrait_type_without_names(consumer, output_type)?; if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + Ok(Expr::TryCast(TryCast::new( + input_expr, + Arc::new(Field::new("", data_type, true)), + ))) } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) + Ok(Expr::Cast(Cast::new( + input_expr, + Arc::new(Field::new("", data_type, true)), + ))) } } None => substrait_err!("Cast expression without output type is not allowed"), diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index f7eedcb7a2b2..023b2fb3f6e1 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -270,10 +270,7 @@ pub(super) fn rename_expressions( .map(|(old_expr, new_field)| { // Check if type (i.e. nested struct field names) match, use Cast to rename if needed let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { - Expr::Cast(Cast::new( - Box::new(old_expr), - new_field.data_type().to_owned(), - )) + Expr::Cast(Cast::new(Box::new(old_expr), Arc::clone(new_field))) } else { old_expr }; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs index 71c2140bac8b..7761c1dda2ad 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs @@ -29,7 +29,7 @@ pub fn from_cast( cast: &Cast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let Cast { expr, data_type } = cast; + let Cast { expr, field } = cast; // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null if let Expr::Literal(lit, _) = expr.as_ref() { // only the untyped(a null scalar value) null literal need this special handling @@ -40,7 +40,9 @@ pub fn from_cast( nullable: true, type_variation_reference: DEFAULT_TYPE_VARIATION_REF, literal_type: Some(LiteralType::Null(to_substrait_type( - producer, data_type, true, + producer, + field.data_type(), + true, )?)), }; return Ok(Expression { @@ -51,7 +53,7 @@ pub fn from_cast( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type(producer, field.data_type(), true)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ThrowException.into(), }, @@ -64,11 +66,11 @@ pub fn from_try_cast( cast: &TryCast, schema: &DFSchemaRef, ) -> datafusion::common::Result { - let TryCast { expr, data_type } = cast; + let TryCast { expr, field } = cast; Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(producer, data_type, true)?), + r#type: Some(to_substrait_type(producer, field.data_type(), true)?), input: Some(Box::new(producer.handle_expr(expr, schema)?)), failure_behavior: FailureBehavior::ReturnNull.into(), },
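
On the producer side the extra information carried by the field is deliberately dropped: a Substrait cast encodes only a type plus a failure behavior, so both arms call `to_substrait_type(producer, field.data_type(), true)` and differ only in the `FailureBehavior` they emit (`ThrowException` for `Cast`, `ReturnNull` for `TryCast`). A small sketch of that symmetry, assuming the `substrait` crate's generated enum path:

use datafusion::logical_expr::Expr;
use substrait::proto::expression::cast::FailureBehavior;

// Cast round-trips as ThrowException and TryCast as ReturnNull, mirroring
// the consumer's mapping in the opposite direction; other expressions are
// not casts at all.
fn failure_behavior_for(expr: &Expr) -> Option<FailureBehavior> {
    match expr {
        Expr::Cast(_) => Some(FailureBehavior::ThrowException),
        Expr::TryCast(_) => Some(FailureBehavior::ReturnNull),
        _ => None,
    }
}
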