From de67c44f44d93e171a5bdf8ddf8b87dd1f96931c Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Wed, 9 Oct 2024 14:46:23 +0100 Subject: [PATCH] Added support for `ScalarUDFImpl::invoke_with_return_type` where the invoke is passed the return type created for the udf instance --- datafusion/expr/src/udf.rs | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1a5d50477b1c8..bd393a0290544 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -203,9 +203,6 @@ impl ScalarUDF { self.inner.simplify(args, info) } - /// Invoke the function on `args`, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke`] for more details. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result { #[allow(deprecated)] @@ -227,6 +224,19 @@ impl ScalarUDF { self.inner.invoke_batch(args, number_rows) } + /// Invoke the function on `args`, also passing the previously computed `return_type`. + /// + /// See [`ScalarUDFImpl::invoke_batch_with_return_type`] for more details. + pub fn invoke_with_return_type( + &self, + args: &[ColumnarValue], + number_rows: usize, + return_type: &DataType, + ) -> Result { + self.inner + .invoke_batch_with_return_type(args, number_rows, return_type) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_no_args`] for more details. @@ -356,7 +366,7 @@ where /// } /// } /// } -/// +/// /// static DOCUMENTATION: OnceLock = OnceLock::new(); /// /// fn get_doc() -> &'static Documentation { @@ -537,6 +547,17 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } + /// Like [`ScalarUDFImpl::invoke_batch`], but additionally receives the value previously + /// returned from [`ScalarUDFImpl::return_type`] for this expression; it is ignored by default.
+ fn invoke_batch_with_return_type( + &self, + args: &[ColumnarValue], + number_rows: usize, + _return_type: &DataType, + ) -> Result { + self.invoke_batch(args, number_rows) + } + /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]