Skip to content

Commit

Permalink
Add fail_on_error to be passed up to make_comet_scalar_udf
Browse files Browse the repository at this point in the history
  • Loading branch information
raulcd committed Oct 2, 2024
1 parent 6d0b46c commit 113ec6d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ use std::fmt::Debug;
use std::sync::Arc;

macro_rules! make_comet_scalar_udf {
($name:expr, $func:ident, $data_type:ident, $fail_on_error:ident) => {{
let scalar_func = CometScalarFunction::new(
$name.to_string(),
Signature::variadic_any(Volatility::Immutable),
$data_type.clone(),
Arc::new(move |args| $func(args, &$data_type)),
);
// TODO Check for overflow
Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func)))
}};
($name:expr, $func:ident, $data_type:ident) => {{
let scalar_func = CometScalarFunction::new(
$name.to_string(),
Expand All @@ -59,6 +69,7 @@ pub fn create_comet_physical_fun(
fun_name: &str,
data_type: DataType,
registry: &dyn FunctionRegistry,
_fail_on_error: &bool,
) -> Result<Arc<ScalarUDF>, DataFusionError> {
match fun_name {
"ceil" => {
Expand All @@ -72,7 +83,7 @@ pub fn create_comet_physical_fun(
make_comet_scalar_udf!("read_side_padding", func, without data_type)
}
"round" => {
make_comet_scalar_udf!("round", spark_round, data_type)
make_comet_scalar_udf!("round", spark_round, data_type, _fail_on_error)
}
"unscaled_value" => {
let func = Arc::new(spark_unscaled_value);
Expand Down
11 changes: 9 additions & 2 deletions native/core/src/execution/datafusion/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,10 +775,12 @@ impl PhysicalPlanner {
Ok(DataType::Decimal128(_p2, _s2)),
) => {
let data_type = return_type.map(to_arrow_datatype).unwrap();
let fail_on_error = false;
let fun_expr = create_comet_physical_fun(
"decimal_div",
data_type.clone(),
&self.session_ctx.state(),
&fail_on_error,
)?;
Ok(Arc::new(ScalarFunctionExpr::new(
"decimal_div",
Expand Down Expand Up @@ -1872,6 +1874,7 @@ impl PhysicalPlanner {
.collect::<Result<Vec<_>, _>>()?;

let fun_name = &expr.func;
let fail_on_error = &expr.fail_on_error;
let input_expr_types = args
.iter()
.map(|x| x.data_type(input_schema.as_ref()))
Expand All @@ -1897,8 +1900,12 @@ impl PhysicalPlanner {
}
};

let fun_expr =
create_comet_physical_fun(fun_name, data_type.clone(), &self.session_ctx.state())?;
let fun_expr = create_comet_physical_fun(
fun_name,
data_type.clone(),
&self.session_ctx.state(),
fail_on_error,
)?;

let args = args
.into_iter()
Expand Down

0 comments on commit 113ec6d

Please sign in to comment.