Implement ScalarUDF in terms of ScalarUDFImpl trait #8713

Merged
merged 2 commits into from
Jan 8, 2024
1 change: 1 addition & 0 deletions datafusion-examples/examples/advanced_udf.rs
@@ -40,6 +40,7 @@ use std::sync::Arc;
/// the power of the second argument `a^b`.
///
/// To do so, we must implement the `ScalarUDFImpl` trait.
#[derive(Debug, Clone)]
struct PowUdf {
signature: Signature,
aliases: Vec<String>,
1 change: 1 addition & 0 deletions datafusion/expr/src/expr.rs
@@ -1948,6 +1948,7 @@ mod test {
);

// UDF
#[derive(Debug)]
Contributor Author @alamb, Jan 1, 2024:

This is an example of the API change: any new impl of `ScalarUDFImpl` must also implement `Debug` (here via `#[derive(Debug)]`). Note that `ScalarUDFImpl` was introduced in #8578 and has not yet been released, so this is not a breaking change for released versions.

struct TestScalarUDF {
signature: Signature,
}
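
To illustrate why the derive becomes mandatory, here is a minimal standalone sketch of a trait with a `Debug` supertrait. `UdfImpl`, `TestUdf`, and `describe` are hypothetical stand-ins, not DataFusion types; only the supertrait bound mirrors the PR.

```rust
use std::fmt::Debug;

// Hypothetical stand-in for `ScalarUDFImpl`; only the `Debug` bound matters here.
trait UdfImpl: Debug {
    fn name(&self) -> &str;
}

// Without this derive, the `impl UdfImpl for TestUdf` below is rejected by the
// compiler because `TestUdf` would not satisfy the `Debug` supertrait bound.
#[derive(Debug)]
struct TestUdf {
    name: String,
}

impl UdfImpl for TestUdf {
    fn name(&self) -> &str {
        &self.name
    }
}

// Formats any implementation through the trait object, which is possible
// because `dyn UdfImpl` implements its supertrait `Debug`.
fn describe(udf: &dyn UdfImpl) -> String {
    format!("{:?}", udf)
}
```

A type that cannot derive `Debug` (for example, one holding a closure) would need a hand-written `Debug` impl instead.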
10 changes: 10 additions & 0 deletions datafusion/expr/src/expr_fn.rs
@@ -984,6 +984,16 @@ pub struct SimpleScalarUDF {
fun: ScalarFunctionImplementation,
}

impl Debug for SimpleScalarUDF {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}
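
The hand-written `Debug` impl is needed because function pointers wrapped in `Arc<dyn Fn…>` do not implement `Debug`. A minimal standalone sketch of the same technique follows; `SimpleUdf` and `FuncImpl` are hypothetical simplifications (the real `ScalarFunctionImplementation` operates on `ColumnarValue`s, not `i64`).

```rust
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

// Hypothetical, simplified stand-in for `ScalarFunctionImplementation`.
type FuncImpl = Arc<dyn Fn(i64) -> i64 + Send + Sync>;

struct SimpleUdf {
    name: String,
    fun: FuncImpl,
}

impl Debug for SimpleUdf {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("SimpleUdf")
            .field("name", &self.name)
            // `dyn Fn` has no `Debug` impl, so print a fixed placeholder instead.
            .field("fun", &"<FUNC>")
            .finish()
    }
}
```

Because the placeholder is a `&str`, it is printed with quotes, e.g. `fun: "<FUNC>"`.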

impl SimpleScalarUDF {
/// Create a new `SimpleScalarUDF` from a name, input types, return type and
/// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
231 changes: 157 additions & 74 deletions datafusion/expr/src/udf.rs
@@ -35,57 +35,35 @@ use std::sync::Arc;
/// functions you supply such name, type signature, return type, and actual
/// implementation.
///
///
/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
///
/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
///
/// # API Note
Contributor Author @alamb, Jan 1, 2024:

As I went through this PR, I noticed that `ScalarUDF` is now basically a pass-through wrapper around `ScalarUDFImpl`. If we didn't want to maintain backwards compatibility, we could probably remove the `ScalarUDF` struct and make it a trait, but I think that would be very disruptive to all existing users of DataFusion, so we should avoid doing so unless absolutely necessary.

///
/// This is a separate struct from `ScalarUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [`create_udf`]: crate::expr_fn::create_udf
/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct ScalarUDF {
Contributor Author:

The API for `ScalarUDF` is not changed at all; only its internal implementation changes.

/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
/// be passed with the slice / vec of columnar values (either scalar or array)
/// with the exception of zero param function, where a singular element vec
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
/// Optional aliases for the function. This list should NOT include the value of `name` as well
aliases: Vec<String>,
}

impl Debug for ScalarUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
inner: Arc<dyn ScalarUDFImpl>,
}

impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.signature == other.signature
self.name() == other.name() && self.signature() == other.signature()
}
}

impl Eq for ScalarUDF {}

impl std::hash::Hash for ScalarUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
self.name().hash(state);
self.signature().hash(state);
}
}
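
As a rough standalone sketch of what delegating equality and hashing through the accessors implies: two handles backed by different implementations compare equal when name and signature match. All names here (`UdfImpl`, `Udf`, `Pow`, `PowAlt`, `hash_of`) are hypothetical, and the signature is simplified to a string.

```rust
use std::collections::hash_map::DefaultHasher;
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

// Hypothetical stand-in for `ScalarUDFImpl`; the signature is just a string here.
trait UdfImpl: Debug + Send + Sync {
    fn name(&self) -> &str;
    fn signature(&self) -> &str;
}

// Stand-in for `ScalarUDF`: a cheap-to-clone handle around the trait object.
#[derive(Debug, Clone)]
struct Udf {
    inner: Arc<dyn UdfImpl>,
}

impl Udf {
    fn name(&self) -> &str {
        self.inner.name()
    }
    fn signature(&self) -> &str {
        self.inner.signature()
    }
}

// Equality and hashing only look at what the accessors expose.
impl PartialEq for Udf {
    fn eq(&self, other: &Self) -> bool {
        self.name() == other.name() && self.signature() == other.signature()
    }
}

impl Eq for Udf {}

impl Hash for Udf {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name().hash(state);
        self.signature().hash(state);
    }
}

// Two distinct implementations that report the same name and signature.
#[derive(Debug)]
struct Pow;
#[derive(Debug)]
struct PowAlt;

impl UdfImpl for Pow {
    fn name(&self) -> &str { "pow" }
    fn signature(&self) -> &str { "(f64, f64)" }
}

impl UdfImpl for PowAlt {
    fn name(&self) -> &str { "pow" }
    fn signature(&self) -> &str { "(f64, f64)" }
}

// Helper that hashes a handle with the standard library's default hasher.
fn hash_of(udf: &Udf) -> u64 {
    let mut hasher = DefaultHasher::new();
    udf.hash(&mut hasher);
    hasher.finish()
}
```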

@@ -101,51 +79,37 @@ impl ScalarUDF {
return_type: &ReturnTypeFunction,
fun: &ScalarFunctionImplementation,
) -> Self {
Self {
Self::new_from_impl(ScalarUdfLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
fun: fun.clone(),
aliases: vec![],
}
})
}

/// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
///
/// Note this is the same as using the `From` impl (`ScalarUDF::from`)
pub fn new_from_impl<F>(fun: F) -> ScalarUDF
where
F: ScalarUDFImpl + Send + Sync + 'static,
F: ScalarUDFImpl + 'static,
{
// TODO change the internal implementation to use the trait object
let arc_fun = Arc::new(fun);
let captured_self = arc_fun.clone();
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
let return_type = captured_self.return_type(arg_types)?;
Ok(Arc::new(return_type))
});

let captured_self = arc_fun.clone();
let func: ScalarFunctionImplementation =
Arc::new(move |args| captured_self.invoke(args));

Self {
name: arc_fun.name().to_string(),
signature: arc_fun.signature().clone(),
return_type: return_type.clone(),
fun: func,
aliases: arc_fun.aliases().to_vec(),
inner: Arc::new(fun),
}
}

/// Adds additional names that can be used to invoke this function, in addition to `name`
pub fn with_aliases(
mut self,
aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
self.aliases
.extend(aliases.into_iter().map(|s| s.to_string()));
self
/// Return the underlying [`ScalarUDFImpl`] trait object for this function
pub fn inner(&self) -> Arc<dyn ScalarUDFImpl> {
self.inner.clone()
}

/// Adds additional names that can be used to invoke this function, in
/// addition to `name`
///
/// If you implement [`ScalarUDFImpl`] directly you should return aliases directly.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
Contributor Author:

This is somewhat of a hack (a wrapper around the scalar UDF). It may make more sense to simply remove `with_aliases`; however, since it was released in DataFusion 34.0.0 (https://docs.rs/datafusion/34.0.0/datafusion/physical_plan/udf/struct.ScalarUDF.html), that would be a breaking API change.

We could deprecate the API 🤔

}

/// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -159,31 +123,46 @@ impl ScalarUDF {
))
}

/// Returns this function's name
/// Returns this function's name.
///
/// See [`ScalarUDFImpl::name`] for more details.
pub fn name(&self) -> &str {
&self.name
self.inner.name()
}

/// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
/// Returns the aliases for this function.
///
/// See [`ScalarUDF::with_aliases`] for more details
pub fn aliases(&self) -> &[String] {
&self.aliases
self.inner.aliases()
}

/// Returns this function's [`Signature`] (what input types are accepted)
/// Returns this function's [`Signature`] (what input types are accepted).
///
/// See [`ScalarUDFImpl::signature`] for more details.
pub fn signature(&self) -> &Signature {
&self.signature
self.inner.signature()
}

/// The datatype this function returns given the input argument input types
/// The datatype this function returns given the input argument input types.
///
/// See [`ScalarUDFImpl::return_type`] for more details.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
self.inner.return_type(args)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

/// Return an [`Arc`] to the function implementation
/// Returns a `ScalarFunctionImplementation` that can invoke the function
/// during execution
pub fn fun(&self) -> ScalarFunctionImplementation {
self.fun.clone()
let captured = self.inner.clone();
Arc::new(move |args| captured.invoke(args))
}
}
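
A minimal standalone sketch of the `fun()` technique: instead of storing a function pointer, the wrapper builds one on demand by cloning the inner `Arc` into a closure. `UdfImpl`, `Udf`, `FuncImpl`, and `Sum` are hypothetical, and arguments are simplified to `i64` slices.

```rust
use std::sync::Arc;

// Hypothetical stand-in for `ScalarUDFImpl`, reduced to `invoke`.
trait UdfImpl: Send + Sync {
    fn invoke(&self, args: &[i64]) -> i64;
}

// Hypothetical, simplified stand-in for `ScalarFunctionImplementation`.
type FuncImpl = Arc<dyn Fn(&[i64]) -> i64 + Send + Sync>;

struct Udf {
    inner: Arc<dyn UdfImpl>,
}

impl Udf {
    // Returns a closure that owns its own reference to the implementation,
    // so it stays valid even after this `Udf` handle is dropped.
    fn fun(&self) -> FuncImpl {
        let captured = Arc::clone(&self.inner);
        Arc::new(move |args: &[i64]| captured.invoke(args))
    }
}

// Example implementation that sums its arguments.
struct Sum;

impl UdfImpl for Sum {
    fn invoke(&self, args: &[i64]) -> i64 {
        args.iter().sum()
    }
}
```

One consequence of this design is that each call to `fun()` allocates a new closure, so callers that need it repeatedly may prefer to keep the returned value around.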

@@ -213,6 +192,7 @@ where
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
/// #[derive(Debug)]
/// struct AddOne {
/// signature: Signature
/// };
@@ -246,7 +226,7 @@ where
/// // Call the function `add_one(col)`
/// let expr = add_one.call(vec![col("a")]);
/// ```
pub trait ScalarUDFImpl {
pub trait ScalarUDFImpl: Debug + Send + Sync {
/// Returns this object as an [`Any`] trait object
fn as_any(&self) -> &dyn Any;

@@ -292,3 +272,106 @@ pub trait ScalarUDFImpl {
&[]
}
}
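
Moving `Send + Sync` onto the trait itself means an `Arc<dyn ScalarUDFImpl>` can be shared across threads without repeating those bounds at each call site (which is presumably why `new_from_impl` could drop them). A standalone sketch under that assumption, with hypothetical names `UdfImpl`, `AddOne`, and `invoke_on_threads`:

```rust
use std::fmt::Debug;
use std::sync::Arc;
use std::thread;

// Hypothetical stand-in for `ScalarUDFImpl` with the same supertrait bounds.
trait UdfImpl: Debug + Send + Sync {
    fn invoke(&self, arg: i64) -> i64;
}

#[derive(Debug)]
struct AddOne;

impl UdfImpl for AddOne {
    fn invoke(&self, arg: i64) -> i64 {
        arg + 1
    }
}

// Because of the supertrait bounds, `Arc<dyn UdfImpl>` is `Send + Sync` and
// can be moved into spawned threads with no extra bounds on this function.
fn invoke_on_threads(udf: Arc<dyn UdfImpl>, inputs: Vec<i64>) -> Vec<i64> {
    // Spawn one thread per input; collect eagerly so all threads start first.
    let handles: Vec<_> = inputs
        .into_iter()
        .map(|x| {
            let udf = Arc::clone(&udf);
            thread::spawn(move || udf.invoke(x))
        })
        .collect();

    // Join in spawn order so results line up with the inputs.
    handles
        .into_iter()
        .map(|handle| handle.join().expect("worker thread panicked"))
        .collect()
}
```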

/// ScalarUDF that adds an alias to the underlying function. It is better to
Contributor Author:

This is somewhat boilerplate, but it is a pretty straightforward example of using trait objects to extend functionality.

/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedScalarUDFImpl {
inner: ScalarUDF,
aliases: Vec<String>,
}

impl AliasedScalarUDFImpl {
pub fn new(
inner: ScalarUDF,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
}
}

impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
self.inner.name()
}

fn signature(&self) -> &Signature {
self.inner.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
self.inner.return_type(arg_types)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
self.inner.invoke(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}
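
The alias wrapper is a decorator over the trait object: it forwards every method to the wrapped UDF and overrides only `aliases`. A reduced standalone sketch of that pattern follows; `UdfImpl`, `Udf`, `Aliased`, and `Pow` are hypothetical stand-ins that keep only `name` and `aliases`.

```rust
use std::fmt::Debug;
use std::sync::Arc;

// Hypothetical stand-in for `ScalarUDFImpl`, reduced to two methods.
trait UdfImpl: Debug + Send + Sync {
    fn name(&self) -> &str;
    fn aliases(&self) -> &[String] {
        &[]
    }
}

// Stand-in for `ScalarUDF`.
#[derive(Debug, Clone)]
struct Udf {
    inner: Arc<dyn UdfImpl>,
}

impl Udf {
    fn new_from_impl<F: UdfImpl + 'static>(fun: F) -> Self {
        Self { inner: Arc::new(fun) }
    }
    fn name(&self) -> &str {
        self.inner.name()
    }
    fn aliases(&self) -> &[String] {
        self.inner.aliases()
    }
    // Wraps `self` in a decorator rather than mutating a stored alias list.
    fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) -> Self {
        Self::new_from_impl(Aliased::new(self, aliases))
    }
}

// Decorator: keeps the wrapped UDF plus the combined alias list.
#[derive(Debug)]
struct Aliased {
    inner: Udf,
    aliases: Vec<String>,
}

impl Aliased {
    fn new(inner: Udf, new_aliases: impl IntoIterator<Item = &'static str>) -> Self {
        // Start from the existing aliases so repeated calls accumulate.
        let mut aliases = inner.aliases().to_vec();
        aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));
        Self { inner, aliases }
    }
}

impl UdfImpl for Aliased {
    fn name(&self) -> &str {
        self.inner.name()
    }
    fn aliases(&self) -> &[String] {
        &self.aliases
    }
}

// A base implementation with no aliases of its own.
#[derive(Debug)]
struct Pow;

impl UdfImpl for Pow {
    fn name(&self) -> &str {
        "pow"
    }
}
```

Note that each `with_aliases` call adds one more layer of indirection, which is one reason returning aliases directly from the trait implementation is the simpler route.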

/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers
/// of the older API (see <https://github.com/apache/arrow-datafusion/pull/8578>
/// for more details)
struct ScalarUdfLegacyWrapper {
/// The name of the function
name: String,
/// The signature (the types of arguments that are supported)
signature: Signature,
/// Function that returns the return type given the argument types
return_type: ReturnTypeFunction,
/// actual implementation
///
/// The fn param is the wrapped function but be aware that the function will
/// be passed with the slice / vec of columnar values (either scalar or array)
/// with the exception of zero param function, where a singular element vec
/// will be passed. In that case the single element is a null array to indicate
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
}

impl Debug for ScalarUdfLegacyWrapper {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}

impl ScalarUDFImpl for ScalarUdfLegacyWrapper {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(arg_types)?;
Ok(res.as_ref().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
(self.fun)(args)
}

fn aliases(&self) -> &[String] {
&[]
}
}
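
The legacy wrapper is an adapter: it stores the old-style function pointers and exposes them through the trait. A standalone sketch of that idea, using hypothetical simplified types (`String` in place of `DataType`, `i64` in place of `ColumnarValue`, and `String` as the error type):

```rust
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;

// Hypothetical stand-in for `ScalarUDFImpl`.
trait UdfImpl: Debug + Send + Sync {
    fn name(&self) -> &str;
    fn return_type(&self, arg_types: &[String]) -> Result<String, String>;
    fn invoke(&self, args: &[i64]) -> Result<i64, String>;
}

// Old-style function pointers; the return type function yields an `Arc`.
type ReturnTypeFn = Arc<dyn Fn(&[String]) -> Result<Arc<String>, String> + Send + Sync>;
type FuncImpl = Arc<dyn Fn(&[i64]) -> Result<i64, String> + Send + Sync>;

struct LegacyWrapper {
    name: String,
    return_type: ReturnTypeFn,
    fun: FuncImpl,
}

// Manual impl because the function pointers do not implement `Debug`.
impl Debug for LegacyWrapper {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("LegacyWrapper")
            .field("name", &self.name)
            .field("fun", &"<FUNC>")
            .finish()
    }
}

impl UdfImpl for LegacyWrapper {
    fn name(&self) -> &str {
        &self.name
    }
    fn return_type(&self, arg_types: &[String]) -> Result<String, String> {
        // Unwrap the `Arc` that the old-style function returns.
        let res = (self.return_type)(arg_types)?;
        Ok(res.as_ref().clone())
    }
    fn invoke(&self, args: &[i64]) -> Result<i64, String> {
        (self.fun)(args)
    }
}

// Builds a wrapper from two closures, as the older API would have supplied.
fn legacy_add_one() -> LegacyWrapper {
    LegacyWrapper {
        name: "add_one".to_string(),
        return_type: Arc::new(|_: &[String]| -> Result<Arc<String>, String> {
            Ok(Arc::new("Int64".to_string()))
        }),
        fun: Arc::new(|args: &[i64]| -> Result<i64, String> {
            args.first()
                .map(|x| x + 1)
                .ok_or_else(|| "expected one argument".to_string())
        }),
    }
}
```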
1 change: 1 addition & 0 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -811,6 +811,7 @@ mod test {

static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();

#[derive(Debug, Clone, Default)]
struct TestScalarUDF {}
impl ScalarUDFImpl for TestScalarUDF {
fn as_any(&self) -> &dyn Any {
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/udf.rs
@@ -36,7 +36,7 @@ pub fn create_physical_expr(

Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun().clone(),
fun.fun(),
Contributor Author:

This is a drive-by cleanup I noticed while working on the code.

input_phy_exprs.to_vec(),
fun.return_type(&input_exprs_types)?,
None,