Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate arrow_cast to a UDF #9610

Merged
merged 13 commits into from
Mar 18, 2024
46 changes: 23 additions & 23 deletions datafusion-examples/examples/to_char.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+------------+",
"| t.values |",
"+------------+",
"| 2020-09-01 |",
"| 2020-09-02 |",
"| 2020-09-03 |",
"| 2020-09-04 |",
"+------------+",
"+-----------------------------------+",
"| arrow_cast(t.values,Utf8(\"Utf8\")) |",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These differences are due to the fact that arrow_cast is just a normal function now rather than a special case in the parser. Thus the naming reflects normal function naming

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting I got stuck implementing the simpliy function because I thought it should convert arrow_cast(t.values,Utf8(\"Utf8\")) to t.values and other similar cases as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah -- this is pretty tricky. arrow_cast was quite special in the parser, so now that it is handled like a normal function it has the same (somewhat strange) function effect of column naming

"+-----------------------------------+",
"| 2020-09-01 |",
"| 2020-09-02 |",
"| 2020-09-03 |",
"| 2020-09-04 |",
"+-----------------------------------+",
],
&result
);
Expand All @@ -146,11 +146,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+-----------------------------------------------------------------+",
"| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |",
"+-----------------------------------------------------------------+",
"| 03-08-2023 14:38:50 |",
"+-----------------------------------------------------------------+",
"+-------------------------------------------------------------------------------------------------------------+",
"| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |",
"+-------------------------------------------------------------------------------------------------------------+",
"| 03-08-2023 14:38:50 |",
"+-------------------------------------------------------------------------------------------------------------+",
],
&result
);
Expand All @@ -165,11 +165,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+---------------------------------------+",
"| to_char(Int64(123456),Utf8(\"pretty\")) |",
"+---------------------------------------+",
"| 1 days 10 hours 17 mins 36 secs |",
"+---------------------------------------+",
"+----------------------------------------------------------------------------+",
"| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |",
"+----------------------------------------------------------------------------+",
"| 1 days 10 hours 17 mins 36 secs |",
"+----------------------------------------------------------------------------+",
],
&result
);
Expand All @@ -184,11 +184,11 @@ async fn main() -> Result<()> {

assert_batches_eq!(
&[
"+----------------------------------------+",
"| to_char(Int64(123456),Utf8(\"iso8601\")) |",
"+----------------------------------------+",
"| PT123456S |",
"+----------------------------------------+",
"+-----------------------------------------------------------------------------+",
"| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |",
"+-----------------------------------------------------------------------------+",
"| PT123456S |",
"+-----------------------------------------------------------------------------+",
],
&result
);
Expand Down
27 changes: 19 additions & 8 deletions datafusion/core/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

//! Tests for the DataFusion SQL query planner that require functions from the
//! datafusion-functions crate.

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -42,20 +45,26 @@ fn init() {
let _ = env_logger::try_init();
}

#[test]
fn select_arrow_cast() {
let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large";
let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\
\n EmptyRelation";
quick_test(sql, expected);
}
#[test]
fn timestamp_nano_ts_none_predicates() -> Result<()> {
let sql = "SELECT col_int32
FROM test
WHERE col_ts_nano_none < (now() - interval '1 hour')";
let plan = test_sql(sql)?;
// a scan should have the now()... predicate folded to a single
// constant and compared to the column without a cast so it can be
// pushed down / pruned
let expected =
"Projection: test.col_int32\
\n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\
\n TableScan: test projection=[col_int32, col_ts_nano_none]";
assert_eq!(expected, format!("{plan:?}"));
quick_test(sql, expected);
Ok(())
}

Expand All @@ -74,19 +83,21 @@ fn timestamp_nano_ts_utc_predicates() {
assert_eq!(expected, format!("{plan:?}"));
}

fn quick_test(sql: &str, expected_plan: &str) {
let plan = test_sql(sql).unwrap();
assert_eq!(expected_plan, format!("{:?}", plan));
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
let statement = &ast[0];

// create a logical query plan
let now_udf = datetime::functions()
.iter()
.find(|f| f.name() == "now")
.unwrap()
.to_owned();
let context_provider = MyContextProvider::default().with_udf(now_udf);
let context_provider = MyContextProvider::default()
.with_udf(datetime::now())
.with_udf(datafusion_functions::core::arrow_cast());
let sql_to_rel = SqlToRel::new(&context_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() {

// compute with builtin `sum` aggregator
let expected = [
"+-------------+",
"| SUM(t.time) |",
"+-------------+",
"| 19000 |",
"+-------------+",
"+---------------------------------------+",
"| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |",
"+---------------------------------------+",
"| 19000 |",
"+---------------------------------------+",
];
assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,63 +15,125 @@
// specific language governing permissions and limitations
// under the License.

//! Implementation of the `arrow_cast` function that allows
//! casting to arbitrary arrow types (rather than SQL types)
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast`

use std::any::Any;
use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc};

use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit};
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit};
use datafusion_common::{
plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue,
internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result,
ScalarValue,
};

use datafusion_common::plan_err;
use datafusion_expr::{Expr, ExprSchemable};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility};

pub const ARROW_CAST_NAME: &str = "arrow_cast";

/// Create an [`Expr`] that evaluates the `arrow_cast` function
/// Implements casting to arbitrary arrow types (rather than SQL types)
///
/// Note that the `arrow_cast` function is somewhat special in that its
/// return depends only on the *value* of its second argument (not its type)
///
/// This function is not a [`BuiltinScalarFunction`] because the
/// return type of [`BuiltinScalarFunction`] depends only on the
/// *types* of the arguments. However, the type of `arrow_type` depends on
/// the *value* of its second argument.
/// It is implemented by calling the same underlying arrow `cast` kernel as
/// normal SQL casts.
///
/// Use the `cast` function to cast to SQL type (which is then mapped
/// to the corresponding arrow type). For example to cast to `int`
/// (which is then mapped to the arrow type `Int32`)
/// For example to cast to `int` using SQL (which is then mapped to the arrow
/// type `Int32`)
///
/// ```sql
/// select cast(column_x as int) ...
/// ```
///
/// Use the `arrow_cast` functiont to cast to a specfic arrow type
/// You can use the `arrow_cast` functiont to cast to a specific arrow type
///
/// For example
/// ```sql
/// select arrow_cast(column_x, 'Float64')
/// ```
/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction
pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> {
#[derive(Debug)]
pub(super) struct ArrowCastFunc {
signature: Signature,
}

impl ArrowCastFunc {
pub fn new() -> Self {
Self {
signature: Signature::any(2, Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for ArrowCastFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"arrow_cast"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
// should be using return_type_from_exprs and not calling the default
// implementation
internal_err!("arrow_cast should return type from exprs")
}

fn return_type_from_exprs(
&self,
args: &[Expr],
_schema: &dyn ExprSchema,
_arg_types: &[DataType],
) -> Result<DataType> {
data_type_from_args(args)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
internal_err!("arrow_cast should have been simplified to cast")
}

fn simplify(
&self,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// convert this into a real cast
let target_type = data_type_from_args(&args)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This simplify logic mirrors the previous behavior in that arrow_cast is replaced with a normal cast

// remove second (type) argument
args.pop().unwrap();
let arg = args.pop().unwrap();

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == target_type {
// the argument's data type is already the correct type
arg
} else {
// Use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: target_type,
})
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
}
}

/// Returns the requested type from the arguments
fn data_type_from_args(args: &[Expr]) -> Result<DataType> {
if args.len() != 2 {
return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len());
}
let arg1 = args.pop().unwrap();
let arg0 = args.pop().unwrap();

// arg1 must be a string
let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 {
v
} else {
let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else {
return plan_err!(
"arrow_cast requires its second argument to be a constant string, got {arg1}"
"arrow_cast requires its second argument to be a constant string, got {:?}",
&args[1]
);
};

// do the actual lookup to the appropriate data type
let data_type = parse_data_type(&data_type_string)?;

arg0.cast_to(&data_type, schema)
parse_data_type(val)
}

/// Parses `str` into a `DataType`.
Expand All @@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr>
/// impl, and maintains the invariant that
/// `parse_data_type(data_type.to_string()) == data_type`
///
/// Example:
/// ```
/// # use datafusion_sql::parse_data_type;
/// # use arrow_schema::DataType;
/// let display_value = "Int32";
///
/// // "Int32" is the Display value of `DataType`
/// assert_eq!(display_value, &format!("{}", DataType::Int32));
///
/// // parse_data_type coverts "Int32" back to `DataType`:
/// let data_type = parse_data_type(display_value).unwrap();
/// assert_eq!(data_type, DataType::Int32);
/// ```
///
/// Remove if added to arrow: <https://github.com/apache/arrow-rs/issues/3821>
pub fn parse_data_type(val: &str) -> Result<DataType> {
fn parse_data_type(val: &str) -> Result<DataType> {
Parser::new(val).parse()
}

Expand Down Expand Up @@ -647,8 +695,6 @@ impl Display for Token {

#[cfg(test)]
mod test {
use arrow_schema::{IntervalUnit, TimeUnit};

use super::*;

#[test]
Expand Down Expand Up @@ -844,7 +890,6 @@ mod test {
assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'"));
}
}
println!(" Ok");
}
}
}
3 changes: 3 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! "core" DataFusion functions

mod arrow_cast;
mod arrowtypeof;
mod getfield;
mod nullif;
Expand All @@ -25,6 +26,7 @@ mod nvl2;
mod r#struct;

// create UDFs
make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast);
make_udf_function!(nullif::NullIfFunc, NULLIF, nullif);
make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
Expand All @@ -35,6 +37,7 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
// Export the functions out of this package, both as expr_fn as well as a list of functions
export_functions!(
(nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."),
(arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."),
(nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"),
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
Expand Down
Loading
Loading