From 13569340bce99e4a317ec4d71e5c46d69dfa733d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 24 Jul 2024 22:30:05 +0800 Subject: [PATCH] ExprBuilder for Physical Aggregate Expr (#11617) * aggregate expr builder Signed-off-by: jayzhan211 * replace parts of test Signed-off-by: jayzhan211 * continue Signed-off-by: jayzhan211 * cleanup all Signed-off-by: jayzhan211 * clipp Signed-off-by: jayzhan211 * add sort Signed-off-by: jayzhan211 * rm field Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 * fix import path Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/src/lib.rs | 5 + .../aggregate_statistics.rs | 20 +- .../combine_partial_final_agg.rs | 41 +-- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 23 +- .../physical-expr-common/src/aggregate/mod.rs | 286 +++++++++++++----- .../physical-plan/src/aggregates/mod.rs | 134 +++----- datafusion/physical-plan/src/windows/mod.rs | 39 +-- datafusion/proto/src/physical_plan/mod.rs | 11 +- .../tests/cases/roundtrip_physical_plan.rs | 177 ++++------- 9 files changed, 369 insertions(+), 367 deletions(-) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 9ab6ed527d82..d9ab9e1c07dd 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -545,6 +545,11 @@ pub mod optimizer { pub use datafusion_optimizer::*; } +/// re-export of [`datafusion_physical_expr_common`] crate +pub mod physical_expr_common { + pub use datafusion_physical_expr_common::*; +} + /// re-export of [`datafusion_physical_expr`] crate pub mod physical_expr { pub use datafusion_physical_expr::*; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index e7580d3e33ef..5f08e4512b3a 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -326,7 +326,7 @@ pub(crate) mod tests { use
datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use datafusion_physical_plan::aggregates::AggregateMode; /// Mock data using a MemoryExec which has an exact count statistic @@ -419,19 +419,11 @@ pub(crate) mod tests { // Return appropriate expr depending if COUNT is for col or table (*) pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { - create_aggregate_expr( - &count_udaf(), - &[self.column()], - &[], - &[], - &[], - schema, - self.column_name(), - false, - false, - false, - ) - .unwrap() + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .name(self.column_name()) + .build() + .unwrap() } /// what argument would this aggregate need in the plan? diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index ddb7d36fb595..6f3274820c8c 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -177,7 +177,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; - use datafusion_physical_plan::udaf::create_aggregate_expr; + use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! 
assert_optimized { @@ -278,19 +278,11 @@ mod tests { name: &str, schema: &Schema, ) -> Arc { - create_aggregate_expr( - &count_udaf(), - &[expr], - &[], - &[], - &[], - schema, - name, - false, - false, - false, - ) - .unwrap() + AggregateExprBuilder::new(count_udaf(), vec![expr]) + .schema(Arc::new(schema.clone())) + .name(name) + .build() + .unwrap() } #[test] @@ -368,19 +360,14 @@ mod tests { #[test] fn aggregations_with_group_combined() -> Result<()> { let schema = schema(); - - let aggr_expr = vec![create_aggregate_expr( - &sum_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "Sum(b)", - false, - false, - false, - )?]; + let aggr_expr = + vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("Sum(b)") + .build() + .unwrap(), + ]; let groups: Vec<(Arc, String)> = vec![(col("c", &schema)?, "c".to_string())]; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 736560da97db..6f286c9aeba1 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -35,7 +35,7 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor} use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_plan::udaf::create_aggregate_expr; +use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; @@ -103,19 +103,14 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str .with_sort_information(vec![sort_keys]), ); - let aggregate_expr = vec![create_aggregate_expr( - &sum_udaf(), - &[col("d", &schema).unwrap()], - &[], - &[], - &[], - &schema, - "sum1", - false, - false, - false, - ) - .unwrap()]; + let aggregate_expr = + 
vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .name("sum1") + .build() + .unwrap(), + ]; let expr = group_by_columns .iter() .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 8c5f9f9e5a7e..b58a5a6faf24 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -22,8 +22,8 @@ pub mod stats; pub mod tdigest; pub mod utils; -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, DFSchema, Result}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; use datafusion_expr::ReversedUDAF; @@ -33,7 +33,7 @@ use datafusion_expr::{ use std::fmt::Debug; use std::{any::Any, sync::Arc}; -use self::utils::{down_cast_any_ref, ordering_fields}; +use self::utils::down_cast_any_ref; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; @@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity; /// `is_reversed` is used to indicate whether the aggregation is running in reverse order, /// it could be used to hint Accumulator to accumulate in the reversed order, /// you can just set to false if you are not reversing expression +/// +/// You can also create expression by [`AggregateExprBuilder`] #[allow(clippy::too_many_arguments)] pub fn create_aggregate_expr( fun: &AggregateUDF, @@ -66,45 +68,23 @@ pub fn create_aggregate_expr( name: impl Into, ignore_nulls: bool, is_distinct: bool, - is_reversed: bool, ) -> Result> { - debug_assert_eq!(sort_exprs.len(), ordering_req.len()); - - let input_exprs_types = 
input_phy_exprs - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - let name = name.into(); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - logical_args: input_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name, - schema: schema.clone(), - dfschema: DFSchema::empty(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - is_distinct, - input_type: input_exprs_types[0].clone(), - is_reversed, - })) + let mut builder = + AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); + builder = builder.sort_exprs(sort_exprs.to_vec()); + builder = builder.order_by(ordering_req.to_vec()); + builder = builder.logical_exprs(input_exprs.to_vec()); + builder = builder.schema(Arc::new(schema.clone())); + builder = builder.name(name); + + if ignore_nulls { + builder = builder.ignore_nulls(); + } + if is_distinct { + builder = builder.distinct(); + } + + builder.build() } #[allow(clippy::too_many_arguments)] @@ -121,44 +101,196 @@ pub fn create_aggregate_expr_with_dfschema( is_distinct: bool, is_reversed: bool, ) -> Result> { - debug_assert_eq!(sort_exprs.len(), ordering_req.len()); - + let mut builder = + AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); + builder = builder.sort_exprs(sort_exprs.to_vec()); + builder = builder.order_by(ordering_req.to_vec()); + builder = builder.logical_exprs(input_exprs.to_vec()); + builder = builder.dfschema(dfschema.clone()); let schema: Schema = dfschema.into(); + builder = builder.schema(Arc::new(schema)); + builder = builder.name(name); + + if ignore_nulls { + builder = 
builder.ignore_nulls(); + } + if is_distinct { + builder = builder.distinct(); + } + if is_reversed { + builder = builder.reversed(); + } + + builder.build() +} + +/// Builder for physical [`AggregateExpr`] +/// +/// `AggregateExpr` contains the information necessary to call +/// an aggregate expression. +#[derive(Debug, Clone)] +pub struct AggregateExprBuilder { + fun: Arc, + /// Physical expressions of the aggregate function + args: Vec>, + /// Logical expressions of the aggregate function, it will be deprecated in + logical_args: Vec, + name: String, + /// Arrow Schema for the aggregate function + schema: SchemaRef, + /// Datafusion Schema for the aggregate function + dfschema: DFSchema, + /// The logical order by expressions, it will be deprecated in + sort_exprs: Vec, + /// The physical order by expressions + ordering_req: LexOrdering, + /// Whether to ignore null values + ignore_nulls: bool, + /// Whether is distinct aggregate function + is_distinct: bool, + /// Whether the expression is reversed + is_reversed: bool, +} + +impl AggregateExprBuilder { + pub fn new(fun: Arc, args: Vec>) -> Self { + Self { + fun, + args, + logical_args: vec![], + name: String::new(), + schema: Arc::new(Schema::empty()), + dfschema: DFSchema::empty(), + sort_exprs: vec![], + ordering_req: vec![], + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + } + } + + pub fn build(self) -> Result> { + let Self { + fun, + args, + logical_args, + name, + schema, + dfschema, + sort_exprs, + ordering_req, + ignore_nulls, + is_distinct, + is_reversed, + } = self; + if args.is_empty() { + return internal_err!("args should not be empty"); + } + + let mut ordering_fields = vec![]; + + debug_assert_eq!(sort_exprs.len(), ordering_req.len()); + if !ordering_req.is_empty() { + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + + ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); + } + + let input_exprs_types = 
args + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(&schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - logical_args: input_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name: name.into(), - schema: schema.clone(), - dfschema: dfschema.clone(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - is_distinct, - input_type: input_exprs_types[0].clone(), - is_reversed, - })) + let data_type = fun.return_type(&input_exprs_types)?; + + Ok(Arc::new(AggregateFunctionExpr { + fun: Arc::unwrap_or_clone(fun), + args, + logical_args, + data_type, + name, + schema: Arc::unwrap_or_clone(schema), + dfschema, + sort_exprs, + ordering_req, + ignore_nulls, + ordering_fields, + is_distinct, + input_type: input_exprs_types[0].clone(), + is_reversed, + })) + } + + pub fn name(mut self, name: impl Into) -> Self { + self.name = name.into(); + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + + pub fn dfschema(mut self, dfschema: DFSchema) -> Self { + self.dfschema = dfschema; + self + } + + pub fn order_by(mut self, order_by: LexOrdering) -> Self { + self.ordering_req = order_by; + self + } + + pub fn reversed(mut self) -> Self { + self.is_reversed = true; + self + } + + pub fn with_reversed(mut self, is_reversed: bool) -> Self { + self.is_reversed = is_reversed; + self + } + + pub fn distinct(mut self) -> Self { + 
self.is_distinct = true; + self + } + + pub fn with_distinct(mut self, is_distinct: bool) -> Self { + self.is_distinct = is_distinct; + self + } + + pub fn ignore_nulls(mut self) -> Self { + self.ignore_nulls = true; + self + } + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + + /// This method will be deprecated in + pub fn sort_exprs(mut self, sort_exprs: Vec) -> Self { + self.sort_exprs = sort_exprs; + self + } + + /// This method will be deprecated in + pub fn logical_exprs(mut self, logical_args: Vec) -> Self { + self.logical_args = logical_args; + self + } } /// An aggregate expression that: diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e7cd5cb2725b..d1152038eb2a 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1211,7 +1211,7 @@ mod tests { use crate::common::collect; use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, create_aggregate_expr_with_dfschema, + create_aggregate_expr_with_dfschema, AggregateExprBuilder, }; use datafusion_physical_expr_common::expressions::Literal; use futures::{FutureExt, Stream}; @@ -1351,18 +1351,11 @@ mod tests { ], }; - let aggregates = vec![create_aggregate_expr( - &count_udaf(), - &[lit(1i8)], - &[datafusion_expr::lit(1i8)], - &[], - &[], - &input_schema, - "COUNT(1)", - false, - false, - false, - )?]; + let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .name("COUNT(1)") + .logical_exprs(vec![datafusion_expr::lit(1i8)]) + .build()?]; let task_ctx = if spill { new_spill_ctx(4, 1000) @@ -1501,18 +1494,13 @@ mod tests { groups: vec![vec![false]], }; - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &input_schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &input_schema, - "AVG(b)", - false, - false, 
- false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .name("AVG(b)") + .build()?, + ]; let task_ctx = if spill { // set to an appropriate value to trigger spill @@ -1803,21 +1791,11 @@ mod tests { } // Median(a) - fn test_median_agg_expr(schema: &Schema) -> Result> { - let args = vec![col("a", schema)?]; - let fun = median_udaf(); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( - &fun, - &args, - &[], - &[], - &[], - schema, - "MEDIAN(a)", - false, - false, - false, - ) + fn test_median_agg_expr(schema: SchemaRef) -> Result> { + AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) + .schema(schema) + .name("MEDIAN(a)") + .build() } #[tokio::test] @@ -1840,21 +1818,16 @@ mod tests { // something that allocates within the aggregator let aggregates_v0: Vec> = - vec![test_median_agg_expr(&input_schema)?]; + vec![test_median_agg_expr(Arc::clone(&input_schema))?]; // use fast-path in `row_hash.rs`. 
- let aggregates_v2: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &input_schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &input_schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates_v2: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .name("AVG(b)") + .build()?, + ]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), @@ -1908,18 +1881,13 @@ mod tests { let groups = PhysicalGroupBy::default(); - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("a", &schema)?], - &[datafusion_expr::col("a")], - &[], - &[], - &schema, - "AVG(a)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(a)") + .build()?, + ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1953,18 +1921,13 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[datafusion_expr::col("b")], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?, + ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -2388,18 +2351,11 @@ mod tests { ], ); - let aggregates: Vec> = vec![create_aggregate_expr( - count_udaf().as_ref(), - &[lit(1)], - &[datafusion_expr::lit(1)], - &[], - &[], - schema.as_ref(), - "1", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&schema)) + .name("1") + .build()?]; let input_batches 
= (0..4) .map(|_| { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 959796489c19..ffe558e21583 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -26,16 +26,16 @@ use crate::{ cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, PhysicalSortExpr, RowNumber, }, - udaf, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, Column, DataFusionError, Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame, - WindowFunctionDefinition, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -44,6 +44,7 @@ use datafusion_physical_expr::{ AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; +use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use itertools::Itertools; mod bounded_window_agg_exec; @@ -95,7 +96,7 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: &[Arc], - logical_args: &[Expr], + _logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -129,7 +130,6 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // TODO: Ordering not supported for Window UDFs yet // Convert `Vec` into `Vec` let sort_exprs = order_by .iter() @@ -137,28 +137,20 @@ pub fn create_window_expr( let field_name = expr.to_string(); let 
field_name = field_name.split('@').next().unwrap_or(&field_name); Expr::Sort(SortExpr { - expr: Box::new(Expr::Column(Column::new( - None::, - field_name, - ))), + expr: Box::new(col(field_name)), asc: !options.descending, nulls_first: options.nulls_first, }) }) .collect::>(); - let aggregate = udaf::create_aggregate_expr( - fun.as_ref(), - args, - logical_args, - &sort_exprs, - order_by, - input_schema, - name, - ignore_nulls, - false, - false, - )?; + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .name(name) + .order_by(order_by.to_vec()) + .sort_exprs(sort_exprs) + .with_ignore_nulls(ignore_nulls) + .build()?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -166,6 +158,7 @@ pub fn create_window_expr( aggregate, ) } + // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name)?, partition_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 8c9e5bbd0e95..5c4d41f0eca6 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -18,6 +18,7 @@ use std::fmt::Debug; use std::sync::Arc; +use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; use prost::bytes::BufMut; use prost::Message; @@ -58,7 +59,7 @@ use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, + AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF}; @@ -501,13 
+502,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { None => registry.udaf(udaf_name)? }; - // TODO: 'logical_exprs' is not supported for UDAF yet. - // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. - let logical_exprs = &[]; + // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. // TODO: `order by` is not supported for UDAF yet - let sort_exprs = &[]; - let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, agg_node.distinct, false) + AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 31ed0837d2f5..3ddc122e3de2 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -24,6 +24,7 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; +use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; use prost::Message; use datafusion::arrow::array::ArrayRef; @@ -64,7 +65,6 @@ use datafusion::physical_plan::placeholder_row::PlaceholderRowExec; use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; -use datafusion::physical_plan::udaf::create_aggregate_expr; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, @@ -86,7 +86,7 @@ use datafusion_expr::{ }; use 
datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::nth_value::nth_value_udaf; -use datafusion_functions_aggregate::string_agg::StringAgg; +use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -291,18 +291,13 @@ fn roundtrip_window() -> Result<()> { )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( - create_aggregate_expr( - &avg_udaf(), - &[cast(col("b", &schema)?, &schema, DataType::Float64)?], - &[], - &[], - &[], - &schema, - "avg(b)", - false, - false, - false, - )?, + AggregateExprBuilder::new( + avg_udaf(), + vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], + ) + .schema(Arc::clone(&schema)) + .name("avg(b)") + .build()?, &[], &[], Arc::new(WindowFrame::new(None)), @@ -315,18 +310,10 @@ fn roundtrip_window() -> Result<()> { ); let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; - let sum_expr = create_aggregate_expr( - &sum_udaf(), - &args, - &[], - &[], - &[], - &schema, - "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", - false, - false, - false, - )?; + let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) + .schema(Arc::clone(&schema)) + .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") + .build()?; let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( sum_expr, @@ -357,49 +344,28 @@ fn rountrip_aggregate() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; + let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?; + let nth_expr = + AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .name("NTH_VALUE(b, 1)") + .build()?; + let str_agg_expr = + AggregateExprBuilder::new(string_agg_udaf(), 
vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .name("NTH_VALUE(b, 1)") + .build()?; + let test_cases: Vec>> = vec![ // AVG - vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?], + vec![avg_expr], // NTH_VALUE - vec![create_aggregate_expr( - &nth_value_udaf(), - &[col("b", &schema)?, lit(1u64)], - &[], - &[], - &[], - &schema, - "NTH_VALUE(b, 1)", - false, - false, - false, - )?], + vec![nth_expr], // STRING_AGG - vec![create_aggregate_expr( - &AggregateUDF::new_from_impl(StringAgg::new()), - &[ - cast(col("b", &schema)?, &schema, DataType::Utf8)?, - lit(ScalarValue::Utf8(Some(",".to_string()))), - ], - &[], - &[], - &[], - &schema, - "STRING_AGG(name, ',')", - false, - false, - false, - )?], + vec![str_agg_expr], ]; for aggregates in test_cases { @@ -426,18 +392,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![create_aggregate_expr( - &avg_udaf(), - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "AVG(b)", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("AVG(b)") + .build()?, + ]; let agg = AggregateExec::try_new( AggregateMode::Final, @@ -498,18 +459,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![create_aggregate_expr( - &udaf, - &[col("b", &schema)?], - &[], - &[], - &[], - &schema, - "example_agg", - false, - false, - false, - )?]; + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .name("example_agg") + .build()?, + ]; roundtrip_test_with_context( Arc::new(AggregateExec::try_new( @@ -994,21 +950,16 
@@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { DataType::Int64, )); - let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string())); - let aggr_args: [Arc; 1] = - [Arc::new(Literal::new(ScalarValue::from(42)))]; - let aggr_expr = create_aggregate_expr( - &udaf, - &aggr_args, - &[], - &[], - &[], - &schema, - "aggregate_udf", - false, - false, - false, - )?; + let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( + "result".to_string(), + ))); + let aggr_args: Vec> = + vec![Arc::new(Literal::new(ScalarValue::from(42)))]; + + let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) + .schema(Arc::clone(&schema)) + .name("aggregate_udf") + .build()?; let filter = Arc::new(FilterExec::try_new( Arc::new(BinaryExpr::new( @@ -1030,18 +981,12 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![col("author", &schema)?], )?); - let aggr_expr = create_aggregate_expr( - &udaf, - &aggr_args, - &[], - &[], - &[], - &schema, - "aggregate_udf", - true, - true, - false, - )?; + let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) + .schema(Arc::clone(&schema)) + .name("aggregate_udf") + .distinct() + .ignore_nulls() + .build()?; let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final,