Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for ScalarUDFImpl::invoke_with_return_type where the invoke is passed the return type created for the udf instance #1

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
98 changes: 57 additions & 41 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,7 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke(args)
Expand All @@ -216,20 +213,27 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
pub fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
#[allow(deprecated)]
Expand Down Expand Up @@ -324,26 +328,37 @@ where
}
}

/// Trait for implementing [`ScalarUDF`].
pub struct ScalarFunctionArgs<'a> {
// The evaluated arguments to the function
pub args: &'a [ColumnarValue],
// The number of rows in record batch being evaluated
pub number_rows: usize,
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
// when creating the physical expression from the logical expression
pub return_type: &'a DataType,
}

/// Trait for implementing user defined scalar functions.
///
/// This trait exposes the full API for implementing user defined functions and
/// can be used to implement any function.
///
/// See [`advanced_udf.rs`] for a full example with complete implementation and
/// [`ScalarUDF`] for other available options.
///
///
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
///
/// # Basic Example
/// ```
/// # use std::any::Any;
/// # use std::sync::OnceLock;
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
///
/// /// This struct for a simple UDF that adds one to an int32
/// #[derive(Debug)]
/// struct AddOne {
/// signature: Signature,
Expand All @@ -356,7 +371,7 @@ where
/// }
/// }
/// }
///
///
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
///
/// fn get_doc() -> &'static Documentation {
Expand All @@ -383,7 +398,9 @@ where
/// Ok(DataType::Int32)
/// }
/// // The actual implementation would add one to the argument
/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
/// unimplemented!()
/// }
/// fn documentation(&self) -> Option<&Documentation> {
/// Some(get_doc())
/// }
Expand Down Expand Up @@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {

/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
/// For the best performance, the implementations of `invoke` should handle
/// the common case when one or more of their arguments are constant values
/// (aka [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke but called",
Expand All @@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Invoke the function with `args` and the number of rows,
/// returning the appropriate result.
///
/// The function will be invoked with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Performance
/// Note: See notes on [`Self::invoke_with_args`]
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
/// See <https://github.com/apache/datafusion/issues/13515> for more details.
fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand All @@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
}
}

/// Invoke the function returning the appropriate result.
///
/// # Performance
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.invoke_batch(args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
///
/// Note: This method is deprecated and will be removed in future releases.
/// User defined functions should implement [`Self::invoke_with_args`] instead.
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
Expand Down Expand Up @@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/benches/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_8192", |b| {
b.iter(|| {
for _ in 0..iterations {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 8192).unwrap());
}
})
Expand All @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_128", |b| {
b.iter(|| {
for _ in 0..iterations_128 {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 128).unwrap());
}
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ mod test {
#[tokio::test]
async fn test_version_udf() {
let version_udf = ScalarUDF::from(VersionFunc::new());
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let version = version_udf.invoke_batch(&[], 1).unwrap();

if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ mod tests {
use arrow::datatypes::{DataType, TimeUnit};
use chrono::NaiveDateTime;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::{adjust_to_local_time, ToLocalTimeFunc};

Expand Down Expand Up @@ -558,7 +558,11 @@ mod tests {

fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
.invoke_with_args(ScalarFunctionArgs {
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
.unwrap();
match res {
ColumnarValue::Scalar(res) => {
Expand Down Expand Up @@ -617,6 +621,7 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>();
let batch_size = input.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, Some(_))));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down Expand Up @@ -1051,7 +1051,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, None)));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/datetime/to_unixtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
.cast_to(&DataType::Int64, None),
#[allow(deprecated)] // TODO: migrate to invoke_with_args
DataType::Utf8 => ToTimestampSecondsFunc::new()
.invoke_batch(args, batch_size)?
.cast_to(&DataType::Int64, None),
Expand Down
20 changes: 10 additions & 10 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let _ = LogFunc::new().invoke_batch(&args, 4);
}

Expand All @@ -286,7 +286,7 @@ mod tests {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new().invoke_batch(&args, 1);
result.expect_err("expected error");
}
Expand All @@ -296,7 +296,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -320,7 +320,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -345,7 +345,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -370,7 +370,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -396,7 +396,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand All @@ -232,7 +232,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand Down
Loading
Loading