Skip to content

Commit

Permalink
Implement ScalarUDF in terms of ScalarUDFImpl trait
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 1, 2024
1 parent d2b3d1c commit 1334667
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 75 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/advanced_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use std::sync::Arc;
/// the power of the second argument `a^b`.
///
/// To do so, we must implement the `ScalarUDFImpl` trait.
#[derive(Debug, Clone)]
struct PowUdf {
signature: Signature,
aliases: Vec<String>,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1948,6 +1948,7 @@ mod test {
);

// UDF
#[derive(Debug)]
struct TestScalarUDF {
signature: Signature,
}
Expand Down
11 changes: 11 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF};
use arrow::datatypes::DataType;
use datafusion_common::{Column, Result};
use std::any::Any;
use std::fmt::Debug;
use std::ops::Not;
use std::sync::Arc;

Expand Down Expand Up @@ -983,6 +984,16 @@ pub struct SimpleScalarUDF {
fun: ScalarFunctionImplementation,
}

impl Debug for SimpleScalarUDF {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}

impl SimpleScalarUDF {
/// Create a new `SimpleScalarUDF` from a name, input types, return type and
/// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
Expand Down
229 changes: 155 additions & 74 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,57 +35,35 @@ use std::sync::Arc;
/// functions you supply such name, type signature, return type, and actual
/// implementation.
///
///
/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
///
/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
///
/// # API Note
///
/// This is a separate struct from `ScalarUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [`create_udf`]: crate::expr_fn::create_udf
/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct ScalarUDF {
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
/// be passed with the slice / vec of columnar values (either scalar or array)
/// with the exception of zero param function, where a singular element vec
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
/// Optional aliases for the function. This list should NOT include the value of `name` as well
aliases: Vec<String>,
}

impl Debug for ScalarUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
inner: Arc<dyn ScalarUDFImpl>,
}

impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.signature == other.signature
self.name() == other.name() && self.signature() == other.signature()
}
}

impl Eq for ScalarUDF {}

impl std::hash::Hash for ScalarUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
self.name().hash(state);
self.signature().hash(state);
}
}

Expand All @@ -101,51 +79,37 @@ impl ScalarUDF {
return_type: &ReturnTypeFunction,
fun: &ScalarFunctionImplementation,
) -> Self {
Self {
Self::new_from_impl(ScalarUdfLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
fun: fun.clone(),
aliases: vec![],
}
})
}

/// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object
///
/// Note this is the same as using the `From` impl (`ScalarUDF::from`)
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
where
F: ScalarUDFImpl + Send + Sync + 'static,
F: ScalarUDFImpl + 'static,
{
// TODO change the internal implementation to use the trait object
let arc_fun = Arc::new(fun);
let captured_self = arc_fun.clone();
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
let return_type = captured_self.return_type(arg_types)?;
Ok(Arc::new(return_type))
});

let captured_self = arc_fun.clone();
let func: ScalarFunctionImplementation =
Arc::new(move |args| captured_self.invoke(args));

Self {
name: arc_fun.name().to_string(),
signature: arc_fun.signature().clone(),
return_type: return_type.clone(),
fun: func,
aliases: arc_fun.aliases().to_vec(),
inner: Arc::new(fun),
}
}

/// Adds additional names that can be used to invoke this function, in addition to `name`
pub fn with_aliases(
mut self,
aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
self.aliases
.extend(aliases.into_iter().map(|s| s.to_string()));
self
/// Return the underlying [`ScalarUDFImpl`] trait object for this function
pub fn inner(&self) -> Arc<dyn ScalarUDFImpl> {
self.inner.clone()
}

/// Adds additional names that can be used to invoke this function, in
/// addition to `name`
///
/// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
}

/// Returns a [`Expr`] logical expression to call this UDF with specified
Expand All @@ -159,31 +123,46 @@ impl ScalarUDF {
))
}

/// Returns this function's name
/// Returns this function's name.
///
/// See [`ScalarUDFImpl::name`] for more details.
pub fn name(&self) -> &str {
&self.name
self.inner.name()
}

/// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
/// Returns the aliases for this function.
///
/// See [`ScalarUDF::with_aliases`] for more details
pub fn aliases(&self) -> &[String] {
&self.aliases
self.inner.aliases()
}

/// Returns this function's [`Signature`] (what input types are accepted)
/// Returns this function's [`Signature`] (what input types are accepted).
///
/// See [`ScalarUDFImpl::signature`] for more details.
pub fn signature(&self) -> &Signature {
&self.signature
self.inner.signature()
}

/// The datatype this function returns given the input argument input types
/// The datatype this function returns given the input argument input types.
///
/// See [`ScalarUDFImpl::return_type`] for more details.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
self.inner.return_type(args)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

/// Return an [`Arc`] to the function implementation
/// Returns a `ScalarFunctionImplementation` that can invoke the function
/// during execution
pub fn fun(&self) -> ScalarFunctionImplementation {
self.fun.clone()
let captured = self.inner.clone();
Arc::new(move |args| captured.invoke(args))
}
}

Expand Down Expand Up @@ -246,7 +225,7 @@ where
/// // Call the function `add_one(col)`
/// let expr = add_one.call(vec![col("a")]);
/// ```
pub trait ScalarUDFImpl {
pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Returns this object as an [`Any`] trait object
fn as_any(&self) -> &dyn Any;

Expand Down Expand Up @@ -292,3 +271,105 @@ pub trait ScalarUDFImpl {
&[]
}
}

/// ScalarUDF that adds an alias to the underlying function. It is better to
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedScalarUDFImpl {
inner: ScalarUDF,
aliases: Vec<String>,
}

impl AliasedScalarUDFImpl {
pub fn new(
inner: ScalarUDF,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
}
}

impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}

fn signature(&self) -> &Signature {
self.inner.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers of the older API
/// (see https://github.com/apache/arrow-datafusion/pull/8578)
struct ScalarUdfLegacyWrapper {
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
/// be passed with the slice / vec of columnar values (either scalar or array)
/// with the exception of zero param function, where a singular element vec
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUdfLegacyWrapper {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}

impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(arg_types)?;
Ok(res.as_ref().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
(self.fun)(args)
}

fn aliases(&self) -> &[String] {
&[]
}
}
1 change: 1 addition & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ mod test {

static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();

#[derive(Debug, Clone, Default)]
struct TestScalarUDF {}
impl ScalarUDFImpl for TestScalarUDF {
fn as_any(&self) -> &dyn Any {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub fn create_physical_expr(

Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun().clone(),
fun.fun(),
input_phy_exprs.to_vec(),
fun.return_type(&input_exprs_types)?,
None,
Expand Down

0 comments on commit 1334667

Please sign in to comment.