From 85719df5f2fff12a1752c47086a4a9cf158c28da Mon Sep 17 00:00:00 2001
From: junxiangMu <63799833+guojidan@users.noreply.github.com>
Date: Wed, 10 Jan 2024 01:27:19 +0800
Subject: [PATCH] Implement trait based API for define AggregateUDF (#8733)

* Implement trait based API for define AggregateUDF

* implement Inner

* fix test case && doc

* fix annotation
---
 datafusion-examples/README.md                 |   1 +
 datafusion-examples/examples/advanced_udaf.rs | 228 ++++++++++++++++++
 .../user_defined/user_defined_aggregates.rs   |  47 ++--
 .../user_defined_scalar_functions.rs          |   4 +-
 datafusion/expr/src/expr_fn.rs                | 109 ++++++++-
 datafusion/expr/src/lib.rs                    |   2 +-
 datafusion/expr/src/udaf.rs                   | 224 ++++++++++++++---
 datafusion/expr/src/udwf.rs                   |   2 +-
 .../optimizer/src/analyzer/type_coercion.rs   |  24 +-
 .../optimizer/src/common_subexpr_eliminate.rs |  23 +-
 .../tests/cases/roundtrip_physical_plan.rs    |  21 +-
 docs/source/library-user-guide/adding-udfs.md |   7 +-
 12 files changed, 581 insertions(+), 111 deletions(-)
 create mode 100644 datafusion-examples/examples/advanced_udaf.rs

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index aae451add9e7..eecb63d3be65 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -62,6 +62,7 @@ cargo run --example csv_sql
 - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
 - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
+- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF)
 - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
 - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF)
 
diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs
new file mode 100644
index 000000000000..8d5314bfbea5
--- /dev/null
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -0,0 +1,228 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, Float32Array},
+    record_batch::RecordBatch,
+};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{cast::as_float64_array, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
+
+/// This example shows how to use the full AggregateUDFImpl API to implement a user
+/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements
+/// a function `accumulator` that returns the `Accumulator` instance.
+///
+/// To do so, we must implement the `AggregateUDFImpl` trait.
+#[derive(Debug, Clone)]
+struct GeoMeanUdf {
+    signature: Signature,
+}
+
+impl GeoMeanUdf {
+    /// Create a new instance of the GeoMeanUdf struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take one argument of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl AggregateUDFImpl for GeoMeanUdf {
+    /// We implement as_any so that we can downcast the AggregateUDFImpl trait object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "geo_mean"
+    }
+
+    /// Return the "signature" of this function -- namely the types of arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the accumulator factory; DataFusion uses it to create new accumulators.
+    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(GeometricMean::new()))
+    }
+
+    /// This is the description of the state. The accumulator's `state()` must match the types here.
+    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+        Ok(vec![DataType::Float64, DataType::UInt32])
+    }
+}
+
+/// A UDAF has state across multiple rows, and thus we require a `struct` with that state.
+#[derive(Debug)]
+struct GeometricMean {
+    n: u32,
+    prod: f64,
+}
+
+impl GeometricMean {
+    // how the struct is initialized
+    pub fn new() -> Self {
+        GeometricMean { n: 0, prod: 1.0 }
+    }
+}
+
+// UDAFs are built using the trait `Accumulator`, which offers DataFusion the necessary
+// functions to use them.
+impl Accumulator for GeometricMean {
+    // This function serializes our state to `ScalarValue`, which DataFusion uses
+    // to pass this state between execution stages.
+    // Note that this can be arbitrary data.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.prod),
+            ScalarValue::from(self.n),
+        ])
+    }
+
+    // DataFusion expects this function to return the final value of this aggregator.
+    // in this case, this is the formula of the geometric mean
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let value = self.prod.powf(1.0 / self.n as f64);
+        Ok(ScalarValue::from(value))
+    }
+
+    // DataFusion calls this function to update the accumulator's state for a batch
+    // of input rows. In this case the product is updated with values from the first column
+    // and the count is updated based on the row count
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Float64(Some(value)) = v {
+                self.prod *= value;
+                self.n += 1;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    // Merge the output of `Self::state()` from other instances of this accumulator
+    // into this accumulator's state
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        let arr = &states[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = states
+                .iter()
+                .map(|array| ScalarValue::try_from_array(array, index))
+                .collect::<Result<Vec<_>>>()?;
+            if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) =
+                (&v[0], &v[1])
+            {
+                self.prod *= prod;
+                self.n += n;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
+
+// create local session context with an in-memory table
+fn create_context() -> Result<SessionContext> {
+    use datafusion::arrow::datatypes::{Field, Schema};
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));
+
+    // define data in two partitions
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Float32Array::from(vec![64.0]))],
+    )?;
+
+    // declare a new context. In Spark API, this corresponds to a new Spark SQL session
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+    Ok(ctx)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    // create the AggregateUDF
+    let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
+    ctx.register_udaf(geometric_mean.clone());
+
+    let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?;
+    sql_df.show().await?;
+
+    // get a DataFrame from the context
+    // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0.
+    let df = ctx.table("t").await?;
+
+    // perform the aggregation
+    let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?;
+
+    // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature.
+
+    // execute the query
+    let results = df.collect().await?;
+
+    // downcast the array to the expected type
+    let result = as_float64_array(results[0].column(0))?;
+
+    // verify that the calculation is correct
+    assert!((result.value(0) - 8.0).abs() < f64::EPSILON);
+    println!("The geometric mean of [2,4,8,64] is {}", result.value(0));
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index fb0ecd02c6b0..5882718acefd 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -36,8 +36,7 @@ use datafusion::{
     assert_batches_eq,
     error::Result,
     logical_expr::{
-        AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
-        StateTypeFunction, TypeSignature, Volatility,
+        AccumulatorFactoryFunction, AggregateUDF, Signature, TypeSignature, Volatility,
     },
     physical_plan::Accumulator,
     prelude::SessionContext,
@@ -46,7 +45,7 @@ use datafusion::{
 use datafusion_common::{
     assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
 };
-use datafusion_expr::create_udaf;
+use datafusion_expr::{create_udaf, SimpleAggregateUDF};
 use datafusion_physical_expr::expressions::AvgAccumulator;
 
 /// Test to show the contents of the setup
@@ -141,7 +140,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() {
     let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t";
     // Note if this query ever does start working
     let err = execute(&ctx, sql).await.unwrap_err();
-    assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { name: \"time_sum\"");
+    assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"<FUNC>\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING");
 }
 
 /// Basic query for with a udaf returning a structure
@@ -408,26 +407,27 @@ impl TimeSum {
     fn register(ctx: &mut SessionContext, test_state: Arc<TestState>, name: &str) {
         let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
+        let input_type = vec![timestamp_type.clone()];
 
         // Returns the same type as its input
-        let return_type = Arc::new(timestamp_type.clone());
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::clone(&return_type)));
+        let return_type = timestamp_type.clone();
 
-        let state_type = Arc::new(vec![timestamp_type.clone()]);
-        let state_type: StateTypeFunction =
-            Arc::new(move |_| Ok(Arc::clone(&state_type)));
+        let state_type = vec![timestamp_type.clone()];
 
         let volatility = Volatility::Immutable;
 
-        let signature = Signature::exact(vec![timestamp_type], volatility);
-
         let captured_state = Arc::clone(&test_state);
         let accumulator: AccumulatorFactoryFunction =
             Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state)))));
 
-        let time_sum =
-            AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type);
+        let time_sum = AggregateUDF::from(SimpleAggregateUDF::new(
+            name,
+            input_type,
+            return_type,
+            volatility,
+            accumulator,
+            state_type,
+        ));
 
         // register the selector as "time_sum"
         ctx.register_udaf(time_sum)
@@ -510,11 +510,8 @@ impl FirstSelector {
     }
 
     fn register(ctx: &mut SessionContext) {
-        let return_type = Arc::new(Self::output_datatype());
-        let state_type = Arc::new(Self::state_datatypes());
-
-        let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
-        let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
+        let return_type = Self::output_datatype();
+        let state_type = Self::state_datatypes();
 
         // Possible input signatures
         let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
@@ -526,13 +523,13 @@ impl FirstSelector {
 
         let name = "first";
 
-        let first = AggregateUDF::new(
+        let first = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
             name,
-            &Signature::one_of(signatures, volatility),
-            &return_type,
-            &accumulator,
-            &state_type,
-        );
+            Signature::one_of(signatures, volatility),
+            return_type,
+            accumulator,
+            state_type,
+        ));
 
         // register the selector as "first"
         ctx.register_udaf(first)
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 985b0bd5bc76..4f39f2374ea9 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -291,8 +291,8 @@ async fn udaf_as_window_func() -> Result<()> {
     context.register_udaf(my_acc);
 
     let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table";
-    let expected = r#"Projection: my_table.a, AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
-  WindowAggr: windowExpr=[[AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
+    let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+  WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "<FUNC>" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
   TableScan: my_table"#;
 
     let dataframe = context.sql(sql).await.unwrap();
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 0491750d18a9..cc8322272aa8 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -25,10 +25,10 @@ use crate::function::PartitionEvaluatorFactory;
 use crate::{
     aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
     logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
-    BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
-    ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
+    BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation,
+    ScalarUDF, Signature, Volatility,
 };
-use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
+use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
 use arrow::datatypes::DataType;
 use datafusion_common::{Column, Result};
 use std::any::Any;
@@ -1047,15 +1047,102 @@ pub fn create_udaf(
     accumulator: AccumulatorFactoryFunction,
     state_type: Arc<Vec<DataType>>,
 ) -> AggregateUDF {
-    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
-    let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
-    AggregateUDF::new(
+    let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone());
+    let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone());
+    AggregateUDF::from(SimpleAggregateUDF::new(
         name,
-        &Signature::exact(input_type, volatility),
-        &return_type,
-        &accumulator,
-        &state_type,
-    )
+        input_type,
+        return_type,
+        volatility,
+        accumulator,
+        state_type,
+    ))
+}
+
+/// Implements [`AggregateUDFImpl`] for functions that have a single signature and
+/// return type.
+pub struct SimpleAggregateUDF {
+    name: String,
+    signature: Signature,
+    return_type: DataType,
+    accumulator: AccumulatorFactoryFunction,
+    state_type: Vec<DataType>,
+}
+
+impl Debug for SimpleAggregateUDF {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("AggregateUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("fun", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl SimpleAggregateUDF {
+    /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and
+    /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility.
+    pub fn new(
+        name: impl Into<String>,
+        input_type: Vec<DataType>,
+        return_type: DataType,
+        volatility: Volatility,
+        accumulator: AccumulatorFactoryFunction,
+        state_type: Vec<DataType>,
+    ) -> Self {
+        let name = name.into();
+        let signature = Signature::exact(input_type, volatility);
+        Self {
+            name,
+            signature,
+            return_type,
+            accumulator,
+            state_type,
+        }
+    }
+
+    pub fn new_with_signature(
+        name: impl Into<String>,
+        signature: Signature,
+        return_type: DataType,
+        accumulator: AccumulatorFactoryFunction,
+        state_type: Vec<DataType>,
+    ) -> Self {
+        let name = name.into();
+        Self {
+            name,
+            signature,
+            return_type,
+            accumulator,
+            state_type,
+        }
+    }
+}
+
+impl AggregateUDFImpl for SimpleAggregateUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        (self.accumulator)(arg)
+    }
+
+    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+        Ok(self.state_type.clone())
+    }
 }
 
 /// Creates a new UDWF with a specific signature, state type and return type.
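As a quick orientation for the new API added above, here is a minimal usage sketch (not part of the diff; the helper function name is hypothetical) showing how a UDAF is now built from `SimpleAggregateUDF` rather than the deprecated `AggregateUDF::new` with function pointers:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::{
    AccumulatorFactoryFunction, AggregateUDF, SimpleAggregateUDF, Volatility,
};

// Hypothetical helper: the accumulator factory is supplied by the caller,
// as in the test call sites updated later in this patch.
fn make_geo_mean_udaf(accumulator: AccumulatorFactoryFunction) -> AggregateUDF {
    AggregateUDF::from(SimpleAggregateUDF::new(
        "geo_mean",
        vec![DataType::Float64],                   // input types
        DataType::Float64,                         // return type
        Volatility::Immutable,
        accumulator,                               // factory creating Accumulator instances
        vec![DataType::Float64, DataType::UInt32], // intermediate state types
    ))
}
```

The same `AggregateUDF::from(...)` construction replaces the old `ReturnTypeFunction`/`StateTypeFunction` closures throughout the updated tests below.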
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 077681d21725..0d431f10c432 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -80,7 +80,7 @@ pub use signature::{
     FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
 };
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
-pub use udaf::AggregateUDF;
+pub use udaf::{AggregateUDF, AggregateUDFImpl};
 pub use udf::{ScalarUDF, ScalarUDFImpl};
 pub use udwf::{WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index cfbca4ab1337..4983f6247d24 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -23,6 +23,7 @@ use crate::{
 };
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
+use std::any::Any;
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
 
@@ -42,36 +43,29 @@ use std::sync::Arc;
 ///
 /// For more information, please see [the examples].
 ///
+/// 1. For simple (less performant) use cases, use [`create_udaf`] and [`simple_udaf.rs`].
+///
+/// 2. For advanced use cases, use [`AggregateUDFImpl`] and [`advanced_udaf.rs`].
+///
+/// # API Note
+/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
+/// compatibility with the older API.
+///
 /// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
 /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
 /// [`Accumulator`]: crate::Accumulator
-#[derive(Clone)]
-pub struct AggregateUDF {
-    /// name
-    name: String,
-    /// Signature (input arguments)
-    signature: Signature,
-    /// Return type
-    return_type: ReturnTypeFunction,
-    /// actual implementation
-    accumulator: AccumulatorFactoryFunction,
-    /// the accumulator's state's description as a function of the return type
-    state_type: StateTypeFunction,
-}
+/// [`create_udaf`]: crate::expr_fn::create_udaf
+/// [`simple_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
+/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
 
-impl Debug for AggregateUDF {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        f.debug_struct("AggregateUDF")
-            .field("name", &self.name)
-            .field("signature", &self.signature)
-            .field("fun", &"<FUNC>")
-            .finish()
-    }
+#[derive(Debug, Clone)]
+pub struct AggregateUDF {
+    inner: Arc<dyn AggregateUDFImpl>,
 }
 
 impl PartialEq for AggregateUDF {
     fn eq(&self, other: &Self) -> bool {
-        self.name == other.name && self.signature == other.signature
+        self.name() == other.name() && self.signature() == other.signature()
     }
 }
 
@@ -79,13 +73,17 @@ impl Eq for AggregateUDF {}
 
 impl std::hash::Hash for AggregateUDF {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.name.hash(state);
-        self.signature.hash(state);
+        self.name().hash(state);
+        self.signature().hash(state);
     }
 }
 
 impl AggregateUDF {
     /// Create a new AggregateUDF
+    ///
+    /// See [`AggregateUDFImpl`] for a more convenient way to create an
+    /// `AggregateUDF` using trait objects
+    #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")]
     pub fn new(
         name: &str,
         signature: &Signature,
@@ -93,15 +91,32 @@ impl AggregateUDF {
         accumulator: &AccumulatorFactoryFunction,
         state_type: &StateTypeFunction,
     ) -> Self {
-        Self {
+        Self::new_from_impl(AggregateUDFLegacyWrapper {
             name: name.to_owned(),
             signature: signature.clone(),
             return_type: return_type.clone(),
             accumulator: accumulator.clone(),
             state_type: state_type.clone(),
+        })
+    }
+
+    /// Create a new `AggregateUDF` from an [`AggregateUDFImpl`] trait object
+    ///
+    /// Note this is the same as using the `From` impl (`AggregateUDF::from`)
+    pub fn new_from_impl<F>(fun: F) -> AggregateUDF
+    where
+        F: AggregateUDFImpl + 'static,
+    {
+        Self {
+            inner: Arc::new(fun),
         }
     }
 
+    /// Return the underlying [`AggregateUDFImpl`] trait object for this function
+    pub fn inner(&self) -> Arc<dyn AggregateUDFImpl> {
+        self.inner.clone()
+    }
+
     /// creates an [`Expr`] that calls the aggregate function.
     ///
     /// This utility allows using the UDAF without requiring access to
@@ -117,33 +132,176 @@ impl AggregateUDF {
     }
 
     /// Returns this function's name
+    ///
+    /// See [`AggregateUDFImpl::name`] for more details.
     pub fn name(&self) -> &str {
-        &self.name
+        self.inner.name()
     }
 
     /// Returns this function's signature (what input types are accepted)
+    ///
+    /// See [`AggregateUDFImpl::signature`] for more details.
     pub fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     /// Return the type of the function given its input types
+    ///
+    /// See [`AggregateUDFImpl::return_type`] for more details.
     pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        // Old API returns an Arc of the datatype for some reason
-        let res = (self.return_type)(args)?;
-        Ok(res.as_ref().clone())
+        self.inner.return_type(args)
     }
 
     /// Return an accumualator the given aggregate, given
     /// its return datatype.
     pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
-        (self.accumulator)(return_type)
+        self.inner.accumulator(return_type)
     }
 
     /// Return the type of the intermediate state used by this aggregator, given
     /// its return datatype. Supports multi-phase aggregations
     pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        // old API returns an Arc for some reason, try and unwrap it here
+        self.inner.state_type(return_type)
+    }
+}
+
+impl<F> From<F> for AggregateUDF
+where
+    F: AggregateUDFImpl + Send + Sync + 'static,
+{
+    fn from(fun: F) -> Self {
+        Self::new_from_impl(fun)
+    }
+}
+
+/// Trait for implementing [`AggregateUDF`].
+///
+/// This trait exposes the full API for implementing user defined aggregate functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udaf.rs`] for a full example with complete implementation and
+/// [`AggregateUDF`] for other available options.
+///
+///
+/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
+/// #[derive(Debug, Clone)]
+/// struct GeoMeanUdf {
+///   signature: Signature
+/// };
+///
+/// impl GeoMeanUdf {
+///   fn new() -> Self {
+///     Self {
+///       signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable)
+///     }
+///   }
+/// }
+///
+/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
+/// impl AggregateUDFImpl for GeoMeanUdf {
+///    fn as_any(&self) -> &dyn Any { self }
+///    fn name(&self) -> &str { "geo_mean" }
+///    fn signature(&self) -> &Signature { &self.signature }
+///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+///      if !matches!(args.get(0), Some(&DataType::Float64)) {
+///        return plan_err!("geo_mean only accepts Float64 arguments");
+///      }
+///      Ok(DataType::Float64)
+///    }
+///    // This is the accumulator factory; DataFusion uses it to create new accumulators.
+///    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> { unimplemented!() }
+///    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+///        Ok(vec![DataType::Float64, DataType::UInt32])
+///    }
+/// }
+///
+/// // Create a new AggregateUDF from the implementation
+/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
+///
+/// // Call the function `geo_mean(col)`
+/// let expr = geometric_mean.call(vec![col("a")]);
+/// ```
+pub trait AggregateUDFImpl: Debug + Send + Sync {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns the function's [`Signature`] for information about what input
+    /// types are accepted and the function's Volatility.
+    fn signature(&self) -> &Signature;
+
+    /// What [`DataType`] will be returned by this function, given the types of
+    /// the arguments
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+    /// Return a new [`Accumulator`] that aggregates values for a specific
+    /// group during query execution.
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
+
+    /// Return the type used to serialize the [`Accumulator`]'s intermediate state.
+    /// See [`Accumulator::state()`] for more details
+    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>>;
+}
+
+/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
+/// of the older API
+pub struct AggregateUDFLegacyWrapper {
+    /// name
+    name: String,
+    /// Signature (input arguments)
+    signature: Signature,
+    /// Return type
+    return_type: ReturnTypeFunction,
+    /// actual implementation
+    accumulator: AccumulatorFactoryFunction,
+    /// the accumulator's state's description as a function of the return type
+    state_type: StateTypeFunction,
+}
+
+impl Debug for AggregateUDFLegacyWrapper {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        f.debug_struct("AggregateUDF")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("fun", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        // Old API returns an Arc of the datatype for some reason
+        let res = (self.return_type)(arg_types)?;
+        Ok(res.as_ref().clone())
+    }
+
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        (self.accumulator)(arg)
+    }
+
+    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
         let res = (self.state_type)(return_type)?;
-        Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone()))
+        Ok(res.as_ref().clone())
     }
 }
diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs
index 239a5e24cbf2..9b8f94f4b020 100644
--- a/datafusion/expr/src/udwf.rs
+++ b/datafusion/expr/src/udwf.rs
@@ -36,7 +36,7 @@ use std::{
 ///
 /// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`].
 ///
-/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udf.rs`].
+/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`].
 ///
 /// # API Note
 /// This is a separate struct from `WindowUDFImpl` to maintain backwards
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 6f1da5f4e6d9..3821279fed0f 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -751,13 +751,13 @@ mod test {
     use datafusion_expr::{
         cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
         AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
-        ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction,
-        Subquery,
+        ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl,
+        SimpleAggregateUDF, Subquery,
     };
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
-        Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility,
+        Expr, LogicalPlan, ScalarUDF, Signature, Volatility,
     };
     use datafusion_physical_expr::expressions::AvgAccumulator;
 
@@ -903,19 +903,17 @@ mod test {
     #[test]
     fn agg_udaf_invalid_input() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
-        let state_type: StateTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
+        let return_type = DataType::Float64;
+        let state_type = vec![DataType::UInt64, DataType::Float64];
         let accumulator: AccumulatorFactoryFunction =
             Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
-        let my_avg = AggregateUDF::new(
+        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
             "MY_AVG",
-            &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
-            &return_type,
-            &accumulator,
-            &state_type,
-        );
+            Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
+            return_type,
+            accumulator,
+            state_type,
+        ));
         let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
             Arc::new(my_avg),
             vec![lit("10")],
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 1e089257c61a..000329d0d078 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -780,8 +780,8 @@ mod test {
         avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
     };
     use datafusion_expr::{
-        grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction,
-        Signature, StateTypeFunction, Volatility,
+        grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature,
+        SimpleAggregateUDF, Volatility,
     };
 
     use crate::optimizer::OptimizerContext;
@@ -901,21 +901,18 @@ mod test {
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
 
-        let return_type: ReturnTypeFunction = Arc::new(|inputs| {
-            assert_eq!(inputs, &[DataType::UInt32]);
-            Ok(Arc::new(DataType::UInt32))
-        });
+        let return_type = DataType::UInt32;
         let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!());
-        let state_type: StateTypeFunction = Arc::new(|_| unimplemented!());
+        let state_type = vec![DataType::UInt32];
         let udf_agg = |inner: Expr| {
             Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
-                Arc::new(AggregateUDF::new(
+                Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
                     "my_agg",
-                    &Signature::exact(vec![DataType::UInt32], Volatility::Stable),
-                    &return_type,
-                    &accumulator,
-                    &state_type,
-                )),
+                    Signature::exact(vec![DataType::UInt32], Volatility::Stable),
+                    return_type.clone(),
+                    accumulator.clone(),
+                    state_type.clone(),
+                ))),
                 vec![inner],
                 false,
                 None,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 27ac5d122f83..dd5fd73c69fd 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -73,8 +73,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{FileTypeWriterOptions, Result};
 use datafusion_expr::{
-    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature,
-    StateTypeFunction, WindowFrame, WindowFrameBound,
+    Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF,
+    WindowFrame, WindowFrameBound,
 };
 use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
 use datafusion_proto::protobuf;
@@ -374,18 +374,17 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
         }
     }
 
-    let rt_func: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64)));
+    let return_type = DataType::Int64;
     let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example)));
-    let st_func: StateTypeFunction =
-        Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));
+    let state_type = vec![DataType::Int64];
 
-    let udaf = AggregateUDF::new(
+    let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
         "example",
-        &Signature::exact(vec![DataType::Int64], Volatility::Immutable),
-        &rt_func,
-        &accumulator,
-        &st_func,
-    );
+        Signature::exact(vec![DataType::Int64], Volatility::Immutable),
+        return_type,
+        accumulator,
+        state_type,
+    ));
 
     let ctx = SessionContext::new();
     ctx.register_udaf(udaf.clone());
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index 1f687f978f30..64dc25411deb 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -398,7 +398,8 @@ impl Accumulator for GeometricMean {
 
 ### registering an Aggregate UDF
 
-To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier.
+To register an Aggregate UDF, you need to wrap the function implementation in an [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper function to make this easier.
+There is a lower level API, with more functionality but more complexity, which is documented in [`advanced_udaf.rs`].
 
 ```rust
 use datafusion::logical_expr::{Volatility, create_udaf};
@@ -421,6 +422,10 @@ let geometric_mean = create_udaf(
 );
 ```
 
+[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html
+[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html
+[`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
+
 The `create_udaf` has six arguments to check:
 
 - The first argument is the name of the function. This is the name that will be used in SQL queries.