diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 6ebf88a0b671b..d530b9abe030a 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -40,6 +40,7 @@ use std::sync::Arc; /// the power of the second argument `a^b`. /// /// To do so, we must implement the `ScalarUDFImpl` trait. +#[derive(Debug, Clone)] struct PowUdf { signature: Signature, aliases: Vec, diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index ebf4d3143c122..5617d217eb9f8 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1948,6 +1948,7 @@ mod test { ); // UDF + #[derive(Debug)] struct TestScalarUDF { signature: Signature, } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index eed41d97ccba1..5439754b8b663 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -32,6 +32,7 @@ use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; +use std::fmt::Debug; use std::ops::Not; use std::sync::Arc; @@ -983,6 +984,16 @@ pub struct SimpleScalarUDF { fun: ScalarFunctionImplementation, } +impl Debug for SimpleScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + impl SimpleScalarUDF { /// Create a new `SimpleScalarUDF` from a name, input types, return type and /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 2ec80a4a9ea1c..6c9b6a6363e08 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -1,4 +1,5 @@ // Licensed to the Apache Software Foundation (ASF) under one +// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -35,48 +36,26 @@ use std::sync::Arc; /// functions you supply such name, type signature, return type, and actual /// implementation. /// -/// /// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`]. /// /// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`]. /// +/// # API Note +/// +/// This is a separate struct from `ScalarUDFImpl` to maintain backwards +/// compatibility with the older API. +/// /// [`create_udf`]: crate::expr_fn::create_udf /// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs /// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct ScalarUDF { - /// The name of the function - name: String, - /// The signature (the types of arguments that are supported) - signature: Signature, - /// Function that returns the return type given the argument types - return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - fun: ScalarFunctionImplementation, - /// Optional aliases for the function. This list should NOT include the value of `name` as well - aliases: Vec, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } + inner: Arc, } impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == other.name() && self.signature() == other.signature() } } @@ -84,8 +63,8 @@ impl Eq for ScalarUDF {} impl std::hash::Hash for ScalarUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -101,13 +80,12 @@ impl ScalarUDF { return_type: &ReturnTypeFunction, fun: &ScalarFunctionImplementation, ) -> Self { - Self { + Self::new_from_impl(ScalarUdfLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), fun: fun.clone(), - aliases: vec![], - } + }) } /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object @@ -115,37 +93,24 @@ impl ScalarUDF { /// Note this is the same as using the `From` impl (`ScalarUDF::from`) pub fn new_from_impl(fun: F) -> ScalarUDF where - F: ScalarUDFImpl + Send + Sync + 'static, + F: ScalarUDFImpl + 'static, { - // TODO change the internal implementation to use the trait object - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let func: ScalarFunctionImplementation = - Arc::new(move |args| captured_self.invoke(args)); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - fun: func, - aliases: arc_fun.aliases().to_vec(), + inner: Arc::new(fun), } } - /// Adds additional names that can be used to invoke this function, in addition to `name` - pub fn with_aliases( - mut self, - aliases: impl IntoIterator, - ) -> Self { - self.aliases - .extend(aliases.into_iter().map(|s| s.to_string())); - self + /// Return the underlying [`ScalarUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + + /// Adds additional names that can be used to invoke this function, in + /// addition to `name` + /// + /// If you implement [`ScalarUDFImpl`] directly you can return aliases directly. + pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { + Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -159,31 +124,47 @@ impl ScalarUDF { )) } - /// Returns this function's name + /// Returns this function's name. + /// + /// See [`ScalarUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } - /// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details + /// Returns the aliases for this function. + /// + /// See [`ScalarUDF::with_aliases`] for more details pub fn aliases(&self) -> &[String] { - &self.aliases + self.inner.aliases() } - /// Returns this function's [`Signature`] (what input types are accepted) + /// Returns this function's [`Signature`] (what input types are accepted). + /// + /// See [`ScalarUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } - /// The datatype this function returns given the input argument input types + /// The datatype this function returns given the input argument input types. + /// + /// See [`ScalarUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) + } + + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke`] for more details. + pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) } - /// Return an [`Arc`] to the function implementation + /// Returns a `ScalarFunctionImplementation` that can invoke the function pub fn fun(&self) -> ScalarFunctionImplementation { - self.fun.clone() + // TODO: use ScalarUDF directly in `ScalarFunctionExpr` and remove this + // method + let captured = self.inner.clone(); + Arc::new(move |args| captured.invoke(args)) } } @@ -246,7 +227,7 @@ where /// // Call the function `add_one(col)` /// let expr = add_one.call(vec![col("a")]); /// ``` -pub trait ScalarUDFImpl { +pub trait ScalarUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -292,3 +273,105 @@ pub trait ScalarUDFImpl { &[] } } + +/// ScalarUDF that adds an alias to the underlying function. It is better to +/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. +#[derive(Debug)] +struct AliasedScalarUDFImpl { + inner: ScalarUDF, + aliases: Vec, +} + +impl AliasedScalarUDFImpl { + pub fn new( + inner: ScalarUDF, + new_aliases: impl IntoIterator, + ) -> Self { + let mut aliases = inner.aliases().to_vec(); + aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); + + Self { inner, aliases } + } +} + +impl ScalarUDFImpl for AliasedScalarUDFImpl { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + self.inner.name() + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + self.inner.return_type(arg_types) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + self.inner.invoke(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers of the older API +/// (see https://github.com/apache/arrow-datafusion/pull/8578) +struct ScalarUdfLegacyWrapper { + /// The name of the function + name: String, + /// The signature (the types of arguments that are supported) + signature: Signature, + /// Function that returns the return type given the argument types + return_type: ReturnTypeFunction, + /// actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size). + fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUdfLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl ScalarUDFImpl for ScalarUdfLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + (self.fun)(args) + } + + fn aliases(&self) -> &[String] { + &[] + } +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad996703..6f1da5f4e6d9b 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -811,6 +811,7 @@ mod test { static TEST_SIGNATURE: OnceLock = OnceLock::new(); + #[derive(Debug, Clone, Default)] struct TestScalarUDF {} impl ScalarUDFImpl for TestScalarUDF { fn as_any(&self) -> &dyn Any {