Skip to content

Commit

Permalink
coerce rules
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Jun 4, 2023
1 parent f8996c2 commit 6a4a04d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
52 changes: 35 additions & 17 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow::{
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{DataFusionError, Result};
use std::collections::BTreeSet;

/// Performs type coercion for function arguments.
///
Expand Down Expand Up @@ -72,30 +73,30 @@ fn get_valid_types(
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual(allowed_types) => {
// special case when no args
if current_types.is_empty() {
return Ok(vec![current_types.to_vec()]);
}
if allowed_types.is_empty() {
return Err(DataFusionError::Plan(
"allowed types cannot be empty".to_string(),
));
}
let first_type = &current_types[0];
// all types must be the same
if current_types[1..].iter().any(|t| t != first_type) {
return Err(DataFusionError::Plan(format!(
"The function expected all arguments to be of type {:?} but received {:?}",
first_type, current_types
)));
}
// first type must be within allowed_types
if !allowed_types.iter().any(|t| t == first_type) {
// if there are any types that are not allowed, return error
if current_types.iter().any(|t| !allowed_types.contains(t)) {
return Err(DataFusionError::Plan(format!(
"The function expected all arguments to be of type {:?} but received {:?}",
allowed_types, current_types
)));
}
vec![current_types.iter().map(|_| first_type.clone()).collect()]
let types_set = current_types
.iter()
.cloned()
.collect::<BTreeSet<DataType>>();
// for each type in the type set, return a vector of the same length as current_types
types_set
.iter()
.map(|t| {
(0..current_types.len())
.map(|_| t.clone())
.collect::<Vec<DataType>>()
})
.collect::<Vec<_>>()
}
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
Expand Down Expand Up @@ -228,7 +229,7 @@ mod tests {
vec![DataType::UInt8, DataType::UInt16],
Some(vec![DataType::UInt8, DataType::UInt16]),
),
// 2 entries, can coerse values
// 2 entries, can coerce values
(
vec![DataType::UInt16, DataType::UInt16],
vec![DataType::UInt8, DataType::UInt16],
Expand Down Expand Up @@ -287,6 +288,23 @@ mod tests {
let is_error = get_valid_types(&signature, &[DataType::Int32]).is_err();
assert!(is_error);

// cast case, with i32 and i64
let signature =
TypeSignature::VariadicEqual(vec![DataType::Int64, DataType::Int32]);
let valid_types = get_valid_types(
&signature,
&[DataType::Int32, DataType::Int64, DataType::Int32],
)?;
assert_eq!(valid_types.len(), 2);
assert_eq!(
valid_types[1],
vec![DataType::Int64, DataType::Int64, DataType::Int64]
);
assert_eq!(
valid_types[0],
vec![DataType::Int32, DataType::Int32, DataType::Int32]
);

Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ mod tests {
Signature::variadic(vec![DataType::Float32], Volatility::Immutable),
vec![DataType::Float32, DataType::Float32],
)?,
// u32 -> f32
// for variadic equal, no casting is performed
case(
vec![DataType::Float32, DataType::UInt32],
Signature::variadic_equal(vec![DataType::Float32], Volatility::Immutable),
Expand Down

0 comments on commit 6a4a04d

Please sign in to comment.