diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index e99f69fbcd55..ef616d72cc1c 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -125,14 +125,14 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------+", - "| t.values |", - "+------------+", - "| 2020-09-01 |", - "| 2020-09-02 |", - "| 2020-09-03 |", - "| 2020-09-04 |", - "+------------+", + "+-----------------------------------+", + "| arrow_cast(t.values,Utf8(\"Utf8\")) |", + "+-----------------------------------+", + "| 2020-09-01 |", + "| 2020-09-02 |", + "| 2020-09-03 |", + "| 2020-09-04 |", + "+-----------------------------------+", ], &result ); @@ -146,11 +146,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+-----------------------------------------------------------------+", - "| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", - "+-----------------------------------------------------------------+", - "| 03-08-2023 14:38:50 |", - "+-----------------------------------------------------------------+", + "+-------------------------------------------------------------------------------------------------------------+", + "| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", + "+-------------------------------------------------------------------------------------------------------------+", + "| 03-08-2023 14:38:50 |", + "+-------------------------------------------------------------------------------------------------------------+", ], &result ); @@ -165,11 +165,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------+", - "| to_char(Int64(123456),Utf8(\"pretty\")) |", - "+---------------------------------------+", - "| 1 days 10 hours 17 mins 36 secs |", - "+---------------------------------------+", + 
"+----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |", + "+----------------------------------------------------------------------------+", + "| 1 days 10 hours 17 mins 36 secs |", + "+----------------------------------------------------------------------------+", ], &result ); @@ -184,11 +184,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------+", - "| to_char(Int64(123456),Utf8(\"iso8601\")) |", - "+----------------------------------------+", - "| PT123456S |", - "+----------------------------------------+", + "+-----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |", + "+-----------------------------------------------------------------------------+", + "| PT123456S |", + "+-----------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index f9696955769e..60010bdddfb8 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! Tests for the DataFusion SQL query planner that require functions from the +//! datafusion-functions crate. 
+ use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -42,12 +45,18 @@ fn init() { let _ = env_logger::try_init(); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; + let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + \n EmptyRelation"; + quick_test(sql, expected); +} #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { let sql = "SELECT col_int32 FROM test WHERE col_ts_nano_none < (now() - interval '1 hour')"; - let plan = test_sql(sql)?; // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned @@ -55,7 +64,7 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { "Projection: test.col_int32\ \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - assert_eq!(expected, format!("{plan:?}")); + quick_test(sql, expected); Ok(()) } @@ -74,6 +83,11 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{plan:?}")); } +fn quick_test(sql: &str, expected_plan: &str) { + let plan = test_sql(sql).unwrap(); + assert_eq!(expected_plan, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... 
@@ -81,12 +95,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let now_udf = datetime::functions() - .iter() - .find(|f| f.name() == "now") - .unwrap() - .to_owned(); - let context_provider = MyContextProvider::default().with_udf(now_udf); + let context_provider = MyContextProvider::default() + .with_udf(datetime::now()) + .with_udf(datafusion_functions::core::arrow_cast()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3f40c55a3ed7..a58a8cf51681 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ - "+-------------+", - "| SUM(t.time) |", - "+-------------+", - "| 19000 |", - "+-------------+", + "+---------------------------------------+", + "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "+---------------------------------------+", + "| 19000 |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs similarity index 90% rename from datafusion/sql/src/expr/arrow_cast.rs rename to datafusion/functions/src/core/arrow_cast.rs index a75cdf9e3c6b..b6c1b5eb9a38 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -15,63 +15,125 @@ // specific language governing permissions and limitations // under the License. -//! Implementation of the `arrow_cast` function that allows -//! casting to arbitrary arrow types (rather than SQL types) +//! 
[`ArrowCastFunc`]: Implementation of the `arrow_cast` +use std::any::Any; use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{ - plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, + ScalarValue, }; -use datafusion_common::plan_err; -use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -pub const ARROW_CAST_NAME: &str = "arrow_cast"; - -/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// Implements casting to arbitrary arrow types (rather than SQL types) +/// +/// Note that the `arrow_cast` function is somewhat special in that its +/// return type depends only on the *value* of its second argument (not its type) /// -/// This function is not a [`BuiltinScalarFunction`] because the -/// return type of [`BuiltinScalarFunction`] depends only on the -/// *types* of the arguments. However, the type of `arrow_type` depends on -/// the *value* of its second argument. +/// It is implemented by calling the same underlying arrow `cast` kernel as +/// normal SQL casts. /// -/// Use the `cast` function to cast to SQL type (which is then mapped -/// to the corresponding arrow type). For example to cast to `int` -/// (which is then mapped to the arrow type `Int32`) +/// For example to cast to `int` using SQL (which is then mapped to the arrow +/// type `Int32`) /// /// ```sql /// select cast(column_x as int) ... 
/// ``` /// -/// Use the `arrow_cast` functiont to cast to a specfic arrow type +/// You can use the `arrow_cast` function to cast to a specific arrow type /// /// For example /// ```sql /// select arrow_cast(column_x, 'Float64') /// ``` -/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction -pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { +#[derive(Debug)] +pub(super) struct ArrowCastFunc { + signature: Signature, +} + +impl ArrowCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrowCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default + // implementation + internal_err!("arrow_cast should return type from exprs") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + data_type_from_args(args) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + internal_err!("arrow_cast should have been simplified to cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // convert this into a real cast + let target_type = data_type_from_args(&args)?; + // remove second (type) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + arg + } else { + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + data_type: target_type, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) + } +} + +/// Returns the 
requested type from the arguments +fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let arg1 = args.pop().unwrap(); - let arg0 = args.pop().unwrap(); - - // arg1 must be a string - let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { - v - } else { + let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { return plan_err!( - "arrow_cast requires its second argument to be a constant string, got {arg1}" + "arrow_cast requires its second argument to be a constant string, got {:?}", + &args[1] ); }; - - // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; - - arg0.cast_to(&data_type, schema) + parse_data_type(val) } /// Parses `str` into a `DataType`. @@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// -/// Example: -/// ``` -/// # use datafusion_sql::parse_data_type; -/// # use arrow_schema::DataType; -/// let display_value = "Int32"; -/// -/// // "Int32" is the Display value of `DataType` -/// assert_eq!(display_value, &format!("{}", DataType::Int32)); -/// -/// // parse_data_type coverts "Int32" back to `DataType`: -/// let data_type = parse_data_type(display_value).unwrap(); -/// assert_eq!(data_type, DataType::Int32); -/// ``` -/// /// Remove if added to arrow: -pub fn parse_data_type(val: &str) -> Result { +fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } @@ -647,8 +695,6 @@ impl Display for Token { #[cfg(test)] mod test { - use arrow_schema::{IntervalUnit, TimeUnit}; - use super::*; #[test] @@ -844,7 +890,6 @@ mod test { assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); } } - println!(" Ok"); } } } diff --git a/datafusion/functions/src/core/mod.rs 
b/datafusion/functions/src/core/mod.rs index 73cc4d18bf9f..5a0bd2c77f63 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,6 +17,7 @@ //! "core" DataFusion functions +mod arrow_cast; mod arrowtypeof; mod getfield; mod nullif; @@ -25,6 +26,7 @@ mod nvl2; mod r#struct; // create UDFs +make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); @@ -35,6 +37,7 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. 
This can be used to cast to a specific `arrow_type`."), (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ffc951a6fa66..582404b29749 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -34,8 +34,6 @@ use sqlparser::ast::{ use std::str::FromStr; use strum::IntoEnumIterator; -use super::arrow_cast::ARROW_CAST_NAME; - /// Suggest a valid function based on an invalid input function name pub fn suggest_valid_function( input_function_name: &str, @@ -249,12 +247,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); }; - - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); - } } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a6f1c78c7250..5e9c0623a265 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index e8e07eebe22d..12d6a4669634 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -42,5 +42,4 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; -pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c9c2bdd694b5..b6077353e5dd 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2566,15 +2566,6 @@ fn approx_median_window() { quick_test(sql, expected); } -#[test] -fn select_arrow_cast() { - let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; - let expected = "\ - Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ - \n EmptyRelation"; - quick_test(sql, expected); -} - #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2670,6 +2661,11 @@ fn logical_plan_with_dialect_and_options( vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "arrow_cast", + vec![DataType::Int64, DataType::Utf8], + DataType::Float64, + )) .with_udf(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8b3bd7eac95d..3e8694f3b2c2 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,11 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error Error during planning: arrow_cast needs 2 arguments, 1 provided +query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. 
You might need to add explicit type casts. SELECT arrow_cast('1') -query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) + +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown @@ -315,7 +316,7 @@ select arrow_cast(interval '30 minutes', 'Duration(Second)'); ---- 0 days 0 hours 30 mins 0 secs -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +query error DataFusion error: This feature is not implemented: Unsupported CAST from Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); @@ -336,7 +337,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))');