Skip to content

Commit

Permalink
Introduce return_type_from_args for ScalarFunction. (#14094)
Browse files Browse the repository at this point in the history
* switch func

Signed-off-by: Jay Zhan <[email protected]>

* fix test

Signed-off-by: Jay Zhan <[email protected]>

* fix test

Signed-off-by: Jay Zhan <[email protected]>

* deprecate old

Signed-off-by: Jay Zhan <[email protected]>

* add try new

Signed-off-by: Jay Zhan <[email protected]>

* deprecate

Signed-off-by: Jay Zhan <[email protected]>

* rm deprecate

Signed-off-by: Jay Zhan <[email protected]>

* replace deprecated func

Signed-off-by: Jay Zhan <[email protected]>

* cleanup

Signed-off-by: Jay Zhan <[email protected]>

* combine type and nullable

Signed-off-by: Jay Zhan <[email protected]>

* fix slowdown

Signed-off-by: Jay Zhan <[email protected]>

* clippy

Signed-off-by: Jay Zhan <[email protected]>

* fix take

Signed-off-by: Jay Zhan <[email protected]>

* fmt

Signed-off-by: Jay Zhan <[email protected]>

* rm duplicated test

Signed-off-by: Jay Zhan <[email protected]>

* refactor: remove unused documentation sections from scalar functions

* upd doc

Signed-off-by: Jay Zhan <[email protected]>

* use scalar value

Signed-off-by: Jay Zhan <[email protected]>

* fix test

Signed-off-by: Jay Zhan <[email protected]>

* fix test

Signed-off-by: Jay Zhan <[email protected]>

* use try_as_str

Signed-off-by: Jay Zhan <[email protected]>

* refactor: improve error handling for constant string arguments in UDFs

* refactor: enhance error messages for constant string requirements in UDFs

* refactor: streamline argument validation in return_type_from_args for UDFs

* rename and doc

Signed-off-by: Jay Zhan <[email protected]>

* refactor: add documentation for nullability of scalar arguments in ReturnTypeArgs

* rm test

Signed-off-by: Jay Zhan <[email protected]>

* refactor: remove unused import of Int32Array in utils tests

---------

Signed-off-by: Jay Zhan <[email protected]>
  • Loading branch information
jayzhan211 authored Jan 20, 2025
1 parent acf66d6 commit d3f1c9a
Show file tree
Hide file tree
Showing 21 changed files with 475 additions and 326 deletions.
16 changes: 8 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/ordering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::ScalarFunctionExpr;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand Down Expand Up @@ -103,14 +104,13 @@ fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?);
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
Expand Down
29 changes: 14 additions & 15 deletions datafusion/core/tests/fuzz_cases/equivalence/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use crate::fuzz_cases::equivalence::utils::{
is_table_same_after_sort, TestScalarUDF,
};
use arrow_schema::SortOptions;
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::equivalence::ProjectionMapping;
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand All @@ -42,14 +43,13 @@ fn project_orderings_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?);
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Expand Down Expand Up @@ -120,14 +120,13 @@ fn ordering_satisfy_after_projection_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
// Floor(a)
let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?) as PhysicalExprRef;
// a + b
let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Expand Down
17 changes: 9 additions & 8 deletions datafusion/core/tests/fuzz_cases/equivalence/properties.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use crate::fuzz_cases::equivalence::utils::{
create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort,
TestScalarUDF,
};
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::{Operator, ScalarUDF};
use datafusion_physical_expr::expressions::{col, BinaryExpr};
use datafusion_physical_expr::{PhysicalExprRef, ScalarFunctionExpr};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
use itertools::Itertools;
Expand All @@ -40,14 +41,14 @@ fn test_find_longest_permutation_random() -> Result<()> {
let table_data_with_properties =
generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;

let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new());
let floor_a = datafusion_physical_expr::udf::create_physical_expr(
&test_fun,
&[col("a", &test_schema)?],
let test_fun = Arc::new(ScalarUDF::new_from_impl(TestScalarUDF::new()));
let col_a = col("a", &test_schema)?;
let floor_a = Arc::new(ScalarFunctionExpr::try_new(
Arc::clone(&test_fun),
vec![col_a],
&test_schema,
&[],
&DFSchema::empty(),
)?;
)?) as PhysicalExprRef;

let a_plus_b = Arc::new(BinaryExpr::new(
col("a", &test_schema)?,
Operator::Plus,
Expand Down
50 changes: 27 additions & 23 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,12 @@ use datafusion_common::cast::{as_float64_array, as_int32_array};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err,
not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, HashMap, Result,
ScalarValue,
not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable,
LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature,
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_functions_nested::range::range_udf;
Expand Down Expand Up @@ -819,32 +818,36 @@ impl ScalarUDFImpl for TakeUDF {
///
/// 1. If the third argument is '0', return the type of the first argument
/// 2. If the third argument is '1', return the type of the second argument
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
_arg_data_types: &[DataType],
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
if args.arg_types.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len());
}

let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
arg_exprs.get(2)
{
if *idx == 0 || *idx == 1 {
*idx as usize
let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) {
// This is for test only, safe to unwrap
let take_idx = take_idx
.unwrap()
.try_as_str()
.unwrap()
.unwrap()
.parse::<usize>()
.unwrap();

if take_idx == 0 || take_idx == 1 {
take_idx
} else {
return plan_err!("The third argument must be 0 or 1, got: {idx}");
return plan_err!("The third argument must be 0 or 1, got: {take_idx}");
}
} else {
return plan_err!(
"The third argument must be a literal of type int64, but got {:?}",
arg_exprs.get(2)
args.scalar_arguments.get(2)
);
};

arg_exprs.get(take_idx).unwrap().get_type(schema)
Ok(ReturnInfo::new_nullable(
args.arg_types[take_idx].to_owned(),
))
}

// The actual implementation
Expand All @@ -854,7 +857,8 @@ impl ScalarUDFImpl for TakeUDF {
_number_rows: usize,
) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1,
_ => unreachable!(),
};
match &args[take_idx] {
Expand All @@ -874,9 +878,9 @@ async fn verify_udf_return_type() -> Result<()> {
// take(smallint_col, double_col, 1) as take1
// FROM alltypes_plain;
let exprs = vec![
take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
take.call(vec![col("smallint_col"), col("double_col"), lit("0")])
.alias("take0"),
take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
take.call(vec![col("smallint_col"), col("double_col"), lit("1")])
.alias("take1"),
];

Expand Down
76 changes: 48 additions & 28 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
};
use crate::udf::ReturnTypeArgs;
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
Expand Down Expand Up @@ -145,32 +146,9 @@ impl ExprSchemable for Expr {
}
}
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
match err {
DataFusionError::Plan(msg) => msg,
err => err.to_string(),
},
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
})?;

// Perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
Expr::ScalarFunction(_func) => {
let (return_type, _) = self.data_type_and_nullable(schema)?;
Ok(return_type)
}
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
Expand Down Expand Up @@ -303,8 +281,9 @@ impl ExprSchemable for Expr {
}
}
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarFunction(ScalarFunction { func, args }) => {
Ok(func.is_nullable(args, input_schema))
Expr::ScalarFunction(_func) => {
let (_, nullable) = self.data_type_and_nullable(input_schema)?;
Ok(nullable)
}
Expr::AggregateFunction(AggregateFunction { func, .. }) => {
Ok(func.is_nullable())
Expand Down Expand Up @@ -415,6 +394,47 @@ impl ExprSchemable for Expr {
Expr::WindowFunction(window_function) => {
self.data_type_and_nullable_with_window_function(schema, window_function)
}
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
.iter()
.map(|e| e.data_type_and_nullable(schema))
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();
// Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
let new_data_types = data_types_with_scalar_udf(&arg_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
match err {
DataFusionError::Plan(msg) => msg,
err => err.to_string(),
},
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_types,
)
)
})?;

let arguments = args
.iter()
.map(|e| match e {
Expr::Literal(sv) => Some(sv),
_ => None,
})
.collect::<Vec<_>>();
let args = ReturnTypeArgs {
arg_types: &new_data_types,
scalar_arguments: &arguments,
nullables: &nullables,
};

let (return_type, nullable) =
func.return_type_from_args(args)?.into_parts();
Ok((return_type, nullable))
}
_ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
}
}
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf::{
scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl,
};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

Expand Down
Loading

0 comments on commit d3f1c9a

Please sign in to comment.