-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate arrow_cast
to a UDF
#9610
Merged
alamb
merged 13 commits into
apache:main
from
alamb:feat/migrate_arrow_cast_to_udf_fixed
Mar 18, 2024
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
b9aa40b
feat: arrow_cast function as UDF
brayanjuls 0c63b47
fix: cargo.lock in datafusion-cli
brayanjuls 040f5c2
fix: unwrap arg1 on match arm
brayanjuls 33cc854
fix: unwrap on matching arms using some
brayanjuls 01d1f6b
Merge remote-tracking branch 'apache/main' into feat/migrate_arrow_ca…
alamb 769ff55
Rewrite to use simplify API
alamb 56c337b
Update error messages
alamb 6025b0a
Fix up tests
alamb 9cb7ff1
Update cargo.lock
alamb 182e1da
fix test
alamb 5b8fc25
fix
alamb 1af869f
Merge remote-tracking branch 'apache/main' into feat/migrate_arrow_ca…
alamb 0c7b7be
Fix merge errors, return error
alamb File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,63 +15,125 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
//! Implementation of the `arrow_cast` function that allows | ||
//! casting to arbitrary arrow types (rather than SQL types) | ||
//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` | ||
|
||
use std::any::Any; | ||
use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; | ||
|
||
use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; | ||
use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; | ||
use datafusion_common::{ | ||
plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, | ||
internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, | ||
ScalarValue, | ||
}; | ||
|
||
use datafusion_common::plan_err; | ||
use datafusion_expr::{Expr, ExprSchemable}; | ||
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; | ||
use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; | ||
|
||
pub const ARROW_CAST_NAME: &str = "arrow_cast"; | ||
|
||
/// Create an [`Expr`] that evaluates the `arrow_cast` function | ||
/// Implements casting to arbitrary arrow types (rather than SQL types) | ||
/// | ||
/// Note that the `arrow_cast` function is somewhat special in that its | ||
/// return depends only on the *value* of its second argument (not its type) | ||
/// | ||
/// This function is not a [`BuiltinScalarFunction`] because the | ||
/// return type of [`BuiltinScalarFunction`] depends only on the | ||
/// *types* of the arguments. However, the type of `arrow_type` depends on | ||
/// the *value* of its second argument. | ||
/// It is implemented by calling the same underlying arrow `cast` kernel as | ||
/// normal SQL casts. | ||
/// | ||
/// Use the `cast` function to cast to SQL type (which is then mapped | ||
/// to the corresponding arrow type). For example to cast to `int` | ||
/// (which is then mapped to the arrow type `Int32`) | ||
/// For example to cast to `int` using SQL (which is then mapped to the arrow | ||
/// type `Int32`) | ||
/// | ||
/// ```sql | ||
/// select cast(column_x as int) ... | ||
/// ``` | ||
/// | ||
/// Use the `arrow_cast` functiont to cast to a specfic arrow type | ||
/// You can use the `arrow_cast` functiont to cast to a specific arrow type | ||
/// | ||
/// For example | ||
/// ```sql | ||
/// select arrow_cast(column_x, 'Float64') | ||
/// ``` | ||
/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction | ||
pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> { | ||
#[derive(Debug)] | ||
pub(super) struct ArrowCastFunc { | ||
signature: Signature, | ||
} | ||
|
||
impl ArrowCastFunc { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::any(2, Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for ArrowCastFunc { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"arrow_cast" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
// should be using return_type_from_exprs and not calling the default | ||
// implementation | ||
internal_err!("arrow_cast should return type from exprs") | ||
} | ||
|
||
fn return_type_from_exprs( | ||
&self, | ||
args: &[Expr], | ||
_schema: &dyn ExprSchema, | ||
_arg_types: &[DataType], | ||
) -> Result<DataType> { | ||
data_type_from_args(args) | ||
} | ||
|
||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
internal_err!("arrow_cast should have been simplified to cast") | ||
} | ||
|
||
fn simplify( | ||
&self, | ||
mut args: Vec<Expr>, | ||
info: &dyn SimplifyInfo, | ||
) -> Result<ExprSimplifyResult> { | ||
// convert this into a real cast | ||
let target_type = data_type_from_args(&args)?; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This simplify logic mirrors the previous behavior in that arrow_cast is replaced with a normal cast |
||
// remove second (type) argument | ||
args.pop().unwrap(); | ||
let arg = args.pop().unwrap(); | ||
|
||
let source_type = info.get_data_type(&arg)?; | ||
let new_expr = if source_type == target_type { | ||
// the argument's data type is already the correct type | ||
arg | ||
} else { | ||
// Use an actual cast to get the correct type | ||
Expr::Cast(datafusion_expr::Cast { | ||
expr: Box::new(arg), | ||
data_type: target_type, | ||
}) | ||
}; | ||
// return the newly written argument to DataFusion | ||
Ok(ExprSimplifyResult::Simplified(new_expr)) | ||
} | ||
} | ||
|
||
/// Returns the requested type from the arguments | ||
fn data_type_from_args(args: &[Expr]) -> Result<DataType> { | ||
if args.len() != 2 { | ||
return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); | ||
} | ||
let arg1 = args.pop().unwrap(); | ||
let arg0 = args.pop().unwrap(); | ||
|
||
// arg1 must be a string | ||
let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { | ||
v | ||
} else { | ||
let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { | ||
return plan_err!( | ||
"arrow_cast requires its second argument to be a constant string, got {arg1}" | ||
"arrow_cast requires its second argument to be a constant string, got {:?}", | ||
&args[1] | ||
); | ||
}; | ||
|
||
// do the actual lookup to the appropriate data type | ||
let data_type = parse_data_type(&data_type_string)?; | ||
|
||
arg0.cast_to(&data_type, schema) | ||
parse_data_type(val) | ||
} | ||
|
||
/// Parses `str` into a `DataType`. | ||
|
@@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec<Expr>, schema: &DFSchema) -> Result<Expr> | |
/// impl, and maintains the invariant that | ||
/// `parse_data_type(data_type.to_string()) == data_type` | ||
/// | ||
/// Example: | ||
/// ``` | ||
/// # use datafusion_sql::parse_data_type; | ||
/// # use arrow_schema::DataType; | ||
/// let display_value = "Int32"; | ||
/// | ||
/// // "Int32" is the Display value of `DataType` | ||
/// assert_eq!(display_value, &format!("{}", DataType::Int32)); | ||
/// | ||
/// // parse_data_type coverts "Int32" back to `DataType`: | ||
/// let data_type = parse_data_type(display_value).unwrap(); | ||
/// assert_eq!(data_type, DataType::Int32); | ||
/// ``` | ||
/// | ||
/// Remove if added to arrow: <https://github.com/apache/arrow-rs/issues/3821> | ||
pub fn parse_data_type(val: &str) -> Result<DataType> { | ||
fn parse_data_type(val: &str) -> Result<DataType> { | ||
Parser::new(val).parse() | ||
} | ||
|
||
|
@@ -647,8 +695,6 @@ impl Display for Token { | |
|
||
#[cfg(test)] | ||
mod test { | ||
use arrow_schema::{IntervalUnit, TimeUnit}; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
|
@@ -844,7 +890,6 @@ mod test { | |
assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); | ||
} | ||
} | ||
println!(" Ok"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These differences are due to the fact that
arrow_cast
is just a normal function now rather than a special case in the parser. Thus the naming reflects normal function namingThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting I got stuck implementing the simpliy function because I thought it should convert
arrow_cast(t.values,Utf8(\"Utf8\"))
tot.values
and other similar cases as well.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah -- this is pretty tricky.
arrow_cast
was quite special in the parser, so now that it is handled like a normal function it has the same (somewhat strange) function effect of column naming