From 6b00b9ae14c30e6f0b174ba0500e5be956f5e1cb Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 13:57:24 +0800 Subject: [PATCH 01/28] switch func Signed-off-by: Jay Zhan --- datafusion/expr/src/expr_schema.rs | 17 ++++-- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 28 ++++++++++ datafusion/functions/src/core/arrow_cast.rs | 28 +++++++--- datafusion/functions/src/core/getfield.rs | 54 +++++++++++++++++-- datafusion/functions/src/core/named_struct.rs | 34 ++++-------- .../functions/src/datetime/date_part.rs | 20 +++---- .../functions/src/datetime/from_unixtime.rs | 27 ++++------ datafusion/physical-expr/src/planner.rs | 39 ++++++++++---- 9 files changed, 173 insertions(+), 76 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 25073ca7eaaa..90c8c9df9cda 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -24,12 +24,12 @@ use crate::type_coercion::binary::get_result_type; use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; +use crate::udf::ReturnTypeArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, TableReference, + not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, ScalarValue, TableReference }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; @@ -168,9 +168,20 @@ impl ExprSchemable for Expr { ) })?; + + let arguments = args.iter().map(|e| match e { + Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), + _ => "".to_string(), + }).collect::>(); + let args = ReturnTypeArgs { + arg_types: &new_data_types, + arguments: &arguments, + }; + + // Perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type - Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) + Ok(func.return_type_from_args(args)?) } Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index a57fd80c48e1..1c9059f607d2 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -93,7 +93,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, ReturnTypeArgs}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 51c42b5c4c30..b1f26f885ffa 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -192,6 +192,13 @@ impl ScalarUDF { self.inner.return_type_from_exprs(args, schema, arg_types) } + pub fn return_type_from_args( + &self, + args: ReturnTypeArgs, + ) -> Result { + self.inner.return_type_from_args(args) + } + /// Do the function rewrite /// /// See [`ScalarUDFImpl::simplify`] for more details. @@ -342,6 +349,12 @@ pub struct ScalarFunctionArgs<'a> { pub return_type: &'a DataType, } + +pub struct ReturnTypeArgs<'a> { + pub arg_types: &'a [DataType], + pub arguments: &'a [String], +} + /// Trait for implementing user defined scalar functions. /// /// This trait exposes the full API for implementing user defined functions and @@ -490,6 +503,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { self.return_type(arg_types) } + fn return_type_from_args( + &self, + args: ReturnTypeArgs, + ) -> Result { + self.return_type(args.arg_types) + } + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true } @@ -739,6 +759,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } + /// ScalarUDF that adds an alias to the underlying function. It is better to /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. #[derive(Debug)] @@ -796,6 +817,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_exprs(args, schema, arg_types) } + fn return_type_from_args( + &self, + args: ReturnTypeArgs, + ) -> Result { + self.inner.return_type_from_args(args) + } + fn invoke_batch( &self, args: &[ColumnarValue], diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 3853737d7b5b..58497360ddb0 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -28,8 +28,7 @@ use std::sync::OnceLock; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility }; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -95,13 +94,28 @@ impl ScalarUDFImpl for ArrowCastFunc { args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) } - fn return_type_from_exprs( + fn return_type_from_args( &self, - args: &[Expr], - _schema: &dyn ExprSchema, - _arg_types: &[DataType], + args: ReturnTypeArgs, ) -> Result { - data_type_from_args(args) + if args.arguments.len() != 2 { + return plan_err!("{} needs 2 arguments, {} provided", self.name(), args.arguments.len()); + } + + let val = &args.arguments[1]; + if val.is_empty() { + return plan_err!( + "{} requires its second argument to be a constant string", + self.name() + ); + }; + + val.parse().map_err(|e| match e { + // If the data type cannot be parsed, return a Plan error to signal an + // error in the input rather than a more general ArrowError + arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + e => arrow_datafusion_err!(e), + }) } fn invoke_batch( diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 5c8e1e803e0f..eba28001d133 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -21,10 +21,10 @@ use arrow::array::{ use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue }; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -103,7 +103,55 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - todo!() + internal_err!("return_type_from_args should be called instead") + } + + fn return_type_from_args( + &self, + args: ReturnTypeArgs, + ) -> Result { + if args.arguments.len() != 2 { + return exec_err!( + "get_field function requires 2 arguments, got {}", + args.arguments.len() + ); + } + + let name = &args.arguments[1]; + if name.is_empty() { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + + let data_type = &args.arg_types[0]; + match (data_type, name) { + (DataType::Map(fields, _), _) => { + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second column is the "value" column both here and in + // execution. + let value_field = fields.get(1).expect("fields should have exactly two members"); + Ok(value_field.data_type().clone()) + }, + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + } + } + (DataType::Struct(fields), s) => { + if s.is_empty() { + plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + } + } + (DataType::Null, _) => Ok(DataType::Null), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), + } } fn return_type_from_exprs( diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 556cad1de1ac..326d7d11e66e 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -19,7 +19,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -117,44 +117,32 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type(&self, _arg_types: &[DataType]) -> Result { internal_err!( - "named_struct: return_type called instead of return_type_from_exprs" + "named_struct: return_type called instead of return_type_from_args" ) } - fn return_type_from_exprs( + fn return_type_from_args( &self, - args: &[Expr], - schema: &dyn datafusion_common::ExprSchema, - _arg_types: &[DataType], + args: ReturnTypeArgs, ) -> Result { // do not accept 0 arguments. - if args.is_empty() { + if args.arguments.is_empty() { return exec_err!( "named_struct requires at least one pair of arguments, got 0 instead" ); } - if args.len() % 2 != 0 { + if args.arguments.len() % 2 != 0 { return exec_err!( "named_struct requires an even number of arguments, got {} instead", - args.len() + args.arguments.len() ); } - let return_fields = args - .chunks_exact(2) - .enumerate() - .map(|(i, chunk)| { - let name = &chunk[0]; - let value = &chunk[1]; - - if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name { - Ok(Field::new(name, value.get_type(schema)?, true)) - } else { - exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2) - } - }) - .collect::>>()?; + let return_fields = args.arguments.iter().step_by(2).zip(args.arg_types.iter().skip(1).step_by(2)).map(|(name, data_type)| { + Ok(Field::new(name, data_type.clone(), true)) + }).collect::>>()?; + Ok(DataType::Struct(Fields::from(return_fields))) } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 0f01b6a21b0a..834fb5934895 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -41,8 +41,7 @@ use datafusion_common::{ ExprSchema, Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, TypeSignature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility }; use datafusion_expr_common::signature::TypeSignatureClass; use datafusion_macros::user_doc; @@ -136,20 +135,17 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_exprs should be called instead") + internal_err!("return_type_from_args should be called instead") } - fn return_type_from_exprs( + fn return_type_from_args( &self, - args: &[Expr], - _schema: &dyn ExprSchema, - _arg_types: &[DataType], + args: ReturnTypeArgs, ) -> Result { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(part))) if is_epoch(part) => { - Ok(DataType::Float64) - } - _ => Ok(DataType::Int32), + if is_epoch(&args.arguments[0]) { + Ok(DataType::Float64) + } else { + Ok(DataType::Int32) } } diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 425da7ddac29..051005e042ca 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -24,7 +24,7 @@ use arrow::datatypes::TimeUnit::Second; use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility }; use datafusion_macros::user_doc; @@ -81,29 +81,20 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_exprs( + + fn return_type_from_args( &self, - args: &[Expr], - _schema: &dyn ExprSchema, - arg_types: &[DataType], + args: ReturnTypeArgs, ) -> Result { - match arg_types.len() { - 1 => Ok(Timestamp(Second, None)), - 2 => match &args[1] { - Expr::Literal(ScalarValue::Utf8(Some(tz))) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), - _ => exec_err!( - "Second argument for `from_unixtime` must be non-null utf8, received {:?}", - arg_types[1]), - }, - _ => exec_err!( - "from_unixtime function requires 1 or 2 arguments, got {}", - arg_types.len() - ), + if args.arguments.len() == 1 { + Ok(Timestamp(Second, None)) + } else { + Ok(Timestamp(Second, Some(Arc::from(args.arguments[1].to_string())))) } } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_exprs instead") + internal_err!("call return_type_from_args instead") } fn invoke_batch( diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 906ca9fd1093..a9c24344f096 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use crate::scalar_function; +use crate::{scalar_function, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, @@ -29,10 +29,11 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; +use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, + binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, ReturnTypeArgs, TryCast }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -109,6 +110,8 @@ pub fn create_physical_expr( execution_props: &ExecutionProps, ) -> Result> { let input_schema: &Schema = &input_dfschema.into(); + // println!("input_dfschema: {:?}", input_dfschema); + // println!("input_schema: {:?}", input_schema); match e { Expr::Alias(Alias { expr, .. }) => { @@ -302,13 +305,31 @@ pub fn create_physical_expr( let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - scalar_function::create_physical_expr( - Arc::clone(func).as_ref(), - &physical_args, - input_schema, - args, - input_dfschema, - ) + let args_types = args.iter().map(|e| e.get_type(input_dfschema)).collect::>>()?; + let arguments = args.iter().map(|e| match e { + Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), + _ => "".to_string(), + }).collect::>(); + + let return_type = + func.return_type_from_args(ReturnTypeArgs { + arg_types: &args_types, + arguments: &arguments, + })?; + let nullable = func.is_nullable(args, input_dfschema); + + // verify that input data types is consistent with function's `TypeSignature` + data_types_with_scalar_udf(&args_types, func)?; + + Ok(Arc::new( + ScalarFunctionExpr::new( + func.name(), + Arc::clone(func), + physical_args, + return_type, + ) + .with_nullable(nullable), + )) } Expr::Between(Between { expr, From b079be3086722569475b7f03d2090c11c2ce9075 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 15:58:57 +0800 Subject: [PATCH 02/28] fix test Signed-off-by: Jay Zhan --- datafusion/expr/src/udf.rs | 2 +- datafusion/functions/src/core/named_struct.rs | 24 +++++++++++++++---- .../sqllogictest/test_files/arrow_typeof.slt | 2 +- datafusion/sqllogictest/test_files/struct.slt | 8 +++---- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b1f26f885ffa..5cd4dac0cb34 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -349,7 +349,7 @@ pub struct ScalarFunctionArgs<'a> { pub return_type: &'a DataType, } - +#[derive(Debug)] pub struct ReturnTypeArgs<'a> { pub arg_types: &'a [DataType], pub arguments: &'a [String], diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 326d7d11e66e..5e87c93d488b 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -19,7 +19,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -48,7 +48,8 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { let name_column = &chunk[0]; let name = match name_column { ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, - _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + // TODO: Implement Display for ColumnarValue + _ => return exec_err!("named_struct even arguments must be string literals at position {}", i * 2) }; Ok((name, chunk[1].clone())) @@ -139,10 +140,25 @@ impl ScalarUDFImpl for NamedStructFunc { ); } - let return_fields = args.arguments.iter().step_by(2).zip(args.arg_types.iter().skip(1).step_by(2)).map(|(name, data_type)| { - Ok(Field::new(name, data_type.clone(), true)) + println!("args: {:?}", args); + + // let return_fields = args.arg_types.iter().step_by(2).zip(args.arguments.iter().skip(1).step_by(2)).map(|(data_type, name)| { + // Ok(Field::new(name, data_type.clone(), true)) + // }).collect::>>()?; + + let names = args.arguments.iter().step_by(2).collect::>(); + let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + + println!("names: {:?}", names); + println!("types: {:?}", types); + + + let return_fields = names.into_iter().zip(types.into_iter()).map(|(name, data_type)| { + Ok(Field::new(name, data_type.to_owned(), true)) }).collect::>>()?; + println!("return_fields: {:?}", return_fields); + Ok(DataType::Struct(Fields::from(return_fields))) } diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 77b10b41ccb3..fc93d0270f1a 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index b05e86e5ea91..79982f32678e 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -151,19 +151,19 @@ query error DataFusion error: Execution error: named_struct requires an even num select named_struct('a', 1, 'b'); # error on even argument not a string literal #1 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(1\) instead at position 0 +query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 select named_struct(1, 'a'); # error on even argument not a string literal #2 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(0\) instead at position 2 +query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 2 select named_struct('corret', 1, 0, 'wrong'); # error on even argument not a string literal #3 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.a instead at position 0 +query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 select named_struct(values.a, 'a') from values; # error on even argument not a string literal #4 -query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.c instead at position 0 +query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 select named_struct(values.c, 'c') from values; # named_struct with mixed scalar and array values #1 From 8c9ee8c785040dbbb54a1a7d4dfa31e0d992d8cb Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 16:13:09 +0800 Subject: [PATCH 03/28] fix test Signed-off-by: Jay Zhan --- datafusion/functions/src/core/arrow_cast.rs | 4 +--- datafusion/functions/src/core/named_struct.rs | 12 ------------ datafusion/functions/src/datetime/date_part.rs | 4 ++-- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 58497360ddb0..c0518c6fe897 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -85,9 +85,7 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - // should be using return_type_from_exprs and not calling the default - // implementation - internal_err!("arrow_cast should return type from exprs") + internal_err!("return_type_from_args should be called instead") } fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 5e87c93d488b..7f03b00931a0 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -140,25 +140,13 @@ impl ScalarUDFImpl for NamedStructFunc { ); } - println!("args: {:?}", args); - - // let return_fields = args.arg_types.iter().step_by(2).zip(args.arguments.iter().skip(1).step_by(2)).map(|(data_type, name)| { - // Ok(Field::new(name, data_type.clone(), true)) - // }).collect::>>()?; - let names = args.arguments.iter().step_by(2).collect::>(); let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); - println!("names: {:?}", names); - println!("types: {:?}", types); - - let return_fields = names.into_iter().zip(types.into_iter()).map(|(name, data_type)| { Ok(Field::new(name, data_type.to_owned(), true)) }).collect::>>()?; - println!("return_fields: {:?}", return_fields); - Ok(DataType::Struct(Fields::from(return_fields))) } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 834fb5934895..27d3171a1a6a 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -38,10 +38,10 @@ use datafusion_common::{ }, exec_err, internal_err, types::logical_string, - ExprSchema, Result, ScalarValue, + Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility + ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility }; use datafusion_expr_common::signature::TypeSignatureClass; use datafusion_macros::user_doc; From 6df7476b97a2dbe49280aacc6566399dd5146834 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 16:17:24 +0800 Subject: [PATCH 04/28] deprecate old Signed-off-by: Jay Zhan --- datafusion/expr/src/expr_schema.rs | 18 +++-- datafusion/expr/src/lib.rs | 4 +- datafusion/expr/src/udf.rs | 16 +---- datafusion/functions/src/core/arrow_cast.rs | 14 ++-- datafusion/functions/src/core/getfield.rs | 66 ++----------------- datafusion/functions/src/core/named_struct.rs | 27 ++++---- .../functions/src/datetime/date_part.rs | 8 +-- .../functions/src/datetime/from_unixtime.rs | 14 ++-- datafusion/physical-expr/src/planner.rs | 30 +++++---- .../physical-expr/src/scalar_function.rs | 1 + 10 files changed, 74 insertions(+), 124 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 90c8c9df9cda..c0e05b458f70 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -29,7 +29,8 @@ use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, ScalarValue, TableReference + not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, + Result, ScalarValue, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; @@ -168,17 +169,20 @@ impl ExprSchemable for Expr { ) })?; - - let arguments = args.iter().map(|e| match e { - Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), - _ => "".to_string(), - }).collect::>(); + let arguments = args + .iter() + .map(|e| match e { + Expr::Literal(ScalarValue::Utf8(s)) => { + s.clone().unwrap_or_default() + } + _ => "".to_string(), + }) + .collect::>(); let args = ReturnTypeArgs { arg_types: &new_data_types, arguments: &arguments, }; - // Perform additional function arguments validation (due to limited // expressiveness of `TypeSignature`), then infer return type Ok(func.return_type_from_args(args)?) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1c9059f607d2..e4d3bd6fb6a1 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -93,7 +93,9 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, ReturnTypeArgs}; +pub use udf::{ + scalar_doc_sections, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, +}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 5cd4dac0cb34..9a0041f9ae7c 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -192,10 +192,7 @@ impl ScalarUDF { self.inner.return_type_from_exprs(args, schema, arg_types) } - pub fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.inner.return_type_from_args(args) } @@ -503,10 +500,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { self.return_type(arg_types) } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.return_type(args.arg_types) } @@ -759,7 +753,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } - /// ScalarUDF that adds an alias to the underlying function. It is better to /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. #[derive(Debug)] @@ -817,10 +810,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_exprs(args, schema, arg_types) } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.inner.return_type_from_args(args) } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index c0518c6fe897..163046629620 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -28,7 +28,8 @@ use std::sync::OnceLock; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility + ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, ScalarUDFImpl, + Signature, Volatility, }; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -92,12 +93,13 @@ impl ScalarUDFImpl for ArrowCastFunc { args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() != 2 { - return plan_err!("{} needs 2 arguments, {} provided", self.name(), args.arguments.len()); + return plan_err!( + "{} needs 2 arguments, {} provided", + self.name(), + args.arguments.len() + ); } let val = &args.arguments[1]; diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index eba28001d133..cfb64d36b8aa 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -21,10 +21,13 @@ use arrow::array::{ use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue + exec_err, internal_err, plan_datafusion_err, plan_err, ExprSchema, Result, + ScalarValue, }; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -106,10 +109,7 @@ impl ScalarUDFImpl for GetFieldFunc { internal_err!("return_type_from_args should be called instead") } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() != 2 { return exec_err!( "get_field function requires 2 arguments, got {}", @@ -154,60 +154,6 @@ impl ScalarUDFImpl for GetFieldFunc { } } - fn return_type_from_exprs( - &self, - args: &[Expr], - schema: &dyn ExprSchema, - _arg_types: &[DataType], - ) -> Result { - if args.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.len() - ); - } - - let name = match &args[1] { - Expr::Literal(name) => name, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - let data_type = args[0].get_type(schema)?; - match (data_type, name) { - (DataType::Map(fields, _), _) => { - match fields.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - // Arrow's MapArray is essentially a ListArray of structs with two columns. They are - // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second column is the "value" column both here and in - // execution. - let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.data_type().clone()) - }, - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - } - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) - } - } - (DataType::Struct(_), _) => plan_err!( - "Only UTF8 strings are valid as an indexed field in a struct" - ), - (DataType::Null, _) => Ok(DataType::Null), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), - } - } - fn invoke_batch( &self, args: &[ColumnarValue], diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 7f03b00931a0..eac9ee933547 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -44,12 +44,16 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { .chunks_exact(2) .enumerate() .map(|(i, chunk)| { - let name_column = &chunk[0]; let name = match name_column { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => { + name_scalar + } // TODO: Implement Display for ColumnarValue - _ => return exec_err!("named_struct even arguments must be string literals at position {}", i * 2) + _ => return exec_err!( + "named_struct even arguments must be string literals at position {}", + i * 2 + ), }; Ok((name, chunk[1].clone())) @@ -117,15 +121,10 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!( - "named_struct: return_type called instead of return_type_from_args" - ) + internal_err!("named_struct: return_type called instead of return_type_from_args") } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // do not accept 0 arguments. if args.arguments.is_empty() { return exec_err!( @@ -143,9 +142,11 @@ impl ScalarUDFImpl for NamedStructFunc { let names = args.arguments.iter().step_by(2).collect::>(); let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); - let return_fields = names.into_iter().zip(types.into_iter()).map(|(name, data_type)| { - Ok(Field::new(name, data_type.to_owned(), true)) - }).collect::>>()?; + let return_fields = names + .into_iter() + .zip(types.into_iter()) + .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) + .collect::>>()?; Ok(DataType::Struct(Fields::from(return_fields))) } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 27d3171a1a6a..00fda51e9d14 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -41,7 +41,8 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility + ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_expr_common::signature::TypeSignatureClass; use datafusion_macros::user_doc; @@ -138,10 +139,7 @@ impl ScalarUDFImpl for DatePartFunc { internal_err!("return_type_from_args should be called instead") } - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if is_epoch(&args.arguments[0]) { Ok(DataType::Float64) } else { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 051005e042ca..4c6061d806fb 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -24,7 +24,8 @@ use arrow::datatypes::TimeUnit::Second; use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility + ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -81,15 +82,14 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() == 1 { Ok(Timestamp(Second, None)) } else { - Ok(Timestamp(Second, Some(Arc::from(args.arguments[1].to_string())))) + Ok(Timestamp( + Second, + Some(Arc::from(args.arguments[1].to_string())), + )) } } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index a9c24344f096..5856faf379e1 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,11 +17,11 @@ use std::sync::Arc; -use crate::{scalar_function, ScalarFunctionExpr}; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, }; +use crate::{scalar_function, ScalarFunctionExpr}; use arrow::datatypes::Schema; use datafusion_common::{ @@ -33,7 +33,8 @@ use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, ReturnTypeArgs, TryCast + binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, + ReturnTypeArgs, TryCast, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -305,17 +306,22 @@ pub fn create_physical_expr( let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - let args_types = args.iter().map(|e| e.get_type(input_dfschema)).collect::>>()?; - let arguments = args.iter().map(|e| match e { - Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), - _ => "".to_string(), - }).collect::>(); + let args_types = args + .iter() + .map(|e| e.get_type(input_dfschema)) + .collect::>>()?; + let arguments = args + .iter() + .map(|e| match e { + Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), + _ => "".to_string(), + }) + .collect::>(); - let return_type = - func.return_type_from_args(ReturnTypeArgs { - arg_types: &args_types, - arguments: &arguments, - })?; + let return_type = func.return_type_from_args(ReturnTypeArgs { + arg_types: &args_types, + arguments: &arguments, + })?; let nullable = func.is_nullable(args, input_dfschema); // verify that input data types is consistent with function's `TypeSignature` diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 0ae4115de67a..866513533392 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -218,6 +218,7 @@ impl PhysicalExpr for ScalarFunctionExpr { } /// Create a physical expression for the UDF. +#[deprecated(since = "44.0.0", note = "use ScalarFunctionExpr::new() instead")] pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], From fe7f6a5289c631ea4dc891bebc12493e7fc03eee Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 16:35:23 +0800 Subject: [PATCH 05/28] add try new Signed-off-by: Jay Zhan --- datafusion/expr/src/udf.rs | 8 +++ datafusion/functions/src/core/named_struct.rs | 6 ++- datafusion/physical-expr/src/planner.rs | 41 +++------------ .../physical-expr/src/scalar_function.rs | 52 ++++++++++++++++++- 4 files changed, 70 insertions(+), 37 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9a0041f9ae7c..c5849a905f71 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -217,6 +217,10 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } + pub fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { + self.inner.is_nullable_from_args_nullable(args_nullables) + } + pub fn invoke_batch( &self, args: &[ColumnarValue], @@ -508,6 +512,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { true } + fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { + true + } + /// Invoke the function on `args`, returning the appropriate result /// /// Note: This method is deprecated and will be removed in future releases. diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index eac9ee933547..bf7d2cab997b 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -50,10 +50,12 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { name_scalar } // TODO: Implement Display for ColumnarValue - _ => return exec_err!( + _ => { + return exec_err!( "named_struct even arguments must be string literals at position {}", i * 2 - ), + ) + } }; Ok((name, chunk[1].clone())) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 5856faf379e1..0ca0f99afad9 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,11 +17,11 @@ use std::sync::Arc; +use crate::ScalarFunctionExpr; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, PhysicalExpr, }; -use crate::{scalar_function, ScalarFunctionExpr}; use arrow::datatypes::Schema; use datafusion_common::{ @@ -29,12 +29,10 @@ use datafusion_common::{ }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction}; -use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, lit, Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, - ReturnTypeArgs, TryCast, + binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -306,36 +304,11 @@ pub fn create_physical_expr( let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - let args_types = args - .iter() - .map(|e| e.get_type(input_dfschema)) - .collect::>>()?; - let arguments = args - .iter() - .map(|e| match e { - Expr::Literal(ScalarValue::Utf8(s)) => s.clone().unwrap_or_default(), - _ => "".to_string(), - }) - .collect::>(); - - let return_type = func.return_type_from_args(ReturnTypeArgs { - arg_types: &args_types, - arguments: &arguments, - })?; - let nullable = func.is_nullable(args, input_dfschema); - - // verify that input data types is consistent with function's `TypeSignature` - data_types_with_scalar_udf(&args_types, func)?; - - Ok(Arc::new( - ScalarFunctionExpr::new( - func.name(), - Arc::clone(func), - physical_args, - return_type, - ) - .with_nullable(nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(func), + physical_args, + input_schema, + )?)) } Expr::Between(Between { expr, diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 866513533392..021dd5103bd0 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,6 +34,7 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::Hash; use std::sync::Arc; +use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; @@ -43,7 +44,9 @@ use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; +use datafusion_expr::{ + expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, +}; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -83,6 +86,53 @@ impl ScalarFunctionExpr { } } + pub fn try_new( + fun: Arc, + args: Vec>, + schema: &Schema, + ) -> Result { + let name = fun.name().to_string(); + let arg_types = args + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types_with_scalar_udf(&arg_types, &fun)?; + + let arg_nullables = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; + let arguments = args + .iter() + .map(|e| { + if let Some(literal) = e.as_any().downcast_ref::() { + if let ScalarValue::Utf8(s) = literal.value() { + s.clone().unwrap_or_default() + } else { + "".to_string() + } + } else { + "".to_string() + } + }) + .collect::>(); + let ret_args = ReturnTypeArgs { + arg_types: &arg_types, + arguments: &arguments, + }; + let return_type = fun.return_type_from_args(ret_args)?; + let nullable = fun.is_nullable_from_args_nullable(&arg_nullables); + Ok(Self { + fun, + name, + args, + return_type, + nullable, + }) + } + /// Get the scalar function implementation pub fn fun(&self) -> &ScalarUDF { &self.fun From 4da4c71332c8d02163daf2d56d76841de699280a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 16:35:32 +0800 Subject: [PATCH 06/28] deprecate Signed-off-by: Jay Zhan --- datafusion/physical-expr/src/scalar_function.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 021dd5103bd0..ad74c5f79655 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -71,6 +71,7 @@ impl Debug for ScalarFunctionExpr { impl ScalarFunctionExpr { /// Create a new Scalar function + #[deprecated(since = "44.0.0", note = "Use `try_new` instead")] pub fn new( name: &str, fun: Arc, @@ -86,6 +87,7 @@ impl ScalarFunctionExpr { } } + /// Create a new Scalar function pub fn try_new( fun: Arc, args: Vec>, From de4b484398cae439e20b06b4f0c1d040c2c9a3f5 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 16:42:26 +0800 Subject: [PATCH 07/28] rm deprecate Signed-off-by: Jay Zhan --- datafusion/expr/src/udf.rs | 2 ++ datafusion/functions/src/core/getfield.rs | 7 ++---- .../functions/src/datetime/from_unixtime.rs | 5 ++--- .../physical-expr/src/scalar_function.rs | 22 +++++++++---------- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index c5849a905f71..296bd8a99f0e 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -512,6 +512,8 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { true } + /// `is_nullable` from pre-computed nullable flags. + /// It has less dependencies on the input arguments. fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { true } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index cfb64d36b8aa..312163d0a03a 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -21,13 +21,10 @@ use arrow::array::{ use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, ExprSchema, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; -use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, -}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 4c6061d806fb..f83b35c4d8c6 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -21,11 +21,10 @@ use std::sync::Arc; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; -use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index ad74c5f79655..8bc3fec59ed0 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -71,7 +71,6 @@ impl Debug for ScalarFunctionExpr { impl ScalarFunctionExpr { /// Create a new Scalar function - #[deprecated(since = "44.0.0", note = "Use `try_new` instead")] pub fn new( name: &str, fun: Arc, @@ -106,20 +105,19 @@ impl ScalarFunctionExpr { .iter() .map(|e| e.nullable(schema)) .collect::>>()?; - let arguments = args + + let arguments: Vec = args .iter() .map(|e| { - if let Some(literal) = e.as_any().downcast_ref::() { - if let ScalarValue::Utf8(s) = literal.value() { - s.clone().unwrap_or_default() - } else { - "".to_string() - } - } else { - "".to_string() - } + e.as_any() + .downcast_ref::() + .map(|literal| match literal.value() { + ScalarValue::Utf8(Some(s)) => s.clone(), + _ => String::new(), + }) + .unwrap_or_else(String::new) }) - .collect::>(); + .collect(); let ret_args = ReturnTypeArgs { arg_types: &arg_types, arguments: &arguments, From 02a64cecb2d84d5aa4b7193409772023cdb6161f Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 17:00:09 +0800 Subject: [PATCH 08/28] reaplce deprecated func Signed-off-by: Jay Zhan --- .../tests/fuzz_cases/equivalence/ordering.rs | 16 +++--- .../fuzz_cases/equivalence/projection.rs | 29 +++++----- .../fuzz_cases/equivalence/properties.rs | 17 +++--- .../physical-expr/src/equivalence/ordering.rs | 57 +++++++++---------- .../src/equivalence/projection.rs | 18 +++--- datafusion/physical-expr/src/lib.rs | 1 + 6 files changed, 68 insertions(+), 70 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs index ecf267185bae..cd9897d43baa 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{ is_table_same_after_sort, TestScalarUDF, }; use arrow_schema::SortOptions; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::ScalarFunctionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?); let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs index f71df50fce2f..78fbda16c0a0 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{ is_table_same_after_sort, TestScalarUDF, }; use arrow_schema::SortOptions; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?); // a + b let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, @@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; // a + b let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs index fc21c620a711..593e1c6c2dca 100644 --- a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{ create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, TestScalarUDF, }; -use datafusion_common::{DFSchema, Result}; +use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use itertools::Itertools; @@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> { let table_data_with_properties = generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = datafusion_physical_expr::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + let col_a = col("a", &test_schema)?; + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![col_a], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; + let a_plus_b = Arc::new(BinaryExpr::new( col("a", &test_schema)?, Operator::Plus, diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index ae502d4d5f67..4a3598354fc7 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -268,11 +268,14 @@ mod tests { }; use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use crate::{AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalSortExpr}; + use crate::{ + AcrossPartitions, ConstExpr, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, + ScalarFunctionExpr, + }; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SortOptions; - use datafusion_common::{DFSchema, Result}; + use datafusion_common::Result; use datafusion_expr::{Operator, ScalarUDF}; use datafusion_physical_expr_common::sort_expr::LexOrdering; @@ -321,28 +324,24 @@ mod tests { let col_d = &col("d", &test_schema)?; let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = &crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + + let floor_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_a)], &test_schema, - &[], - &DFSchema::empty(), - )?; - let floor_f = &crate::udf::create_physical_expr( - &test_fun, - &[col("f", &test_schema)?], + )?) as PhysicalExprRef; + let floor_f = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_f)], &test_schema, - &[], - &DFSchema::empty(), - )?; - let exp_a = &crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], + )?) as PhysicalExprRef; + let exp_a = Arc::new(ScalarFunctionExpr::try_new( + Arc::clone(&test_fun), + vec![Arc::clone(col_a)], &test_schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; + let a_plus_b = Arc::new(BinaryExpr::new( Arc::clone(col_a), Operator::Plus, @@ -386,7 +385,7 @@ mod tests { // constants vec![col_e], // requirement [floor(a) ASC], - vec![(floor_a, options)], + vec![(&floor_a, options)], // expected: requirement is satisfied. true, ), @@ -404,7 +403,7 @@ mod tests { // constants vec![col_e], // requirement [floor(f) ASC], (Please note that a=f) - vec![(floor_f, options)], + vec![(&floor_f, options)], // expected: requirement is satisfied. true, ), @@ -443,7 +442,7 @@ mod tests { // constants vec![col_e], // requirement [floor(a) ASC, a+b ASC], - vec![(floor_a, options), (&a_plus_b, options)], + vec![(&floor_a, options), (&a_plus_b, options)], // expected: requirement is satisfied. false, ), @@ -464,7 +463,7 @@ mod tests { // constants vec![col_e], // requirement [exp(a) ASC, a+b ASC], - vec![(exp_a, options), (&a_plus_b, options)], + vec![(&exp_a, options), (&a_plus_b, options)], // expected: requirement is not satisfied. // TODO: If we know that exp function is 1-to-1 function. // we could have deduced that above requirement is satisfied. @@ -484,7 +483,7 @@ mod tests { // constants vec![col_e], // requirement [a ASC, d ASC, floor(a) ASC], - vec![(col_a, options), (col_d, options), (floor_a, options)], + vec![(col_a, options), (col_d, options), (&floor_a, options)], // expected: requirement is satisfied. true, ), @@ -502,7 +501,7 @@ mod tests { // constants vec![col_e], // requirement [a ASC, floor(a) ASC, a + b ASC], - vec![(col_a, options), (floor_a, options), (&a_plus_b, options)], + vec![(col_a, options), (&floor_a, options), (&a_plus_b, options)], // expected: requirement is not satisfied. false, ), @@ -523,7 +522,7 @@ mod tests { vec![ (col_a, options), (col_c, options), - (floor_a, options), + (&floor_a, options), (&a_plus_b, options), ], // expected: requirement is not satisfied. @@ -550,7 +549,7 @@ mod tests { (col_a, options), (col_b, options), (col_c, options), - (floor_a, options), + (&floor_a, options), ], // expected: requirement is satisfied. true, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 681484fd6bff..d1e7625525ae 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -143,12 +143,11 @@ mod tests { }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; - use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; + use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; - use datafusion_common::DFSchema; use datafusion_expr::{Operator, ScalarUDF}; #[test] @@ -667,14 +666,13 @@ mod tests { Arc::clone(col_b), )) as Arc; - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let round_c = &create_physical_expr( - &test_fun, - &[Arc::clone(col_c)], + let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new())); + + let round_c = Arc::new(ScalarFunctionExpr::try_new( + test_fun, + vec![Arc::clone(col_c)], &schema, - &[], - &DFSchema::empty(), - )?; + )?) as PhysicalExprRef; let option_asc = SortOptions { descending: false, @@ -685,7 +683,7 @@ mod tests { (col_b, "b_new".to_string()), (col_a, "a_new".to_string()), (col_c, "c_new".to_string()), - (round_c, "round_c_res".to_string()), + (&round_c, "round_c_res".to_string()), ]; let proj_exprs = proj_exprs .into_iter() diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 4c55f4ddba93..11d6f54a7cc3 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -32,6 +32,7 @@ mod physical_expr; pub mod planner; mod scalar_function; pub mod udf { + #[allow(deprecated)] pub use crate::scalar_function::create_physical_expr; } pub mod utils; From f26ce70fcced5007343e1dc3179762d77301a3ff Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 12 Jan 2025 17:22:43 +0800 Subject: [PATCH 09/28] cleanup Signed-off-by: Jay Zhan --- datafusion/expr/src/udf.rs | 35 +++++++++++++++++++ datafusion/functions/src/core/arrow_cast.rs | 13 +++---- datafusion/functions/src/core/coalesce.rs | 8 ++--- datafusion/functions/src/datetime/now.rs | 4 +-- datafusion/physical-expr/src/planner.rs | 2 -- .../physical-expr/src/scalar_function.rs | 2 +- 6 files changed, 49 insertions(+), 15 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 296bd8a99f0e..9da8947b61a8 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -182,6 +182,7 @@ impl ScalarUDF { /// /// /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. + #[allow(deprecated)] pub fn return_type_from_exprs( &self, args: &[Expr], @@ -213,6 +214,7 @@ impl ScalarUDF { self.inner.invoke(args) } + #[allow(deprecated)] pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { self.inner.is_nullable(args, schema) } @@ -352,7 +354,9 @@ pub struct ScalarFunctionArgs<'a> { #[derive(Debug)] pub struct ReturnTypeArgs<'a> { + /// The data types of the arguments to the function pub arg_types: &'a [DataType], + /// The Utf8 arguments to the function, if the expression is not Utf8, it will be empty string pub arguments: &'a [String], } @@ -495,6 +499,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function must consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). + #[deprecated(since = "45.0.0", note = "Use `return_type_from_args` instead")] fn return_type_from_exprs( &self, _args: &[Expr], @@ -504,10 +509,39 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { self.return_type(arg_types) } + /// What [`DataType`] will be returned by this function, given the + /// arguments? + /// + /// Note most UDFs should implement [`Self::return_type`] and not this + /// function. The output type for most functions only depends on the types + /// of their inputs (e.g. `sqrt(f32)` is always `f32`). + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// This method can be overridden for functions that return different + /// *types* based on the *values* of their arguments. + /// + /// For example, the following two function calls get the same argument + /// types (something and a `Utf8` string) but return different types based + /// on the value of the second argument: + /// + /// * `arrow_cast(x, 'Int16')` --> `Int16` + /// * `arrow_cast(x, 'Float32')` --> `Float32` + /// + /// # Notes: + /// + /// This function must consistently return the same type for the same + /// logical input even if the input is simplified (e.g. it must return the same + /// value for `('foo' | 'bar')` as it does for ('foobar'). fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.return_type(args.arg_types) } + #[deprecated( + since = "45.0.0", + note = "Use `is_nullable_from_args_nullable` instead" + )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true } @@ -811,6 +845,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { &self.aliases } + #[allow(deprecated)] fn return_type_from_exprs( &self, args: &[Expr], diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 163046629620..f9d4b2841e99 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -18,9 +18,10 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` use arrow::datatypes::DataType; +use datafusion_common::DataFusionError; use datafusion_common::{ - arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, - ExprSchema, Result, ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, Result, + ScalarValue, }; use std::any::Any; use std::sync::OnceLock; @@ -28,8 +29,8 @@ use std::sync::OnceLock; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ExprSchemable, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, + Volatility, }; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -89,8 +90,8 @@ impl ScalarUDFImpl for ArrowCastFunc { internal_err!("return_type_from_args should be called instead") } - fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) + fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { + args_nullables.iter().any(|&nullable| nullable) } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 4f9e83fbf0d9..600a9d504961 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -19,10 +19,10 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, ExprSchema, Result}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; use std::any::Any; @@ -69,8 +69,8 @@ impl ScalarUDFImpl for CoalesceFunc { } // If any the arguments in coalesce is non-null, the result is non-null - fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { - args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true)) + fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { + args_nullables.iter().all(|&nullable| nullable) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 67cd49b7fd84..bb950178c111 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; use std::any::Any; -use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, @@ -106,7 +106,7 @@ impl ScalarUDFImpl for NowFunc { &self.aliases } - fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { + fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { false } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 0ca0f99afad9..e05de362bf14 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -109,8 +109,6 @@ pub fn create_physical_expr( execution_props: &ExecutionProps, ) -> Result> { let input_schema: &Schema = &input_dfschema.into(); - // println!("input_dfschema: {:?}", input_dfschema); - // println!("input_schema: {:?}", input_schema); match e { Expr::Alias(Alias { expr, .. }) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 8bc3fec59ed0..8a9e9af58bb5 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -268,7 +268,7 @@ impl PhysicalExpr for ScalarFunctionExpr { } /// Create a physical expression for the UDF. -#[deprecated(since = "44.0.0", note = "use ScalarFunctionExpr::new() instead")] +#[deprecated(since = "45.0.0", note = "use ScalarFunctionExpr::new() instead")] pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], From b967034b2d8a273531ed9ab97d46a605fc0346ca Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 13 Jan 2025 22:17:55 +0800 Subject: [PATCH 10/28] combine type and nullable Signed-off-by: Jay Zhan --- datafusion/expr/src/expr_schema.rs | 98 +++++++++++-------- datafusion/expr/src/lib.rs | 3 +- datafusion/expr/src/udf.rs | 71 ++++++++++---- datafusion/functions/src/core/arrow_cast.rs | 22 ++--- datafusion/functions/src/core/coalesce.rs | 23 +++-- datafusion/functions/src/core/getfield.rs | 10 +- datafusion/functions/src/core/named_struct.rs | 8 +- .../functions/src/datetime/date_part.rs | 8 +- .../functions/src/datetime/from_unixtime.rs | 11 ++- datafusion/functions/src/datetime/now.rs | 16 +-- .../physical-expr/src/scalar_function.rs | 6 +- 11 files changed, 168 insertions(+), 108 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index c0e05b458f70..cf42f03e010a 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -25,7 +25,9 @@ use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; use crate::udf::ReturnTypeArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; +use crate::{ + utils, LogicalPlan, Projection, ScalarUDF, Subquery, WindowFunctionDefinition, +}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -146,46 +148,9 @@ impl ExprSchemable for Expr { } } } - Expr::ScalarFunction(ScalarFunction { func, args }) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) - .map_err(|err| { - plan_datafusion_err!( - "{} {}", - match err { - DataFusionError::Plan(msg) => msg, - err => err.to_string(), - }, - utils::generate_signature_error_msg( - func.name(), - func.signature().clone(), - &arg_data_types, - ) - ) - })?; - - let arguments = args - .iter() - .map(|e| match e { - Expr::Literal(ScalarValue::Utf8(s)) => { - s.clone().unwrap_or_default() - } - _ => "".to_string(), - }) - .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, - arguments: &arguments, - }; - - // Perform additional function arguments validation (due to limited - // expressiveness of `TypeSignature`), then infer return type - Ok(func.return_type_from_args(args)?) + Expr::ScalarFunction(_func) => { + let (return_type, _) = self.data_type_and_nullable(schema)?; + Ok(return_type) } Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) @@ -318,8 +283,9 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), - Expr::ScalarFunction(ScalarFunction { func, args }) => { - Ok(func.is_nullable(args, input_schema)) + Expr::ScalarFunction(_func) => { + let (_, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(nullable) } Expr::AggregateFunction(AggregateFunction { func, .. }) => { Ok(func.is_nullable()) @@ -430,6 +396,52 @@ impl ExprSchemable for Expr { Expr::WindowFunction(window_function) => { self.data_type_and_nullable_with_window_function(schema, window_function) } + Expr::ScalarFunction(ScalarFunction { func, args }) => { + let arg_data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_data_types, + ) + ) + })?; + + let arguments = args + .iter() + .map(|e| match e { + Expr::Literal(ScalarValue::Utf8(s)) => { + s.clone().unwrap_or_default() + } + _ => "".to_string(), + }) + .collect::>(); + let nullables = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; + let args = ReturnTypeArgs { + arg_types: &new_data_types, + arguments: &arguments, + nullables: &nullables, + }; + + let (return_type, nullable) = + func.return_type_from_args(args)?.into_parts(); + Ok((return_type, nullable)) + } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e4d3bd6fb6a1..017415da8f23 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -94,7 +94,8 @@ pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9da8947b61a8..a5cf8a7834e0 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -20,11 +20,15 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::type_coercion::functions::data_types_with_scalar_udf; use crate::{ - ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, + utils, ColumnarValue, Documentation, Expr, ExprSchemable, + ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, ExprSchema, Result}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, DataFusionError, ExprSchema, Result, ScalarValue, +}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; @@ -193,7 +197,7 @@ impl ScalarUDF { self.inner.return_type_from_exprs(args, schema, arg_types) } - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.inner.return_type_from_args(args) } @@ -219,10 +223,6 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } - pub fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { - self.inner.is_nullable_from_args_nullable(args_nullables) - } - pub fn invoke_batch( &self, args: &[ColumnarValue], @@ -358,6 +358,48 @@ pub struct ReturnTypeArgs<'a> { pub arg_types: &'a [DataType], /// The Utf8 arguments to the function, if the expression is not Utf8, it will be empty string pub arguments: &'a [String], + pub nullables: &'a [bool], +} + +#[derive(Debug)] +pub struct ReturnInfo { + return_type: DataType, + nullable: bool, +} + +impl ReturnInfo { + pub fn new(return_type: DataType, nullable: bool) -> Self { + Self { + return_type, + nullable, + } + } + + pub fn new_nullable(return_type: DataType) -> Self { + Self { + return_type, + nullable: true, + } + } + + pub fn new_non_nullable(return_type: DataType) -> Self { + Self { + return_type, + nullable: false, + } + } + + pub fn return_type(&self) -> &DataType { + &self.return_type + } + + pub fn nullable(&self) -> bool { + self.nullable + } + + pub fn into_parts(self) -> (DataType, bool) { + (self.return_type, self.nullable) + } } /// Trait for implementing user defined scalar functions. @@ -534,24 +576,19 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function must consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.return_type(args.arg_types) + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + let return_type = self.return_type(args.arg_types)?; + Ok(ReturnInfo::new_nullable(return_type)) } #[deprecated( since = "45.0.0", - note = "Use `is_nullable_from_args_nullable` instead" + note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true } - /// `is_nullable` from pre-computed nullable flags. - /// It has less dependencies on the input arguments. - fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { - true - } - /// Invoke the function on `args`, returning the appropriate result /// /// Note: This method is deprecated and will be removed in future releases. @@ -855,7 +892,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type_from_exprs(args, schema, arg_types) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.inner.return_type_from_args(args) } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index f9d4b2841e99..4b0f668526b5 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -18,6 +18,7 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` use arrow::datatypes::DataType; +use arrow::error::ArrowError; use datafusion_common::DataFusionError; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, Result, @@ -29,8 +30,8 @@ use std::sync::OnceLock; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, + Signature, Volatility, }; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -90,11 +91,9 @@ impl ScalarUDFImpl for ArrowCastFunc { internal_err!("return_type_from_args should be called instead") } - fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { - args_nullables.iter().any(|&nullable| nullable) - } + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + let nullable = args.nullables.iter().any(|&nullable| nullable); - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() != 2 { return plan_err!( "{} needs 2 arguments, {} provided", @@ -111,12 +110,13 @@ impl ScalarUDFImpl for ArrowCastFunc { ); }; - val.parse().map_err(|e| match e { + match val.parse::() { + Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), // If the data type cannot be parsed, return a Plan error to signal an // error in the input rather than a more general ArrowError - arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), - e => arrow_datafusion_err!(e), - }) + Err(ArrowError::ParseError(e)) => Err(plan_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + } } fn invoke_batch( @@ -201,7 +201,7 @@ fn data_type_from_args(args: &[Expr]) -> Result { val.parse().map_err(|e| match e { // If the data type cannot be parsed, return a Plan error to signal an // error in the input rather than a more general ArrowError - arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), e => arrow_datafusion_err!(e), }) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 600a9d504961..73b88af6f25f 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -19,10 +19,10 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; -use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use itertools::Itertools; use std::any::Any; @@ -60,17 +60,20 @@ impl ScalarUDFImpl for CoalesceFunc { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types + fn return_type(&self, _arg_types: &[DataType]) -> Result { + internal_err!("return_type_from_args should be called instead") + } + + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // If any the arguments in coalesce is non-null, the result is non-null + let nullable = args.nullables.iter().all(|&nullable| nullable); + let return_type = args + .arg_types .iter() .find_or_first(|d| !d.is_null()) .unwrap() - .clone()) - } - - // If any the arguments in coalesce is non-null, the result is non-null - fn is_nullable_from_args_nullable(&self, args_nullables: &[bool]) -> bool { - args_nullables.iter().all(|&nullable| nullable) + .clone(); + Ok(ReturnInfo::new(return_type, nullable)) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 312163d0a03a..50259354b452 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -24,7 +24,7 @@ use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -106,7 +106,7 @@ impl ScalarUDFImpl for GetFieldFunc { internal_err!("return_type_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() != 2 { return exec_err!( "get_field function requires 2 arguments, got {}", @@ -131,7 +131,7 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.data_type().clone()) + Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) }, _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -143,10 +143,10 @@ impl ScalarUDFImpl for GetFieldFunc { ) } else { let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) } } - (DataType::Null, _) => Ok(DataType::Null), + (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bf7d2cab997b..47b5abdf43ad 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -19,7 +19,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; -use datafusion_expr::{ColumnarValue, Documentation, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -126,7 +126,7 @@ impl ScalarUDFImpl for NamedStructFunc { internal_err!("named_struct: return_type called instead of return_type_from_args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // do not accept 0 arguments. if args.arguments.is_empty() { return exec_err!( @@ -150,7 +150,9 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(DataType::Struct(Fields::from(return_fields))) + Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( + return_fields, + )))) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 00fda51e9d14..d720a0e32ecf 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -41,7 +41,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::TypeSignatureClass; @@ -139,11 +139,11 @@ impl ScalarUDFImpl for DatePartFunc { internal_err!("return_type_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if is_epoch(&args.arguments[0]) { - Ok(DataType::Float64) + Ok(ReturnInfo::new_nullable(DataType::Float64)) } else { - Ok(DataType::Int32) + Ok(ReturnInfo::new_nullable(DataType::Int32)) } } diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index f83b35c4d8c6..65237b47d326 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -24,7 +24,8 @@ use arrow::datatypes::TimeUnit::Second; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnTypeArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -81,14 +82,14 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arguments.len() == 1 { - Ok(Timestamp(Second, None)) + Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { - Ok(Timestamp( + Ok(ReturnInfo::new_nullable(Timestamp( Second, Some(Arc::from(args.arguments[1].to_string())), - )) + ))) } } diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index bb950178c111..76e875737637 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -23,7 +23,8 @@ use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_macros::user_doc; @@ -76,8 +77,15 @@ impl ScalarUDFImpl for NowFunc { &self.signature } + fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { + Ok(ReturnInfo::new_non_nullable(Timestamp( + Nanosecond, + Some("+00:00".into()), + ))) + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, Some("+00:00".into()))) + internal_err!("return_type_from_args should be called instead") } fn invoke_batch( @@ -106,10 +114,6 @@ impl ScalarUDFImpl for NowFunc { &self.aliases } - fn is_nullable_from_args_nullable(&self, _args_nullables: &[bool]) -> bool { - false - } - fn documentation(&self) -> Option<&Documentation> { self.doc() } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 8a9e9af58bb5..f2269f7fb336 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -101,7 +101,7 @@ impl ScalarFunctionExpr { // verify that input data types is consistent with function's `TypeSignature` data_types_with_scalar_udf(&arg_types, &fun)?; - let arg_nullables = args + let nullables = args .iter() .map(|e| e.nullable(schema)) .collect::>>()?; @@ -121,9 +121,9 @@ impl ScalarFunctionExpr { let ret_args = ReturnTypeArgs { arg_types: &arg_types, arguments: &arguments, + nullables: &nullables, }; - let return_type = fun.return_type_from_args(ret_args)?; - let nullable = fun.is_nullable_from_args_nullable(&arg_nullables); + let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); Ok(Self { fun, name, From 50cac9e0d9fae20bee422e2228058ba8537c9983 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 13 Jan 2025 22:34:48 +0800 Subject: [PATCH 11/28] fix slowdown Signed-off-by: Jay Zhan --- datafusion/expr/src/expr_schema.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index cf42f03e010a..a3b31578e941 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -397,13 +397,9 @@ impl ExprSchemable for Expr { self.data_type_and_nullable_with_window_function(schema, window_function) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let arg_data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - + let (arg_types, nullables) : (Vec, Vec) = args.iter().map(|e| e.data_type_and_nullable(schema)).collect::>>()?.into_iter().unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) + let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { plan_datafusion_err!( "{} {}", @@ -414,7 +410,7 @@ impl ExprSchemable for Expr { utils::generate_signature_error_msg( func.name(), func.signature().clone(), - &arg_data_types, + &arg_types, ) ) })?; @@ -428,10 +424,6 @@ impl ExprSchemable for Expr { _ => "".to_string(), }) .collect::>(); - let nullables = args - .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; let args = ReturnTypeArgs { arg_types: &new_data_types, arguments: &arguments, From 79092319c4f33960fee587725530b0738f739790 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 13 Jan 2025 22:39:51 +0800 Subject: [PATCH 12/28] clippy Signed-off-by: Jay Zhan --- datafusion/expr/src/expr_schema.rs | 11 +++++++---- datafusion/expr/src/udf.rs | 8 ++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a3b31578e941..163d2539416f 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -25,9 +25,7 @@ use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; use crate::udf::ReturnTypeArgs; -use crate::{ - utils, LogicalPlan, Projection, ScalarUDF, Subquery, WindowFunctionDefinition, -}; +use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ @@ -397,7 +395,12 @@ impl ExprSchemable for Expr { self.data_type_and_nullable_with_window_function(schema, window_function) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables) : (Vec, Vec) = args.iter().map(|e| e.data_type_and_nullable(schema)).collect::>>()?.into_iter().unzip(); + let (arg_types, nullables): (Vec, Vec) = args + .iter() + .map(|e| e.data_type_and_nullable(schema)) + .collect::>>()? + .into_iter() + .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index a5cf8a7834e0..b32ec5654ebb 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -20,15 +20,11 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; -use crate::type_coercion::functions::data_types_with_scalar_udf; use crate::{ - utils, ColumnarValue, Documentation, Expr, ExprSchemable, - ScalarFunctionImplementation, Signature, + ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{ - not_impl_err, plan_datafusion_err, DataFusionError, ExprSchema, Result, ScalarValue, -}; +use datafusion_common::{not_impl_err, ExprSchema, Result}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; From 9a9565956e47524e1adf654859f1704fa13c0542 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Jan 2025 08:38:25 +0800 Subject: [PATCH 13/28] fix take Signed-off-by: Jay Zhan --- .../user_defined_scalar_functions.rs | 41 +++++++++---------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 30b3c6e2bbeb..91d02afd5f86 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -34,14 +34,12 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result, + not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable, - LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -819,32 +817,30 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_exprs( + fn return_type_from_args( &self, - arg_exprs: &[Expr], - schema: &dyn ExprSchema, - _arg_data_types: &[DataType], - ) -> Result { - if arg_exprs.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); + args: ReturnTypeArgs, + ) -> Result { + if args.arg_types.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); } - let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = - arg_exprs.get(2) - { - if *idx == 0 || *idx == 1 { - *idx as usize + let take_idx = if let Some(take_idx) = args.arguments.get(2) { + let take_idx = take_idx.parse::().unwrap(); + + if take_idx == 0 || take_idx == 1 { + take_idx } else { - return plan_err!("The third argument must be 0 or 1, got: {idx}"); + return plan_err!("The third argument must be 0 or 1, got: {take_idx}"); } } else { return plan_err!( "The third argument must be a literal of type int64, but got {:?}", - arg_exprs.get(2) + args.arguments.get(2) ); }; - arg_exprs.get(take_idx).unwrap().get_type(schema) + Ok(ReturnInfo::new_nullable(args.arg_types[take_idx].to_owned())) } // The actual implementation @@ -854,7 +850,8 @@ impl ScalarUDFImpl for TakeUDF { _number_rows: usize, ) -> Result { let take_idx = match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0, + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1, _ => unreachable!(), }; match &args[take_idx] { @@ -874,9 +871,9 @@ async fn verify_udf_return_type() -> Result<()> { // take(smallint_col, double_col, 1) as take1 // FROM alltypes_plain; let exprs = vec![ - take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)]) + take.call(vec![col("smallint_col"), col("double_col"), lit("0")]) .alias("take0"), - take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)]) + take.call(vec![col("smallint_col"), col("double_col"), lit("1")]) .alias("take1"), ]; From 9320f349f0eb685e10e07a29f1d0d3298d3023e7 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Jan 2025 08:38:39 +0800 Subject: [PATCH 14/28] fmt Signed-off-by: Jay Zhan --- .../user_defined_scalar_functions.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 91d02afd5f86..f76e07b992c8 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -34,12 +34,13 @@ use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, - not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, - ScalarValue, + not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, + OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -817,10 +818,7 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_args( - &self, - args: ReturnTypeArgs, - ) -> Result { + fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { if args.arg_types.len() != 3 { return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); } @@ -840,7 +838,9 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable(args.arg_types[take_idx].to_owned())) + Ok(ReturnInfo::new_nullable( + args.arg_types[take_idx].to_owned(), + )) } // The actual implementation From 03bd527c7a4d3383e826f1eb76e58ab55c52f88c Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 14 Jan 2025 20:09:47 +0800 Subject: [PATCH 15/28] rm duplicated test Signed-off-by: Jay Zhan --- datafusion/functions/src/core/coalesce.rs | 36 ----------------------- 1 file changed, 36 deletions(-) diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 73b88af6f25f..abd200741834 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -177,39 +177,3 @@ fn get_coalesce_doc() -> &'static Documentation { .build() }) } - -#[cfg(test)] -mod test { - use arrow::datatypes::DataType; - - use datafusion_expr::ScalarUDFImpl; - - use crate::core; - - #[test] - fn test_coalesce_return_types() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Date32, DataType::Date32]) - .unwrap(); - assert_eq!(return_type, DataType::Date32); - } - - #[test] - fn test_coalesce_return_types_with_nulls_first() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Null, DataType::Date32]) - .unwrap(); - assert_eq!(return_type, DataType::Date32); - } - - #[test] - fn test_coalesce_return_types_with_nulls_last() { - let coalesce = core::coalesce::CoalesceFunc::new(); - let return_type = coalesce - .return_type(&[DataType::Int64, DataType::Null]) - .unwrap(); - assert_eq!(return_type, DataType::Int64); - } -} From 26e6346d81b63100735fb2dd6c7146b74716d0cc Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 18 Jan 2025 16:09:59 +0800 Subject: [PATCH 16/28] refactor: remove unused documentation sections from scalar functions --- datafusion/functions/src/core/coalesce.rs | 28 +------------------ datafusion/functions/src/core/getfield.rs | 3 +- datafusion/functions/src/core/named_struct.rs | 3 +- 3 files changed, 3 insertions(+), 31 deletions(-) diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ce7de3a45855..602fe0fd9585 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -21,8 +21,7 @@ use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnInfo, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use itertools::Itertools; @@ -169,28 +168,3 @@ impl ScalarUDFImpl for CoalesceFunc { self.doc() } } - -static DOCUMENTATION: OnceLock = OnceLock::new(); - -fn get_coalesce_doc() -> &'static Documentation { - DOCUMENTATION.get_or_init(|| { - Documentation::builder( - DOC_SECTION_CONDITIONAL, - "Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.", - "coalesce(expression1[, ..., expression_n])") - .with_sql_example(r#"```sql -> select coalesce(null, null, 'datafusion'); -+----------------------------------------+ -| coalesce(NULL,NULL,Utf8("datafusion")) | -+----------------------------------------+ -| datafusion | -+----------------------------------------+ -```"#, - ) - .with_argument( - "expression1, expression_n", - "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." - ) - .build() - }) -} diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index badc4f10a2cd..4105e7182b1e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -23,8 +23,7 @@ use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnInfo, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 608d88d962d1..b1e07c1fbb99 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,8 +18,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue}; -use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable, ReturnInfo, ReturnTypeArgs}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; From 3f2ae5cbc7fc461956580b113cd523629004b6a3 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 18 Jan 2025 16:21:15 +0800 Subject: [PATCH 17/28] upd doc Signed-off-by: Jay Zhan --- datafusion/expr/src/udf.rs | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index c1fbf992861f..2be96e5820da 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -505,38 +505,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_exprs`], + /// If you provide an implementation for [`Self::return_type_from_args`], /// DataFusion will not call `return_type` (this function). In this case it /// is recommended to return [`DataFusionError::Internal`]. /// /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; - /// What [`DataType`] will be returned by this function, given the - /// arguments? - /// - /// Note most UDFs should implement [`Self::return_type`] and not this - /// function. The output type for most functions only depends on the types - /// of their inputs (e.g. `sqrt(f32)` is always `f32`). - /// - /// By default, this function calls [`Self::return_type`] with the - /// types of each argument. - /// - /// This method can be overridden for functions that return different - /// *types* based on the *values* of their arguments. - /// - /// For example, the following two function calls get the same argument - /// types (something and a `Utf8` string) but return different types based - /// on the value of the second argument: - /// - /// * `arrow_cast(x, 'Int16')` --> `Int16` - /// * `arrow_cast(x, 'Float32')` --> `Float32` - /// - /// # Notes: - /// - /// This function must consistently return the same type for the same - /// logical input even if the input is simplified (e.g. it must return the same - /// value for `('foo' | 'bar')` as it does for ('foobar'). #[deprecated(since = "45.0.0", note = "Use `return_type_from_args` instead")] fn return_type_from_exprs( &self, From 5ad7b5c63f9e9cf20adc43e8c5834c5061b7da87 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 11:15:10 +0800 Subject: [PATCH 18/28] use scalar value Signed-off-by: Jay Zhan --- datafusion/common/src/utils/mod.rs | 12 ++++- datafusion/expr/src/expr_schema.rs | 8 ++-- datafusion/expr/src/udf.rs | 4 +- datafusion/functions/src/core/arrow_cast.rs | 46 +++++++++++-------- datafusion/functions/src/core/getfield.rs | 38 +++++++-------- datafusion/functions/src/core/named_struct.rs | 15 +++++- .../functions/src/datetime/date_part.rs | 16 +++++-- .../functions/src/datetime/from_unixtime.rs | 18 ++++++-- .../physical-expr/src/scalar_function.rs | 10 ++-- 9 files changed, 102 insertions(+), 65 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 29d33fec14ab..bcee49cc0e08 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -889,7 +889,8 @@ pub fn get_available_parallelism() -> usize { mod tests { use super::*; use crate::ScalarValue::Null; - use arrow::array::Float64Array; + use arrow::{array::Float64Array, util::pretty::pretty_format_columns}; + use arrow_array::Int32Array; use sqlparser::tokenizer::Span; #[test] @@ -1201,4 +1202,13 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } + + #[test] + fn test132() { + let a = Arc::new(Int32Array::from(vec![3; 200])) as ArrayRef; + println!( + "display {}", + pretty_format_columns("ColumnarValue(ArrayRef)", &[a]).unwrap() + ) + } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 163d2539416f..d9598b467e90 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -30,7 +30,7 @@ use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, - Result, ScalarValue, TableReference, + Result, TableReference, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; @@ -421,10 +421,8 @@ impl ExprSchemable for Expr { let arguments = args .iter() .map(|e| match e { - Expr::Literal(ScalarValue::Utf8(s)) => { - s.clone().unwrap_or_default() - } - _ => "".to_string(), + Expr::Literal(sv) => Some(sv), + _ => None, }) .collect::>(); let args = ReturnTypeArgs { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 2be96e5820da..f6397409205d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -24,7 +24,7 @@ use crate::{ ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, ExprSchema, Result}; +use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; @@ -353,7 +353,7 @@ pub struct ReturnTypeArgs<'a> { /// The data types of the arguments to the function pub arg_types: &'a [DataType], /// The Utf8 arguments to the function, if the expression is not Utf8, it will be empty string - pub arguments: &'a [String], + pub arguments: &'a [Option<&'a ScalarValue>], pub nullables: &'a [bool], } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 275535e310b2..dea683c3e046 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,9 +19,9 @@ use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::DataFusionError; +use datafusion_common::{exec_datafusion_err, DataFusionError}; use datafusion_common::{ - arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, Result, + arrow_datafusion_err, internal_err, exec_err, Result, ScalarValue, }; use std::any::Any; @@ -119,27 +119,33 @@ impl ScalarUDFImpl for ArrowCastFunc { let nullable = args.nullables.iter().any(|&nullable| nullable); if args.arguments.len() != 2 { - return plan_err!( + return exec_err!( "{} needs 2 arguments, {} provided", self.name(), args.arguments.len() ); } - let val = &args.arguments[1]; - if val.is_empty() { - return plan_err!( - "{} requires its second argument to be a constant string", - self.name() - ); - }; - - match val.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), - // If the data type cannot be parsed, return a Plan error to signal an - // error in the input rather than a more general ArrowError - Err(ArrowError::ParseError(e)) => Err(plan_datafusion_err!("{e}")), - Err(e) => Err(arrow_datafusion_err!(e)), + match args.arguments[1].as_ref() { + Some(ScalarValue::Utf8(Some(casted_type))) if !casted_type.is_empty() => { + match casted_type.parse::() { + Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + } + } + Some(ScalarValue::Utf8(Some(_))) => { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + } + _ => { + exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) + } } } @@ -185,10 +191,10 @@ impl ScalarUDFImpl for ArrowCastFunc { /// Returns the requested type from the arguments fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { - return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); + return exec_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { - return plan_err!( + return exec_err!( "arrow_cast requires its second argument to be a constant string, got {:?}", &args[1] ); @@ -197,7 +203,7 @@ fn data_type_from_args(args: &[Expr]) -> Result { val.parse().map_err(|e| match e { // If the data type cannot be parsed, return a Plan error to signal an // error in the input rather than a more general ArrowError - ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + ArrowError::ParseError(e) => exec_datafusion_err!("{e}"), e => arrow_datafusion_err!(e), }) } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 4105e7182b1e..ae900076bdcb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -21,7 +21,7 @@ use arrow::array::{ use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, + exec_err, internal_err, plan_datafusion_err, Result, ScalarValue, }; use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -154,15 +154,7 @@ impl ScalarUDFImpl for GetFieldFunc { ); } - let name = &args.arguments[1]; - if name.is_empty() { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - - let data_type = &args.arg_types[0]; - match (data_type, name) { + match (&args.arg_types[0], args.arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -173,21 +165,25 @@ impl ScalarUDFImpl for GetFieldFunc { let value_field = fields.get(1).expect("fields should have exactly two members"); Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) }, - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } } - (DataType::Struct(fields), s) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) - } + (DataType::Struct(fields), Some(ScalarValue::Utf8(Some(s)))) if !s.is_empty() => { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + } + (DataType::Struct(_), Some(ScalarValue::Utf8(Some(_)))) => { + exec_err!( + "Struct based indexed access requires a non-empty string" + ) + } + (DataType::Struct(_), _) => { + exec_err!( + "Struct based indexed access requires a constant string" + ) } (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), + (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index b1e07c1fbb99..bc600f0e20ed 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -173,7 +173,20 @@ impl ScalarUDFImpl for NamedStructFunc { ); } - let names = args.arguments.iter().step_by(2).collect::>(); + let names = args + .arguments + .iter() + .step_by(2) + .map(|x| match x { + Some(ScalarValue::Utf8(Some(name))) if !name.is_empty() => Ok(name), + Some(ScalarValue::Utf8(Some(_))) => { + exec_err!("{} requires field name as non-empty string", self.name()) + } + _ => { + exec_err!("{} requires field name as constant string", self.name()) + } + }) + .collect::>>()?; let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); let return_fields = names diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index d720a0e32ecf..ba79867ad082 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,10 +140,18 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if is_epoch(&args.arguments[0]) { - Ok(ReturnInfo::new_nullable(DataType::Float64)) - } else { - Ok(ReturnInfo::new_nullable(DataType::Int32)) + match args.arguments[0].as_ref() { + Some(ScalarValue::Utf8(Some(part))) if !part.is_empty() => { + if is_epoch(part) { + Ok(ReturnInfo::new_nullable(DataType::Float64)) + } else { + Ok(ReturnInfo::new_nullable(DataType::Int32)) + } + } + Some(ScalarValue::Utf8(Some(_))) => { + exec_err!("{} requires non-empty string", self.name()) + } + _ => exec_err!("{} requires constant string", self.name()), } } diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 65237b47d326..48bfa0c43da6 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -86,10 +86,20 @@ impl ScalarUDFImpl for FromUnixtimeFunc { if args.arguments.len() == 1 { Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { - Ok(ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(args.arguments[1].to_string())), - ))) + match &args.arguments[1] { + Some(ScalarValue::Utf8(Some(v))) if !v.is_empty() => { + Ok(ReturnInfo::new_nullable(Timestamp( + Second, + Some(Arc::from(v.to_string())), + ))) + } + Some(ScalarValue::Utf8(Some(_))) => { + exec_err!("{} requires its second argument to be a non-empty constant string", self.name()) + } + _ => { + exec_err!("{} requires its second argument to be a constant string", self.name()) + } + } } } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index f2269f7fb336..c78ef2ed7dcb 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -106,18 +106,14 @@ impl ScalarFunctionExpr { .map(|e| e.nullable(schema)) .collect::>>()?; - let arguments: Vec = args + let arguments = args .iter() .map(|e| { e.as_any() .downcast_ref::() - .map(|literal| match literal.value() { - ScalarValue::Utf8(Some(s)) => s.clone(), - _ => String::new(), - }) - .unwrap_or_else(String::new) + .map(|literal| literal.value()) }) - .collect(); + .collect::>(); let ret_args = ReturnTypeArgs { arg_types: &arg_types, arguments: &arguments, From 0545181891d7a0f66c5016a0c8ddadd42f6f8bda Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 11:19:00 +0800 Subject: [PATCH 19/28] fix test Signed-off-by: Jay Zhan --- datafusion/functions/src/core/named_struct.rs | 7 ++++--- datafusion/sqllogictest/test_files/arrow_typeof.slt | 2 +- datafusion/sqllogictest/test_files/struct.slt | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bc600f0e20ed..70a75994217d 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -176,14 +176,15 @@ impl ScalarUDFImpl for NamedStructFunc { let names = args .arguments .iter() + .enumerate() .step_by(2) - .map(|x| match x { + .map(|(i, x)| match x { Some(ScalarValue::Utf8(Some(name))) if !name.is_empty() => Ok(name), Some(ScalarValue::Utf8(Some(_))) => { - exec_err!("{} requires field name as non-empty string", self.name()) + exec_err!("{} requires {i}-th (0-indexed) field name as non-empty string", self.name()) } _ => { - exec_err!("{} requires field name as constant string", self.name()) + exec_err!("{} requires {i}-th (0-indexed) field name as constant string", self.name()) } }) .collect::>>()?; diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index fc93d0270f1a..f75b4eeb7656 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string +query error DataFusion error: Execution error: arrow_cast requires its second argument to be a constant string SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 79982f32678e..9f767e99ab93 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -151,19 +151,19 @@ query error DataFusion error: Execution error: named_struct requires an even num select named_struct('a', 1, 'b'); # error on even argument not a string literal #1 -query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string select named_struct(1, 'a'); # error on even argument not a string literal #2 -query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 2 +query error DataFusion error: Execution error: named_struct requires 2\-th \(0\-indexed\) field name as constant string select named_struct('corret', 1, 0, 'wrong'); # error on even argument not a string literal #3 -query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string select named_struct(values.a, 'a') from values; # error on even argument not a string literal #4 -query error DataFusion error: Execution error: named_struct even arguments must be string literals at position 0 +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string select named_struct(values.c, 'c') from values; # named_struct with mixed scalar and array values #1 From 30142670fb936f88f87b195d2be27af56af3bce0 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 11:43:54 +0800 Subject: [PATCH 20/28] fix test Signed-off-by: Jay Zhan --- .../user_defined/user_defined_scalar_functions.rs | 5 ++++- datafusion/functions/src/core/arrow_cast.rs | 5 ++--- datafusion/functions/src/core/named_struct.rs | 10 ++++++++-- datafusion/functions/src/datetime/from_unixtime.rs | 5 ++++- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index f76e07b992c8..525067878083 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -824,7 +824,10 @@ impl ScalarUDFImpl for TakeUDF { } let take_idx = if let Some(take_idx) = args.arguments.get(2) { - let take_idx = take_idx.parse::().unwrap(); + // This is for test only, safe to unwrap + let take_idx = take_idx.unwrap() + .try_as_str().unwrap().unwrap() + .parse::().unwrap(); if take_idx == 0 || take_idx == 1 { take_idx diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index dea683c3e046..63551a8a72ac 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,11 +19,10 @@ use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_common::{exec_datafusion_err, DataFusionError}; use datafusion_common::{ - arrow_datafusion_err, internal_err, exec_err, Result, - ScalarValue, + arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; +use datafusion_common::{exec_datafusion_err, DataFusionError}; use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 70a75994217d..6465ce321441 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -181,10 +181,16 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(i, x)| match x { Some(ScalarValue::Utf8(Some(name))) if !name.is_empty() => Ok(name), Some(ScalarValue::Utf8(Some(_))) => { - exec_err!("{} requires {i}-th (0-indexed) field name as non-empty string", self.name()) + exec_err!( + "{} requires {i}-th (0-indexed) field name as non-empty string", + self.name() + ) } _ => { - exec_err!("{} requires {i}-th (0-indexed) field name as constant string", self.name()) + exec_err!( + "{} requires {i}-th (0-indexed) field name as constant string", + self.name() + ) } }) .collect::>>()?; diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 48bfa0c43da6..f34f76e70d31 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -97,7 +97,10 @@ impl ScalarUDFImpl for FromUnixtimeFunc { exec_err!("{} requires its second argument to be a non-empty constant string", self.name()) } _ => { - exec_err!("{} requires its second argument to be a constant string", self.name()) + exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) } } } From 84636981d47b60181576a48110e89a5ee02c67d0 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 12:09:08 +0800 Subject: [PATCH 21/28] use try_as_str Signed-off-by: Jay Zhan --- .../user_defined_scalar_functions.rs | 10 +++-- datafusion/functions/src/core/arrow_cast.rs | 29 +++++++-------- datafusion/functions/src/core/getfield.rs | 18 +++++---- datafusion/functions/src/core/named_struct.rs | 37 +++++++++++++------ .../functions/src/datetime/date_part.rs | 20 +++++----- .../functions/src/datetime/from_unixtime.rs | 25 ++++++------- 6 files changed, 78 insertions(+), 61 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 525067878083..a18ba82a4483 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -825,9 +825,13 @@ impl ScalarUDFImpl for TakeUDF { let take_idx = if let Some(take_idx) = args.arguments.get(2) { // This is for test only, safe to unwrap - let take_idx = take_idx.unwrap() - .try_as_str().unwrap().unwrap() - .parse::().unwrap(); + let take_idx = take_idx + .unwrap() + .try_as_str() + .unwrap() + .unwrap() + .parse::() + .unwrap(); if take_idx == 0 || take_idx == 1 { take_idx diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 63551a8a72ac..7a061846bcf8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -125,27 +125,26 @@ impl ScalarUDFImpl for ArrowCastFunc { ); } - match args.arguments[1].as_ref() { - Some(ScalarValue::Utf8(Some(casted_type))) if !casted_type.is_empty() => { + args.arguments[1].map_or_else(|| exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ), |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ), |casted_type| { + if casted_type.is_empty() { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + } else { match casted_type.parse::() { Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), } } - Some(ScalarValue::Utf8(Some(_))) => { - exec_err!( - "{} requires its second argument to be a non-empty constant string", - self.name() - ) - } - _ => { - exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ) - } - } + })) } fn invoke_batch( diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index ae900076bdcb..618d9775e9f6 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -168,14 +168,16 @@ impl ScalarUDFImpl for GetFieldFunc { _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } } - (DataType::Struct(fields), Some(ScalarValue::Utf8(Some(s)))) if !s.is_empty() => { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) - } - (DataType::Struct(_), Some(ScalarValue::Utf8(Some(_)))) => { - exec_err!( - "Struct based indexed access requires a non-empty string" - ) + (DataType::Struct(fields), Some(sv)) => { + sv.try_as_str().flatten().map_or_else(|| exec_err!("Field name must be a constant string"), + |field_name| { + if field_name.is_empty() { + exec_err!("Field name must be a non-empty string") + } else { + let field = fields.iter().find(|f| f.name() == field_name); + field.ok_or(plan_datafusion_err!("Field {field_name} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + } + }) } (DataType::Struct(_), _) => { exec_err!( diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 6465ce321441..f9b187ab346f 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -178,20 +178,35 @@ impl ScalarUDFImpl for NamedStructFunc { .iter() .enumerate() .step_by(2) - .map(|(i, x)| match x { - Some(ScalarValue::Utf8(Some(name))) if !name.is_empty() => Ok(name), - Some(ScalarValue::Utf8(Some(_))) => { - exec_err!( + .map(|(i, sv)| { + sv.map_or_else( + || { + exec_err!( + "{} requires {i}-th (0-indexed) field name as constant string", + self.name() + ) + }, + |sv| { + sv.try_as_str().flatten().map_or_else( + || { + exec_err!( + "{} requires {i}-th (0-indexed) field name as constant string", + self.name() + ) + }, + |name| { + if name.is_empty() { + exec_err!( "{} requires {i}-th (0-indexed) field name as non-empty string", self.name() ) - } - _ => { - exec_err!( - "{} requires {i}-th (0-indexed) field name as constant string", - self.name() - ) - } + } else { + Ok(name) + } + }, + ) + }, + ) }) .collect::>>()?; let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index ba79867ad082..ceed7775b79b 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,19 +140,17 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - match args.arguments[0].as_ref() { - Some(ScalarValue::Utf8(Some(part))) if !part.is_empty() => { - if is_epoch(part) { - Ok(ReturnInfo::new_nullable(DataType::Float64)) - } else { - Ok(ReturnInfo::new_nullable(DataType::Int32)) - } - } - Some(ScalarValue::Utf8(Some(_))) => { + args.arguments[0].map_or_else(|| exec_err!("{} requires constant string", self.name()), + |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!("{} requires constant string", self.name()), + |part| { + if part.is_empty() { exec_err!("{} requires non-empty string", self.name()) + } else if is_epoch(part) { + Ok(ReturnInfo::new_nullable(DataType::Float64)) + } else { + Ok(ReturnInfo::new_nullable(DataType::Int32)) } - _ => exec_err!("{} requires constant string", self.name()), - } + })) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index f34f76e70d31..d9d9ef0e9077 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -86,23 +86,22 @@ impl ScalarUDFImpl for FromUnixtimeFunc { if args.arguments.len() == 1 { Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { - match &args.arguments[1] { - Some(ScalarValue::Utf8(Some(v))) if !v.is_empty() => { - Ok(ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(v.to_string())), - ))) - } - Some(ScalarValue::Utf8(Some(_))) => { - exec_err!("{} requires its second argument to be a non-empty constant string", self.name()) - } - _ => { + args.arguments[1].map_or_else(|| exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ), |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ), |tz| { + if tz.is_empty() { exec_err!( - "{} requires its second argument to be a constant string", + "{} requires its second argument to be a non-empty constant string", self.name() ) + } else { + Ok(ReturnInfo::new_nullable(Timestamp(Second, Some(Arc::from(tz.to_string()))))) } - } + })) } } From 40dfc6cd6fbf95f1b4acb4a6b30695d2454ee14e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 12:30:19 +0800 Subject: [PATCH 22/28] refactor: improve error handling for constant string arguments in UDFs --- datafusion/functions/src/core/arrow_cast.rs | 46 ++++++++++++------- .../functions/src/datetime/date_part.rs | 28 ++++++----- .../functions/src/datetime/from_unixtime.rs | 37 ++++++++------- 3 files changed, 68 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 7a061846bcf8..66b48722a387 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -125,26 +125,40 @@ impl ScalarUDFImpl for ArrowCastFunc { ); } - args.arguments[1].map_or_else(|| exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ), |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ), |casted_type| { - if casted_type.is_empty() { + args.arguments[1].map_or_else( + || { exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) + }, + |sv| { + sv.try_as_str().flatten().map_or_else( + || { + exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) + }, + |casted_type| { + if casted_type.is_empty() { + exec_err!( "{} requires its second argument to be a non-empty constant string", self.name() ) - } else { - match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), - Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), - Err(e) => Err(arrow_datafusion_err!(e)), - } - } - })) + } else { + match casted_type.parse::() { + Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Err(ArrowError::ParseError(e)) => { + Err(exec_datafusion_err!("{e}")) + } + Err(e) => Err(arrow_datafusion_err!(e)), + } + } + }, + ) + }, + ) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index ceed7775b79b..d116d46701f7 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,17 +140,23 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - args.arguments[0].map_or_else(|| exec_err!("{} requires constant string", self.name()), - |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!("{} requires constant string", self.name()), - |part| { - if part.is_empty() { - exec_err!("{} requires non-empty string", self.name()) - } else if is_epoch(part) { - Ok(ReturnInfo::new_nullable(DataType::Float64)) - } else { - Ok(ReturnInfo::new_nullable(DataType::Int32)) - } - })) + args.arguments[0].map_or_else( + || exec_err!("{} requires constant string", self.name()), + |sv| { + sv.try_as_str().flatten().map_or_else( + || exec_err!("{} requires constant string", self.name()), + |part| { + if part.is_empty() { + exec_err!("{} requires non-empty string", self.name()) + } else if is_epoch(part) { + Ok(ReturnInfo::new_nullable(DataType::Float64)) + } else { + Ok(ReturnInfo::new_nullable(DataType::Int32)) + } + }, + ) + }, + ) } fn invoke_batch( diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d9d9ef0e9077..d0555dec66f7 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -86,22 +86,27 @@ impl ScalarUDFImpl for FromUnixtimeFunc { if args.arguments.len() == 1 { Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { - args.arguments[1].map_or_else(|| exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ), |sv| sv.try_as_str().flatten().map_or_else(|| exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ), |tz| { - if tz.is_empty() { - exec_err!( - "{} requires its second argument to be a non-empty constant string", - self.name() - ) - } else { - Ok(ReturnInfo::new_nullable(Timestamp(Second, Some(Arc::from(tz.to_string()))))) - } - })) + args.arguments[1] + .and_then(|sv| { + sv.try_as_str() + .flatten() + .filter(|s| !s.is_empty()) + .map(|tz| { + ReturnInfo::new_nullable(Timestamp( + Second, + Some(Arc::from(tz.to_string())), + )) + }) + }) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a constant string", + self.name() + ) + }, + Ok, + ) } } From c321ff8769b120c76f0693dcec7ce60556d96c55 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 12:46:04 +0800 Subject: [PATCH 23/28] refactor: enhance error messages for constant string requirements in UDFs --- datafusion/functions/src/core/getfield.rs | 24 +++++-------- datafusion/functions/src/core/named_struct.rs | 35 +++++-------------- .../functions/src/datetime/date_part.rs | 30 ++++++++-------- datafusion/sqllogictest/test_files/struct.slt | 8 ++--- 4 files changed, 36 insertions(+), 61 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 618d9775e9f6..76aaf56efd8b 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -168,22 +168,16 @@ impl ScalarUDFImpl for GetFieldFunc { _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } } - (DataType::Struct(fields), Some(sv)) => { - sv.try_as_str().flatten().map_or_else(|| exec_err!("Field name must be a constant string"), - |field_name| { - if field_name.is_empty() { - exec_err!("Field name must be a non-empty string") - } else { - let field = fields.iter().find(|f| f.name() == field_name); - field.ok_or(plan_datafusion_err!("Field {field_name} not found in struct")).map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) - } + (DataType::Struct(fields),sv) => { + sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || exec_err!("Field name must be a non-empty string"), + |field_name| { + fields.iter().find(|f| f.name() == field_name) + .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) + .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) }) - } - (DataType::Struct(_), _) => { - exec_err!( - "Struct based indexed access requires a constant string" - ) - } + }, (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index f9b187ab346f..b055799484af 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -178,36 +178,17 @@ impl ScalarUDFImpl for NamedStructFunc { .iter() .enumerate() .step_by(2) - .map(|(i, sv)| { - sv.map_or_else( - || { + .map(|(i, sv)| + sv.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || exec_err!( - "{} requires {i}-th (0-indexed) field name as constant string", + "{} requires {i}-th (0-indexed) field name as non-empty constant string", self.name() + ), + Ok ) - }, - |sv| { - sv.try_as_str().flatten().map_or_else( - || { - exec_err!( - "{} requires {i}-th (0-indexed) field name as constant string", - self.name() - ) - }, - |name| { - if name.is_empty() { - exec_err!( - "{} requires {i}-th (0-indexed) field name as non-empty string", - self.name() - ) - } else { - Ok(name) - } - }, - ) - }, - ) - }) + ) .collect::>>()?; let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index d116d46701f7..0b86e3b221cb 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,23 +140,23 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - args.arguments[0].map_or_else( - || exec_err!("{} requires constant string", self.name()), - |sv| { - sv.try_as_str().flatten().map_or_else( - || exec_err!("{} requires constant string", self.name()), - |part| { - if part.is_empty() { - exec_err!("{} requires non-empty string", self.name()) - } else if is_epoch(part) { - Ok(ReturnInfo::new_nullable(DataType::Float64)) + args.arguments[0] + .and_then(|sv| { + sv.try_as_str() + .flatten() + .filter(|s| !s.is_empty()) + .map(|part| { + if is_epoch(part) { + ReturnInfo::new_nullable(DataType::Float64) } else { - Ok(ReturnInfo::new_nullable(DataType::Int32)) + ReturnInfo::new_nullable(DataType::Int32) } - }, - ) - }, - ) + }) + }) + .map_or_else( + || exec_err!("{} requires non-empty constant string", self.name()), + Ok, + ) } fn invoke_batch( diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 9f767e99ab93..d671798b7d0f 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -151,19 +151,19 @@ query error DataFusion error: Execution error: named_struct requires an even num select named_struct('a', 1, 'b'); # error on even argument not a string literal #1 -query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(1, 'a'); # error on even argument not a string literal #2 -query error DataFusion error: Execution error: named_struct requires 2\-th \(0\-indexed\) field name as constant string +query error DataFusion error: Execution error: named_struct requires 2\-th \(0\-indexed\) field name as non\-empty constant string select named_struct('corret', 1, 0, 'wrong'); # error on even argument not a string literal #3 -query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(values.a, 'a') from values; # error on even argument not a string literal #4 -query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as constant string +query error DataFusion error: Execution error: named_struct requires 0\-th \(0\-indexed\) field name as non\-empty constant string select named_struct(values.c, 'c') from values; # named_struct with mixed scalar and array values #1 From 8ea6cefbe6d259c4b4b16d0bc06f84a124ac7aaa Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 19 Jan 2025 14:05:18 +0800 Subject: [PATCH 24/28] refactor: streamline argument validation in return_type_from_args for UDFs --- datafusion/functions/src/core/arrow_cast.rs | 60 ++++++------------- datafusion/functions/src/core/getfield.rs | 8 +-- .../functions/src/datetime/date_part.rs | 3 + .../functions/src/datetime/from_unixtime.rs | 3 + .../sqllogictest/test_files/arrow_typeof.slt | 2 +- 5 files changed, 27 insertions(+), 49 deletions(-) diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 66b48722a387..a25e94916b60 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -117,48 +117,24 @@ impl ScalarUDFImpl for ArrowCastFunc { fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { let nullable = args.nullables.iter().any(|&nullable| nullable); - if args.arguments.len() != 2 { - return exec_err!( - "{} needs 2 arguments, {} provided", - self.name(), - args.arguments.len() - ); - } - - args.arguments[1].map_or_else( - || { - exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ) - }, - |sv| { - sv.try_as_str().flatten().map_or_else( - || { - exec_err!( - "{} requires its second argument to be a constant string", - self.name() - ) - }, - |casted_type| { - if casted_type.is_empty() { - exec_err!( - "{} requires its second argument to be a non-empty constant string", - self.name() - ) - } else { - match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), - Err(ArrowError::ParseError(e)) => { - Err(exec_datafusion_err!("{e}")) - } - Err(e) => Err(arrow_datafusion_err!(e)), - } - } - }, - ) - }, - ) + // Length check handled in the signature + debug_assert_eq!(args.arguments.len(), 2); + + args.arguments[1] + .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) + .map_or_else( + || { + exec_err!( + "{} requires its second argument to be a non-empty constant string", + self.name() + ) + }, + |casted_type| match casted_type.parse::() { + Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), + Err(e) => Err(arrow_datafusion_err!(e)), + }, + ) } fn invoke_batch( diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 76aaf56efd8b..3d83cdfdcbe9 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -147,12 +147,8 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arguments.len() != 2 { - return exec_err!( - "get_field function requires 2 arguments, got {}", - args.arguments.len() - ); - } + // Length check handled in the signature + debug_assert_eq!(args.arguments.len(), 2); match (&args.arg_types[0], args.arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 0b86e3b221cb..94fed982b650 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,6 +140,9 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // Length check handled in the signature + debug_assert_eq!(args.arguments.len(), 2); + args.arguments[0] .and_then(|sv| { sv.try_as_str() diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d0555dec66f7..90d88b00ea66 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -83,6 +83,9 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + // Length check handled in the signature + debug_assert!(matches!(args.arguments.len(), 1 | 2)); + if args.arguments.len() == 1 { Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index f75b4eeb7656..654218531f1d 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error SELECT arrow_cast('1') -query error DataFusion error: Execution error: arrow_cast requires its second argument to be a constant string +query error DataFusion error: Execution error: arrow_cast requires its second argument to be a non\-empty constant string SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown From 486a3b63017f9a8d7fb14663365dc21872df7b1e Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 20 Jan 2025 08:49:45 +0800 Subject: [PATCH 25/28] rename and doc Signed-off-by: Jay Zhan --- .../user_defined_scalar_functions.rs | 4 ++-- datafusion/expr/src/expr_schema.rs | 2 +- datafusion/expr/src/udf.rs | 19 +++++++++++++++++-- datafusion/functions/src/core/arrow_cast.rs | 4 ++-- datafusion/functions/src/core/getfield.rs | 4 ++-- datafusion/functions/src/core/named_struct.rs | 8 ++++---- .../functions/src/datetime/date_part.rs | 4 ++-- .../functions/src/datetime/from_unixtime.rs | 6 +++--- .../physical-expr/src/scalar_function.rs | 2 +- 9 files changed, 34 insertions(+), 19 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a18ba82a4483..a228eb0286aa 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -823,7 +823,7 @@ impl ScalarUDFImpl for TakeUDF { return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); } - let take_idx = if let Some(take_idx) = args.arguments.get(2) { + let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { // This is for test only, safe to unwrap let take_idx = take_idx .unwrap() @@ -841,7 +841,7 @@ impl ScalarUDFImpl for TakeUDF { } else { return plan_err!( "The third argument must be a literal of type int64, but got {:?}", - args.arguments.get(2) + args.scalar_arguments.get(2) ); }; diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d9598b467e90..08eb06160c09 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -427,7 +427,7 @@ impl ExprSchemable for Expr { .collect::>(); let args = ReturnTypeArgs { arg_types: &new_data_types, - arguments: &arguments, + scalar_arguments: &arguments, nullables: &nullables, }; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index f6397409205d..237f567425b1 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -348,15 +348,30 @@ pub struct ScalarFunctionArgs<'a> { pub return_type: &'a DataType, } +/// Information about arguments passed to the function +/// +/// This structure contains metadata about how the function was called +/// such as the type of the arguments, any scalar arguments and if the +/// arguments can (ever) be null +/// +/// See [`ScalarUDFImpl::return_type_from_args`] for more information #[derive(Debug)] pub struct ReturnTypeArgs<'a> { /// The data types of the arguments to the function pub arg_types: &'a [DataType], - /// The Utf8 arguments to the function, if the expression is not Utf8, it will be empty string - pub arguments: &'a [Option<&'a ScalarValue>], + /// Is argument `i` to the function a scalar (constant) + /// + /// If argument `i` is not a scalar, it will be None + /// + /// For example, if a function is called like `my_function(column_a, 5)` + /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` + pub scalar_arguments: &'a [Option<&'a ScalarValue>], pub nullables: &'a [bool], } +/// Return metadata for this function. +/// +/// See [`ScalarUDFImpl::return_type_from_args`] for more information #[derive(Debug)] pub struct ReturnInfo { return_type: DataType, diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index a25e94916b60..b0fba57460f8 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -118,9 +118,9 @@ impl ScalarUDFImpl for ArrowCastFunc { let nullable = args.nullables.iter().any(|&nullable| nullable); // Length check handled in the signature - debug_assert_eq!(args.arguments.len(), 2); + debug_assert_eq!(args.scalar_arguments.len(), 2); - args.arguments[1] + args.scalar_arguments[1] .and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty())) .map_or_else( || { diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3d83cdfdcbe9..7c72d4594583 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -148,9 +148,9 @@ impl ScalarUDFImpl for GetFieldFunc { fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // Length check handled in the signature - debug_assert_eq!(args.arguments.len(), 2); + debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.arguments[1].as_ref()) { + match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index b055799484af..70c9a425790c 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -160,21 +160,21 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // do not accept 0 arguments. - if args.arguments.is_empty() { + if args.scalar_arguments.is_empty() { return exec_err!( "named_struct requires at least one pair of arguments, got 0 instead" ); } - if args.arguments.len() % 2 != 0 { + if args.scalar_arguments.len() % 2 != 0 { return exec_err!( "named_struct requires an even number of arguments, got {} instead", - args.arguments.len() + args.scalar_arguments.len() ); } let names = args - .arguments + .scalar_arguments .iter() .enumerate() .step_by(2) diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 94fed982b650..bec378e137c0 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -141,9 +141,9 @@ impl ScalarUDFImpl for DatePartFunc { fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // Length check handled in the signature - debug_assert_eq!(args.arguments.len(), 2); + debug_assert_eq!(args.scalar_arguments.len(), 2); - args.arguments[0] + args.scalar_arguments[0] .and_then(|sv| { sv.try_as_str() .flatten() diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 90d88b00ea66..534b7a4fa638 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -84,12 +84,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { // Length check handled in the signature - debug_assert!(matches!(args.arguments.len(), 1 | 2)); + debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); - if args.arguments.len() == 1 { + if args.scalar_arguments.len() == 1 { Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) } else { - args.arguments[1] + args.scalar_arguments[1] .and_then(|sv| { sv.try_as_str() .flatten() diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index c78ef2ed7dcb..936adbc098d6 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -116,7 +116,7 @@ impl ScalarFunctionExpr { .collect::>(); let ret_args = ReturnTypeArgs { arg_types: &arg_types, - arguments: &arguments, + scalar_arguments: &arguments, nullables: &nullables, }; let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); From 78b81731c796fc5b8f7399ebc58a8a812e546dba Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 20 Jan 2025 08:50:24 +0800 Subject: [PATCH 26/28] refactor: add documentation for nullability of scalar arguments in ReturnTypeArgs --- datafusion/expr/src/udf.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 237f567425b1..bb5a405a9352 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -366,6 +366,7 @@ pub struct ReturnTypeArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], + /// Can argument `i` (ever) null? pub nullables: &'a [bool], } From a72f11652f27a73ae4f58f87884e125ee1294ff5 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 20 Jan 2025 08:53:37 +0800 Subject: [PATCH 27/28] rm test Signed-off-by: Jay Zhan --- datafusion/common/src/utils/mod.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index bcee49cc0e08..87943121c7fc 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -1202,13 +1202,4 @@ mod tests { assert_eq!(expected, transposed); Ok(()) } - - #[test] - fn test132() { - let a = Arc::new(Int32Array::from(vec![3; 200])) as ArrayRef; - println!( - "display {}", - pretty_format_columns("ColumnarValue(ArrayRef)", &[a]).unwrap() - ) - } } From 61abb93518b16add69f1aecfd46e99d974459909 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 20 Jan 2025 09:19:42 +0800 Subject: [PATCH 28/28] refactor: remove unused import of Int32Array in utils tests --- datafusion/common/src/utils/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 87943121c7fc..29d33fec14ab 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -889,8 +889,7 @@ pub fn get_available_parallelism() -> usize { mod tests { use super::*; use crate::ScalarValue::Null; - use arrow::{array::Float64Array, util::pretty::pretty_format_columns}; - use arrow_array::Int32Array; + use arrow::array::Float64Array; use sqlparser::tokenizer::Span; #[test]