From 0faaa266d8a57a3e81ed6900e771582643773366 Mon Sep 17 00:00:00 2001 From: Xin Li Date: Thu, 1 Aug 2024 18:31:46 +0800 Subject: [PATCH] improve AccumulatorArgs --- .../physical_plan/parquet/opener.rs | 2 +- .../physical_plan/parquet/statistics.rs | 1 + .../core/src/execution/session_state.rs | 2 +- .../aggregate_statistics.rs | 3 +- .../combine_partial_final_agg.rs | 5 +- .../src/physical_optimizer/limit_pushdown.rs | 2 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 3 +- .../src/expressions/column.rs | 4 +- datafusion/expr/src/expressions/mod.rs | 18 ++ datafusion/expr/src/function.rs | 6 +- datafusion/expr/src/lib.rs | 2 + .../src/physical_expr.rs | 6 +- datafusion/expr/src/tree_node.rs | 86 +++++++++- datafusion/expr/src/utils.rs | 141 ++++++++++++++++ .../src/approx_distinct.rs | 41 ++--- .../functions-aggregate/src/approx_median.rs | 2 +- .../src/approx_percentile_cont.rs | 43 +++-- .../src/approx_percentile_cont_with_weight.rs | 5 +- .../functions-aggregate/src/array_agg.rs | 13 +- datafusion/functions-aggregate/src/average.rs | 19 ++- datafusion/functions-aggregate/src/count.rs | 3 +- datafusion/functions-aggregate/src/median.rs | 2 +- .../functions-aggregate/src/nth_value.rs | 31 ++-- datafusion/functions-aggregate/src/stddev.rs | 8 +- .../functions-aggregate/src/string_agg.rs | 33 ++-- .../physical-expr-common/src/aggregate/mod.rs | 35 ++-- .../src/expressions/cast.rs | 4 +- .../src/expressions/literal.rs | 3 +- .../src/expressions/mod.rs | 1 - datafusion/physical-expr-common/src/lib.rs | 2 - .../physical-expr-common/src/sort_expr.rs | 3 +- .../physical-expr-common/src/tree_node.rs | 105 ------------ datafusion/physical-expr-common/src/utils.rs | 159 +----------------- datafusion/physical-expr/benches/case_when.rs | 4 +- datafusion/physical-expr/benches/is_null.rs | 4 +- .../src/equivalence/properties.rs | 6 +- .../physical-expr/src/expressions/binary.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-expr/src/lib.rs | 4 +- datafusion/physical-expr/src/physical_expr.rs | 4 +- .../physical-plan/src/aggregates/mod.rs | 16 +- datafusion/physical-plan/src/union.rs | 2 +- datafusion/physical-plan/src/windows/mod.rs | 7 +- datafusion/proto/src/physical_plan/mod.rs | 6 +- .../tests/cases/roundtrip_physical_plan.rs | 22 +-- datafusion/substrait/src/serializer.rs | 1 - 47 files changed, 434 insertions(+), 441 deletions(-) rename datafusion/{physical-expr-common => expr}/src/expressions/column.rs (98%) create mode 100644 datafusion/expr/src/expressions/mod.rs rename datafusion/{physical-expr-common => expr}/src/physical_expr.rs (98%) delete mode 100644 datafusion/physical-expr-common/src/tree_node.rs diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 4edc0ac525de..9e83d4dfe274 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -30,7 +30,7 @@ use crate::datasource::schema_adapter::SchemaAdapterFactory; use crate::physical_optimizer::pruning::PruningPredicate; use arrow_schema::{ArrowError, SchemaRef}; use datafusion_common::{exec_err, Result}; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use futures::{StreamExt, TryStreamExt}; use log::debug; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 11b8f5fc6c79..eec7c95fff94 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -517,6 +517,7 @@ macro_rules! make_data_page_stats_iterator { } } + #[allow(clippy::redundant_closure_call)] impl<'a, I> Iterator for $iterator_type<'a, I> where I: Iterator, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index ccad0240fddb..fd9b5c786859 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -48,6 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::planner::ExprPlanner; use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; use datafusion_expr::simplify::SimplifyInfo; @@ -61,7 +62,6 @@ use datafusion_optimizer::{ Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, }; use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::ExecutionPlan; use datafusion_sql::parser::{DFParser, Statement}; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index a8332d1d55e4..590f9dc8fde1 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -324,6 +324,7 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; + use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; @@ -421,7 +422,7 @@ pub(crate) mod tests { // Return appropriate expr depending if COUNT is for col or table (*) pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![self.column()]) - .schema(Arc::new(schema.clone())) + .dfschema(schema.clone().to_dfschema().unwrap()) .name(self.column_name()) .build() .unwrap() diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 6f3274820c8c..ab547b86f582 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -174,6 +174,7 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; @@ -279,7 +280,7 @@ mod tests { schema: &Schema, ) -> Arc { AggregateExprBuilder::new(count_udaf(), vec![expr]) - .schema(Arc::new(schema.clone())) + .dfschema(schema.clone().to_dfschema().unwrap()) .name(name) .build() .unwrap() @@ -363,7 +364,7 @@ mod tests { let aggr_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("Sum(b)") .build() .unwrap(), diff --git a/datafusion/core/src/physical_optimizer/limit_pushdown.rs b/datafusion/core/src/physical_optimizer/limit_pushdown.rs index 4379a34a9426..ef641e40b78b 100644 --- a/datafusion/core/src/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/limit_pushdown.rs @@ -256,10 +256,10 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use datafusion_expr::expressions::column::col; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::Partitioning; - use datafusion_physical_expr_common::expressions::column::col; use datafusion_physical_expr_common::expressions::lit; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 6f286c9aeba1..31fa59af8c18 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -32,6 +32,7 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::ToDFSchema; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; @@ -106,7 +107,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let aggregate_expr = vec![ AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema().unwrap()) .name("sum1") .build() .unwrap(), diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/expr/src/expressions/column.rs similarity index 98% rename from datafusion/physical-expr-common/src/expressions/column.rs rename to datafusion/expr/src/expressions/column.rs index 5397599ea2dc..fa8e6188038a 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/expr/src/expressions/column.rs @@ -21,12 +21,12 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::ColumnarValue; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; @@ -89,7 +89,7 @@ impl PhysicalExpr for Column { /// Evaluate the expression fn evaluate(&self, batch: &RecordBatch) -> Result { self.bounds_check(batch.schema().as_ref())?; - Ok(ColumnarValue::Array(batch.column(self.index).clone())) + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/expr/src/expressions/mod.rs b/datafusion/expr/src/expressions/mod.rs new file mode 100644 index 000000000000..d102422081dc --- /dev/null +++ b/datafusion/expr/src/expressions/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod column; diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1f3f8ace4d17..1ed793228ded 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,6 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. +use crate::physical_expr::PhysicalExpr; use crate::ColumnarValue; use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Field}; @@ -91,11 +92,8 @@ pub struct AccumulatorArgs<'a> { /// ``` pub is_distinct: bool, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], - /// The logical expression of arguments the aggregate function takes. - pub input_exprs: &'a [Expr], + pub input_exprs: &'a [Arc], } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 0a5cf4653a22..3e02b0fdb3ed 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -46,10 +46,12 @@ pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; +pub mod expressions; pub mod function; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod logical_plan; +pub mod physical_expr; pub mod planner; pub mod registry; pub mod simplify; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/expr/src/physical_expr.rs similarity index 98% rename from datafusion/physical-expr-common/src/physical_expr.rs rename to datafusion/expr/src/physical_expr.rs index e62606a42e6f..35bf5455df67 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/expr/src/physical_expr.rs @@ -23,15 +23,15 @@ use std::sync::Arc; use crate::expressions::column::Column; use crate::utils::scatter; +use crate::interval_arithmetic::Interval; +use crate::sort_properties::ExprProperties; +use crate::ColumnarValue; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{internal_err, not_impl_err, plan_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::ColumnarValue; /// See [create_physical_expr](https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html) /// for examples of creating `PhysicalExpr` from `Expr` diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a97b9f010f79..813257122b65 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,15 +17,20 @@ //! Tree node implementation for logical expr +use std::fmt::{self, Display, Formatter}; +use std::sync::Arc; + use crate::expr::{ AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; +use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + ConcreteTreeNode, DynTreeNode, Transformed, TreeNode, TreeNodeIterator, + TreeNodeRecursion, }; use datafusion_common::{map_until_stop_and_collect, Result}; @@ -401,3 +406,82 @@ fn transform_vec Result>>( ) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +impl DynTreeNode for dyn PhysicalExpr { + fn arc_children(&self) -> Vec<&Arc> { + self.children() + } + + fn with_new_arc_children( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_if_necessary(arc_self, new_children) + } +} + +/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are +/// two ways to access child plans—directly from the plan and through child nodes—it's +/// recommended to perform mutable operations via [`Self::update_expr_from_children`]. +#[derive(Debug)] +pub struct ExprContext { + /// The physical expression associated with this context. + pub expr: Arc, + /// Custom data payload of the node. + pub data: T, + /// Child contexts of this node. + pub children: Vec, +} + +impl ExprContext { + pub fn new(expr: Arc, data: T, children: Vec) -> Self { + Self { + expr, + data, + children, + } + } + + pub fn update_expr_from_children(mut self) -> Result { + let children_expr = self.children.iter().map(|c| Arc::clone(&c.expr)).collect(); + self.expr = with_new_children_if_necessary(self.expr, children_expr)?; + Ok(self) + } +} + +impl ExprContext { + pub fn new_default(plan: Arc) -> Self { + let children = plan + .children() + .into_iter() + .cloned() + .map(Self::new_default) + .collect(); + Self::new(plan, Default::default(), children) + } +} + +impl Display for ExprContext { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "expr: {:?}", self.expr)?; + write!(f, "data:{}", self.data)?; + write!(f, "") + } +} + +impl ConcreteTreeNode for ExprContext { + fn children(&self) -> &[Self] { + &self.children + } + + fn take_children(mut self) -> (Self, Vec) { + let children = std::mem::take(&mut self.children); + (self, children) + } + + fn with_new_children(mut self, children: Vec) -> Result { + self.children = children; + self.update_expr_from_children() + } +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2ef1597abfd1..1d6919494587 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -23,12 +23,18 @@ use std::sync::Arc; use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; +use crate::physical_expr::PhysicalExpr; use crate::signature::{Signature, TypeSignature}; +use crate::sort_properties::ExprProperties; +use crate::tree_node::ExprContext; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, }; +use arrow::array::MutableArrayData; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow_array::{make_array, Array, ArrayRef, BooleanArray}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -1248,8 +1254,77 @@ impl AggregateOrderSensitivity { } } +/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` +/// are taken, when the mask evaluates `false` values null values are filled. +/// +/// # Arguments +/// * `mask` - Boolean values used to determine where to put the `truthy` values +/// * `truthy` - All values of this array are to scatter according to `mask` into final result. +pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { + let truthy = truthy.to_data(); + + // update the mask so that any null values become false + // (SlicesIterator doesn't respect nulls) + let mask = and_kleene(mask, &is_not_null(mask)?)?; + + let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); + + // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to + // fill with falsy values + + // keep track of how much is filled + let mut filled = 0; + // keep track of current position we have in truthy array + let mut true_pos = 0; + + SlicesIterator::new(&mask).for_each(|(start, end)| { + // the gap needs to be filled with nulls + if start > filled { + mutable.extend_nulls(start - filled); + } + // fill with truthy values + let len = end - start; + mutable.extend(0, true_pos, true_pos + len); + true_pos += len; + filled = end; + }); + // the remaining part is falsy + if filled < mask.len() { + mutable.extend_nulls(mask.len() - filled); + } + + let data = mutable.freeze(); + Ok(make_array(data)) +} + +/// Represents a [`PhysicalExpr`] node with associated properties (order and +/// range) in a context where properties are tracked. +pub type ExprPropertiesNode = ExprContext; + +impl ExprPropertiesNode { + /// Constructs a new `ExprPropertiesNode` with unknown properties for a + /// given physical expression. This node initializes with default properties + /// and recursively applies this to all child expressions. + pub fn new_unknown(expr: Arc) -> Self { + let children = expr + .children() + .into_iter() + .cloned() + .map(Self::new_unknown) + .collect(); + Self { + expr, + data: ExprProperties::new_unknown(), + children, + } + } +} + #[cfg(test)] mod tests { + use arrow_array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, @@ -1696,4 +1771,70 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn scatter_int() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = + Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_int_end_with_false() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); + let mask = BooleanArray::from(vec![true, false, true, false, false, false]); + + // output should be same length as mask + let expected = + Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_with_null_mask() -> Result<()> { + let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); + let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] + .into_iter() + .collect(); + + // output should treat nulls as though they are false + let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_int32_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } + + #[test] + fn scatter_boolean() -> Result<()> { + let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); + let mask = BooleanArray::from(vec![true, true, false, false, true]); + + // the output array is expected to be the same length as the mask array + let expected = BooleanArray::from_iter(vec![ + Some(false), + Some(false), + None, + None, + Some(false), + ]); + let result = scatter(&mask, truthy.as_ref())?; + let result = as_boolean_array(&result)?; + + assert_eq!(&expected, result); + Ok(()) + } } diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 56ef32e7ebe0..bcd132ec4910 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,28 +277,29 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match &acc_args.input_types[0] { - // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL - // TODO support for boolean (trivial case) - // https://github.com/apache/datafusion/issues/1109 - DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), - DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), - DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), - DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), - other => { - return not_impl_err!( + let accumulator: Box = + match &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())? { + // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL + // TODO support for boolean (trivial case) + // https://github.com/apache/datafusion/issues/1109 + DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), + DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), + DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), + DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), + other => { + return not_impl_err!( "Support for 'approx_distinct' for data type {other} is not implemented" ) - } - }; + } + }; Ok(accumulator) } } diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index e12e3445a83e..f37c164799bd 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_types[0].clone(), + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 844e48f0a44d..8fdd45e71cf3 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -30,21 +30,20 @@ use arrow::{ }; use arrow_schema::{Field, Schema}; +use datafusion_common::DataFusionError; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, - ScalarValue, + downcast_value, internal_err, not_impl_err, plan_err, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, - Volatility, + Accumulator, AggregateUDFImpl, ColumnarValue, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; make_udaf_expr_and_func!( ApproxPercentileCont, @@ -105,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] { + let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.dfschema.as_arrow())? { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 @@ -134,24 +133,22 @@ impl ApproxPercentileCont { } } -fn get_lit_value(expr: &Expr) -> datafusion_common::Result { - let empty_schema = Arc::new(Schema::empty()); - let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - let dfschema = DFSchema::empty(); - let expr = - limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, &dfschema)?; - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), +fn get_lit_value( + physical_expr: &Arc, +) -> datafusion_common::Result { + match physical_expr.evaluate(&RecordBatch::new_empty(Arc::new(Schema::empty())))? { ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + ColumnarValue::Array(_) => internal_err!( + "The expr {:?} can't be evaluated to scalar value", + physical_expr + ), } } -fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; +fn validate_input_percentile_expr( + physical_expr: &Arc, +) -> datafusion_common::Result { + let lit = get_lit_value(physical_expr)?; let percentile = match &lit { ScalarValue::Float32(Some(q)) => *q as f64, ScalarValue::Float64(Some(q)) => *q, @@ -170,8 +167,10 @@ fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result Ok(percentile) } -fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; +fn validate_input_max_size_expr( + physical_expr: &Arc, +) -> datafusion_common::Result { + let lit = get_lit_value(physical_expr)?; let max_size = match &lit { ScalarValue::UInt8(Some(q)) => *q as usize, ScalarValue::UInt16(Some(q)) => *q as usize, diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 0dbea1fb1ff7..0c62e996763a 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -131,8 +132,8 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { let sub_args = AccumulatorArgs { input_exprs: &[ - acc_args.input_exprs[0].clone(), - acc_args.input_exprs[2].clone(), + Arc::clone(&acc_args.input_exprs[0]), + Arc::clone(&acc_args.input_exprs[2]), ], ..acc_args }; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index bb25de113525..abb344eaf693 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -117,16 +117,15 @@ impl AggregateUDFImpl for ArrayAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + if acc_args.is_distinct { - return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); } if acc_args.sort_exprs.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); } let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( @@ -140,7 +139,7 @@ impl AggregateUDFImpl for ArrayAgg { .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - &acc_args.input_types[0], + &data_type, &ordering_dtypes, ordering_req, acc_args.is_reversed, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 228bce1979a3..f9d10426df0b 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,7 +93,10 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - match (&acc_args.input_types[0], acc_args.data_type) { + let input_type = + acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; + + match (&input_type, acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,7 +123,7 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - &acc_args.input_types[0], + &input_type, acc_args.data_type ), } @@ -154,10 +157,12 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (&args.input_types[0], args.data_type) { + let sum_data_type = &args.input_exprs[0].data_type(args.dfschema.as_arrow())?; + + match (sum_data_type, args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -176,7 +181,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, avg_fn, ))) @@ -197,7 +202,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], + sum_data_type, args.data_type, avg_fn, ))) @@ -205,7 +210,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - &args.input_types[0], + sum_data_type, args.data_type ), } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index e2d59003fca1..aacff28baeea 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -148,7 +148,8 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = &acc_args.input_types[0]; + let data_type = + &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index febf1fcd2fef..a0a1dbeb4d3c 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = &acc_args.input_types[0]; + let dt = &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 7c4b9a7f06c6..6362bdcc9287 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -30,10 +30,11 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, ReversedUDAF, Signature, Volatility, + Accumulator, AggregateUDFImpl, ReversedUDAF, Signature, Volatility, }; use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; use datafusion_physical_expr_common::aggregate::utils::ordering_fields; +use datafusion_physical_expr_common::expressions::Literal; use datafusion_physical_expr_common::sort_expr::{ limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, PhysicalSortExpr, @@ -87,20 +88,26 @@ impl AggregateUDFImpl for NthValueAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let n = match acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Int64(Some(value))) => { - if acc_args.is_reversed { - Ok(-value) - } else { - Ok(value) + let Some(n) = acc_args.input_exprs[1] + .as_any() + .downcast_ref::() + .and_then(|lit| match lit.value() { + ScalarValue::Int64(Some(n)) => { + if acc_args.is_reversed { + Some(-n) + } else { + Some(*n) + } } - } - _ => not_impl_err!( + _ => None, + }) + else { + return not_impl_err!( "{} not supported for n: {}", self.name(), &acc_args.input_exprs[1] - ), - }?; + ); + }; let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( acc_args.sort_exprs, @@ -114,7 +121,7 @@ impl AggregateUDFImpl for NthValueAgg { NthValueAccumulator::try_new( n, - &acc_args.input_types[0], + &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?, &ordering_dtypes, ordering_req, ) diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 1d2257d90133..caa27d059729 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -274,9 +274,9 @@ mod tests { use arrow::{array::*, datatypes::*}; use datafusion_common::DFSchema; + use datafusion_expr::expressions::column::{col, Column}; use datafusion_expr::AggregateUDF; use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; - use datafusion_physical_expr_common::expressions::column::col; use super::*; @@ -334,8 +334,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + input_exprs: &[Arc::new(Column::new("a", 0))], }; let args2 = AccumulatorArgs { @@ -346,8 +345,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + input_exprs: &[Arc::new(Column::new("a", 0))], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 371cc8fb9739..9d16616e1c9a 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,8 +24,9 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, }; +use datafusion_physical_expr_common::expressions::Literal; use std::any::Any; make_udaf_expr_and_func!( @@ -82,21 +83,25 @@ impl AggregateUDFImpl for StringAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - match &acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Utf8(Some(delimiter))) - | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Expr::Literal(ScalarValue::Utf8(None)) - | Expr::Literal(ScalarValue::LargeUtf8(None)) - | Expr::Literal(ScalarValue::Null) => { - Ok(Box::new(StringAggAccumulator::new(""))) - } - _ => not_impl_err!( + let Some(delimiter) = acc_args.input_exprs[1] + .as_any() + .downcast_ref::() + .and_then(|lit| match lit.value() { + ScalarValue::Utf8(Some(s)) => Some(s.as_str()), + ScalarValue::LargeUtf8(Some(s)) => Some(s.as_str()), + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Null => Some(""), + _ => None, + }) + else { + return not_impl_err!( "StringAgg not supported for delimiter {}", &acc_args.input_exprs[1] - ), - } + ); + }; + + Ok(Box::new(StringAggAccumulator::new(delimiter))) } } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 0707301b2557..acf1d3c4aef6 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -18,9 +18,9 @@ use std::fmt::Debug; use std::{any::Any, sync::Arc}; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, ToDFSchema}; use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; @@ -30,9 +30,9 @@ use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; -use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; +use datafusion_expr::physical_expr::PhysicalExpr; use self::utils::down_cast_any_ref; @@ -76,7 +76,7 @@ pub fn create_aggregate_expr( builder = builder.sort_exprs(sort_exprs.to_vec()); builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.schema(Arc::new(schema.clone())); + builder = builder.dfschema(Arc::new(schema.clone()).to_dfschema()?); builder = builder.name(name); if ignore_nulls { @@ -109,8 +109,6 @@ pub fn create_aggregate_expr_with_dfschema( builder = builder.order_by(ordering_req.to_vec()); builder = builder.logical_exprs(input_exprs.to_vec()); builder = builder.dfschema(dfschema.clone()); - let schema: Schema = dfschema.into(); - builder = builder.schema(Arc::new(schema)); builder = builder.name(name); if ignore_nulls { @@ -138,8 +136,6 @@ pub struct AggregateExprBuilder { /// Logical expressions of the aggregate function, it will be deprecated in logical_args: Vec, name: String, - /// Arrow Schema for the aggregate function - schema: SchemaRef, /// Datafusion Schema for the aggregate function dfschema: DFSchema, /// The logical order by expressions, it will be deprecated in @@ -161,7 +157,6 @@ impl AggregateExprBuilder { args, logical_args: vec![], name: String::new(), - schema: Arc::new(Schema::empty()), dfschema: DFSchema::empty(), sort_exprs: vec![], ordering_req: vec![], @@ -177,7 +172,6 @@ impl AggregateExprBuilder { args, logical_args, name, - schema, dfschema, sort_exprs, ordering_req, @@ -195,7 +189,7 @@ impl AggregateExprBuilder { if !ordering_req.is_empty() { let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(&schema)) + .map(|e| e.expr.data_type(dfschema.as_arrow())) .collect::>>()?; ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); @@ -203,7 +197,7 @@ impl AggregateExprBuilder { let input_exprs_types = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.data_type(dfschema.as_arrow())) .collect::>>()?; check_arg_count( @@ -236,11 +230,6 @@ impl AggregateExprBuilder { self } - pub fn schema(mut self, schema: SchemaRef) -> Self { - self.schema = schema; - self - } - pub fn dfschema(mut self, dfschema: DFSchema) -> Self { self.dfschema = dfschema; self @@ -524,8 +513,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -540,8 +528,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -611,8 +598,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; @@ -626,8 +612,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, + input_exprs: &self.args, name: &self.name, is_reversed: self.is_reversed, }; diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs b/datafusion/physical-expr-common/src/expressions/cast.rs index dd6131ad65c3..2b0058bb338f 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr-common/src/expressions/cast.rs @@ -20,7 +20,7 @@ use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use datafusion_expr::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, DataType::*, Schema}; @@ -235,7 +235,7 @@ pub fn cast( mod tests { use super::*; - use crate::expressions::column::col; + use datafusion_expr::expressions::column::col; use arrow::{ array::{ diff --git a/datafusion/physical-expr-common/src/expressions/literal.rs b/datafusion/physical-expr-common/src/expressions/literal.rs index b3cff1ef69ba..1be46d13f5fb 100644 --- a/datafusion/physical-expr-common/src/expressions/literal.rs +++ b/datafusion/physical-expr-common/src/expressions/literal.rs @@ -21,14 +21,13 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; - use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::physical_expr::{down_cast_any_ref, PhysicalExpr}; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ColumnarValue, Expr}; diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index dd534cc07d20..b53bdc829440 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -16,7 +16,6 @@ // under the License. mod cast; -pub mod column; pub mod literal; pub use cast::{cast, cast_with_options, CastExpr}; diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f03eedd4cf65..9a694b3478c4 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -20,7 +20,5 @@ pub mod binary_map; pub mod binary_view_map; pub mod datum; pub mod expressions; -pub mod physical_expr; pub mod sort_expr; -pub mod tree_node; pub mod utils; diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 2b506b74216f..9f9beb3190e6 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -21,13 +21,12 @@ use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::PhysicalExpr; use crate::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; - use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use datafusion_common::{exec_err, DFSchema, Result}; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::{ColumnarValue, Expr}; /// Represents Sort operation for a column in a RecordBatch diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs deleted file mode 100644 index d9892ce55509..000000000000 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ /dev/null @@ -1,105 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module provides common traits for visiting or rewriting tree nodes easily. - -use std::fmt::{self, Display, Formatter}; -use std::sync::Arc; - -use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; - -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; -use datafusion_common::Result; - -impl DynTreeNode for dyn PhysicalExpr { - fn arc_children(&self) -> Vec<&Arc> { - self.children() - } - - fn with_new_arc_children( - &self, - arc_self: Arc, - new_children: Vec>, - ) -> Result> { - with_new_children_if_necessary(arc_self, new_children) - } -} - -/// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are -/// two ways to access child plans—directly from the plan and through child nodes—it's -/// recommended to perform mutable operations via [`Self::update_expr_from_children`]. -#[derive(Debug)] -pub struct ExprContext { - /// The physical expression associated with this context. - pub expr: Arc, - /// Custom data payload of the node. - pub data: T, - /// Child contexts of this node. - pub children: Vec, -} - -impl ExprContext { - pub fn new(expr: Arc, data: T, children: Vec) -> Self { - Self { - expr, - data, - children, - } - } - - pub fn update_expr_from_children(mut self) -> Result { - let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - self.expr = with_new_children_if_necessary(self.expr, children_expr)?; - Ok(self) - } -} - -impl ExprContext { - pub fn new_default(plan: Arc) -> Self { - let children = plan - .children() - .into_iter() - .cloned() - .map(Self::new_default) - .collect(); - Self::new(plan, Default::default(), children) - } -} - -impl Display for ExprContext { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "expr: {:?}", self.expr)?; - write!(f, "data:{}", self.data)?; - write!(f, "") - } -} - -impl ConcreteTreeNode for ExprContext { - fn children(&self) -> &[Self] { - &self.children - } - - fn take_children(mut self) -> (Self, Vec) { - let children = std::mem::take(&mut self.children); - (self, children) - } - - fn with_new_children(mut self, children: Vec) -> Result { - self.children = children; - self.update_expr_from_children() - } -} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 0978a906a5dc..f4a1616c9131 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,86 +17,14 @@ use std::sync::Arc; -use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; -use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; - -use datafusion_common::{exec_err, DFSchema, Result}; -use datafusion_expr::expr::Alias; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::Expr; - -use crate::expressions::column::Column; use crate::expressions::literal::Literal; use crate::expressions::CastExpr; -use crate::physical_expr::PhysicalExpr; use crate::sort_expr::PhysicalSortExpr; -use crate::tree_node::ExprContext; - -/// Represents a [`PhysicalExpr`] node with associated properties (order and -/// range) in a context where properties are tracked. -pub type ExprPropertiesNode = ExprContext; - -impl ExprPropertiesNode { - /// Constructs a new `ExprPropertiesNode` with unknown properties for a - /// given physical expression. This node initializes with default properties - /// and recursively applies this to all child expressions. - pub fn new_unknown(expr: Arc) -> Self { - let children = expr - .children() - .into_iter() - .cloned() - .map(Self::new_unknown) - .collect(); - Self { - expr, - data: ExprProperties::new_unknown(), - children, - } - } -} - -/// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` -/// are taken, when the mask evaluates `false` values null values are filled. -/// -/// # Arguments -/// * `mask` - Boolean values used to determine where to put the `truthy` values -/// * `truthy` - All values of this array are to scatter according to `mask` into final result. -pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { - let truthy = truthy.to_data(); - - // update the mask so that any null values become false - // (SlicesIterator doesn't respect nulls) - let mask = and_kleene(mask, &is_not_null(mask)?)?; - - let mut mutable = MutableArrayData::new(vec![&truthy], true, mask.len()); - - // the SlicesIterator slices only the true values. So the gaps left by this iterator we need to - // fill with falsy values - - // keep track of how much is filled - let mut filled = 0; - // keep track of current position we have in truthy array - let mut true_pos = 0; - - SlicesIterator::new(&mask).for_each(|(start, end)| { - // the gap needs to be filled with nulls - if start > filled { - mutable.extend_nulls(start - filled); - } - // fill with truthy values - let len = end - start; - mutable.extend(0, true_pos, true_pos + len); - true_pos += len; - filled = end; - }); - // the remaining part is falsy - if filled < mask.len() { - mutable.extend_nulls(mask.len() - filled); - } - - let data = mutable.freeze(); - Ok(make_array(data)) -} +use datafusion_common::{exec_err, DFSchema, Result}; +use datafusion_expr::expr::Alias; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; +use datafusion_expr::Expr; /// Reverses the ORDER BY expression, which is useful during equivalent window /// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into @@ -136,80 +64,3 @@ pub fn limited_convert_logical_expr_to_physical_expr_with_dfschema( ), } } - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::array::Int32Array; - - use datafusion_common::cast::{as_boolean_array, as_int32_array}; - - use super::*; - - #[test] - fn scatter_int() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = - Int32Array::from_iter(vec![Some(1), Some(10), None, None, Some(11)]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_int_end_with_false() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); - let mask = BooleanArray::from(vec![true, false, true, false, false, false]); - - // output should be same length as mask - let expected = - Int32Array::from_iter(vec![Some(1), None, Some(10), None, None, None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_with_null_mask() -> Result<()> { - let truthy = Arc::new(Int32Array::from(vec![1, 10, 11])); - let mask: BooleanArray = vec![Some(false), None, Some(true), Some(true), None] - .into_iter() - .collect(); - - // output should treat nulls as though they are false - let expected = Int32Array::from_iter(vec![None, None, Some(1), Some(10), None]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_int32_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } - - #[test] - fn scatter_boolean() -> Result<()> { - let truthy = Arc::new(BooleanArray::from(vec![false, false, false, true])); - let mask = BooleanArray::from(vec![true, true, false, false, true]); - - // the output array is expected to be the same length as the mask array - let expected = BooleanArray::from_iter(vec![ - Some(false), - Some(false), - None, - None, - Some(false), - ]); - let result = scatter(&mask, truthy.as_ref())?; - let result = as_boolean_array(&result)?; - - assert_eq!(&expected, result); - Ok(()) - } -} diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 862edd9c1fac..74c52a9294cd 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -21,11 +21,11 @@ use arrow_array::builder::{Int32Builder, StringBuilder}; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; -use datafusion_physical_expr_common::expressions::column::Column; use datafusion_physical_expr_common::expressions::Literal; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; fn make_col(name: &str, index: usize) -> Arc { diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs index 3dad8e9b456a..3cc7c4f0c37b 100644 --- a/datafusion/physical-expr/benches/is_null.rs +++ b/datafusion/physical-expr/benches/is_null.rs @@ -20,9 +20,9 @@ use arrow::record_batch::RecordBatch; use arrow_array::builder::Int32Builder; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::expressions::column::Column; +use datafusion_expr::physical_expr::PhysicalExpr; use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; -use datafusion_physical_expr_common::expressions::column::Column; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a6e9fba28167..fcf3ac81c5c2 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -34,12 +34,12 @@ use crate::{ use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, JoinSide, JoinType, Result}; +use datafusion_expr::expressions::column::Column; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::physical_expr::with_new_schema; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_expr::utils::ExprPropertiesNode; use datafusion_physical_expr_common::expressions::CastExpr; -use datafusion_physical_expr_common::physical_expr::with_new_schema; -use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c34dcdfb7598..b7d2202b9f31 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -684,8 +684,8 @@ mod tests { use crate::expressions::{col, lit, try_cast, Literal}; use datafusion_common::plan_datafusion_err; + use datafusion_expr::expressions::column::Column; use datafusion_expr::type_coercion::binary::get_input_types; - use datafusion_physical_expr_common::expressions::column::Column; /// Performs a binary operation, applying any type coercion necessary fn binary_op( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b428d562bd1b..b433f6c5e3d2 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -32,7 +32,7 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_expr::expressions::column::Column; use datafusion_physical_expr_common::expressions::Literal; use itertools::Itertools; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7cbe4e796844..951bef4521e3 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,8 +47,8 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; +pub use datafusion_expr::expressions::column::{col, Column}; pub use datafusion_expr::utils::format_state_name; -pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; pub use in_list::{in_list, InListExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 2e78119eba46..0fa01fc689d6 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -55,7 +55,7 @@ pub use physical_expr::{ PhysicalExprRef, }; -pub use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub use datafusion_expr::physical_expr::PhysicalExpr; pub use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, @@ -69,5 +69,5 @@ pub use utils::split_conjunction; // For backwards compatibility pub mod tree_node { - pub use datafusion_physical_expr_common::tree_node::ExprContext; + pub use datafusion_expr::tree_node::ExprContext; } diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c60a772b9ce2..942b92abc43c 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,10 +17,10 @@ use std::sync::Arc; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_expr::physical_expr::PhysicalExpr; use itertools::izip; -pub use datafusion_physical_expr_common::physical_expr::down_cast_any_ref; +pub use datafusion_expr::physical_expr::down_cast_any_ref; /// Shared [`PhysicalExpr`]. pub type PhysicalExprRef = Arc; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d1152038eb2a..fd2510fbf90d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1195,7 +1195,7 @@ mod tests { use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, - DataFusionError, ScalarValue, + DataFusionError, ScalarValue, ToDFSchema, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; @@ -1352,7 +1352,7 @@ mod tests { }; let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("COUNT(1)") .logical_exprs(vec![datafusion_expr::lit(1i8)]) .build()?]; @@ -1497,7 +1497,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -1793,7 +1793,7 @@ mod tests { // Median(a) fn test_median_agg_expr(schema: SchemaRef) -> Result> { AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) - .schema(schema) + .dfschema(schema.to_dfschema()?) .name("MEDIAN(a)") .build() } @@ -1824,7 +1824,7 @@ mod tests { let aggregates_v2: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) - .schema(Arc::clone(&input_schema)) + .dfschema(Arc::clone(&input_schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -1884,7 +1884,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(a)") .build()?, ]; @@ -1924,7 +1924,7 @@ mod tests { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -2353,7 +2353,7 @@ mod tests { let aggregates: Vec> = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("1") .build()?]; diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9321fdb2cadf..41f0878595b7 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -586,8 +586,8 @@ mod tests { use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; + use datafusion_expr::expressions::column::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::expressions::column::col; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index a462430ca381..dbfea253959e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue, ToDFSchema}; use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, @@ -145,7 +145,7 @@ pub fn create_window_expr( .collect::>(); let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) + .dfschema(Arc::new(input_schema.clone()).to_dfschema()?) .name(name) .order_by(order_by.to_vec()) .sort_exprs(sort_exprs) @@ -413,6 +413,7 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } } +#[allow(clippy::needless_borrow)] pub(crate) fn calc_requirements< T: Borrow>, S: Borrow, @@ -430,7 +431,7 @@ pub(crate) fn calc_requirements< let PhysicalSortExpr { expr, options } = element.borrow(); if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { sort_reqs.push(PhysicalSortRequirement::new( - Arc::clone(expr), + Arc::clone(&expr), Some(*options), )); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1f433ff01d12..3c0d6664da17 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -61,7 +61,9 @@ use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, +}; use datafusion_expr::{AggregateUDF, ScalarUDF}; use crate::common::{byte_to_string, str_to_byte}; @@ -509,7 +511,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. // TODO: `order by` is not supported for UDAF yet - AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() + AggregateExprBuilder::new(agg_udf, input_phy_expr).dfschema(Arc::clone(&physical_schema).to_dfschema()?).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 3ddc122e3de2..caab8f0a77f7 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -79,7 +79,9 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, ToDFSchema, +}; use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, @@ -295,7 +297,7 @@ fn roundtrip_window() -> Result<()> { avg_udaf(), vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], ) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("avg(b)") .build()?, &[], @@ -311,7 +313,7 @@ fn roundtrip_window() -> Result<()> { let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") .build()?; @@ -345,17 +347,17 @@ fn rountrip_aggregate() -> Result<()> { vec![(col("a", &schema)?, "unused".to_string())]; let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?; let nth_expr = AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("NTH_VALUE(b, 1)") .build()?; let str_agg_expr = AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, lit(1u64)]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("NTH_VALUE(b, 1)") .build()?; @@ -395,7 +397,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("AVG(b)") .build()?, ]; @@ -462,7 +464,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![ AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("example_agg") .build()?, ]; @@ -957,7 +959,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { vec![Arc::new(Literal::new(ScalarValue::from(42)))]; let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("aggregate_udf") .build()?; @@ -982,7 +984,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { )?); let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) - .schema(Arc::clone(&schema)) + .dfschema(Arc::clone(&schema).to_dfschema()?) .name("aggregate_udf") .distinct() .ignore_nulls() diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc37..e8698253edb5 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -27,7 +27,6 @@ use substrait::proto::Plan; use std::fs::OpenOptions; use std::io::{Read, Write}; -#[allow(clippy::suspicious_open_options)] pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> { let protobuf_out = serialize_bytes(sql, ctx).await; let mut file = OpenOptions::new().create(true).write(true).open(path)?;