From 30be8a064be5cac85fbcc298e9e5ad471d28f96a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 13:40:13 +0800 Subject: [PATCH 01/14] use logical type for signature Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 2 + datafusion/common/src/types/logical.rs | 6 ++ datafusion/common/src/types/native.rs | 46 +++++++++++-- datafusion/expr-common/src/signature.rs | 10 ++- .../expr/src/type_coercion/functions.rs | 64 ++++++++++++++----- datafusion/functions-aggregate/src/stddev.rs | 3 +- .../functions-aggregate/src/variance.rs | 5 +- datafusion/functions/src/string/bit_length.rs | 2 +- datafusion/functions/src/string/concat.rs | 2 +- datafusion/functions/src/string/concat_ws.rs | 2 +- datafusion/functions/src/string/lower.rs | 2 +- .../functions/src/string/octet_length.rs | 2 +- datafusion/functions/src/string/upper.rs | 2 +- .../functions/src/unicode/character_length.rs | 2 +- datafusion/functions/src/unicode/lpad.rs | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 16 files changed, 118 insertions(+), 36 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 541d464d381f..f91b5faf3240 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1382,6 +1382,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "hashbrown 0.14.5", "hex", "itertools", @@ -1404,6 +1405,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs index bde393992a0c..a65392cae344 100644 --- a/datafusion/common/src/types/logical.rs +++ b/datafusion/common/src/types/logical.rs @@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType { } } +impl std::fmt::Display for dyn LogicalType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl PartialEq for dyn LogicalType { fn eq(&self, other: &Self) -> bool { self.signature().eq(&other.signature()) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index bfb546783ea2..8241e64a9aab 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -16,8 +16,8 @@ // under the License. use super::{ - LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, - TypeSignature, + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalTypeRef, + LogicalUnionFields, TypeSignature, }; use crate::error::{Result, _internal_err}; use arrow::compute::can_cast_types; @@ -25,6 +25,7 @@ use arrow_schema::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::sync::Arc; +use std::sync::{Arc, OnceLock}; /// Representation of a type that DataFusion can handle natively. It is a subset /// of the physical variants in Arrow's native [`DataType`]. @@ -348,6 +349,12 @@ impl LogicalType for NativeType { // mapping solutions to provide backwards compatibility while transitioning from // the purely physical system to a logical / physical system. +impl From<&DataType> for NativeType { + fn from(value: &DataType) -> Self { + value.clone().into() + } +} + impl From for NativeType { fn from(value: DataType) -> Self { use NativeType::*; @@ -392,8 +399,37 @@ impl From for NativeType { } } -impl From<&DataType> for NativeType { - fn from(value: &DataType) -> Self { - value.clone().into() +impl NativeType { + #[inline] + pub fn is_numeric(&self) -> bool { + use NativeType::*; + matches!( + self, + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float16 + | Float32 + | Float64 + ) + } + + /// This function is the NativeType version of `can_cast_types`. + /// It handles general coercion rules that are widely applicable. + /// Avoid adding specific coercion cases here. + /// Aim to keep this logic as SIMPLE as possible! + pub fn can_cast_to(&self, target_type: &Self) -> bool { + // In Postgres, most functions coerce numeric strings to numeric inputs, + // but they do not accept numeric inputs as strings. + if self.is_numeric() && target_type == &NativeType::String { + return false; + } + + true } } diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 24cb54f634b1..dae979cffc67 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -19,6 +19,7 @@ //! and return types of functions in DataFusion. use arrow::datatypes::DataType; +use datafusion_common::types::LogicalTypeRef; /// Constant that is used as a placeholder for any valid timezone. /// This is used where a function can accept a timestamp type with any @@ -109,7 +110,7 @@ pub enum TypeSignature { /// For example, `Coercible(vec![DataType::Float64])` accepts /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` /// since i32 and f32 can be casted to f64 - Coercible(Vec), + Coercible(Vec), /// Fixed number of arguments of arbitrary types /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` Any(usize), @@ -201,7 +202,10 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({num})")] } - TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { + TypeSignature::Coercible(types) => { + vec![Self::join_types(types, ", ")] + } + TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { @@ -322,7 +326,7 @@ impl Signature { } } /// Target coerce types in order - pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), volatility, diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 85f8e20ba4a5..2a0d15bd9c80 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,6 +23,7 @@ use arrow::{ }; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, + types::{logical_string, NativeType}, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; @@ -401,6 +402,10 @@ fn get_valid_types( .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::String(number) => { + // TODO: we can switch to coercible after all the string functions support utf8view since it is choosen as the default string type. + // + // let data_types = get_valid_types(&TypeSignature::Coercible(vec![logical_string(); *number]), current_types)?.swap_remove(0); + if *number < 1 { return plan_err!( "The signature expected at least one argument but received {}", @@ -415,20 +420,38 @@ fn get_valid_types( ); } - fn coercion_rule( + let mut new_types = Vec::with_capacity(current_types.len()); + for data_type in current_types.iter() { + let logical_data_type: NativeType = data_type.into(); + + match logical_data_type { + NativeType::String => { + new_types.push(data_type.to_owned()); + } + NativeType::Null => { + new_types.push(DataType::Utf8); + } + _ => { + return plan_err!( + "The signature expected NativeType::String but received {data_type}" + ); + } + } + } + + let data_types = new_types; + + // Find the common string type for the given types + fn find_common_type( lhs_type: &DataType, rhs_type: &DataType, ) -> Result { match (lhs_type, rhs_type) { - (DataType::Null, DataType::Null) => Ok(DataType::Utf8), - (DataType::Null, data_type) | (data_type, DataType::Null) => { - coercion_rule(data_type, &DataType::Utf8) - } (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => { - coercion_rule(lhs, rhs) + find_common_type(lhs, rhs) } (DataType::Dictionary(_, v), other) - | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other), + | (other, DataType::Dictionary(_, v)) => find_common_type(v, other), _ => { if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) { Ok(coerced_type) @@ -444,15 +467,13 @@ fn get_valid_types( } // Length checked above, safe to unwrap - let mut coerced_type = current_types.first().unwrap().to_owned(); - for t in current_types.iter().skip(1) { - coerced_type = coercion_rule(&coerced_type, t)?; + let mut coerced_type = data_types.first().unwrap().to_owned(); + for t in data_types.iter().skip(1) { + coerced_type = find_common_type(&coerced_type, t)?; } fn base_type_or_default_type(data_type: &DataType) -> DataType { - if data_type.is_null() { - DataType::Utf8 - } else if let DataType::Dictionary(_, v) = data_type { + if let DataType::Dictionary(_, v) = data_type { base_type_or_default_type(v) } else { data_type.to_owned() @@ -506,14 +527,25 @@ fn get_valid_types( ); } + let mut new_types = Vec::with_capacity(current_types.len()); for (data_type, target_type) in current_types.iter().zip(target_types.iter()) { - if !can_cast_types(data_type, target_type) { - return plan_err!("{data_type} is not coercible to {target_type}"); + let logical_data_type: NativeType = data_type.into(); + if logical_data_type == *target_type.native() { + new_types.push(data_type.to_owned()); + } else if logical_data_type.can_cast_to(target_type.native()) { + let casted_type = target_type.default_cast_for(data_type)?; + new_types.push(casted_type); + } else { + return plan_err!( + "The signature expected {:?} but received {:?}", + target_type.native(), + logical_data_type + ); } } - vec![target_types.to_owned()] + vec![new_types] } TypeSignature::Uniform(number, valid_types) => valid_types .iter() diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 95269ed8217c..dbd3dafc4053 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -25,6 +25,7 @@ use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::types::logical_float64; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; @@ -72,7 +73,7 @@ impl Stddev { pub fn new() -> Self { Self { signature: Signature::coercible( - vec![DataType::Float64], + vec![logical_float64()], Volatility::Immutable, ), alias: vec!["stddev_samp".to_string()], diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 810247a2884a..75245f25c7d4 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -29,7 +29,8 @@ use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ - downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, + downcast_value, not_impl_err, plan_err, types::logical_float64, DataFusionError, + Result, ScalarValue, }; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::{ @@ -83,7 +84,7 @@ impl VarianceSample { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], signature: Signature::coercible( - vec![DataType::Float64], + vec![logical_float64()], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 25b56341fcaa..d02c2b6a65f4 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -79,7 +79,7 @@ impl ScalarUDFImpl for BitLengthFunc { ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), )), - _ => unreachable!(), + _ => unreachable!("bit length"), }, } } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a4218c39e7b2..5d63b21f5db7 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -186,7 +186,7 @@ impl ScalarUDFImpl for ConcatFunc { } }; } - _ => unreachable!(), + _ => unreachable!("concat"), } } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 8d966f495663..264436054ec6 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ConcatWsFunc { ColumnarValueRef::NonNullableArray(string_array) } } - _ => unreachable!(), + _ => unreachable!("concat ws"), }; let mut columns = Vec::with_capacity(args.len() - 1); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index b07189a832dc..9a87d22fc5ad 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -107,7 +107,7 @@ mod tests { let args = vec![ColumnarValue::Array(input)]; let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, - _ => unreachable!(), + _ => unreachable!("lower"), }; assert_eq!(&expected, &result); Ok(()) diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 2ac2bf70da23..89f71d457199 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -82,7 +82,7 @@ impl ScalarUDFImpl for OctetLengthFunc { ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar( ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)), )), - _ => unreachable!(), + _ => unreachable!("OctetLengthFunc"), }, } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 042c26b2e3da..600fa1dd2b93 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -107,7 +107,7 @@ mod tests { let args = vec![ColumnarValue::Array(input)]; let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, - _ => unreachable!(), + _ => unreachable!("upper"), }; assert_eq!(&expected, &result); Ok(()) diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 7858a59664d3..eca8d3fd493d 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -128,7 +128,7 @@ fn character_length(args: &[ArrayRef]) -> Result { let string_array = args[0].as_string_view(); character_length_general::(string_array) } - _ => unreachable!(), + _ => unreachable!("CharacterLengthFunc"), } } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 767eda203c8f..a639bcedcd1f 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -162,7 +162,7 @@ pub fn lpad(args: &[ArrayRef]) -> Result { length_array, &args[2], ), - (_, _) => unreachable!(), + (_, _) => unreachable!("lpad"), } } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 145172f31fd7..5770da06639a 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1940,7 +1940,7 @@ select position('' in '') ---- 1 -query error DataFusion error: Error during planning: Error during planning: Int64 and Int64 are not coercible to a common string +query error DataFusion error: Error during planning: Error during planning: The signature expected NativeType::String but received Int64 select position(1 in 1) query I From 4b6d4336f52bda04e04eb20631e32ff7475d2966 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 15:11:23 +0800 Subject: [PATCH 02/14] fmt & clippy Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 2 -- datafusion/common/src/types/native.rs | 5 ++--- datafusion/expr/src/type_coercion/functions.rs | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f91b5faf3240..541d464d381f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1382,7 +1382,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-expr-common", "hashbrown 0.14.5", "hex", "itertools", @@ -1405,7 +1404,6 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-expr-common", "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 8241e64a9aab..c606306acd77 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -16,8 +16,8 @@ // under the License. use super::{ - LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalTypeRef, - LogicalUnionFields, TypeSignature, + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, + TypeSignature, }; use crate::error::{Result, _internal_err}; use arrow::compute::can_cast_types; @@ -25,7 +25,6 @@ use arrow_schema::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; use std::sync::Arc; -use std::sync::{Arc, OnceLock}; /// Representation of a type that DataFusion can handle natively. It is a subset /// of the physical variants in Arrow's native [`DataType`]. diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 2a0d15bd9c80..f609159650ac 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,7 +23,7 @@ use arrow::{ }; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, - types::{logical_string, NativeType}, + types::NativeType, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; From 5cb2aa972eebbfb67869fb14482a33d32241338a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 16:50:36 +0800 Subject: [PATCH 03/14] numeric Signed-off-by: jayzhan211 --- datafusion/common/src/types/native.rs | 1 + .../expr/src/type_coercion/functions.rs | 16 +++++++++++-- .../functions-aggregate/src/first_last.rs | 24 ++++--------------- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index c606306acd77..a9109415a58e 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -415,6 +415,7 @@ impl NativeType { | Float16 | Float32 | Float64 + | Decimal(_, _) ) } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index f609159650ac..684ab5252911 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -497,8 +497,20 @@ fn get_valid_types( ); } - let mut valid_type = current_types.first().unwrap().clone(); - for t in current_types.iter().skip(1) { + let mut new_types = Vec::with_capacity(current_types.len()); + for data_type in current_types.iter() { + let logical_data_type: NativeType = data_type.into(); + if logical_data_type.is_numeric() { + new_types.push(data_type.to_owned()); + } else { + return plan_err!( + "The signature expected NativeType::Numeric but received {data_type}" + ); + } + } + + let mut valid_type = new_types.first().unwrap().clone(); + for t in new_types.iter().skip(1) { if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { valid_type = coerced_type; } else { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0b05713499a9..d9ac539aac49 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -33,8 +33,8 @@ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr, - ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature, + SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; @@ -79,15 +79,7 @@ impl Default for FirstValue { impl FirstValue { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // TODO: we can introduce more strict signature that only numeric of array types are allowed - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Numeric(1), - TypeSignature::Uniform(1, vec![DataType::Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::any(1, Volatility::Immutable), requirement_satisfied: false, } } @@ -406,15 +398,7 @@ impl Default for LastValue { impl LastValue { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // TODO: we can introduce more strict signature that only numeric of array types are allowed - TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Numeric(1), - TypeSignature::Uniform(1, vec![DataType::Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::any(1, Volatility::Immutable), requirement_satisfied: false, } } From c9ac3c5a13788df012ec249fb163be5c75f65dfa Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 20:18:18 +0800 Subject: [PATCH 04/14] fix numeric Signed-off-by: jayzhan211 --- datafusion/common/src/types/native.rs | 20 +++++-- .../expr/src/type_coercion/functions.rs | 56 +++++++++++-------- datafusion/functions/src/math/abs.rs | 2 +- datafusion/sqllogictest/test_files/expr.slt | 2 +- datafusion/sqllogictest/test_files/math.slt | 2 +- datafusion/sqllogictest/test_files/scalar.slt | 4 +- 6 files changed, 52 insertions(+), 34 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index a9109415a58e..537f147078f6 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -423,13 +423,21 @@ impl NativeType { /// It handles general coercion rules that are widely applicable. /// Avoid adding specific coercion cases here. /// Aim to keep this logic as SIMPLE as possible! - pub fn can_cast_to(&self, target_type: &Self) -> bool { - // In Postgres, most functions coerce numeric strings to numeric inputs, - // but they do not accept numeric inputs as strings. - if self.is_numeric() && target_type == &NativeType::String { - return false; + /// + /// Ensure there is a corresponding test for this function. + pub fn can_coerce_to(&self, target_type: &Self) -> bool { + if self.eq(target_type) { + return true; } - true + if self.is_numeric() && target_type.is_numeric() { + return true; + } + + if self.eq(&NativeType::Null) && target_type == &NativeType::String { + return true; + } + + false } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 684ab5252911..6140ff23c483 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,7 +23,7 @@ use arrow::{ }; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, - types::NativeType, + types::{LogicalType, NativeType}, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; @@ -423,19 +423,16 @@ fn get_valid_types( let mut new_types = Vec::with_capacity(current_types.len()); for data_type in current_types.iter() { let logical_data_type: NativeType = data_type.into(); - - match logical_data_type { - NativeType::String => { - new_types.push(data_type.to_owned()); - } - NativeType::Null => { + if logical_data_type.can_coerce_to(&NativeType::String) { + if data_type.is_null() { new_types.push(DataType::Utf8); + } else { + new_types.push(data_type.to_owned()); } - _ => { - return plan_err!( - "The signature expected NativeType::String but received {data_type}" - ); - } + } else { + return plan_err!( + "The signature expected NativeType::String but received {data_type}" + ); } } @@ -497,20 +494,25 @@ fn get_valid_types( ); } - let mut new_types = Vec::with_capacity(current_types.len()); - for data_type in current_types.iter() { - let logical_data_type: NativeType = data_type.into(); - if logical_data_type.is_numeric() { - new_types.push(data_type.to_owned()); - } else { + // Find common numeric type amongs given types except string + let mut valid_type = current_types.first().unwrap().to_owned(); + for t in current_types.iter().skip(1) { + let logical_data_type: NativeType = t.into(); + // Skip string, assume it is numeric string, let arrow::cast handle the actual casting logic + if logical_data_type == NativeType::String { + continue; + } + + if logical_data_type == NativeType::Null { + continue; + } + + if !logical_data_type.is_numeric() { return plan_err!( - "The signature expected NativeType::Numeric but received {data_type}" + "The signature expected NativeType::Numeric but received {t}" ); } - } - let mut valid_type = new_types.first().unwrap().clone(); - for t in new_types.iter().skip(1) { if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { valid_type = coerced_type; } else { @@ -522,6 +524,14 @@ fn get_valid_types( } } + let logical_data_type: NativeType = valid_type.clone().into(); + // Fallback to default type if we don't know which type to coerced to + // f64 is choosen since most of the math function utilize Signature::numeric, + // and their default type is double precision + if matches!(logical_data_type, NativeType::String | NativeType::Null) { + valid_type = DataType::Float64; + } + vec![vec![valid_type; *number]] } TypeSignature::Coercible(target_types) => { @@ -545,7 +555,7 @@ fn get_valid_types( let logical_data_type: NativeType = data_type.into(); if logical_data_type == *target_type.native() { new_types.push(data_type.to_owned()); - } else if logical_data_type.can_cast_to(target_type.native()) { + } else if logical_data_type.can_coerce_to(target_type.native()) { let casted_type = target_type.default_cast_for(data_type)?; new_types.push(casted_type); } else { diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 5511a57d8566..798939162a63 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -117,7 +117,7 @@ impl Default for AbsFunc { impl AbsFunc { pub fn new() -> Self { Self { - signature: Signature::any(1, Volatility::Immutable), + signature: Signature::numeric(1, Volatility::Immutable), } } } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 182afff7a693..708abcde8eac 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -22,7 +22,7 @@ SELECT true, false, false = false, true = false true false true false # test_mathematical_expressions_with_null -query RRRRRRRRRRRRRRRRRR?RRRRRIIIRRRRRRBB +query RRRRRRRRRRRRRRRRRRRRRRRRIIIRRRRRRBB SELECT sqrt(NULL), cbrt(NULL), diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 1bc972a3e37d..7f431092ff6c 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -126,7 +126,7 @@ statement error SELECT abs(1, 2); # abs: unsupported argument type -query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Float64 type SELECT abs('foo'); diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 5770da06639a..fdd28a2db9f9 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -67,13 +67,13 @@ CREATE TABLE small_floats( ## abs # abs scalar function -query III rowsort +query III select abs(64), abs(0), abs(-64); ---- 64 0 64 # abs scalar nulls -query ? rowsort +query R select abs(null); ---- NULL From ec65ca1346082c10a6d6723f0a0a855b9d155a87 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 20:56:12 +0800 Subject: [PATCH 05/14] deprecate coercible Signed-off-by: jayzhan211 --- datafusion/common/src/types/native.rs | 22 -------- datafusion/expr-common/src/signature.rs | 3 +- .../expr/src/type_coercion/functions.rs | 55 +++---------------- datafusion/functions-aggregate/src/stddev.rs | 6 +- .../functions-aggregate/src/variance.rs | 8 +-- 5 files changed, 14 insertions(+), 80 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 537f147078f6..f359d9b11066 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -418,26 +418,4 @@ impl NativeType { | Decimal(_, _) ) } - - /// This function is the NativeType version of `can_cast_types`. - /// It handles general coercion rules that are widely applicable. - /// Avoid adding specific coercion cases here. - /// Aim to keep this logic as SIMPLE as possible! - /// - /// Ensure there is a corresponding test for this function. - pub fn can_coerce_to(&self, target_type: &Self) -> bool { - if self.eq(target_type) { - return true; - } - - if self.is_numeric() && target_type.is_numeric() { - return true; - } - - if self.eq(&NativeType::Null) && target_type == &NativeType::String { - return true; - } - - false - } } diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index dae979cffc67..95c63bd2a4ca 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -124,7 +124,7 @@ pub enum TypeSignature { /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), /// Fixed number of arguments of numeric types. - /// See to know which type is considered numeric + /// See [`NativeType::is_numeric`] to know which type is considered numeric Numeric(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. @@ -326,6 +326,7 @@ impl Signature { } } /// Target coerce types in order + #[deprecated(since = "42.0.0", note = "Use String, Numeric")] pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 6140ff23c483..e84ef721c8df 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,7 +23,7 @@ use arrow::{ }; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, - types::{LogicalType, NativeType}, + types::NativeType, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; @@ -402,10 +402,6 @@ fn get_valid_types( .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::String(number) => { - // TODO: we can switch to coercible after all the string functions support utf8view since it is choosen as the default string type. - // - // let data_types = get_valid_types(&TypeSignature::Coercible(vec![logical_string(); *number]), current_types)?.swap_remove(0); - if *number < 1 { return plan_err!( "The signature expected at least one argument but received {}", @@ -423,12 +419,11 @@ fn get_valid_types( let mut new_types = Vec::with_capacity(current_types.len()); for data_type in current_types.iter() { let logical_data_type: NativeType = data_type.into(); - if logical_data_type.can_coerce_to(&NativeType::String) { - if data_type.is_null() { - new_types.push(DataType::Utf8); - } else { - new_types.push(data_type.to_owned()); - } + if logical_data_type == NativeType::String { + new_types.push(data_type.to_owned()); + } else if logical_data_type == NativeType::Null { + // TODO: Switch to Utf8View if all the string functions supports Utf8View + new_types.push(DataType::Utf8); } else { return plan_err!( "The signature expected NativeType::String but received {data_type}" @@ -526,7 +521,7 @@ fn get_valid_types( let logical_data_type: NativeType = valid_type.clone().into(); // Fallback to default type if we don't know which type to coerced to - // f64 is choosen since most of the math function utilize Signature::numeric, + // f64 is choosen since most of the math function utilize Signature::numeric, // and their default type is double precision if matches!(logical_data_type, NativeType::String | NativeType::Null) { valid_type = DataType::Float64; @@ -534,40 +529,8 @@ fn get_valid_types( vec![vec![valid_type; *number]] } - TypeSignature::Coercible(target_types) => { - if target_types.is_empty() { - return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() - ); - } - if target_types.len() != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - target_types.len(), - current_types.len() - ); - } - - let mut new_types = Vec::with_capacity(current_types.len()); - for (data_type, target_type) in current_types.iter().zip(target_types.iter()) - { - let logical_data_type: NativeType = data_type.into(); - if logical_data_type == *target_type.native() { - new_types.push(data_type.to_owned()); - } else if logical_data_type.can_coerce_to(target_type.native()) { - let casted_type = target_type.default_cast_for(data_type)?; - new_types.push(casted_type); - } else { - return plan_err!( - "The signature expected {:?} but received {:?}", - target_type.native(), - logical_data_type - ); - } - } - - vec![new_types] + TypeSignature::Coercible(_target_types) => { + return plan_err!("Deprecated, use String, Numeric directly"); } TypeSignature::Uniform(number, valid_types) => valid_types .iter() diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index dbd3dafc4053..5c6c92c486a0 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -25,7 +25,6 @@ use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::types::logical_float64; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; @@ -72,10 +71,7 @@ impl Stddev { /// Create a new STDDEV aggregate function pub fn new() -> Self { Self { - signature: Signature::coercible( - vec![logical_float64()], - Volatility::Immutable, - ), + signature: Signature::numeric(1, Volatility::Immutable), alias: vec!["stddev_samp".to_string()], } } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 75245f25c7d4..344e21b40662 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -29,8 +29,7 @@ use std::sync::OnceLock; use std::{fmt::Debug, sync::Arc}; use datafusion_common::{ - downcast_value, not_impl_err, plan_err, types::logical_float64, DataFusionError, - Result, ScalarValue, + downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; use datafusion_expr::{ @@ -83,10 +82,7 @@ impl VarianceSample { pub fn new() -> Self { Self { aliases: vec![String::from("var_sample"), String::from("var_samp")], - signature: Signature::coercible( - vec![logical_float64()], - Volatility::Immutable, - ), + signature: Signature::numeric(1, Volatility::Immutable), } } } From 1aba4cbbdc67f33d88df6af799a6c47b55ba25d7 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 21:16:52 +0800 Subject: [PATCH 06/14] introduce numeric and numeric string Signed-off-by: jayzhan211 --- datafusion/expr-common/src/signature.rs | 17 ++++ .../expr/src/type_coercion/functions.rs | 85 +++++++++++++------ datafusion/functions/src/math/abs.rs | 2 +- datafusion/sqllogictest/test_files/math.slt | 5 ++ 4 files changed, 81 insertions(+), 28 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 95c63bd2a4ca..59346eadb29b 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -126,6 +126,13 @@ pub enum TypeSignature { /// Fixed number of arguments of numeric types. /// See [`NativeType::is_numeric`] to know which type is considered numeric Numeric(usize), + /// Fixed number of arguments of numeric types. + /// See [`NativeType::is_numeric`] to know which type is considered numeric + /// This signature accepts numeric string + /// Example of functions In Postgres that support numeric string + /// 1. Mathematical Functions, like `abs` + /// 2. `to_timestamp` + NumericAndNumericString(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. /// Null is considerd as `Utf8` by default @@ -202,6 +209,9 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({num})")] } + TypeSignature::NumericAndNumericString(num) => { + vec![format!("NumericAndNumericString({num})")] + } TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } @@ -292,6 +302,13 @@ impl Signature { } } + pub fn numeric_and_numeric_string(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::NumericAndNumericString(arg_count), + volatility, + } + } + /// A specified number of numeric arguments pub fn string(arg_count: usize, volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index e84ef721c8df..7fbb7646e874 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -178,6 +178,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { type_signature, TypeSignature::UserDefined | TypeSignature::Numeric(_) + | TypeSignature::NumericAndNumericString(_) | TypeSignature::String(_) | TypeSignature::Coercible(_) | TypeSignature::Any(_) @@ -396,25 +397,29 @@ fn get_valid_types( } } + fn function_length_check(length: usize, expected_length: usize) -> Result<()> { + if length < 1 { + return plan_err!( + "The signature expected at least one argument but received {expected_length}" + ); + } + + if length != expected_length { + return plan_err!( + "The signature expected {length} arguments but received {expected_length}" + ); + } + + Ok(()) + } + let valid_types = match signature { TypeSignature::Variadic(valid_types) => valid_types .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::String(number) => { - if *number < 1 { - return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() - ); - } - if *number != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - number, - current_types.len() - ); - } + function_length_check(current_types.len(), *number)?; let mut new_types = Vec::with_capacity(current_types.len()); for data_type in current_types.iter() { @@ -474,20 +479,8 @@ fn get_valid_types( vec![vec![base_type_or_default_type(&coerced_type); *number]] } - TypeSignature::Numeric(number) => { - if *number < 1 { - return plan_err!( - "The signature expected at least one argument but received {}", - current_types.len() - ); - } - if *number != current_types.len() { - return plan_err!( - "The signature expected {} arguments but received {}", - number, - current_types.len() - ); - } + TypeSignature::NumericAndNumericString(number) => { + function_length_check(current_types.len(), *number)?; // Find common numeric type amongs given types except string let mut valid_type = current_types.first().unwrap().to_owned(); @@ -529,6 +522,44 @@ fn get_valid_types( vec![vec![valid_type; *number]] } + TypeSignature::Numeric(number) => { + function_length_check(current_types.len(), *number)?; + + // Find common numeric type amongs given types except string + let mut valid_type = current_types.first().unwrap().to_owned(); + for t in current_types.iter().skip(1) { + let logical_data_type: NativeType = t.into(); + if logical_data_type == NativeType::Null { + continue; + } + + if !logical_data_type.is_numeric() { + return plan_err!( + "The signature expected NativeType::Numeric but received {t}" + ); + } + + if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { + valid_type = coerced_type; + } else { + return plan_err!( + "{} and {} are not coercible to a common numeric type", + valid_type, + t + ); + } + } + + let logical_data_type: NativeType = valid_type.clone().into(); + // Fallback to default type if we don't know which type to coerced to + // f64 is choosen since most of the math function utilize Signature::numeric, + // and their default type is double precision + if logical_data_type == NativeType::Null { + valid_type = DataType::Float64; + } + + vec![vec![valid_type; *number]] + } TypeSignature::Coercible(_target_types) => { return plan_err!("Deprecated, use String, Numeric directly"); } diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 798939162a63..363b147f2de9 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -117,7 +117,7 @@ impl Default for AbsFunc { impl AbsFunc { pub fn new() -> Self { Self { - signature: Signature::numeric(1, Volatility::Immutable), + signature: Signature::numeric_and_numeric_string(1, Volatility::Immutable), } } } diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 7f431092ff6c..760e99d80e44 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -129,6 +129,11 @@ SELECT abs(1, 2); query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Float64 type SELECT abs('foo'); +# abs: numeric string +query R +select abs('-1.2'); +---- +1.2 statement ok CREATE TABLE test_nullable_integer( From acd2ceb12f5d10099de1be534f122ee866563eb8 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 21:20:30 +0800 Subject: [PATCH 07/14] fix doc Signed-off-by: jayzhan211 --- datafusion/expr-common/src/signature.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 59346eadb29b..7ea1df7b36cf 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -125,6 +125,8 @@ pub enum TypeSignature { ArraySignature(ArrayFunctionSignature), /// Fixed number of arguments of numeric types. /// See [`NativeType::is_numeric`] to know which type is considered numeric + /// + /// [`NativeType::is_numeric`]: datafusion_common Numeric(usize), /// Fixed number of arguments of numeric types. /// See [`NativeType::is_numeric`] to know which type is considered numeric @@ -132,6 +134,8 @@ pub enum TypeSignature { /// Example of functions In Postgres that support numeric string /// 1. Mathematical Functions, like `abs` /// 2. `to_timestamp` + /// + /// [`NativeType::is_numeric`]: datafusion_common NumericAndNumericString(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. From 3855c6fc277bcfe7c4ebdd3dbd5881be208c6d35 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 21:27:34 +0800 Subject: [PATCH 08/14] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/functions.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 7fbb7646e874..e878563ea480 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -436,8 +436,6 @@ fn get_valid_types( } } - let data_types = new_types; - // Find the common string type for the given types fn find_common_type( lhs_type: &DataType, @@ -464,8 +462,8 @@ fn get_valid_types( } // Length checked above, safe to unwrap - let mut coerced_type = data_types.first().unwrap().to_owned(); - for t in data_types.iter().skip(1) { + let mut coerced_type = new_types.first().unwrap().to_owned(); + for t in new_types.iter().skip(1) { coerced_type = find_common_type(&coerced_type, t)?; } From 804242cfd80dbe15fcdc87be4aa1985f68fe193e Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 22:35:41 +0800 Subject: [PATCH 09/14] add back coercible Signed-off-by: jayzhan211 --- datafusion/common/src/types/native.rs | 9 +++++ datafusion/expr-common/src/signature.rs | 1 - .../expr/src/type_coercion/functions.rs | 40 +++++++++++++++++-- datafusion/functions/src/string/repeat.rs | 16 +++----- datafusion/sqllogictest/test_files/expr.slt | 8 ++++ 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index f359d9b11066..650785e68f42 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -418,4 +418,13 @@ impl NativeType { | Decimal(_, _) ) } + + #[inline] + pub fn is_integer(&self) -> bool { + use NativeType::*; + matches!( + self, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 + ) + } } diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 7ea1df7b36cf..8088e56c3e28 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -347,7 +347,6 @@ impl Signature { } } /// Target coerce types in order - #[deprecated(since = "42.0.0", note = "Use String, Numeric")] pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { Self { type_signature: TypeSignature::Coercible(target_types), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index e878563ea480..8c32bc5c1598 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -23,7 +23,7 @@ use arrow::{ }; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, - types::NativeType, + types::{LogicalType, NativeType}, utils::{coerced_fixed_size_list_to_list, list_ndims}, Result, }; @@ -558,8 +558,42 @@ fn get_valid_types( vec![vec![valid_type; *number]] } - TypeSignature::Coercible(_target_types) => { - return plan_err!("Deprecated, use String, Numeric directly"); + TypeSignature::Coercible(target_types) => { + function_length_check(current_types.len(), target_types.len())?; + + // Aim to keep this logic as SIMPLE as possible! + // Make sure the corresponding test is covered + // If this function becomes COMPLEX, create another new signature! + fn can_cast_to(logical_type: &NativeType, target_type: &NativeType) -> bool { + if logical_type == target_type { + return true; + } + + if logical_type == &NativeType::Null { + return true; + } + + if target_type.is_integer() && logical_type.is_integer() { + return true; + } + + false + } + + let mut new_types = Vec::with_capacity(current_types.len()); + for (current_type, target_type) in + current_types.iter().zip(target_types.iter()) + { + let logical_type: NativeType = current_type.into(); + let target_logical_type = target_type.native(); + if can_cast_to(&logical_type, target_logical_type) { + let target_type = + target_logical_type.default_cast_for(current_type)?; + new_types.push(target_type); + } + } + + vec![new_types] } TypeSignature::Uniform(number, valid_types) => valid_types .iter() diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index aa69f9c6609a..249ce15d6dbe 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -25,11 +25,12 @@ use arrow::array::{ OffsetSizeTrait, StringViewArray, }; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View}; +use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::cast::as_int64_array; +use datafusion_common::types::{logical_int64, logical_string}; use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; -use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; #[derive(Debug)] @@ -46,15 +47,8 @@ impl Default for RepeatFunc { impl RepeatFunc { pub fn new() -> Self { Self { - signature: Signature::one_of( - vec![ - // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`. - // If that fails, it proceeds to `(Utf8, Int64)`. - TypeSignature::Exact(vec![Utf8View, Int64]), - TypeSignature::Exact(vec![Utf8, Int64]), - TypeSignature::Exact(vec![LargeUtf8, Int64]), - ], + signature: Signature::coercible( + vec![logical_string(), logical_int64()], Volatility::Immutable, ), } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 708abcde8eac..c653113fd438 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -550,6 +550,14 @@ SELECT repeat(NULL, 4) ---- NULL +query T +select repeat('-1.2', arrow_cast(3, 'Int32')); +---- +-1.2-1.2-1.2 + +query error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Float64\] to the signature +select repeat('-1.2', 3.2); + query T SELECT replace('abcdefabcdef', 'cd', 'XX') ---- From 793b0dae255e98b2ec2e2fd8f40f71cabf2885e0 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 4 Nov 2024 22:40:58 +0800 Subject: [PATCH 10/14] rename Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/functions.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 8c32bc5c1598..75dd16c39ed5 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -564,7 +564,7 @@ fn get_valid_types( // Aim to keep this logic as SIMPLE as possible! // Make sure the corresponding test is covered // If this function becomes COMPLEX, create another new signature! - fn can_cast_to(logical_type: &NativeType, target_type: &NativeType) -> bool { + fn can_coerce_to(logical_type: &NativeType, target_type: &NativeType) -> bool { if logical_type == target_type { return true; } @@ -586,7 +586,7 @@ fn get_valid_types( { let logical_type: NativeType = current_type.into(); let target_logical_type = target_type.native(); - if can_cast_to(&logical_type, target_logical_type) { + if can_coerce_to(&logical_type, target_logical_type) { let target_type = target_logical_type.default_cast_for(current_type)?; new_types.push(target_type); From 2cfa273d036805f7ddd63101dfc17c6b2c3fe2d0 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 5 Nov 2024 07:31:04 +0800 Subject: [PATCH 11/14] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/functions.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 75dd16c39ed5..245b889b3077 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -564,7 +564,10 @@ fn get_valid_types( // Aim to keep this logic as SIMPLE as possible! // Make sure the corresponding test is covered // If this function becomes COMPLEX, create another new signature! - fn can_coerce_to(logical_type: &NativeType, target_type: &NativeType) -> bool { + fn can_coerce_to( + logical_type: &NativeType, + target_type: &NativeType, + ) -> bool { if logical_type == target_type { return true; } From 6e97cb15cde6c4cdfea6e54a4d7c6b62ec51fa50 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 6 Nov 2024 08:28:08 +0800 Subject: [PATCH 12/14] rm numeric string signature Signed-off-by: jayzhan211 --- datafusion/expr-common/src/signature.rs | 19 -------- .../expr/src/type_coercion/functions.rs | 44 ------------------- datafusion/functions/src/math/abs.rs | 2 +- datafusion/sqllogictest/test_files/math.slt | 10 +++-- .../sqllogictest/test_files/timestamps.slt | 7 +++ 5 files changed, 14 insertions(+), 68 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 8088e56c3e28..e0aa6153da9d 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -128,15 +128,6 @@ pub enum TypeSignature { /// /// [`NativeType::is_numeric`]: datafusion_common Numeric(usize), - /// Fixed number of arguments of numeric types. - /// See [`NativeType::is_numeric`] to know which type is considered numeric - /// This signature accepts numeric string - /// Example of functions In Postgres that support numeric string - /// 1. Mathematical Functions, like `abs` - /// 2. `to_timestamp` - /// - /// [`NativeType::is_numeric`]: datafusion_common - NumericAndNumericString(usize), /// Fixed number of arguments of all the same string types. /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. /// Null is considerd as `Utf8` by default @@ -213,9 +204,6 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({num})")] } - TypeSignature::NumericAndNumericString(num) => { - vec![format!("NumericAndNumericString({num})")] - } TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } @@ -306,13 +294,6 @@ impl Signature { } } - pub fn numeric_and_numeric_string(arg_count: usize, volatility: Volatility) -> Self { - Self { - type_signature: TypeSignature::NumericAndNumericString(arg_count), - volatility, - } - } - /// A specified number of numeric arguments pub fn string(arg_count: usize, volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 245b889b3077..9b05f83e2f70 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -178,7 +178,6 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { type_signature, TypeSignature::UserDefined | TypeSignature::Numeric(_) - | TypeSignature::NumericAndNumericString(_) | TypeSignature::String(_) | TypeSignature::Coercible(_) | TypeSignature::Any(_) @@ -477,49 +476,6 @@ fn get_valid_types( vec![vec![base_type_or_default_type(&coerced_type); *number]] } - TypeSignature::NumericAndNumericString(number) => { - function_length_check(current_types.len(), *number)?; - - // Find common numeric type amongs given types except string - let mut valid_type = current_types.first().unwrap().to_owned(); - for t in current_types.iter().skip(1) { - let logical_data_type: NativeType = t.into(); - // Skip string, assume it is numeric string, let arrow::cast handle the actual casting logic - if logical_data_type == NativeType::String { - continue; - } - - if logical_data_type == NativeType::Null { - continue; - } - - if !logical_data_type.is_numeric() { - return plan_err!( - "The signature expected NativeType::Numeric but received {t}" - ); - } - - if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { - valid_type = coerced_type; - } else { - return plan_err!( - "{} and {} are not coercible to a common numeric type", - valid_type, - t - ); - } - } - - let logical_data_type: NativeType = valid_type.clone().into(); - // Fallback to default type if we don't know which type to coerced to - // f64 is choosen since most of the math function utilize Signature::numeric, - // and their default type is double precision - if matches!(logical_data_type, NativeType::String | NativeType::Null) { - valid_type = DataType::Float64; - } - - vec![vec![valid_type; *number]] - } TypeSignature::Numeric(number) => { function_length_check(current_types.len(), *number)?; diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 363b147f2de9..798939162a63 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -117,7 +117,7 @@ impl Default for AbsFunc { impl AbsFunc { pub fn new() -> Self { Self { - signature: Signature::numeric_and_numeric_string(1, Volatility::Immutable), + signature: Signature::numeric(1, Volatility::Immutable), } } } diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 760e99d80e44..e86d78a62353 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -126,14 +126,16 @@ statement error SELECT abs(1, 2); # abs: unsupported argument type -query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Float64 type +query error This feature is not implemented: Unsupported data type Utf8 for function abs SELECT abs('foo'); # abs: numeric string -query R +# TODO: In Postgres, '-1.2' is unknown type and interpreted to float8 so they don't fail on this query +query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs select abs('-1.2'); ----- -1.2 + +query error DataFusion error: This feature is not implemented: Unsupported data type Utf8 for function abs +select abs(arrow_cast('-1.2', 'Utf8')); statement ok CREATE TABLE test_nullable_integer( diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index a09a63a791fc..fbd022593ee8 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -3293,3 +3293,10 @@ drop view t_utc; statement ok drop view t_europe; + +# TODO: In Postgres, '-1' is unknown type and interpreted to float8 so they don't fail on this query +query error DataFusion error: Arrow error: Parser error: Error parsing timestamp from '\-1': timestamp must contain at least 10 characters +select to_timestamp('-1'); + +query error DataFusion error: Arrow error: Parser error: Error parsing timestamp from '\-1': timestamp must contain at least 10 characters +select to_timestamp(arrow_cast('-1', 'Utf8')); From 99f2cd3bf9a0d000284cac27aadb4ff41952d6b8 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 6 Nov 2024 08:30:53 +0800 Subject: [PATCH 13/14] typo Signed-off-by: jayzhan211 --- datafusion/expr/src/type_coercion/functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9b05f83e2f70..11051df4d97f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -506,7 +506,7 @@ fn get_valid_types( let logical_data_type: NativeType = valid_type.clone().into(); // Fallback to default type if we don't know which type to coerced to - // f64 is choosen since most of the math function utilize Signature::numeric, + // f64 is chosen since most of the math functions utilize Signature::numeric, // and their default type is double precision if logical_data_type == NativeType::Null { valid_type = DataType::Float64; From f96e5d8267f342025ba5173d09230062e787a810 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 6 Nov 2024 22:53:46 +0800 Subject: [PATCH 14/14] improve doc and err msg Signed-off-by: jayzhan211 --- datafusion/common/src/types/native.rs | 8 +++++++- datafusion/expr-common/src/signature.rs | 2 +- datafusion/expr/src/type_coercion/functions.rs | 4 ++-- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs index 650785e68f42..7e326dc15bb2 100644 --- a/datafusion/common/src/types/native.rs +++ b/datafusion/common/src/types/native.rs @@ -24,7 +24,7 @@ use arrow::compute::can_cast_types; use arrow_schema::{ DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, }; -use std::sync::Arc; +use std::{fmt::Display, sync::Arc}; /// Representation of a type that DataFusion can handle natively. It is a subset /// of the physical variants in Arrow's native [`DataType`]. @@ -183,6 +183,12 @@ pub enum NativeType { Map(LogicalFieldRef), } +impl Display for NativeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NativeType::{self:?}") + } +} + impl LogicalType for NativeType { fn native(&self) -> &NativeType { self diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index e0aa6153da9d..6e78f31e6a3c 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -107,7 +107,7 @@ pub enum TypeSignature { /// Exact number of arguments of an exact type Exact(Vec), /// The number of arguments that can be coerced to in order - /// For example, `Coercible(vec![DataType::Float64])` accepts + /// For example, `Coercible(vec![logical_float64()])` accepts /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` /// since i32 and f32 can be casted to f64 Coercible(Vec), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 11051df4d97f..5a4d89a0b2ec 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -430,7 +430,7 @@ fn get_valid_types( new_types.push(DataType::Utf8); } else { return plan_err!( - "The signature expected NativeType::String but received {data_type}" + "The signature expected NativeType::String but received {logical_data_type}" ); } } @@ -489,7 +489,7 @@ fn get_valid_types( if !logical_data_type.is_numeric() { return plan_err!( - "The signature expected NativeType::Numeric but received {t}" + "The signature expected NativeType::Numeric but received {logical_data_type}" ); } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index fdd28a2db9f9..fe7d1a90c5bd 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1940,7 +1940,7 @@ select position('' in '') ---- 1 -query error DataFusion error: Error during planning: Error during planning: The signature expected NativeType::String but received Int64 +query error DataFusion error: Error during planning: Error during planning: The signature expected NativeType::String but received NativeType::Int64 select position(1 in 1) query I