Skip to content
46 changes: 23 additions & 23 deletions datafusion-examples/examples/to_char.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+------------+",
"| t.values |",
"+------------+",
"| 2020-09-01 |",
"| 2020-09-02 |",
"| 2020-09-03 |",
"| 2020-09-04 |",
"+------------+",
"+-----------------------------------+",
"| arrow_cast(t.values,Utf8(\"Utf8\")) |",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These differences are due to the fact that arrow_cast is just a normal function now rather than a special case in the parser. Thus the naming reflects normal function naming

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I got stuck implementing the simplify function because I thought it should convert arrow_cast(t.values,Utf8(\"Utf8\")) to t.values, and handle other similar cases as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah -- this is pretty tricky. arrow_cast used to be special-cased in the parser; now that it is handled like a normal function, it gets the same (somewhat strange) column naming as any other function call.

"+-----------------------------------+",
"| 2020-09-01 |",
"| 2020-09-02 |",
"| 2020-09-03 |",
"| 2020-09-04 |",
"+-----------------------------------+",
],
&result
);
Expand All @@ -146,11 +146,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+-----------------------------------------------------------------+",
"| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |",
"+-----------------------------------------------------------------+",
"| 03-08-2023 14:38:50 |",
"+-----------------------------------------------------------------+",
"+-------------------------------------------------------------------------------------------------------------+",
"| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |",
"+-------------------------------------------------------------------------------------------------------------+",
"| 03-08-2023 14:38:50 |",
"+-------------------------------------------------------------------------------------------------------------+",
],
&result
);
Expand All @@ -165,11 +165,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+---------------------------------------+",
"| to_char(Int64(123456),Utf8(\"pretty\")) |",
"+---------------------------------------+",
"| 1 days 10 hours 17 mins 36 secs |",
"+---------------------------------------+",
"+----------------------------------------------------------------------------+",
"| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |",
"+----------------------------------------------------------------------------+",
"| 1 days 10 hours 17 mins 36 secs |",
"+----------------------------------------------------------------------------+",
],
&result
);
Expand All @@ -184,11 +184,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+----------------------------------------+",
"| to_char(Int64(123456),Utf8(\"iso8601\")) |",
"+----------------------------------------+",
"| PT123456S |",
"+----------------------------------------+",
"+-----------------------------------------------------------------------------+",
"| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |",
"+-----------------------------------------------------------------------------+",
"| PT123456S |",
"+-----------------------------------------------------------------------------+",
],
&result
);
Expand Down
27 changes: 19 additions & 8 deletions datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

//! Tests for the DataFusion SQL query planner that require functions from the
//! datafusion-functions crate.

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -42,20 +45,26 @@ fn init() {
let _ = env_logger::try_init();
}

#[test]
fn select_arrow_cast() {
    // `arrow_cast` applied to literals is expected to show up in the plan as
    // literals of the requested arrow type.
    quick_test(
        "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large",
        "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\n EmptyRelation",
    );
}
#[test]
fn timestamp_nano_ts_none_predicates() -> Result<()> {
    let sql = "SELECT col_int32
FROM test
WHERE col_ts_nano_none < (now() - interval '1 hour')";
    // a scan should have the now()... predicate folded to a single
    // constant and compared to the column without a cast so it can be
    // pushed down / pruned
    let expected =
        "Projection: test.col_int32\
        \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\
        \n TableScan: test projection=[col_int32, col_ts_nano_none]";
    // `quick_test` plans the query and compares the debug-formatted plan, so
    // no separate `test_sql` + `assert_eq!` pass is needed here.
    quick_test(sql, expected);
    Ok(())
}

Expand All @@ -74,19 +83,21 @@ fn timestamp_nano_ts_utc_predicates() {
assert_eq!(expected, format!("{plan:?}"));
}

/// Plans `sql` with `test_sql` and asserts that the debug rendering of the
/// resulting plan equals `expected_plan`.
///
/// Panics if planning fails or if the rendered plan differs.
fn quick_test(sql: &str, expected_plan: &str) {
    let actual_plan = format!("{:?}", test_sql(sql).unwrap());
    assert_eq!(expected_plan, actual_plan);
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
let statement = &ast[0];

// create a logical query plan
let now_udf = datetime::functions()
.iter()
.find(|f| f.name() == "now")
.unwrap()
.to_owned();
let context_provider = MyContextProvider::default().with_udf(now_udf);
let context_provider = MyContextProvider::default()
.with_udf(datetime::now())
.with_udf(datafusion_functions::core::arrow_cast());
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() {

// compute with builtin `sum` aggregator
let expected = [
"+-------------+",
"| SUM(t.time) |",
"+-------------+",
"| 19000 |",
"+-------------+",
"+---------------------------------------+",
"| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |",
"+---------------------------------------+",
"| 19000 |",
"+---------------------------------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,125 @@
// specific language governing permissions and limitations
// under the License.

//! Implementation of the `arrow_cast` function that allows
//! casting to arbitrary arrow types (rather than SQL types)
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` function

use std::any::Any;
use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};

use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::{
plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue,
internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result,
ScalarValue,
};

use datafusion_common::plan_err;
use datafusion_expr::{Expr, ExprSchemable};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};

pub const ARROW_CAST_NAME: &str = "arrow_cast";

/// Create an [`Expr`] that evaluates the `arrow_cast` function
/// Implements casting to arbitrary arrow types (rather than SQL types)
///
/// Note that the `arrow_cast` function is somewhat special in that its
/// return depends only on the *value* of its second argument (not its type)
///
/// This function is not a [`BuiltinScalarFunction`] because the
/// return type of [`BuiltinScalarFunction`] depends only on the
/// *types* of the arguments. However, the type of `arrow_type` depends on
/// the *value* of its second argument.
/// It is implemented by calling the same underlying arrow `cast` kernel as
/// normal SQL casts.
///
/// Use the `cast` function to cast to SQL type (which is then mapped
/// to the corresponding arrow type). For example to cast to `int`
/// (which is then mapped to the arrow type `Int32`)
/// For example to cast to `int` using SQL (which is then mapped to the arrow
/// type `Int32`)
///
/// ```sql
/// select cast(column_x as int) ...
/// ```
///
/// Use the `arrow_cast` function to cast to a specific arrow type
/// You can use the `arrow_cast` function to cast to a specific arrow type
///
/// For example
/// ```sql
/// select arrow_cast(column_x, 'Float64')
/// ```
/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction
pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> {
#[derive(Debug)]
pub(super) struct ArrowCastFunc {
    /// Signature handed out by the `ScalarUDFImpl::signature` implementation
    signature: Signature,
}

impl ArrowCastFunc {
    /// Creates the `arrow_cast` function with a two-argument
    /// `Signature::any` signature and `Volatility::Immutable`.
    pub fn new() -> Self {
        let signature = Signature::any(2, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for ArrowCastFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"arrow_cast"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
// should be using return_type_from_exprs and not calling the default
// implementation
internal_err!("arrow_cast should return type from exprs")
}

fn return_type_from_exprs(
&self,
args: &[Expr],
_schema: &dyn ExprSchema,
_arg_types: &[DataType],
) -> Result<DataType> {
data_type_from_args(args)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
internal_err!("arrow_cast should have been simplified to cast")
}

fn simplify(
&self,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// convert this into a real cast
let target_type = data_type_from_args(&args)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This simplify logic mirrors the previous behavior in that arrow_cast is replaced with a normal cast

// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == target_type {
// the argument's data type is already the correct type
arg
} else {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: target_type,
})
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
}
}

/// Returns the requested type from the arguments
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
if args.len() != 2 {
return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len());
}
let arg1 = args.pop().unwrap();
let arg0 = args.pop().unwrap();

// arg1 must be a string
let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 {
v
} else {
let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else {
return plan_err!(
"arrow_cast requires its second argument to be a constant string, got {arg1}"
"arrow_cast requires its second argument to be a constant string, got {:?}",
&args[1]
);
};

// do the actual lookup to the appropriate data type
let data_type = parse_data_type(&data_type_string)?;

arg0.cast_to(&data_type, schema)
parse_data_type(val)
}

/// Parses `str` into a `DataType`.
Expand All @@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr>
/// impl, and maintains the invariant that
/// `parse_data_type(data_type.to_string()) == data_type`
///
/// Example:
/// ```
/// # use datafusion_sql::parse_data_type;
/// # use arrow_schema::DataType;
/// let display_value = "Int32";
///
/// // "Int32" is the Display value of `DataType`
/// assert_eq!(display_value, &format!("{}", DataType::Int32));
///
/// // parse_data_type converts "Int32" back to `DataType`:
/// let data_type = parse_data_type(display_value).unwrap();
/// assert_eq!(data_type, DataType::Int32);
/// ```
///
/// Remove if added to arrow: <https://github.com/apache/arrow-rs/issues/3821>
pub fn parse_data_type(val: &str) -> Result<DataType> {
fn parse_data_type(val: &str) -> Result<DataType> {
    // All parsing is delegated to `Parser`; whatever it returns (the parsed
    // type or an error) is passed back to the caller unchanged.
    Parser::new(val).parse()
}

Expand Down Expand Up @@ -647,8 +695,6 @@ impl Display for Token {

#[cfg(test)]
mod test {
use arrow_schema::{IntervalUnit, TimeUnit};

use super::*;

#[test]
Expand Down Expand Up @@ -844,7 +890,6 @@ mod test {
assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'"));
}
}
println!(" Ok");
}
}
}
3 changes: 3 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! "core" DataFusion functions

mod arrow_cast;
mod arrowtypeof;
mod getfield;
mod nullif;
Expand All @@ -25,6 +26,7 @@ mod nvl2;
mod r#struct;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast);
make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);
make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
Expand All @@ -35,6 +37,7 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
(arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."),
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
Expand Down
Loading