diff --git a/Cargo.toml b/Cargo.toml index 3431c4673e0c..02b1f1ccd92a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,13 +23,16 @@ members = [ "datafusion/catalog", "datafusion/core", "datafusion/expr", + "datafusion/expr-common", "datafusion/execution", - "datafusion/functions-aggregate", "datafusion/functions", + "datafusion/functions-aggregate", + "datafusion/functions-aggregate-common", "datafusion/functions-nested", "datafusion/optimizer", - "datafusion/physical-expr-common", "datafusion/physical-expr", + "datafusion/physical-expr-common", + "datafusion/physical-expr-functions-aggregate", "datafusion/physical-optimizer", "datafusion/physical-plan", "datafusion/proto", @@ -94,12 +97,15 @@ datafusion-common = { path = "datafusion/common", version = "41.0.0", default-fe datafusion-common-runtime = { path = "datafusion/common-runtime", version = "41.0.0" } datafusion-execution = { path = "datafusion/execution", version = "41.0.0" } datafusion-expr = { path = "datafusion/expr", version = "41.0.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "41.0.0" } datafusion-functions = { path = "datafusion/functions", version = "41.0.0" } datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "41.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "41.0.0" } datafusion-functions-nested = { path = "datafusion/functions-nested", version = "41.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "41.0.0", default-features = false } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "41.0.0", default-features = false } datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "41.0.0", default-features = false } +datafusion-physical-expr-functions-aggregate = { path = "datafusion/physical-expr-functions-aggregate", version = "41.0.0" } datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "41.0.0" } datafusion-physical-plan = { path = "datafusion/physical-plan", version = "41.0.0" } datafusion-proto = { path = "datafusion/proto", version = "41.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2eb93da7c020..134cde8976d6 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1153,6 +1153,7 @@ dependencies = [ "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", + "datafusion-physical-expr-functions-aggregate", "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", @@ -1278,6 +1279,9 @@ dependencies = [ "arrow-buffer", "chrono", "datafusion-common", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", + "datafusion-physical-expr-common", "paste", "serde_json", "sqlparser", @@ -1285,6 +1289,15 @@ dependencies = [ "strum_macros 0.26.4", ] +[[package]] +name = "datafusion-expr-common" +version = "41.0.0" +dependencies = [ + "arrow", + "datafusion-common", + "paste", +] + [[package]] name = "datafusion-functions" version = "41.0.0" @@ -1320,12 +1333,26 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate-common", + "datafusion-physical-expr", "datafusion-physical-expr-common", "log", "paste", "sqlparser", ] +[[package]] +name = "datafusion-functions-aggregate-common" +version = "41.0.0" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr-common", + "datafusion-physical-expr-common", + "rand", +] + [[package]] name = "datafusion-functions-nested" version = "41.0.0" @@ -1380,6 +1407,8 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", @@ -1399,11 +1428,25 @@ dependencies = [ "ahash", "arrow", "datafusion-common", - "datafusion-expr", + "datafusion-expr-common", "hashbrown 0.14.5", "rand", ] +[[package]] +name = "datafusion-physical-expr-functions-aggregate" +version = "41.0.0" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", + "datafusion-physical-expr-common", + "rand", +] + [[package]] name = "datafusion-physical-optimizer" version = "41.0.0" @@ -1431,8 +1474,10 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", + "datafusion-physical-expr-functions-aggregate", "futures", "half", "hashbrown 0.14.5", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0714c3e94a85..e678c93ede8b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -109,6 +109,7 @@ datafusion-functions-nested = { workspace = true, optional = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +datafusion-physical-expr-functions-aggregate = { workspace = true } datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index d4b82f288bdd..6b3773e4f6d5 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -556,6 +556,11 @@ pub mod physical_expr_common { pub use datafusion_physical_expr_common::*; } +/// re-export of [`datafusion_physical_expr_functions_aggregate`] crate +pub mod physical_expr_functions_aggregate { + pub use datafusion_physical_expr_functions_aggregate::*; +} + /// re-export of [`datafusion_physical_expr`] crate pub mod physical_expr { pub use datafusion_physical_expr::*; diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 843efcc7b0d2..f65a4c837a60 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -177,7 +177,7 @@ mod tests { use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; - use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; + use datafusion_physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected macro_rules! assert_optimized { diff --git a/datafusion/core/src/physical_optimizer/limit_pushdown.rs b/datafusion/core/src/physical_optimizer/limit_pushdown.rs index 4379a34a9426..d02737ff0959 100644 --- a/datafusion/core/src/physical_optimizer/limit_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/limit_pushdown.rs @@ -258,9 +258,8 @@ mod tests { use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::BinaryExpr; + use datafusion_physical_expr::expressions::{col, lit}; use datafusion_physical_expr::Partitioning; - use datafusion_physical_expr_common::expressions::column::col; - use datafusion_physical_expr_common::expressions::lit; use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::empty::EmptyExec; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab0765ac0deb..7eb468f56eeb 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -58,7 +58,7 @@ use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties, + displayable, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, }; @@ -73,7 +73,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ - self, physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + self, create_function_physical_name, physical_name, AggregateFunction, Alias, + GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -83,6 +84,7 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -1559,6 +1561,17 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { + let name = if let Some(name) = name { + name + } else { + create_function_physical_name( + func.name(), + *distinct, + args, + order_by.as_ref(), + )? + }; + let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { @@ -1575,7 +1588,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = { - let sort_exprs = order_by.clone().unwrap_or(vec![]); let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( exprs, @@ -1588,18 +1600,15 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ordering_reqs: Vec = physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr_with_dfschema( - func, - &physical_args, - args, - &sort_exprs, - &ordering_reqs, - logical_input_schema, - name, - ignore_nulls, - *distinct, - false, - )?; + let schema: Schema = logical_input_schema.clone().into(); + let agg_expr = + AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) + .order_by(ordering_reqs.to_vec()) + .schema(Arc::new(schema)) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .with_distinct(*distinct) + .build()?; (agg_expr, filter, physical_sort_exprs) }; diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 937344ef5e4e..ca8376fdec0a 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -54,7 +54,7 @@ use datafusion_physical_expr::{ use async_trait::async_trait; use datafusion_catalog::Session; -use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion_physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use futures::Stream; use tempfile::TempDir; // backwards compatibility diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 4cecb0b69335..138e5bda7f39 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -25,6 +25,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; use datafusion::common::Result; use datafusion::datasource::MemTable; +use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -35,7 +36,6 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor} use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; -use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml new file mode 100644 index 000000000000..7e477efc4ebc --- /dev/null +++ b/datafusion/expr-common/Cargo.toml @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-expr-common" +description = "Logical plan and expression representation for DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_expr_common" +path = "src/lib.rs" + +[features] + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true } +paste = "^1.0" diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs similarity index 100% rename from datafusion/expr/src/accumulator.rs rename to datafusion/expr-common/src/accumulator.rs diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs similarity index 100% rename from datafusion/expr/src/columnar_value.rs rename to datafusion/expr-common/src/columnar_value.rs diff --git a/datafusion/expr/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs similarity index 97% rename from datafusion/expr/src/groups_accumulator.rs rename to datafusion/expr-common/src/groups_accumulator.rs index 886bd8443e4d..e66b27d073d1 100644 --- a/datafusion/expr/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -17,7 +17,7 @@ //! Vectorized [`GroupsAccumulator`] -use arrow_array::{ArrayRef, BooleanArray}; +use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. @@ -75,7 +75,7 @@ impl EmitTo { /// expected that each `GroupAccumulator` will use something like `Vec<..>` /// to store the group states. /// -/// [`Accumulator`]: crate::Accumulator +/// [`Accumulator`]: crate::accumulator::Accumulator /// [Aggregating Millions of Groups Fast blog]: https://arrow.apache.org/blog/2023/08/05/datafusion_fast_grouping/ pub trait GroupsAccumulator: Send { /// Updates the accumulator's state from its arguments, encoded as @@ -140,7 +140,7 @@ pub trait GroupsAccumulator: Send { /// See [`Self::evaluate`] for details on the required output /// order and `emit_to`. /// - /// [`Accumulator::state`]: crate::Accumulator::state + /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn state(&mut self, emit_to: EmitTo) -> Result>; /// Merges intermediate state (the output from [`Self::state`]) @@ -197,7 +197,7 @@ pub trait GroupsAccumulator: Send { /// state directly to the next aggregation phase with minimal processing /// using this method. /// - /// [`Accumulator::state`]: crate::Accumulator::state + /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn convert_to_state( &self, _values: &[ArrayRef], diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs similarity index 99% rename from datafusion/expr/src/interval_arithmetic.rs rename to datafusion/expr-common/src/interval_arithmetic.rs index 553cdd8c8709..e3ff412e785b 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -17,16 +17,16 @@ //! Interval arithmetic library +use crate::operator::Operator; use crate::type_coercion::binary::get_result_type; -use crate::Operator; -use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use std::borrow::Borrow; use std::fmt::{self, Display, Formatter}; use std::ops::{AddAssign, SubAssign}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use arrow::datatypes::{IntervalUnit, TimeUnit}; +use arrow::datatypes::{ + DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, +}; use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; use datafusion_common::{internal_err, Result, ScalarValue}; @@ -120,12 +120,12 @@ macro_rules! value_transition { IntervalYearMonth(None) } IntervalDayTime(Some(value)) - if value == arrow_buffer::IntervalDayTime::$bound => + if value == arrow::datatypes::IntervalDayTime::$bound => { IntervalDayTime(None) } IntervalMonthDayNano(Some(value)) - if value == arrow_buffer::IntervalMonthDayNano::$bound => + if value == arrow::datatypes::IntervalMonthDayNano::$bound => { IntervalMonthDayNano(None) } @@ -1135,12 +1135,12 @@ fn next_value_helper(value: ScalarValue) -> ScalarValue { } IntervalDayTime(Some(val)) => IntervalDayTime(Some(increment_decrement::< INC, - arrow_buffer::IntervalDayTime, + arrow::datatypes::IntervalDayTime, >(val))), IntervalMonthDayNano(Some(val)) => { IntervalMonthDayNano(Some(increment_decrement::< INC, - arrow_buffer::IntervalMonthDayNano, + arrow::datatypes::IntervalMonthDayNano, >(val))) } _ => value, // Unbounded values return without change. @@ -1177,7 +1177,7 @@ fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { /// Example usage: /// ``` /// use datafusion_common::DataFusionError; -/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// use datafusion_expr_common::interval_arithmetic::{satisfy_greater, Interval}; /// /// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; /// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; @@ -1552,8 +1552,8 @@ fn cast_scalar_value( /// ``` /// use arrow::datatypes::DataType; /// use datafusion_common::ScalarValue; -/// use datafusion_expr::interval_arithmetic::Interval; -/// use datafusion_expr::interval_arithmetic::NullableInterval; +/// use datafusion_expr_common::interval_arithmetic::Interval; +/// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// // [1, 2) U {NULL} /// let maybe_null = NullableInterval::MaybeNull { @@ -1674,9 +1674,9 @@ impl NullableInterval { /// /// ``` /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_expr::interval_arithmetic::Interval; - /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// use datafusion_expr_common::operator::Operator; + /// use datafusion_expr_common::interval_arithmetic::Interval; + /// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// // 4 > 3 -> true /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); @@ -1798,8 +1798,8 @@ impl NullableInterval { /// /// ``` /// use datafusion_common::ScalarValue; - /// use datafusion_expr::interval_arithmetic::Interval; - /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// use datafusion_expr_common::interval_arithmetic::Interval; + /// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/expr-common/src/lib.rs similarity index 56% rename from datafusion/physical-expr/src/aggregate/mod.rs rename to datafusion/expr-common/src/lib.rs index b477a815bf80..179dd75ace85 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/expr-common/src/lib.rs @@ -15,14 +15,22 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod groups_accumulator; -pub(crate) mod stats; +//! Logical Expr types and traits for [DataFusion] +//! +//! This crate contains types and traits that are used by both Logical and Physical expressions. +//! They are kept in their own crate to avoid physical expressions depending on logical expressions. +//! +//! +//! [DataFusion]: -pub mod utils { - pub use datafusion_physical_expr_common::aggregate::utils::{ - adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays, - get_sort_options, ordering_fields, DecimalAverager, Hashable, - }; -} +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; +pub mod accumulator; +pub mod columnar_value; +pub mod groups_accumulator; +pub mod interval_arithmetic; +pub mod operator; +pub mod signature; +pub mod sort_properties; +pub mod type_coercion; diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr-common/src/operator.rs similarity index 67% rename from datafusion/expr/src/operator.rs rename to datafusion/expr-common/src/operator.rs index 9bb8c48d6c71..e013b6fafa22 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -15,14 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Operator module contains foundational types that are used to represent operators in DataFusion. - -use crate::expr_fn::binary_expr; -use crate::Expr; -use crate::Like; use std::fmt; -use std::ops; -use std::ops::Not; /// Operators applied to expressions #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -287,202 +280,3 @@ impl fmt::Display for Operator { write!(f, "{display}") } } - -/// Support ` + ` fluent style -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -/// Support ` - ` fluent style -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -/// Support ` * ` fluent style -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -/// Support ` / ` fluent style -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -/// Support ` % ` fluent style -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} - -/// Support ` & ` fluent style -impl ops::BitAnd for Expr { - type Output = Self; - - fn bitand(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseAnd, rhs) - } -} - -/// Support ` | ` fluent style -impl ops::BitOr for Expr { - type Output = Self; - - fn bitor(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseOr, rhs) - } -} - -/// Support ` ^ ` fluent style -impl ops::BitXor for Expr { - type Output = Self; - - fn bitxor(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseXor, rhs) - } -} - -/// Support ` << ` fluent style -impl ops::Shl for Expr { - type Output = Self; - - fn shl(self, rhs: Self) -> Self::Output { - binary_expr(self, Operator::BitwiseShiftLeft, rhs) - } -} - -/// Support ` >> ` fluent style -impl ops::Shr for Expr { - type Output = Self; - - fn shr(self, rhs: Self) -> Self::Output { - binary_expr(self, Operator::BitwiseShiftRight, rhs) - } -} - -/// Support `- ` fluent style -impl ops::Neg for Expr { - type Output = Self; - - fn neg(self) -> Self::Output { - Expr::Negative(Box::new(self)) - } -} - -/// Support `NOT ` fluent style -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - match self { - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::Like(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::SimilarTo(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - _ => Expr::Not(Box::new(self)), - } - } -} - -#[cfg(test)] -mod tests { - use crate::lit; - - #[test] - fn test_operators() { - // Add - assert_eq!( - format!("{}", lit(1u32) + lit(2u32)), - "UInt32(1) + UInt32(2)" - ); - // Sub - assert_eq!( - format!("{}", lit(1u32) - lit(2u32)), - "UInt32(1) - UInt32(2)" - ); - // Mul - assert_eq!( - format!("{}", lit(1u32) * lit(2u32)), - "UInt32(1) * UInt32(2)" - ); - // Div - assert_eq!( - format!("{}", lit(1u32) / lit(2u32)), - "UInt32(1) / UInt32(2)" - ); - // Rem - assert_eq!( - format!("{}", lit(1u32) % lit(2u32)), - "UInt32(1) % UInt32(2)" - ); - // BitAnd - assert_eq!( - format!("{}", lit(1u32) & lit(2u32)), - "UInt32(1) & UInt32(2)" - ); - // BitOr - assert_eq!( - format!("{}", lit(1u32) | lit(2u32)), - "UInt32(1) | UInt32(2)" - ); - // BitXor - assert_eq!( - format!("{}", lit(1u32) ^ lit(2u32)), - "UInt32(1) BIT_XOR UInt32(2)" - ); - // Shl - assert_eq!( - format!("{}", lit(1u32) << lit(2u32)), - "UInt32(1) << UInt32(2)" - ); - // Shr - assert_eq!( - format!("{}", lit(1u32) >> lit(2u32)), - "UInt32(1) >> UInt32(2)" - ); - // Neg - assert_eq!(format!("{}", -lit(1u32)), "(- UInt32(1))"); - // Not - assert_eq!(format!("{}", !lit(1u32)), "NOT UInt32(1)"); - } -} diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr-common/src/signature.rs similarity index 97% rename from datafusion/expr/src/signature.rs rename to datafusion/expr-common/src/signature.rs index 577c663142a1..4dcfa423e371 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -75,7 +75,7 @@ pub enum Volatility { /// /// ``` /// # use arrow::datatypes::{DataType, TimeUnit}; -/// # use datafusion_expr::{TIMEZONE_WILDCARD, TypeSignature}; +/// # use datafusion_expr_common::signature::{TIMEZONE_WILDCARD, TypeSignature}; /// let type_signature = TypeSignature::Exact(vec![ /// // A nanosecond precision timestamp with ANY timezone /// // matches Timestamp(Nanosecond, Some("+0:00")) @@ -93,9 +93,7 @@ pub enum TypeSignature { Variadic(Vec), /// The acceptable signature and coercions rules to coerce arguments to this /// signature are special for this function. If this signature is specified, - /// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types. - /// - /// [`ScalarUDFImpl::coerce_types`]: crate::udf::ScalarUDFImpl::coerce_types + /// DataFusion will call `ScalarUDFImpl::coerce_types` to prepare argument types. UserDefined, /// One or more arguments with arbitrary types VariadicAny, @@ -176,7 +174,7 @@ impl std::fmt::Display for ArrayFunctionSignature { } impl TypeSignature { - pub(crate) fn to_string_repr(&self) -> Vec { + pub fn to_string_repr(&self) -> Vec { match self { TypeSignature::Variadic(types) => { vec![format!("{}, ..", Self::join_types(types, "/"))] @@ -213,10 +211,7 @@ impl TypeSignature { } /// Helper function to join types with specified delimiter. - pub(crate) fn join_types( - types: &[T], - delimiter: &str, - ) -> String { + pub fn join_types(types: &[T], delimiter: &str) -> String { types .iter() .map(|t| t.to_string()) diff --git a/datafusion/expr/src/sort_properties.rs b/datafusion/expr-common/src/sort_properties.rs similarity index 100% rename from datafusion/expr/src/sort_properties.rs rename to datafusion/expr-common/src/sort_properties.rs diff --git a/datafusion/physical-expr/src/aggregate/stats.rs b/datafusion/expr-common/src/type_coercion.rs similarity index 91% rename from datafusion/physical-expr/src/aggregate/stats.rs rename to datafusion/expr-common/src/type_coercion.rs index d9338f5a962f..e934c6eaf35b 100644 --- a/datafusion/physical-expr/src/aggregate/stats.rs +++ b/datafusion/expr-common/src/type_coercion.rs @@ -15,4 +15,5 @@ // specific language governing permissions and limitations // under the License. -pub use datafusion_physical_expr_common::aggregate::stats::StatsType; +pub mod aggregates; +pub mod binary; diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs similarity index 99% rename from datafusion/expr/src/type_coercion/aggregates.rs rename to datafusion/expr-common/src/type_coercion/aggregates.rs index e7e58bf84362..40ee596eee05 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::TypeSignature; +use crate::signature::TypeSignature; use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs similarity index 99% rename from datafusion/expr/src/type_coercion/binary.rs rename to datafusion/expr-common/src/type_coercion/binary.rs index 6de0118f6bae..05e365a0b988 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -20,7 +20,7 @@ use std::collections::HashSet; use std::sync::Arc; -use crate::Operator; +use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; @@ -569,7 +569,7 @@ fn string_temporal_coercion( } /// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric -pub(crate) fn binary_numeric_coercion( +pub fn binary_numeric_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 1b6878b6f49e..b5d34d9a3834 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -43,7 +43,10 @@ arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } chrono = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } paste = "^1.0" serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index d8be2b434732..cd7a0c8aa918 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -18,11 +18,15 @@ //! Function module contains typing and signature for built-in and user defined functions. use crate::ColumnarValue; -use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{DFSchema, Result}; +use crate::{Expr, PartitionEvaluator}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::sync::Arc; +pub use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, +}; + #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar @@ -46,86 +50,6 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// [`AccumulatorArgs`] contains information about how an aggregate -/// function was called, including the types of its arguments and any optional -/// ordering expressions. -#[derive(Debug)] -pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub data_type: &'a DataType, - - /// The schema of the input arguments - pub schema: &'a Schema, - - /// The schema of the input arguments - pub dfschema: &'a DFSchema, - - /// Whether to ignore nulls. - /// - /// SQL allows the user to specify `IGNORE NULLS`, for example: - /// - /// ```sql - /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; - /// ``` - pub ignore_nulls: bool, - - /// The expressions in the `ORDER BY` clause passed to this aggregator. - /// - /// SQL allows the user to specify the ordering of arguments to the - /// aggregate using an `ORDER BY`. For example: - /// - /// ```sql - /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; - /// ``` - /// - /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. - pub sort_exprs: &'a [Expr], - - /// Whether the aggregation is running in reverse order - pub is_reversed: bool, - - /// The name of the aggregate expression - pub name: &'a str, - - /// Whether the aggregate function is distinct. - /// - /// ```sql - /// SELECT COUNT(DISTINCT column1) FROM t; - /// ``` - pub is_distinct: bool, - - /// The input types of the aggregate function. - pub input_types: &'a [DataType], - - /// The logical expression of arguments the aggregate function takes. - pub input_exprs: &'a [Expr], -} - -/// [`StateFieldsArgs`] contains information about the fields that an -/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`]. -/// -/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields -pub struct StateFieldsArgs<'a> { - /// The name of the aggregate function. - pub name: &'a str, - - /// The input types of the aggregate function. - pub input_types: &'a [DataType], - - /// The return type of the aggregate function. - pub return_type: &'a DataType, - - /// The ordering fields of the aggregate function. - pub ordering_fields: &'a [Field], - - /// Whether the aggregate function is distinct. - pub is_distinct: bool, -} - -/// Factory that returns an accumulator for the given aggregate function. -pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; - /// Factory that creates a PartitionEvaluator for the given window /// function pub type PartitionEvaluatorFactory = diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index f5460918fa70..260065f69af9 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -27,13 +27,10 @@ //! //! The [expr_fn] module contains functions for creating expressions. -mod accumulator; mod built_in_window_function; -mod columnar_value; mod literal; -mod operator; +mod operation; mod partition_evaluator; -mod signature; mod table_source; mod udaf; mod udf; @@ -46,13 +43,20 @@ pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; pub mod function; -pub mod groups_accumulator; -pub mod interval_arithmetic; +pub mod groups_accumulator { + pub use datafusion_expr_common::groups_accumulator::*; +} + +pub mod interval_arithmetic { + pub use datafusion_expr_common::interval_arithmetic::*; +} pub mod logical_plan; pub mod planner; pub mod registry; pub mod simplify; -pub mod sort_properties; +pub mod sort_properties { + pub use datafusion_expr_common::sort_properties::*; +} pub mod test; pub mod tree_node; pub mod type_coercion; @@ -62,9 +66,15 @@ pub mod window_frame; pub mod window_function; pub mod window_state; -pub use accumulator::Accumulator; pub use built_in_window_function::BuiltInWindowFunction; -pub use columnar_value::ColumnarValue; +pub use datafusion_expr_common::accumulator::Accumulator; +pub use datafusion_expr_common::columnar_value::ColumnarValue; +pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +pub use datafusion_expr_common::operator::Operator; +pub use datafusion_expr_common::signature::{ + ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, +}; +pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GroupingSet, Like, Sort as SortExpr, TryCast, WindowFunctionDefinition, @@ -75,14 +85,9 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use groups_accumulator::{EmitTo, GroupsAccumulator}; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; -pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; -pub use signature::{ - ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, -}; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; diff --git a/datafusion/expr/src/operation.rs b/datafusion/expr/src/operation.rs new file mode 100644 index 000000000000..6b79a8248b29 --- /dev/null +++ b/datafusion/expr/src/operation.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains implementations of operations (unary, binary etc.) for DataFusion expressions. + +use crate::expr_fn::binary_expr; +use crate::{Expr, Like}; +use datafusion_expr_common::operator::Operator; +use std::ops::{self, Not}; + +/// Support ` + ` fluent style +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +/// Support ` - ` fluent style +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +/// Support ` * ` fluent style +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +/// Support ` / ` fluent style +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +/// Support ` % ` fluent style +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} + +/// Support ` & ` fluent style +impl ops::BitAnd for Expr { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseAnd, rhs) + } +} + +/// Support ` | ` fluent style +impl ops::BitOr for Expr { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseOr, rhs) + } +} + +/// Support ` ^ ` fluent style +impl ops::BitXor for Expr { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseXor, rhs) + } +} + +/// Support ` << ` fluent style +impl ops::Shl for Expr { + type Output = Self; + + fn shl(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::BitwiseShiftLeft, rhs) + } +} + +/// Support ` >> ` fluent style +impl ops::Shr for Expr { + type Output = Self; + + fn shr(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::BitwiseShiftRight, rhs) + } +} + +/// Support `- ` fluent style +impl ops::Neg for Expr { + type Output = Self; + + fn neg(self) -> Self::Output { + Expr::Negative(Box::new(self)) + } +} + +/// Support `NOT ` fluent style +impl Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + match self { + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => Expr::Like(Like::new( + !negated, + expr, + pattern, + escape_char, + case_insensitive, + )), + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => Expr::SimilarTo(Like::new( + !negated, + expr, + pattern, + escape_char, + case_insensitive, + )), + _ => Expr::Not(Box::new(self)), + } + } +} + +#[cfg(test)] +mod tests { + use crate::lit; + + #[test] + fn test_operators() { + // Add + assert_eq!( + format!("{}", lit(1u32) + lit(2u32)), + "UInt32(1) + UInt32(2)" + ); + // Sub + assert_eq!( + format!("{}", lit(1u32) - lit(2u32)), + "UInt32(1) - UInt32(2)" + ); + // Mul + assert_eq!( + format!("{}", lit(1u32) * lit(2u32)), + "UInt32(1) * UInt32(2)" + ); + // Div + assert_eq!( + format!("{}", lit(1u32) / lit(2u32)), + "UInt32(1) / UInt32(2)" + ); + // Rem + assert_eq!( + format!("{}", lit(1u32) % lit(2u32)), + "UInt32(1) % UInt32(2)" + ); + // BitAnd + assert_eq!( + format!("{}", lit(1u32) & lit(2u32)), + "UInt32(1) & UInt32(2)" + ); + // BitOr + assert_eq!( + format!("{}", lit(1u32) | lit(2u32)), + "UInt32(1) | UInt32(2)" + ); + // BitXor + assert_eq!( + format!("{}", lit(1u32) ^ lit(2u32)), + "UInt32(1) BIT_XOR UInt32(2)" + ); + // Shl + assert_eq!( + format!("{}", lit(1u32) << lit(2u32)), + "UInt32(1) << UInt32(2)" + ); + // Shr + assert_eq!( + format!("{}", lit(1u32) >> lit(2u32)), + "UInt32(1) >> UInt32(2)" + ); + // Neg + assert_eq!(format!("{}", -lit(1u32)), "(- UInt32(1))"); + // Not + assert_eq!(format!("{}", !lit(1u32)), "NOT UInt32(1)"); + } +} diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 4f2776516d3e..190374b01dd2 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -17,9 +17,6 @@ use std::sync::Arc; -use crate::signature::{ - ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, -}; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; use arrow::{ compute::can_cast_types, @@ -29,6 +26,9 @@ use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, Result, }; +use datafusion_expr_common::signature::{ + ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, +}; use super::binary::{binary_numeric_coercion, comparison_coercion}; diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index e0d1236aac2d..3a5c65fb46ee 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -31,11 +31,14 @@ //! i64. However, i64 -> i32 is never performed as there are i64 //! values which can not be represented by i32 values. -pub mod aggregates; -pub mod binary; +pub mod aggregates { + pub use datafusion_expr_common::type_coercion::aggregates::*; +} pub mod functions; pub mod other; +pub use datafusion_expr_common::type_coercion::binary; + use arrow::datatypes::DataType; /// Determine whether the given data type `dt` represents signed numeric values. pub fn is_signed_numeric(dt: &DataType) -> bool { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3a292b2b49bf..d136aeaf0908 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -422,7 +422,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// See [retract_batch] for more details. /// - /// [retract_batch]: crate::accumulator::Accumulator::retract_batch + /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch fn create_sliding_accumulator( &self, args: AccumulatorArgs, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 34b5909f0a5a..f5434726e23d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,22 +17,19 @@ //! [`ScalarUDF`]: Scalar User Defined Functions -use std::any::Any; -use std::fmt::{self, Debug, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::Arc; - -use arrow::datatypes::DataType; - -use datafusion_common::{not_impl_err, ExprSchema, Result}; - use crate::expr::schema_name_from_exprs_comma_seperated_without_space; -use crate::interval_arithmetic::Interval; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, ExprSchema, Result}; +use datafusion_expr_common::interval_arithmetic::Interval; +use std::any::Any; +use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. /// diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c3e4505ed19c..7b650d1ab448 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -23,10 +23,10 @@ use std::sync::Arc; use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; -use crate::signature::{Signature, TypeSignature}; use crate::{ and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, }; +use datafusion_expr_common::signature::{Signature, TypeSignature}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion_common::tree_node::{ @@ -40,6 +40,8 @@ use datafusion_common::{ use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; +pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; + /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; @@ -1219,37 +1221,6 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } -/// Represents the sensitivity of an aggregate expression to ordering. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum AggregateOrderSensitivity { - /// Indicates that the aggregate expression is insensitive to ordering. - /// Ordering at the input is not important for the result of the aggregator. - Insensitive, - /// Indicates that the aggregate expression has a hard requirement on ordering. - /// The aggregator can not produce a correct result unless its ordering - /// requirement is satisfied. - HardRequirement, - /// Indicates that ordering is beneficial for the aggregate expression in terms - /// of evaluation efficiency. The aggregator can produce its result efficiently - /// when its required ordering is satisfied; however, it can still produce the - /// correct result (albeit less efficiently) when its required ordering is not met. - Beneficial, -} - -impl AggregateOrderSensitivity { - pub fn is_insensitive(&self) -> bool { - self.eq(&AggregateOrderSensitivity::Insensitive) - } - - pub fn is_beneficial(&self) -> bool { - self.eq(&AggregateOrderSensitivity::Beneficial) - } - - pub fn hard_requires(&self) -> bool { - self.eq(&AggregateOrderSensitivity::HardRequirement) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml new file mode 100644 index 000000000000..a8296ce11f30 --- /dev/null +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-aggregate-common" +description = "Utility functions for implementing aggregate functions for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_functions_aggregate_common" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ahash = { workspace = true } +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +rand = { workspace = true } diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs new file mode 100644 index 000000000000..ddf0085b9de4 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::Result; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::{ + physical_expr::PhysicalExpr, sort_expr::PhysicalSortExpr, +}; +use std::sync::Arc; + +/// [`AccumulatorArgs`] contains information about how an aggregate +/// function was called, including the types of its arguments and any optional +/// ordering expressions. +#[derive(Debug)] +pub struct AccumulatorArgs<'a> { + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The schema of the input arguments + pub schema: &'a Schema, + + /// Whether to ignore nulls. + /// + /// SQL allows the user to specify `IGNORE NULLS`, for example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; + /// ``` + pub ignore_nulls: bool, + + /// The expressions in the `ORDER BY` clause passed to this aggregator. + /// + /// SQL allows the user to specify the ordering of arguments to the + /// aggregate using an `ORDER BY`. For example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; + /// ``` + /// + /// If no `ORDER BY` is specified, `ordering_req` will be empty. + pub ordering_req: &'a [PhysicalSortExpr], + + /// Whether the aggregation is running in reverse order + pub is_reversed: bool, + + /// The name of the aggregate expression + pub name: &'a str, + + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + + /// The physical expression of arguments the aggregate function takes. + pub exprs: &'a [Arc], +} + +/// Factory that returns an accumulator for the given aggregate function. +pub type AccumulatorFactoryFunction = + Arc Result> + Send + Sync>; + +/// [`StateFieldsArgs`] contains information about the fields that an +/// aggregate function's accumulator should have. Used for `AggregateUDFImpl::state_fields`. +pub struct StateFieldsArgs<'a> { + /// The name of the aggregate function. + pub name: &'a str, + + /// The input types of the aggregate function. + pub input_types: &'a [DataType], + + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The ordering fields of the aggregate function. + pub ordering_fields: &'a [Field], + + /// Whether the aggregate function is distinct. + pub is_distinct: bool, +} diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs new file mode 100644 index 000000000000..016e54e68835 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`AggregateExpr`] which defines the interface all aggregate expressions +//! (built-in and custom) need to satisfy. + +use crate::order::AggregateOrderSensitivity; +use arrow::datatypes::Field; +use datafusion_common::exec_err; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::groups_accumulator::GroupsAccumulator; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; + +pub mod count_distinct; +pub mod groups_accumulator; + +/// An aggregate expression that: +/// * knows its resulting field +/// * knows how to create its accumulator +/// * knows its accumulator's state's field +/// * knows the expressions from whose its accumulator will receive values +/// +/// Any implementation of this trait also needs to implement the +/// `PartialEq` to allows comparing equality between the +/// trait objects. +pub trait AggregateExpr: Send + Sync + Debug + PartialEq { + /// Returns the aggregate expression as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// the field of the final result of this aggregation. + fn field(&self) -> Result; + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + fn create_accumulator(&self) -> Result>; + + /// the fields that encapsulate the Accumulator's state + /// the number of fields here equals the number of states that the accumulator contains + fn state_fields(&self) -> Result>; + + /// expressions that are passed to the Accumulator. + /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. + fn expressions(&self) -> Vec>; + + /// Order by requirements for the aggregate function + /// By default it is `None` (there is no requirement) + /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + /// Indicates whether aggregator can produce the correct result with any + /// arbitrary input ordering. By default, we assume that aggregate expressions + /// are order insensitive. + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + /// Sets the indicator whether ordering requirements of the aggregator is + /// satisfied by its input. If this is not the case, aggregators with order + /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce + /// the correct result with possibly more work internally. + /// + /// # Returns + /// + /// Returns `Ok(Some(updated_expr))` if the process completes successfully. + /// If the expression can benefit from existing input ordering, but does + /// not implement the method, returns an error. Order insensitive and hard + /// requirement aggregators return `Ok(None)`. + fn with_beneficial_ordering( + self: Arc, + _requirement_satisfied: bool, + ) -> Result>> { + if self.order_bys().is_some() && self.order_sensitivity().is_beneficial() { + return exec_err!( + "Should implement with satisfied for aggregator :{:?}", + self.name() + ); + } + Ok(None) + } + + /// Human readable name such as `"MIN(c2)"`. The default + /// implementation returns placeholder text. + fn name(&self) -> &str { + "AggregateExpr: default name" + } + + /// If the aggregate expression has a specialized + /// [`GroupsAccumulator`] implementation. If this returns true, + /// `[Self::create_groups_accumulator`] will be called. + fn groups_accumulator_supported(&self) -> bool { + false + } + + /// Return a specialized [`GroupsAccumulator`] that manages state + /// for all groups. + /// + /// For maximum performance, a [`GroupsAccumulator`] should be + /// implemented in addition to [`Accumulator`]. + fn create_groups_accumulator(&self) -> Result> { + not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") + } + + /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). + fn reverse_expr(&self) -> Option> { + None + } + + /// Creates accumulator implementation that supports retract + fn create_sliding_accumulator(&self) -> Result> { + not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") + } + + /// Returns all expressions used in the [`AggregateExpr`]. + /// These expressions are (1)function arguments, (2) order by expressions. + fn all_expressions(&self) -> AggregatePhysicalExpressions { + let args = self.expressions(); + let order_bys = self.order_bys().unwrap_or(&[]); + let order_by_exprs = order_bys + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + AggregatePhysicalExpressions { + args, + order_by_exprs, + } + } + + /// Rewrites [`AggregateExpr`], with new expressions given. The argument should be consistent + /// with the return value of the [`AggregateExpr::all_expressions`] method. + /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + fn with_new_expressions( + &self, + _args: Vec>, + _order_by_exprs: Vec>, + ) -> Option> { + None + } + + /// If this function is max, return (output_field, true) + /// if the function is min, return (output_field, false) + /// otherwise return None (the default) + /// + /// output_field is the name of the column produced by this aggregate + /// + /// Note: this is used to use special aggregate implementations in certain conditions + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + None + } +} + +/// Stores the physical expressions used inside the `AggregateExpr`. +pub struct AggregatePhysicalExpressions { + /// Aggregate function arguments + pub args: Vec>, + /// Order by expressions + pub order_by_exprs: Vec>, +} diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs similarity index 100% rename from datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs rename to datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs similarity index 95% rename from datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs rename to datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs index 360d64ce0141..ee61128979e1 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs @@ -17,13 +17,13 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values -use crate::binary_map::{ArrowBytesSet, OutputType}; -use crate::binary_view_map::ArrowBytesViewSet; use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; use std::fmt::Debug; use std::sync::Arc; diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs similarity index 98% rename from datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs rename to datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index e525118b9a17..d128a8af58ee 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -35,9 +35,9 @@ use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; -use crate::aggregate::utils::Hashable; +use crate::utils::Hashable; #[derive(Debug)] pub struct PrimitiveDistinctCountAccumulator diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs similarity index 97% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 592c130b69d8..644221edd04d 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -15,19 +15,24 @@ // specific language governing permissions and limitations // under the License. +//! Utilities for implementing GroupsAccumulator //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] +pub mod accumulate; +pub mod bool_op; +pub mod prim_op; + use arrow::{ - array::{AsArray, UInt32Builder}, + array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray, UInt32Builder}, compute, datatypes::UInt32Type, }; -use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs similarity index 99% rename from datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 3fcd570f514e..455fc5fec450 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -17,13 +17,13 @@ //! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] //! -//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator +//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::ArrowPrimitiveType; -use datafusion_expr::EmitTo; +use datafusion_expr_common::groups_accumulator::EmitTo; /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -48,7 +48,7 @@ use datafusion_expr::EmitTo; /// had at least one value to accumulate so they do not need to track /// if they have seen values for a particular group. /// -/// [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator +/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator #[derive(Debug)] pub struct NullState { /// Have we seen any non-filtered input values for `group_index`? diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs similarity index 98% rename from datafusion/physical-expr-common/src/aggregate/groups_accumulator/bool_op.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index 8498d69dd333..be2b5e48a8db 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder}; use arrow::buffer::BooleanBuffer; use datafusion_common::Result; -use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use super::accumulate::NullState; diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs similarity index 98% rename from datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index acf1ae525c79..b5c6171af37c 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -23,7 +23,7 @@ use arrow::compute; use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::DataType; use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; -use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use super::accumulate::NullState; diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/functions-aggregate-common/src/lib.rs similarity index 59% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs rename to datafusion/functions-aggregate-common/src/lib.rs index 3c0f3a28fedb..cc50ff70913b 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -15,13 +15,20 @@ // specific language governing permissions and limitations // under the License. -mod adapter; -pub use adapter::GroupsAccumulatorAdapter; +//! Common Aggregate functionality for [DataFusion] +//! +//! This crate contains traits and utilities commonly used to implement aggregate functions +//! They are kept in their own crate to avoid physical expressions depending on logical expressions. +//! +//! [DataFusion]: -// Backward compatibility -#[allow(unused_imports)] -pub(crate) mod accumulate { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; -} +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] -pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; +pub mod accumulator; +pub mod aggregate; +pub mod merge_arrays; +pub mod order; +pub mod stats; +pub mod tdigest; +pub mod utils; diff --git a/datafusion/physical-expr-common/src/aggregate/merge_arrays.rs b/datafusion/functions-aggregate-common/src/merge_arrays.rs similarity index 100% rename from datafusion/physical-expr-common/src/aggregate/merge_arrays.rs rename to datafusion/functions-aggregate-common/src/merge_arrays.rs diff --git a/datafusion/functions-aggregate-common/src/order.rs b/datafusion/functions-aggregate-common/src/order.rs new file mode 100644 index 000000000000..bfa6e39138f9 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/order.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the sensitivity of an aggregate expression to ordering. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum AggregateOrderSensitivity { + /// Indicates that the aggregate expression is insensitive to ordering. + /// Ordering at the input is not important for the result of the aggregator. + Insensitive, + /// Indicates that the aggregate expression has a hard requirement on ordering. + /// The aggregator can not produce a correct result unless its ordering + /// requirement is satisfied. + HardRequirement, + /// Indicates that ordering is beneficial for the aggregate expression in terms + /// of evaluation efficiency. The aggregator can produce its result efficiently + /// when its required ordering is satisfied; however, it can still produce the + /// correct result (albeit less efficiently) when its required ordering is not met. + Beneficial, +} + +impl AggregateOrderSensitivity { + pub fn is_insensitive(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Insensitive) + } + + pub fn is_beneficial(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Beneficial) + } + + pub fn hard_requires(&self) -> bool { + self.eq(&AggregateOrderSensitivity::HardRequirement) + } +} diff --git a/datafusion/physical-expr-common/src/aggregate/stats.rs b/datafusion/functions-aggregate-common/src/stats.rs similarity index 100% rename from datafusion/physical-expr-common/src/aggregate/stats.rs rename to datafusion/functions-aggregate-common/src/stats.rs diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs similarity index 100% rename from datafusion/physical-expr-common/src/aggregate/tdigest.rs rename to datafusion/functions-aggregate-common/src/tdigest.rs diff --git a/datafusion/physical-expr-common/src/aggregate/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs similarity index 98% rename from datafusion/physical-expr-common/src/aggregate/utils.rs rename to datafusion/functions-aggregate-common/src/utils.rs index 9e380bd820ff..7b8ce0397af8 100644 --- a/datafusion/physical-expr-common/src/aggregate/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -29,11 +29,10 @@ use arrow::{ }, }; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; -use crate::sort_expr::PhysicalSortExpr; - -use super::AggregateExpr; +use crate::aggregate::AggregateExpr; /// Downcast a `Box` or `Arc` /// and return the inner trait object as [`Any`] so diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 4f2bd864832e..636b2e42d236 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "datafusion-functions-aggregate" -description = "Aggregate function packages for the DataFusion query engine" +description = "Traits and types for logical plans and expressions for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] readme = "README.md" version = { workspace = true } @@ -44,6 +44,8 @@ arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 875112ca8d47..65956cb8a1de 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -20,25 +20,22 @@ use arrow::datatypes::Int32Type; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; use arrow_schema::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::DFSchema; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::count::Count; +use datafusion_physical_expr::expressions::col; use std::sync::Arc; fn prepare_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); - let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let accumulator_args = AccumulatorArgs { - data_type: &DataType::Int64, + return_type: &DataType::Int64, schema: &schema, - dfschema: &df_schema, ignore_nulls: false, - sort_exprs: &[], + ordering_req: &[], is_reversed: false, name: "COUNT(f)", is_distinct: false, - input_types: &[DataType::Int32], - input_exprs: &[datafusion_expr::col("f")], + exprs: &[col("f", &schema).unwrap()], }; let count_fn = Count::new(); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index dfaa93cdeff7..652d447129dc 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -20,25 +20,22 @@ use arrow::datatypes::Int64Type; use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; use arrow_schema::{DataType, Field, Schema}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_common::DFSchema; use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; use datafusion_functions_aggregate::sum::Sum; +use datafusion_physical_expr::expressions::col; use std::sync::Arc; fn prepare_accumulator(data_type: &DataType) -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); - let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let accumulator_args = AccumulatorArgs { - data_type, + return_type: data_type, schema: &schema, - dfschema: &df_schema, ignore_nulls: false, - sort_exprs: &[], + ordering_req: &[], is_reversed: false, name: "SUM(f)", is_distinct: false, - input_types: &[data_type.clone()], - input_exprs: &[datafusion_expr::col("f")], + exprs: &[col("f", &schema).unwrap()], }; let sum_fn = Sum::new(); diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 56ef32e7ebe0..cf8217fe981d 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,7 +277,9 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match &acc_args.input_types[0] { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + + let accumulator: Box = match data_type { // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL // TODO support for boolean (trivial case) // https://github.com/apache/datafusion/issues/1109 diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index c386ad89f0fb..7a7b12432544 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_types[0].clone(), + acc_args.exprs[0].data_type(acc_args.schema)?, ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index ffa623c13b0b..89d827e86859 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -31,7 +31,7 @@ use arrow::{ use arrow_schema::{Field, Schema}; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, + downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; @@ -41,10 +41,10 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature, Volatility, }; -use datafusion_physical_expr_common::aggregate::tdigest::{ +use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); @@ -105,15 +105,16 @@ impl ApproxPercentileCont { pub(crate) fn create_accumulator( &self, args: AccumulatorArgs, - ) -> datafusion_common::Result { - let percentile = validate_input_percentile_expr(&args.input_exprs[1])?; - let tdigest_max_size = if args.input_exprs.len() == 3 { - Some(validate_input_max_size_expr(&args.input_exprs[2])?) + ) -> Result { + let percentile = validate_input_percentile_expr(&args.exprs[1])?; + let tdigest_max_size = if args.exprs.len() == 3 { + Some(validate_input_max_size_expr(&args.exprs[2])?) } else { None }; - let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] { + let data_type = args.exprs[0].data_type(args.schema)?; + let accumulator: ApproxPercentileAccumulator = match data_type { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 @@ -142,31 +143,30 @@ impl ApproxPercentileCont { } } -fn get_lit_value(expr: &Expr) -> datafusion_common::Result { +fn get_scalar_value(expr: &Arc) -> Result { let empty_schema = Arc::new(Schema::empty()); - let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); - let dfschema = DFSchema::empty(); - let expr = - limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, &dfschema)?; - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), - ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), + let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); + if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { + Ok(s) + } else { + internal_err!("Didn't expect ColumnarValue::Array") } } -fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; - let percentile = match &lit { - ScalarValue::Float32(Some(q)) => *q as f64, - ScalarValue::Float64(Some(q)) => *q, - got => return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.data_type() - ) +fn validate_input_percentile_expr(expr: &Arc) -> Result { + let percentile = match get_scalar_value(expr)? { + ScalarValue::Float32(Some(value)) => { + value as f64 + } + ScalarValue::Float64(Some(value)) => { + value + } + sv => { + return not_impl_err!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + sv.data_type() + ) + } }; // Ensure the percentile is between 0 and 1. @@ -178,22 +178,24 @@ fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result Ok(percentile) } -fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result { - let lit = get_lit_value(expr)?; - let max_size = match &lit { - ScalarValue::UInt8(Some(q)) => *q as usize, - ScalarValue::UInt16(Some(q)) => *q as usize, - ScalarValue::UInt32(Some(q)) => *q as usize, - ScalarValue::UInt64(Some(q)) => *q as usize, - ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, - got => return not_impl_err!( - "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", - got.data_type() - ) +fn validate_input_max_size_expr(expr: &Arc) -> Result { + let max_size = match get_scalar_value(expr)? { + ScalarValue::UInt8(Some(q)) => q as usize, + ScalarValue::UInt16(Some(q)) => q as usize, + ScalarValue::UInt32(Some(q)) => q as usize, + ScalarValue::UInt64(Some(q)) => q as usize, + ScalarValue::Int32(Some(q)) if q > 0 => q as usize, + ScalarValue::Int64(Some(q)) if q > 0 => q as usize, + ScalarValue::Int16(Some(q)) if q > 0 => q as usize, + ScalarValue::Int8(Some(q)) if q > 0 => q as usize, + sv => { + return not_impl_err!( + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", + sv.data_type() + ) + } }; + Ok(max_size) } @@ -205,10 +207,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialised /// state. - fn state_fields( - &self, - args: StateFieldsArgs, - ) -> datafusion_common::Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( format_state_name(args.name, "max_size"), @@ -252,14 +251,11 @@ impl AggregateUDFImpl for ApproxPercentileCont { } #[inline] - fn accumulator( - &self, - acc_args: AccumulatorArgs, - ) -> datafusion_common::Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(self.create_accumulator(acc_args)?)) } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { if !arg_types[0].is_numeric() { return plan_err!("approx_percentile_cont requires numeric input types"); } @@ -307,7 +303,7 @@ impl ApproxPercentileAccumulator { } // public for approx_percentile_cont_with_weight - pub fn convert_to_float(values: &ArrayRef) -> datafusion_common::Result> { + pub fn convert_to_float(values: &ArrayRef) -> Result> { match values.data_type() { DataType::Float64 => { let array = downcast_value!(values, Float64Array); @@ -315,7 +311,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::Float32 => { let array = downcast_value!(values, Float32Array); @@ -323,7 +319,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::Int64 => { let array = downcast_value!(values, Int64Array); @@ -331,7 +327,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::Int32 => { let array = downcast_value!(values, Int32Array); @@ -339,7 +335,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::Int16 => { let array = downcast_value!(values, Int16Array); @@ -347,7 +343,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::Int8 => { let array = downcast_value!(values, Int8Array); @@ -355,7 +351,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::UInt64 => { let array = downcast_value!(values, UInt64Array); @@ -363,7 +359,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::UInt32 => { let array = downcast_value!(values, UInt32Array); @@ -371,7 +367,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::UInt16 => { let array = downcast_value!(values, UInt16Array); @@ -379,7 +375,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } DataType::UInt8 => { let array = downcast_value!(values, UInt8Array); @@ -387,7 +383,7 @@ impl ApproxPercentileAccumulator { .values() .iter() .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) + .collect::>>()?) } e => internal_err!( "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" @@ -397,11 +393,11 @@ impl ApproxPercentileAccumulator { } impl Accumulator for ApproxPercentileAccumulator { - fn state(&mut self) -> datafusion_common::Result> { + fn state(&mut self) -> Result> { Ok(self.digest.to_scalar_state().into_iter().collect()) } - fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // Remove any nulls before computing the percentile let mut values = Arc::clone(&values[0]); if values.nulls().is_some() { @@ -413,7 +409,7 @@ impl Accumulator for ApproxPercentileAccumulator { Ok(()) } - fn evaluate(&mut self) -> datafusion_common::Result { + fn evaluate(&mut self) -> Result { if self.digest.count() == 0 { return ScalarValue::try_from(self.return_type.clone()); } @@ -436,7 +432,7 @@ impl Accumulator for ApproxPercentileAccumulator { }) } - fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } @@ -446,10 +442,10 @@ impl Accumulator for ApproxPercentileAccumulator { states .iter() .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>() + .collect::>>() .map(|state| TDigest::from_scalar_state(&state)) }) - .collect::>>()?; + .collect::>>()?; self.merge_digests(&states); @@ -472,7 +468,7 @@ impl Accumulator for ApproxPercentileAccumulator { mod tests { use arrow_schema::DataType; - use datafusion_physical_expr_common::aggregate::tdigest::TDigest; + use datafusion_functions_aggregate_common::tdigest::TDigest; use crate::approx_percentile_cont::ApproxPercentileAccumulator; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 0dbea1fb1ff7..fee67ba1623d 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,6 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; +use std::sync::Arc; use arrow::{ array::ArrayRef, @@ -29,7 +30,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature}; -use datafusion_physical_expr_common::aggregate::tdigest::{ +use datafusion_functions_aggregate_common::tdigest::{ Centroid, TDigest, DEFAULT_MAX_SIZE, }; @@ -123,16 +124,16 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.input_exprs.len() != 3 { + if acc_args.exprs.len() != 3 { return plan_err!( "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" ); } let sub_args = AccumulatorArgs { - input_exprs: &[ - acc_args.input_exprs[0].clone(), - acc_args.input_exprs[2].clone(), + exprs: &[ + Arc::clone(&acc_args.exprs[0]), + Arc::clone(&acc_args.exprs[2]), ], ..acc_args }; diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 36c9d6a0d7c8..b641d388a7c5 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -29,12 +29,9 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::AggregateUDFImpl; use datafusion_expr::{Accumulator, Signature, Volatility}; -use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; -use datafusion_physical_expr_common::aggregate::utils::ordering_fields; -use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, - PhysicalSortExpr, -}; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use std::collections::{HashSet, VecDeque}; use std::sync::Arc; @@ -117,32 +114,26 @@ impl AggregateUDFImpl for ArrayAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + if acc_args.is_distinct { - return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); } - if acc_args.sort_exprs.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new( - &acc_args.input_types[0], - )?)); + if acc_args.ordering_req.is_empty() { + return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); } - let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( - acc_args.sort_exprs, - acc_args.dfschema, - )?; - - let ordering_dtypes = ordering_req + let ordering_dtypes = acc_args + .ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - &acc_args.input_types[0], + &data_type, &ordering_dtypes, - ordering_req, + acc_args.ordering_req.to_vec(), acc_args.is_reversed, ) .map(|acc| Box::new(acc) as _) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 288e0b09f809..1be3cd6b0714 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -34,8 +34,8 @@ use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; -use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; -use datafusion_physical_expr_common::aggregate::utils::DecimalAverager; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_functions_aggregate_common::utils::DecimalAverager; use log::debug; use std::any::Any; use std::fmt::Debug; @@ -92,8 +92,10 @@ impl AggregateUDFImpl for Avg { return exec_err!("avg(DISTINCT) aggregations are not available"); } use DataType::*; + + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&acc_args.input_types[0], acc_args.data_type) { + match (&data_type, acc_args.return_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,8 +122,8 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - &acc_args.input_types[0], - acc_args.data_type + &data_type, + acc_args.return_type ), } } @@ -143,7 +145,7 @@ impl AggregateUDFImpl for Avg { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( - args.data_type, + args.return_type, DataType::Float64 | DataType::Decimal128(_, _) ) } @@ -153,12 +155,14 @@ impl AggregateUDFImpl for Avg { args: AccumulatorArgs, ) -> Result> { use DataType::*; + + let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&args.input_types[0], args.data_type) { + match (&data_type, args.return_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], - args.data_type, + &data_type, + args.return_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -176,8 +180,8 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], - args.data_type, + &data_type, + args.return_type, avg_fn, ))) } @@ -197,16 +201,16 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &args.input_types[0], - args.data_type, + &data_type, + args.return_type, avg_fn, ))) } _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - &args.input_types[0], - args.data_type + &data_type, + args.return_type ), } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index f6dd0bc20a83..aa65062e3330 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -38,7 +38,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; /// This macro helps create group accumulators based on bitwise operations typically used internally @@ -84,7 +84,7 @@ macro_rules! accumulator_helper { /// `is_distinct` is boolean value indicating whether the operation is distinct or not. macro_rules! downcast_bitwise_accumulator { ($args:ident, $opr:expr, $is_distinct: expr) => { - match $args.data_type { + match $args.return_type { DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), @@ -98,7 +98,7 @@ macro_rules! downcast_bitwise_accumulator { "{} not supported for {}: {}", stringify!($opr), $args.name, - $args.data_type + $args.return_type ) } } @@ -224,7 +224,7 @@ impl AggregateUDFImpl for BitwiseOperation { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = args.data_type; + let data_type = args.return_type; let operation = &self.operation; downcast_integer! { data_type => (group_accumulator_helper, data_type, operation), diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index b91fbb9ff709..b993b2a4979c 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -35,7 +35,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; // returns the new value after bool_and/bool_or with the new values, taking nullability into account macro_rules! typed_bool_and_or_batch { @@ -149,14 +149,14 @@ impl AggregateUDFImpl for BoolAnd { &self, args: AccumulatorArgs, ) -> Result> { - match args.data_type { + match args.return_type { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.data_type + args.return_type ), } } @@ -269,14 +269,14 @@ impl AggregateUDFImpl for BoolOr { &self, args: AccumulatorArgs, ) -> Result> { - match args.data_type { + match args.return_type { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.data_type + args.return_type ), } } diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index c2d7a89081d6..88f01b06d2d9 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -36,7 +36,7 @@ use datafusion_expr::{ utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::stats::StatsType; +use datafusion_functions_aggregate_common::stats::StatsType; make_udaf_expr_and_func!( Correlation, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index aea05442536e..04b1921c7b9e 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,7 +16,7 @@ // under the License. use ahash::RandomState; -use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use std::collections::HashSet; use std::ops::BitAnd; use std::{fmt::Debug, sync::Arc}; @@ -47,14 +47,12 @@ use datafusion_expr::{ EmitTo, GroupsAccumulator, Signature, Volatility, }; use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; -use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; -use datafusion_physical_expr_common::{ - aggregate::count_distinct::{ - BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, - PrimitiveDistinctCountAccumulator, - }, - binary_map::OutputType, +use datafusion_functions_aggregate_common::aggregate::count_distinct::{ + BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, + PrimitiveDistinctCountAccumulator, }; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_physical_expr_common::binary_map::OutputType; make_udaf_expr_and_func!( Count, @@ -145,11 +143,11 @@ impl AggregateUDFImpl for Count { return Ok(Box::new(CountAccumulator::new())); } - if acc_args.input_exprs.len() > 1 { + if acc_args.exprs.len() > 1 { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = &acc_args.input_types[0]; + let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( @@ -271,7 +269,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { return false; } - args.input_exprs.len() == 1 + args.exprs.len() == 1 } fn create_groups_accumulator( diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs index 6f03b256fd9f..d0abb079ef15 100644 --- a/datafusion/functions-aggregate/src/covariance.rs +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -35,7 +35,7 @@ use datafusion_expr::{ utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::stats::StatsType; +use datafusion_functions_aggregate_common::stats::StatsType; make_udaf_expr_and_func!( CovarianceSample, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 587767b8e356..2162442f054e 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -34,11 +34,8 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, Signature, TypeSignature, Volatility, }; -use datafusion_physical_expr_common::aggregate::utils::get_sort_options; -use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, - PhysicalSortExpr, -}; +use datafusion_functions_aggregate_common::utils::get_sort_options; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; create_func!(FirstValue, first_value_udaf); @@ -117,24 +114,21 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( - acc_args.sort_exprs, - acc_args.dfschema, - )?; - - let ordering_dtypes = ordering_req + let ordering_dtypes = acc_args + .ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; // When requirement is empty, or it is signalled by outside caller that // the ordering requirement is/will be satisfied. - let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; + let requirement_satisfied = + acc_args.ordering_req.is_empty() || self.requirement_satisfied; FirstValueAccumulator::try_new( - acc_args.data_type, + acc_args.return_type, &ordering_dtypes, - ordering_req, + acc_args.ordering_req.to_vec(), acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) @@ -416,22 +410,19 @@ impl AggregateUDFImpl for LastValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( - acc_args.sort_exprs, - acc_args.dfschema, - )?; - - let ordering_dtypes = ordering_req + let ordering_dtypes = acc_args + .ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied; + let requirement_satisfied = + acc_args.ordering_req.is_empty() || self.requirement_satisfied; LastValueAccumulator::try_new( - acc_args.data_type, + acc_args.return_type, &ordering_dtypes, - ordering_req, + acc_args.ordering_req.to_vec(), acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index febf1fcd2fef..7dd0de14c3c0 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -38,7 +38,7 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::utils::Hashable; +use datafusion_functions_aggregate_common::utils::Hashable; make_udaf_expr_and_func!( Median, @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = &acc_args.input_types[0]; + let dt = acc_args.exprs[0].data_type(acc_args.schema)?; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index f19d6d767ba1..f9a08631bfb9 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -49,7 +49,7 @@ use arrow::datatypes::{ }; use arrow_schema::IntervalUnit; use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; -use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use std::fmt::Debug; use arrow::datatypes::i256; @@ -156,7 +156,7 @@ impl AggregateUDFImpl for Max { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) + Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) } fn aliases(&self) -> &[String] { @@ -166,7 +166,7 @@ impl AggregateUDFImpl for Max { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.data_type, + args.return_type, Int8 | Int16 | Int32 | Int64 @@ -192,7 +192,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.data_type; + let data_type = args.return_type; match data_type { Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), @@ -253,7 +253,7 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) } fn is_descending(&self) -> Option { @@ -925,7 +925,7 @@ impl AggregateUDFImpl for Min { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) + Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) } fn aliases(&self) -> &[String] { @@ -935,7 +935,7 @@ impl AggregateUDFImpl for Min { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.data_type, + args.return_type, Int8 | Int16 | Int32 | Int64 @@ -961,7 +961,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.data_type; + let data_type = args.return_type; match data_type { Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), @@ -1022,7 +1022,7 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) } fn is_descending(&self) -> Option { diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index dc7c6c86f213..cb1ddd4738c4 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -30,14 +30,12 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValu use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, ReversedUDAF, Signature, Volatility, -}; -use datafusion_physical_expr_common::aggregate::merge_arrays::merge_ordered_arrays; -use datafusion_physical_expr_common::aggregate::utils::ordering_fields; -use datafusion_physical_expr_common::sort_expr::{ - limited_convert_logical_sort_exprs_to_physical_with_dfschema, LexOrdering, - PhysicalSortExpr, + Accumulator, AggregateUDFImpl, ReversedUDAF, Signature, Volatility, }; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; make_udaf_expr_and_func!( NthValueAgg, @@ -87,36 +85,39 @@ impl AggregateUDFImpl for NthValueAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let n = match acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Int64(Some(value))) => { + let n = match acc_args.exprs[1] + .as_any() + .downcast_ref::() + .map(|lit| lit.value()) + { + Some(ScalarValue::Int64(Some(value))) => { if acc_args.is_reversed { - Ok(-value) + -*value } else { - Ok(value) + *value } } - _ => not_impl_err!( - "{} not supported for n: {}", - self.name(), - &acc_args.input_exprs[1] - ), - }?; - - let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( - acc_args.sort_exprs, - acc_args.dfschema, - )?; + _ => { + return not_impl_err!( + "{} not supported for n: {}", + self.name(), + &acc_args.exprs[1] + ) + } + }; - let ordering_dtypes = ordering_req + let ordering_dtypes = acc_args + .ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; NthValueAccumulator::try_new( n, - &acc_args.input_types[0], + &data_type, &ordering_dtypes, - ordering_req, + acc_args.ordering_req.to_vec(), ) .map(|acc| Box::new(acc) as _) } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index df757ddc0422..180f4ad3cf37 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -27,7 +27,7 @@ use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; -use datafusion_physical_expr_common::aggregate::stats::StatsType; +use datafusion_functions_aggregate_common::stats::StatsType; use crate::variance::VarianceAccumulator; @@ -269,16 +269,12 @@ impl Accumulator for StddevAccumulator { #[cfg(test)] mod tests { - use std::sync::Arc; - + use super::*; use arrow::{array::*, datatypes::*}; - - use datafusion_common::DFSchema; use datafusion_expr::AggregateUDF; - use datafusion_physical_expr_common::aggregate::utils::get_accum_scalar_values_as_arrays; - use datafusion_physical_expr_common::expressions::column::col; - - use super::*; + use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; + use datafusion_physical_expr::expressions::col; + use std::sync::Arc; #[test] fn stddev_f64_merge_1() -> Result<()> { @@ -325,31 +321,26 @@ mod tests { agg2: Arc, schema: &Schema, ) -> Result { - let dfschema = DFSchema::empty(); let args1 = AccumulatorArgs { - data_type: &DataType::Float64, + return_type: &DataType::Float64, schema, - dfschema: &dfschema, ignore_nulls: false, - sort_exprs: &[], + ordering_req: &[], name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + exprs: &[col("a", schema)?], }; let args2 = AccumulatorArgs { - data_type: &DataType::Float64, + return_type: &DataType::Float64, schema, - dfschema: &dfschema, ignore_nulls: false, - sort_exprs: &[], + ordering_req: &[], name: "a", is_distinct: false, is_reversed: false, - input_types: &[DataType::Float64], - input_exprs: &[datafusion_expr::col("a")], + exprs: &[col("a", schema)?], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 5d91a52bc4c6..a7e9a37e23ad 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,8 +24,9 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, }; +use datafusion_physical_expr::expressions::Literal; use std::any::Any; make_udaf_expr_and_func!( @@ -82,21 +83,20 @@ impl AggregateUDFImpl for StringAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - match &acc_args.input_exprs[1] { - Expr::Literal(ScalarValue::Utf8(Some(delimiter))) - | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Expr::Literal(ScalarValue::Utf8(None)) - | Expr::Literal(ScalarValue::LargeUtf8(None)) - | Expr::Literal(ScalarValue::Null) => { - Ok(Box::new(StringAggAccumulator::new(""))) - } - _ => not_impl_err!( - "StringAgg not supported for delimiter {}", - &acc_args.input_exprs[1] - ), + if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { + return match lit.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) + } + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), + e => not_impl_err!("StringAgg not supported for delimiter {}", e), + }; } + + not_impl_err!("expect literal") } } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 08e3908a5829..7e40c1bd17a8 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -39,8 +39,8 @@ use datafusion_expr::utils::format_state_name; use datafusion_expr::{ Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_physical_expr_common::aggregate::utils::Hashable; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::utils::Hashable; make_udaf_expr_and_func!( Sum, @@ -58,14 +58,18 @@ make_udaf_expr_and_func!( /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($args:ident, $helper:ident) => { - match $args.data_type { - DataType::UInt64 => $helper!(UInt64Type, $args.data_type), - DataType::Int64 => $helper!(Int64Type, $args.data_type), - DataType::Float64 => $helper!(Float64Type, $args.data_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.data_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.data_type), + match $args.return_type { + DataType::UInt64 => $helper!(UInt64Type, $args.return_type), + DataType::Int64 => $helper!(Int64Type, $args.return_type), + DataType::Float64 => $helper!(Float64Type, $args.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), _ => { - not_impl_err!("Sum not supported for {}: {}", $args.name, $args.data_type) + not_impl_err!( + "Sum not supported for {}: {}", + $args.name, + $args.return_type + ) } } }; diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index c772608cb376..4c78a42ea494 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -use datafusion_physical_expr_common::aggregate::stats::StatsType; +use datafusion_functions_aggregate_common::stats::StatsType; make_udaf_expr_and_func!( VarianceSample, diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 3ef2d5345533..45ccb08e52e9 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -39,6 +39,6 @@ path = "src/lib.rs" ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } hashbrown = { workspace = true } rand = { workspace = true } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs deleted file mode 100644 index 350023352b12..000000000000 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ /dev/null @@ -1,807 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::fmt::Debug; -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - -use datafusion_common::exec_err; -use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; -use datafusion_expr::expr::create_function_physical_name; -use datafusion_expr::function::StateFieldsArgs; -use datafusion_expr::type_coercion::aggregates::check_arg_count; -use datafusion_expr::utils::AggregateOrderSensitivity; -use datafusion_expr::ReversedUDAF; -use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, -}; - -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; -use crate::utils::reverse_order_bys; - -use self::utils::down_cast_any_ref; - -pub mod count_distinct; -pub mod groups_accumulator; -pub mod merge_arrays; -pub mod stats; -pub mod tdigest; -pub mod utils; - -/// Creates a physical expression of the UDAF, that includes all necessary type coercion. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. -/// -/// `input_exprs` and `sort_exprs` are used for customizing Accumulator -/// whose behavior depends on arguments such as the `ORDER BY`. -/// -/// For example to call `ARRAY_AGG(x ORDER BY y)` would pass `y` to `sort_exprs`, `x` to `input_exprs` -/// -/// `input_exprs` and `sort_exprs` are used for customizing Accumulator as the arguments in `AccumulatorArgs`, -/// if you don't need them it is fine to pass empty slice `&[]`. -/// -/// `is_reversed` is used to indicate whether the aggregation is running in reverse order, -/// it could be used to hint Accumulator to accumulate in the reversed order, -/// you can just set to false if you are not reversing expression -/// -/// You can also create expression by [`AggregateExprBuilder`] -#[allow(clippy::too_many_arguments)] -pub fn create_aggregate_expr( - fun: &AggregateUDF, - input_phy_exprs: &[Arc], - input_exprs: &[Expr], - sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], - schema: &Schema, - name: Option, - ignore_nulls: bool, - is_distinct: bool, -) -> Result> { - let mut builder = - AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); - builder = builder.sort_exprs(sort_exprs.to_vec()); - builder = builder.order_by(ordering_req.to_vec()); - builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.schema(Arc::new(schema.clone())); - if let Some(name) = name { - builder = builder.alias(name); - } - - if ignore_nulls { - builder = builder.ignore_nulls(); - } - if is_distinct { - builder = builder.distinct(); - } - - builder.build() -} - -#[allow(clippy::too_many_arguments)] -// This is not for external usage, consider creating with `create_aggregate_expr` instead. -pub fn create_aggregate_expr_with_dfschema( - fun: &AggregateUDF, - input_phy_exprs: &[Arc], - input_exprs: &[Expr], - sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], - dfschema: &DFSchema, - alias: Option, - ignore_nulls: bool, - is_distinct: bool, - is_reversed: bool, -) -> Result> { - let mut builder = - AggregateExprBuilder::new(Arc::new(fun.clone()), input_phy_exprs.to_vec()); - builder = builder.sort_exprs(sort_exprs.to_vec()); - builder = builder.order_by(ordering_req.to_vec()); - builder = builder.logical_exprs(input_exprs.to_vec()); - builder = builder.dfschema(dfschema.clone()); - let schema: Schema = dfschema.into(); - builder = builder.schema(Arc::new(schema)); - if let Some(alias) = alias { - builder = builder.alias(alias); - } - - if ignore_nulls { - builder = builder.ignore_nulls(); - } - if is_distinct { - builder = builder.distinct(); - } - if is_reversed { - builder = builder.reversed(); - } - - builder.build() -} - -/// Builder for physical [`AggregateExpr`] -/// -/// `AggregateExpr` contains the information necessary to call -/// an aggregate expression. -#[derive(Debug, Clone)] -pub struct AggregateExprBuilder { - fun: Arc, - /// Physical expressions of the aggregate function - args: Vec>, - /// Logical expressions of the aggregate function, it will be deprecated in - logical_args: Vec, - alias: Option, - /// Arrow Schema for the aggregate function - schema: SchemaRef, - /// Datafusion Schema for the aggregate function - dfschema: DFSchema, - /// The logical order by expressions, it will be deprecated in - sort_exprs: Vec, - /// The physical order by expressions - ordering_req: LexOrdering, - /// Whether to ignore null values - ignore_nulls: bool, - /// Whether is distinct aggregate function - is_distinct: bool, - /// Whether the expression is reversed - is_reversed: bool, -} - -impl AggregateExprBuilder { - pub fn new(fun: Arc, args: Vec>) -> Self { - Self { - fun, - args, - logical_args: vec![], - alias: None, - schema: Arc::new(Schema::empty()), - dfschema: DFSchema::empty(), - sort_exprs: vec![], - ordering_req: vec![], - ignore_nulls: false, - is_distinct: false, - is_reversed: false, - } - } - - pub fn build(self) -> Result> { - let Self { - fun, - args, - logical_args, - alias, - schema, - dfschema, - sort_exprs, - ordering_req, - ignore_nulls, - is_distinct, - is_reversed, - } = self; - if args.is_empty() { - return internal_err!("args should not be empty"); - } - - let mut ordering_fields = vec![]; - - debug_assert_eq!(sort_exprs.len(), ordering_req.len()); - if !ordering_req.is_empty() { - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(&schema)) - .collect::>>()?; - - ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); - } - - let input_exprs_types = args - .iter() - .map(|arg| arg.data_type(&schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let data_type = fun.return_type(&input_exprs_types)?; - let name = match alias { - None => create_function_physical_name( - fun.name(), - is_distinct, - &logical_args, - if sort_exprs.is_empty() { - None - } else { - Some(&sort_exprs) - }, - )?, - Some(alias) => alias, - }; - - Ok(Arc::new(AggregateFunctionExpr { - fun: Arc::unwrap_or_clone(fun), - args, - logical_args, - data_type, - name, - schema: Arc::unwrap_or_clone(schema), - dfschema, - sort_exprs, - ordering_req, - ignore_nulls, - ordering_fields, - is_distinct, - input_types: input_exprs_types, - is_reversed, - })) - } - - pub fn alias(mut self, alias: impl Into) -> Self { - self.alias = Some(alias.into()); - self - } - - pub fn schema(mut self, schema: SchemaRef) -> Self { - self.schema = schema; - self - } - - pub fn dfschema(mut self, dfschema: DFSchema) -> Self { - self.dfschema = dfschema; - self - } - - pub fn order_by(mut self, order_by: LexOrdering) -> Self { - self.ordering_req = order_by; - self - } - - pub fn reversed(mut self) -> Self { - self.is_reversed = true; - self - } - - pub fn with_reversed(mut self, is_reversed: bool) -> Self { - self.is_reversed = is_reversed; - self - } - - pub fn distinct(mut self) -> Self { - self.is_distinct = true; - self - } - - pub fn with_distinct(mut self, is_distinct: bool) -> Self { - self.is_distinct = is_distinct; - self - } - - pub fn ignore_nulls(mut self) -> Self { - self.ignore_nulls = true; - self - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// This method will be deprecated in - pub fn sort_exprs(mut self, sort_exprs: Vec) -> Self { - self.sort_exprs = sort_exprs; - self - } - - /// This method will be deprecated in - pub fn logical_exprs(mut self, logical_args: Vec) -> Self { - self.logical_args = logical_args; - self - } -} - -/// An aggregate expression that: -/// * knows its resulting field -/// * knows how to create its accumulator -/// * knows its accumulator's state's field -/// * knows the expressions from whose its accumulator will receive values -/// -/// Any implementation of this trait also needs to implement the -/// `PartialEq` to allows comparing equality between the -/// trait objects. -pub trait AggregateExpr: Send + Sync + Debug + PartialEq { - /// Returns the aggregate expression as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// the field of the final result of this aggregation. - fn field(&self) -> Result; - - /// the accumulator used to accumulate values from the expressions. - /// the accumulator expects the same number of arguments as `expressions` and must - /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>; - - /// the fields that encapsulate the Accumulator's state - /// the number of fields here equals the number of states that the accumulator contains - fn state_fields(&self) -> Result>; - - /// expressions that are passed to the Accumulator. - /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. - fn expressions(&self) -> Vec>; - - /// Order by requirements for the aggregate function - /// By default it is `None` (there is no requirement) - /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - /// Indicates whether aggregator can produce the correct result with any - /// arbitrary input ordering. By default, we assume that aggregate expressions - /// are order insensitive. - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - AggregateOrderSensitivity::Insensitive - } - - /// Sets the indicator whether ordering requirements of the aggregator is - /// satisfied by its input. If this is not the case, aggregators with order - /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce - /// the correct result with possibly more work internally. - /// - /// # Returns - /// - /// Returns `Ok(Some(updated_expr))` if the process completes successfully. - /// If the expression can benefit from existing input ordering, but does - /// not implement the method, returns an error. Order insensitive and hard - /// requirement aggregators return `Ok(None)`. - fn with_beneficial_ordering( - self: Arc, - _requirement_satisfied: bool, - ) -> Result>> { - if self.order_bys().is_some() && self.order_sensitivity().is_beneficial() { - return exec_err!( - "Should implement with satisfied for aggregator :{:?}", - self.name() - ); - } - Ok(None) - } - - /// Human readable name such as `"MIN(c2)"`. The default - /// implementation returns placeholder text. - fn name(&self) -> &str { - "AggregateExpr: default name" - } - - /// If the aggregate expression has a specialized - /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. - fn groups_accumulator_supported(&self) -> bool { - false - } - - /// Return a specialized [`GroupsAccumulator`] that manages state - /// for all groups. - /// - /// For maximum performance, a [`GroupsAccumulator`] should be - /// implemented in addition to [`Accumulator`]. - fn create_groups_accumulator(&self) -> Result> { - not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") - } - - /// Construct an expression that calculates the aggregate in reverse. - /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). - /// For aggregates that do not support calculation in reverse, - /// returns None (which is the default value). - fn reverse_expr(&self) -> Option> { - None - } - - /// Creates accumulator implementation that supports retract - fn create_sliding_accumulator(&self) -> Result> { - not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") - } - - /// Returns all expressions used in the [`AggregateExpr`]. - /// These expressions are (1)function arguments, (2) order by expressions. - fn all_expressions(&self) -> AggregatePhysicalExpressions { - let args = self.expressions(); - let order_bys = self.order_bys().unwrap_or(&[]); - let order_by_exprs = order_bys - .iter() - .map(|sort_expr| sort_expr.expr.clone()) - .collect::>(); - AggregatePhysicalExpressions { - args, - order_by_exprs, - } - } - - /// Rewrites [`AggregateExpr`], with new expressions given. The argument should be consistent - /// with the return value of the [`AggregateExpr::all_expressions`] method. - /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. - /// TODO: This method only rewrites the [`PhysicalExpr`]s and does not handle [`Expr`]s. - /// This can cause silent bugs and should be fixed in the future (possibly with physical-to-logical - /// conversions). - fn with_new_expressions( - &self, - _args: Vec>, - _order_by_exprs: Vec>, - ) -> Option> { - None - } - - /// If this function is max, return (output_field, true) - /// if the function is min, return (output_field, false) - /// otherwise return None (the default) - /// - /// output_field is the name of the column produced by this aggregate - /// - /// Note: this is used to use special aggregate implementations in certain conditions - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - None - } -} - -/// Stores the physical expressions used inside the `AggregateExpr`. -pub struct AggregatePhysicalExpressions { - /// Aggregate function arguments - pub args: Vec>, - /// Order by expressions - pub order_by_exprs: Vec>, -} - -/// Physical aggregate expression of a UDAF. -#[derive(Debug, Clone)] -pub struct AggregateFunctionExpr { - fun: AggregateUDF, - args: Vec>, - logical_args: Vec, - /// Output / return type of this aggregate - data_type: DataType, - name: String, - schema: Schema, - dfschema: DFSchema, - // The logical order by expressions - sort_exprs: Vec, - // The physical order by expressions - ordering_req: LexOrdering, - // Whether to ignore null values - ignore_nulls: bool, - // fields used for order sensitive aggregation functions - ordering_fields: Vec, - is_distinct: bool, - is_reversed: bool, - input_types: Vec, -} - -impl AggregateFunctionExpr { - /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` - pub fn fun(&self) -> &AggregateUDF { - &self.fun - } - - /// Return if the aggregation is distinct - pub fn is_distinct(&self) -> bool { - self.is_distinct - } - - /// Return if the aggregation ignores nulls - pub fn ignore_nulls(&self) -> bool { - self.ignore_nulls - } - - /// Return if the aggregation is reversed - pub fn is_reversed(&self) -> bool { - self.is_reversed - } -} - -impl AggregateExpr for AggregateFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn expressions(&self) -> Vec> { - self.args.clone() - } - - fn state_fields(&self) -> Result> { - let args = StateFieldsArgs { - name: &self.name, - input_types: &self.input_types, - return_type: &self.data_type, - ordering_fields: &self.ordering_fields, - is_distinct: self.is_distinct, - }; - - self.fun.state_fields(args) - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs { - data_type: &self.data_type, - schema: &self.schema, - dfschema: &self.dfschema, - ignore_nulls: self.ignore_nulls, - sort_exprs: &self.sort_exprs, - is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, - name: &self.name, - is_reversed: self.is_reversed, - }; - - self.fun.accumulator(acc_args) - } - - fn create_sliding_accumulator(&self) -> Result> { - let args = AccumulatorArgs { - data_type: &self.data_type, - schema: &self.schema, - dfschema: &self.dfschema, - ignore_nulls: self.ignore_nulls, - sort_exprs: &self.sort_exprs, - is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, - name: &self.name, - is_reversed: self.is_reversed, - }; - - let accumulator = self.fun.create_sliding_accumulator(args)?; - - // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to - // implement retract_batch method in order to run correctly - // currently in DataFusion. - // - // If this `retract_batches` is not present, there is no way - // to calculate result correctly. For example, the query - // - // ```sql - // SELECT - // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a - // FROM - // t - // ``` - // - // 1. First sum value will be the sum of rows between `[0, 1)`, - // - // 2. Second sum value will be the sum of rows between `[0, 2)` - // - // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. - // - // Since the accumulator keeps the running sum: - // - // 1. First sum we add to the state sum value between `[0, 1)` - // - // 2. Second sum we add to the state sum value between `[1, 2)` - // (`[0, 1)` is already in the state sum, hence running sum will - // cover `[0, 2)` range) - // - // 3. Third sum we add to the state sum value between `[2, 3)` - // (`[0, 2)` is already in the state sum). Also we need to - // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the appropriate range. - // - // When we use `UNBOUNDED PRECEDING` in the query starting - // index will always be 0 for the desired range, and hence the - // `retract_batch` method will not be called. In this case - // having retract_batch is not a requirement. - // - // This approach is a a bit different than window function - // approach. In window function (when they use a window frame) - // they get all the desired range during evaluation. - if !accumulator.supports_retract_batch() { - return not_impl_err!( - "Aggregate can not be used as a sliding accumulator because \ - `retract_batch` is not implemented: {}", - self.name - ); - } - Ok(accumulator) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - let args = AccumulatorArgs { - data_type: &self.data_type, - schema: &self.schema, - dfschema: &self.dfschema, - ignore_nulls: self.ignore_nulls, - sort_exprs: &self.sort_exprs, - is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, - name: &self.name, - is_reversed: self.is_reversed, - }; - self.fun.groups_accumulator_supported(args) - } - - fn create_groups_accumulator(&self) -> Result> { - let args = AccumulatorArgs { - data_type: &self.data_type, - schema: &self.schema, - dfschema: &self.dfschema, - ignore_nulls: self.ignore_nulls, - sort_exprs: &self.sort_exprs, - is_distinct: self.is_distinct, - input_types: &self.input_types, - input_exprs: &self.logical_args, - name: &self.name, - is_reversed: self.is_reversed, - }; - self.fun.create_groups_accumulator(args) - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - if self.ordering_req.is_empty() { - return None; - } - - if !self.order_sensitivity().is_insensitive() { - return Some(&self.ordering_req); - } - - None - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity { - if !self.ordering_req.is_empty() { - // If there is requirement, use the sensitivity of the implementation - self.fun.order_sensitivity() - } else { - // If no requirement, aggregator is order insensitive - AggregateOrderSensitivity::Insensitive - } - } - - fn with_beneficial_ordering( - self: Arc, - beneficial_ordering: bool, - ) -> Result>> { - let Some(updated_fn) = self - .fun - .clone() - .with_beneficial_ordering(beneficial_ordering)? - else { - return Ok(None); - }; - create_aggregate_expr_with_dfschema( - &updated_fn, - &self.args, - &self.logical_args, - &self.sort_exprs, - &self.ordering_req, - &self.dfschema, - Some(self.name().to_string()), - self.ignore_nulls, - self.is_distinct, - self.is_reversed, - ) - .map(Some) - } - - fn reverse_expr(&self) -> Option> { - match self.fun.reverse_udf() { - ReversedUDAF::NotSupported => None, - ReversedUDAF::Identical => Some(Arc::new(self.clone())), - ReversedUDAF::Reversed(reverse_udf) => { - let reverse_ordering_req = reverse_order_bys(&self.ordering_req); - let reverse_sort_exprs = self - .sort_exprs - .iter() - .map(|e| { - if let Expr::Sort(s) = e { - Expr::Sort(s.reverse()) - } else { - // Expects to receive `Expr::Sort`. - unreachable!() - } - }) - .collect::>(); - let mut name = self.name().to_string(); - // If the function is changed, we need to reverse order_by clause as well - // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) - if self.fun().name() == reverse_udf.name() { - } else { - replace_order_by_clause(&mut name); - } - replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); - let reverse_aggr = create_aggregate_expr_with_dfschema( - &reverse_udf, - &self.args, - &self.logical_args, - &reverse_sort_exprs, - &reverse_ordering_req, - &self.dfschema, - Some(name), - self.ignore_nulls, - self.is_distinct, - !self.is_reversed, - ) - .unwrap(); - - Some(reverse_aggr) - } - } - } - - fn get_minmax_desc(&self) -> Option<(Field, bool)> { - self.fun - .is_descending() - .and_then(|flag| self.field().ok().map(|f| (f, flag))) - } -} - -impl PartialEq for AggregateFunctionExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.fun == x.fun - && self.args.len() == x.args.len() - && self - .args - .iter() - .zip(x.args.iter()) - .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) - }) - .unwrap_or(false) - } -} - -fn replace_order_by_clause(order_by: &mut String) { - let suffixes = [ - (" DESC NULLS FIRST]", " ASC NULLS LAST]"), - (" ASC NULLS FIRST]", " DESC NULLS LAST]"), - (" DESC NULLS LAST]", " ASC NULLS FIRST]"), - (" ASC NULLS LAST]", " DESC NULLS FIRST]"), - ]; - - if let Some(start) = order_by.find("ORDER BY [") { - if let Some(end) = order_by[start..].find(']') { - let order_by_start = start + 9; - let order_by_end = start + end; - - let column_order = &order_by[order_by_start..=order_by_end]; - for (suffix, replacement) in suffixes { - if column_order.ends_with(suffix) { - let new_order = column_order.replace(suffix, replacement); - order_by.replace_range(order_by_start..=order_by_end, &new_order); - break; - } - } - } - } -} - -fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { - *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); -} diff --git a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index edf608a2054f..d21bdb3434c4 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -60,7 +60,7 @@ impl ArrowBytesSet { /// Return the contents of this set and replace it with a new empty /// set with the same output type - pub(super) fn take(&mut self) -> Self { + pub fn take(&mut self) -> Self { Self(self.0.take()) } diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index d0ba5f113b6f..96c08d0d3a5b 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -22,7 +22,8 @@ use arrow::compute::SortOptions; use arrow::error::ArrowError; use datafusion_common::internal_err; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::operator::Operator; use std::sync::Arc; /// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs deleted file mode 100644 index dd534cc07d20..000000000000 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod cast; -pub mod column; -pub mod literal; - -pub use cast::{cast, cast_with_options, CastExpr}; -pub use literal::{lit, Literal}; diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f03eedd4cf65..7e2ea0c49397 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -pub mod aggregate; +//! Physical Expr Common packages for [DataFusion] +//! This package contains high level PhysicalExpr trait +//! +//! [DataFusion]: + pub mod binary_map; pub mod binary_view_map; pub mod datum; -pub mod expressions; pub mod physical_expr; pub mod sort_expr; pub mod tree_node; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index e62606a42e6f..75d300dd0107 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -20,18 +20,16 @@ use std::fmt::{Debug, Display}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::column::Column; use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{internal_err, not_impl_err, plan_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::ColumnarValue; +use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::ExprProperties; /// See [create_physical_expr](https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html) /// for examples of creating `PhysicalExpr` from `Expr` @@ -193,33 +191,6 @@ pub fn with_new_children_if_necessary( } } -/// Rewrites an expression according to new schema; i.e. changes the columns it -/// refers to with the column at corresponding index in the new schema. Returns -/// an error if the given schema has fewer columns than the original schema. -/// Note that the resulting expression may not be valid if data types in the -/// new schema is incompatible with expression nodes. -pub fn with_new_schema( - expr: Arc, - schema: &SchemaRef, -) -> Result> { - Ok(expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let idx = col.index(); - let Some(field) = schema.fields().get(idx) else { - return plan_err!( - "New schema has fewer columns than original schema" - ); - }; - let new_col = Column::new(field.name(), idx); - Ok(Transformed::yes(Arc::new(new_col) as _)) - } else { - Ok(Transformed::no(expr)) - } - })? - .data) -} - pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { if any.is::>() { any.downcast_ref::>() diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 2b506b74216f..9dc54d2eb2d0 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -22,13 +22,12 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use crate::utils::limited_convert_logical_expr_to_physical_expr_with_dfschema; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; -use datafusion_common::{exec_err, DFSchema, Result}; -use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_common::Result; +use datafusion_expr_common::columnar_value::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch #[derive(Clone, Debug)] @@ -272,29 +271,3 @@ pub type LexRequirement = Vec; ///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which /// represents a reference to a lexicographical ordering requirement. pub type LexRequirementRef<'a> = &'a [PhysicalSortRequirement]; - -/// Converts each [`Expr::Sort`] into a corresponding [`PhysicalSortExpr`]. -/// Returns an error if the given logical expression is not a [`Expr::Sort`]. -pub fn limited_convert_logical_sort_exprs_to_physical_with_dfschema( - exprs: &[Expr], - dfschema: &DFSchema, -) -> Result> { - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in exprs { - let Expr::Sort(sort) = expr else { - return exec_err!("Expects to receive sort expression"); - }; - sort_exprs.push(PhysicalSortExpr::new( - limited_convert_logical_expr_to_physical_expr_with_dfschema( - sort.expr.as_ref(), - dfschema, - )?, - SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - )) - } - Ok(sort_exprs) -} diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 0978a906a5dc..d2c9bf1a2408 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -20,14 +20,9 @@ use std::sync::Arc; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; -use datafusion_common::{exec_err, DFSchema, Result}; -use datafusion_expr::expr::Alias; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::Expr; - -use crate::expressions::column::Column; -use crate::expressions::literal::Literal; -use crate::expressions::CastExpr; +use datafusion_common::Result; +use datafusion_expr_common::sort_properties::ExprProperties; + use crate::physical_expr::PhysicalExpr; use crate::sort_expr::PhysicalSortExpr; use crate::tree_node::ExprContext; @@ -108,35 +103,6 @@ pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec`. -/// If conversion is not supported yet, returns Error. -pub fn limited_convert_logical_expr_to_physical_expr_with_dfschema( - expr: &Expr, - dfschema: &DFSchema, -) -> Result> { - match expr { - Expr::Alias(Alias { expr, .. }) => Ok( - limited_convert_logical_expr_to_physical_expr_with_dfschema(expr, dfschema)?, - ), - Expr::Column(col) => { - let idx = dfschema.index_of_column(col)?; - Ok(Arc::new(Column::new(&col.name, idx))) - } - Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( - limited_convert_logical_expr_to_physical_expr_with_dfschema( - cast_expr.expr.as_ref(), - dfschema, - )?, - cast_expr.data_type.clone(), - None, - ))), - Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), - _ => exec_err!( - "Unsupported expression: {expr} for conversion to Arc" - ), - } -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/datafusion/physical-expr-functions-aggregate/Cargo.toml b/datafusion/physical-expr-functions-aggregate/Cargo.toml new file mode 100644 index 000000000000..6eed89614c53 --- /dev/null +++ b/datafusion/physical-expr-functions-aggregate/Cargo.toml @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-physical-expr-functions-aggregate" +description = "Logical plan and expression representation for DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_physical_expr_functions_aggregate" +path = "src/lib.rs" + +[features] + +[dependencies] +ahash = { workspace = true } +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +rand = { workspace = true } diff --git a/datafusion/physical-expr-functions-aggregate/src/aggregate.rs b/datafusion/physical-expr-functions-aggregate/src/aggregate.rs new file mode 100644 index 000000000000..8185f0fdd51f --- /dev/null +++ b/datafusion/physical-expr-functions-aggregate/src/aggregate.rs @@ -0,0 +1,486 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_expr::expr::create_function_physical_name; +use datafusion_expr::AggregateUDF; +use datafusion_expr::ReversedUDAF; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::groups_accumulator::GroupsAccumulator; +use datafusion_expr_common::type_coercion::aggregates::check_arg_count; +use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs; +use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; +use datafusion_functions_aggregate_common::aggregate::AggregateExpr; +use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; +use datafusion_functions_aggregate_common::utils::{self, down_cast_any_ref}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_expr_common::utils::reverse_order_bys; + +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; + +/// Builder for physical [`AggregateExpr`] +/// +/// `AggregateExpr` contains the information necessary to call +/// an aggregate expression. +#[derive(Debug, Clone)] +pub struct AggregateExprBuilder { + fun: Arc, + /// Physical expressions of the aggregate function + args: Vec>, + alias: Option, + /// Arrow Schema for the aggregate function + schema: SchemaRef, + /// The physical order by expressions + ordering_req: LexOrdering, + /// Whether to ignore null values + ignore_nulls: bool, + /// Whether is distinct aggregate function + is_distinct: bool, + /// Whether the expression is reversed + is_reversed: bool, +} + +impl AggregateExprBuilder { + pub fn new(fun: Arc, args: Vec>) -> Self { + Self { + fun, + args, + alias: None, + schema: Arc::new(Schema::empty()), + ordering_req: vec![], + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + } + } + + pub fn build(self) -> Result> { + let Self { + fun, + args, + alias, + schema, + ordering_req, + ignore_nulls, + is_distinct, + is_reversed, + } = self; + if args.is_empty() { + return internal_err!("args should not be empty"); + } + + let mut ordering_fields = vec![]; + + if !ordering_req.is_empty() { + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + + ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types); + } + + let input_exprs_types = args + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; + + let data_type = fun.return_type(&input_exprs_types)?; + let name = match alias { + // TODO: Ideally, we should build the name from physical expressions + None => create_function_physical_name(fun.name(), is_distinct, &[], None)?, + Some(alias) => alias, + }; + + Ok(Arc::new(AggregateFunctionExpr { + fun: Arc::unwrap_or_clone(fun), + args, + data_type, + name, + schema: Arc::unwrap_or_clone(schema), + ordering_req, + ignore_nulls, + ordering_fields, + is_distinct, + input_types: input_exprs_types, + is_reversed, + })) + } + + pub fn alias(mut self, alias: impl Into) -> Self { + self.alias = Some(alias.into()); + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + + pub fn order_by(mut self, order_by: LexOrdering) -> Self { + self.ordering_req = order_by; + self + } + + pub fn reversed(mut self) -> Self { + self.is_reversed = true; + self + } + + pub fn with_reversed(mut self, is_reversed: bool) -> Self { + self.is_reversed = is_reversed; + self + } + + pub fn distinct(mut self) -> Self { + self.is_distinct = true; + self + } + + pub fn with_distinct(mut self, is_distinct: bool) -> Self { + self.is_distinct = is_distinct; + self + } + + pub fn ignore_nulls(mut self) -> Self { + self.ignore_nulls = true; + self + } + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } +} + +/// Physical aggregate expression of a UDAF. +#[derive(Debug, Clone)] +pub struct AggregateFunctionExpr { + fun: AggregateUDF, + args: Vec>, + /// Output / return type of this aggregate + data_type: DataType, + name: String, + schema: Schema, + // The physical order by expressions + ordering_req: LexOrdering, + // Whether to ignore null values + ignore_nulls: bool, + // fields used for order sensitive aggregation functions + ordering_fields: Vec, + is_distinct: bool, + is_reversed: bool, + input_types: Vec, +} + +impl AggregateFunctionExpr { + /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` + pub fn fun(&self) -> &AggregateUDF { + &self.fun + } + + /// Return if the aggregation is distinct + pub fn is_distinct(&self) -> bool { + self.is_distinct + } + + /// Return if the aggregation ignores nulls + pub fn ignore_nulls(&self) -> bool { + self.ignore_nulls + } + + /// Return if the aggregation is reversed + pub fn is_reversed(&self) -> bool { + self.is_reversed + } +} + +impl AggregateExpr for AggregateFunctionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn state_fields(&self) -> Result> { + let args = StateFieldsArgs { + name: &self.name, + input_types: &self.input_types, + return_type: &self.data_type, + ordering_fields: &self.ordering_fields, + is_distinct: self.is_distinct, + }; + + self.fun.state_fields(args) + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + let acc_args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: &self.ordering_req, + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + + self.fun.accumulator(acc_args) + } + + fn create_sliding_accumulator(&self) -> Result> { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: &self.ordering_req, + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + + let accumulator = self.fun.create_sliding_accumulator(args)?; + + // Accumulators that have window frame startings different + // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to + // implement retract_batch method in order to run correctly + // currently in DataFusion. + // + // If this `retract_batches` is not present, there is no way + // to calculate result correctly. For example, the query + // + // ```sql + // SELECT + // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a + // FROM + // t + // ``` + // + // 1. First sum value will be the sum of rows between `[0, 1)`, + // + // 2. Second sum value will be the sum of rows between `[0, 2)` + // + // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. + // + // Since the accumulator keeps the running sum: + // + // 1. First sum we add to the state sum value between `[0, 1)` + // + // 2. Second sum we add to the state sum value between `[1, 2)` + // (`[0, 1)` is already in the state sum, hence running sum will + // cover `[0, 2)` range) + // + // 3. Third sum we add to the state sum value between `[2, 3)` + // (`[0, 2)` is already in the state sum). Also we need to + // retract values between `[0, 1)` by this way we can obtain sum + // between [1, 3) which is indeed the appropriate range. + // + // When we use `UNBOUNDED PRECEDING` in the query starting + // index will always be 0 for the desired range, and hence the + // `retract_batch` method will not be called. In this case + // having retract_batch is not a requirement. + // + // This approach is a a bit different than window function + // approach. In window function (when they use a window frame) + // they get all the desired range during evaluation. + if !accumulator.supports_retract_batch() { + return not_impl_err!( + "Aggregate can not be used as a sliding accumulator because \ + `retract_batch` is not implemented: {}", + self.name + ); + } + Ok(accumulator) + } + + fn name(&self) -> &str { + &self.name + } + + fn groups_accumulator_supported(&self) -> bool { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: &self.ordering_req, + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + self.fun.groups_accumulator_supported(args) + } + + fn create_groups_accumulator(&self) -> Result> { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: &self.ordering_req, + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + self.fun.create_groups_accumulator(args) + } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + if self.ordering_req.is_empty() { + return None; + } + + if !self.order_sensitivity().is_insensitive() { + return Some(&self.ordering_req); + } + + None + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + if !self.ordering_req.is_empty() { + // If there is requirement, use the sensitivity of the implementation + self.fun.order_sensitivity() + } else { + // If no requirement, aggregator is order insensitive + AggregateOrderSensitivity::Insensitive + } + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + let Some(updated_fn) = self + .fun + .clone() + .with_beneficial_ordering(beneficial_ordering)? + else { + return Ok(None); + }; + + AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) + .order_by(self.ordering_req.to_vec()) + .schema(Arc::new(self.schema.clone())) + .alias(self.name().to_string()) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(self.is_reversed) + .build() + .map(Some) + } + + fn reverse_expr(&self) -> Option> { + match self.fun.reverse_udf() { + ReversedUDAF::NotSupported => None, + ReversedUDAF::Identical => Some(Arc::new(self.clone())), + ReversedUDAF::Reversed(reverse_udf) => { + let reverse_ordering_req = reverse_order_bys(&self.ordering_req); + let mut name = self.name().to_string(); + // If the function is changed, we need to reverse order_by clause as well + // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) + if self.fun().name() == reverse_udf.name() { + } else { + replace_order_by_clause(&mut name); + } + replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); + + AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) + .order_by(reverse_ordering_req.to_vec()) + .schema(Arc::new(self.schema.clone())) + .alias(name) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(!self.is_reversed) + .build() + .ok() + } + } + } + + fn get_minmax_desc(&self) -> Option<(Field, bool)> { + self.fun + .is_descending() + .and_then(|flag| self.field().ok().map(|f| (f, flag))) + } +} + +impl PartialEq for AggregateFunctionExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.fun == x.fun + && self.args.len() == x.args.len() + && self + .args + .iter() + .zip(x.args.iter()) + .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) + }) + .unwrap_or(false) + } +} + +fn replace_order_by_clause(order_by: &mut String) { + let suffixes = [ + (" DESC NULLS FIRST]", " ASC NULLS LAST]"), + (" ASC NULLS FIRST]", " DESC NULLS LAST]"), + (" DESC NULLS LAST]", " ASC NULLS FIRST]"), + (" ASC NULLS LAST]", " DESC NULLS FIRST]"), + ]; + + if let Some(start) = order_by.find("ORDER BY [") { + if let Some(end) = order_by[start..].find(']') { + let order_by_start = start + 9; + let order_by_end = start + end; + + let column_order = &order_by[order_by_start..=order_by_end]; + for (suffix, replacement) in suffixes { + if column_order.ends_with(suffix) { + let new_order = column_order.replace(suffix, replacement); + order_by.replace_range(order_by_start..=order_by_end, &new_order); + break; + } + } + } + } +} + +fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { + *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); +} diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr-functions-aggregate/src/lib.rs similarity index 87% rename from datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs rename to datafusion/physical-expr-functions-aggregate/src/lib.rs index 5b0182c5db8a..2ff7ff5777ec 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr-functions-aggregate/src/lib.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Utilities for implementing GroupsAccumulator +//! Technically, all aggregate functions that depend on `expr` crate should be included here. -pub mod accumulate; -pub mod bool_op; -pub mod prim_op; +pub mod aggregate; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 8436b5279bd7..c53f7a6c4771 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -56,6 +56,8 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs index 862edd9c1fac..8a34f34a82db 100644 --- a/datafusion/physical-expr/benches/case_when.rs +++ b/datafusion/physical-expr/benches/case_when.rs @@ -22,9 +22,7 @@ use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; -use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; -use datafusion_physical_expr_common::expressions::column::Column; -use datafusion_physical_expr_common::expressions::Literal; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs index 3dad8e9b456a..7d26557afb1b 100644 --- a/datafusion/physical-expr/benches/is_null.rs +++ b/datafusion/physical-expr/benches/is_null.rs @@ -20,8 +20,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::builder::Int32Builder; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_physical_expr::expressions::{IsNotNullExpr, IsNullExpr}; -use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr::expressions::{Column, IsNotNullExpr, IsNullExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index ffa58e385322..0296b7a247d6 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -47,7 +47,7 @@ use datafusion_common::JoinType; /// /// ```rust /// # use datafusion_physical_expr::ConstExpr; -/// # use datafusion_physical_expr_common::expressions::lit; +/// # use datafusion_physical_expr::expressions::lit; /// let col = lit(5); /// // Create a constant expression from a physical expression ref /// let const_expr = ConstExpr::from(&col); diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index a6e9fba28167..a5d54ee56cff 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -24,7 +24,7 @@ use crate::equivalence::{ collapse_lex_req, EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::Literal; +use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ physical_exprs_contains, ConstExpr, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, @@ -36,9 +36,6 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, JoinSide, JoinType, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_physical_expr_common::expressions::column::Column; -use datafusion_physical_expr_common::expressions::CastExpr; -use datafusion_physical_expr_common::physical_expr::with_new_schema; use datafusion_physical_expr_common::utils::ExprPropertiesNode; use indexmap::{IndexMap, IndexSet}; diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c34dcdfb7598..347a5d82dbec 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -681,11 +681,9 @@ pub fn binary( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, try_cast, Literal}; - + use crate::expressions::{col, lit, try_cast, Column, Literal}; use datafusion_common::plan_datafusion_err; use datafusion_expr::type_coercion::binary::get_input_types; - use datafusion_physical_expr_common::expressions::column::Column; /// Performs a binary operation, applying any type coercion necessary fn binary_op( diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b428d562bd1b..583a4ef32542 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -32,8 +32,7 @@ use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::expressions::column::Column; -use datafusion_physical_expr_common::expressions::Literal; +use super::{Column, Literal}; use itertools::Itertools; type WhenThen = (Arc, Arc); @@ -548,8 +547,8 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit, BinaryExpr}; + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; @@ -558,7 +557,6 @@ mod tests { use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; - use datafusion_physical_expr_common::expressions::Literal; #[test] fn case_with_expr() -> Result<()> { diff --git a/datafusion/physical-expr-common/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs similarity index 98% rename from datafusion/physical-expr-common/src/expressions/cast.rs rename to datafusion/physical-expr/src/expressions/cast.rs index dd6131ad65c3..5621473c4fdb 100644 --- a/datafusion/physical-expr-common/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -27,9 +27,9 @@ use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::ExprProperties; -use datafusion_expr::ColumnarValue; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::ExprProperties; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, @@ -136,7 +136,7 @@ impl PhysicalExpr for CastExpr { children: Vec>, ) -> Result> { Ok(Arc::new(CastExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.cast_type.clone(), Some(self.cast_options.clone()), ))) @@ -211,7 +211,7 @@ pub fn cast_with_options( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - Ok(expr.clone()) + Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs similarity index 82% rename from datafusion/physical-expr-common/src/expressions/column.rs rename to datafusion/physical-expr/src/expressions/column.rs index 5397599ea2dc..79d15fdb02e8 100644 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -25,7 +25,9 @@ use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, Result}; +use arrow_schema::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; @@ -89,7 +91,7 @@ impl PhysicalExpr for Column { /// Evaluate the expression fn evaluate(&self, batch: &RecordBatch) -> Result { self.bounds_check(batch.schema().as_ref())?; - Ok(ColumnarValue::Array(batch.column(self.index).clone())) + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } fn children(&self) -> Vec<&Arc> { @@ -136,6 +138,33 @@ pub fn col(name: &str, schema: &Schema) -> Result> { Ok(Arc::new(Column::new_with_schema(name, schema)?)) } +/// Rewrites an expression according to new schema; i.e. changes the columns it +/// refers to with the column at corresponding index in the new schema. Returns +/// an error if the given schema has fewer columns than the original schema. +/// Note that the resulting expression may not be valid if data types in the +/// new schema is incompatible with expression nodes. +pub fn with_new_schema( + expr: Arc, + schema: &SchemaRef, +) -> Result> { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let idx = col.index(); + let Some(field) = schema.fields().get(idx) else { + return plan_err!( + "New schema has fewer columns than original schema" + ); + }; + let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) + } + })? + .data) +} + #[cfg(test)] mod test { use super::Column; diff --git a/datafusion/physical-expr-common/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs similarity index 95% rename from datafusion/physical-expr-common/src/expressions/literal.rs rename to datafusion/physical-expr/src/expressions/literal.rs index b3cff1ef69ba..ed24e9028153 100644 --- a/datafusion/physical-expr-common/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -28,9 +28,10 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; -use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_expr::Expr; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value #[derive(Debug, PartialEq, Eq, Hash)] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index cbb697b5f304..9e65889d8758 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -20,10 +20,13 @@ #[macro_use] mod binary; mod case; +mod cast; +mod column; mod in_list; mod is_not_null; mod is_null; mod like; +mod literal; mod negative; mod no_op; mod not; @@ -42,14 +45,14 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; +pub use cast::{cast, CastExpr}; +pub use column::{col, with_new_schema, Column}; pub use datafusion_expr::utils::format_state_name; -pub use datafusion_physical_expr_common::expressions::column::{col, Column}; -pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; -pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use like::{like, LikeExpr}; +pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 2e78119eba46..c4255172d680 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -14,10 +14,32 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + // Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 #![deny(clippy::clone_on_ref_ptr)] -pub mod aggregate; +// Backward compatibility +pub mod aggregate { + pub(crate) mod groups_accumulator { + #[allow(unused_imports)] + pub(crate) mod accumulate { + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; + } + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ + accumulate::NullState, GroupsAccumulatorAdapter, + }; + } + pub(crate) mod stats { + pub use datafusion_functions_aggregate_common::stats::StatsType; + } + pub mod utils { + pub use datafusion_functions_aggregate_common::utils::{ + adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays, + get_sort_options, ordering_fields, DecimalAverager, Hashable, + }; + } + pub use datafusion_functions_aggregate_common::aggregate::AggregateExpr; +} pub mod analysis; pub mod binary_map { pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; @@ -45,7 +67,7 @@ pub mod execution_props { pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; -pub use datafusion_physical_expr_common::aggregate::{ +pub use datafusion_functions_aggregate_common::aggregate::{ AggregateExpr, AggregatePhysicalExpressions, }; pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c60a772b9ce2..c718e6b054ef 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use itertools::izip; pub use datafusion_physical_expr_common::physical_expr::down_cast_any_ref; diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index d3f66bdea93d..78da4dc9c53f 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -52,8 +52,10 @@ datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } +datafusion-physical-expr-functions-aggregate = { workspace = true } futures = { workspace = true } half = { workspace = true } hashbrown = { workspace = true } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index d72da9b30049..4d39eff42b5f 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1203,26 +1203,23 @@ mod tests { use arrow::datatypes::DataType; use arrow_array::{Float32Array, Int32Array}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, internal_err, DFSchema, DFSchemaRef, - DataFusionError, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, + ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use datafusion_expr::expr::Sort; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; + use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::lit; use datafusion_physical_expr::PhysicalSortExpr; use crate::common::collect; - use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr_with_dfschema, AggregateExprBuilder, - }; - use datafusion_physical_expr_common::expressions::Literal; + use datafusion_physical_expr::expressions::Literal; + use datafusion_physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1270,22 +1267,19 @@ mod tests { } /// Generates some mock data for aggregate tests. - fn some_data_v2() -> (Arc, DFSchemaRef, Vec) { + fn some_data_v2() -> (Arc, Vec) { // Define a schema: let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, false), Field::new("b", DataType::Float64, false), ])); - let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); - // Generate data so that first and last value results are at 2nd and // 3rd partitions. With this construction, we guarantee we don't receive // the expected result by accident, but merging actually works properly; // i.e. it doesn't depend on the data insertion order. ( Arc::clone(&schema), - Arc::new(df_schema), vec![ RecordBatch::try_new( Arc::clone(&schema), @@ -1363,7 +1357,6 @@ mod tests { let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) .schema(Arc::clone(&input_schema)) .alias("COUNT(1)") - .logical_exprs(vec![datafusion_expr::lit(1i8)]) .build()?]; let task_ctx = if spill { @@ -1980,65 +1973,36 @@ mod tests { // FIRST_VALUE(b ORDER BY b ) fn test_first_value_agg_expr( schema: &Schema, - dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { - let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::col("b")), - asc: !sort_options.descending, - nulls_first: sort_options.nulls_first, - })]; - let ordering_req = vec![PhysicalSortExpr { + let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; - let args = vec![col("b", schema)?]; - let logical_args = vec![datafusion_expr::col("b")]; - let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); - datafusion_physical_expr_common::aggregate::create_aggregate_expr_with_dfschema( - &func, - &args, - &logical_args, - &sort_exprs, - &ordering_req, - dfschema, - None, - false, - false, - false, - ) + let args = [col("b", schema)?]; + + AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) + .order_by(ordering_req.to_vec()) + .schema(Arc::new(schema.clone())) + .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) + .build() } // LAST_VALUE(b ORDER BY b ) fn test_last_value_agg_expr( schema: &Schema, - dfschema: &DFSchema, sort_options: SortOptions, ) -> Result> { - let sort_exprs = vec![datafusion_expr::Expr::Sort(Sort { - expr: Box::new(datafusion_expr::col("b")), - asc: !sort_options.descending, - nulls_first: sort_options.nulls_first, - })]; - let ordering_req = vec![PhysicalSortExpr { + let ordering_req = [PhysicalSortExpr { expr: col("b", schema)?, options: sort_options, }]; - let args = vec![col("b", schema)?]; - let logical_args = vec![datafusion_expr::col("b")]; - let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - create_aggregate_expr_with_dfschema( - &func, - &args, - &logical_args, - &sort_exprs, - &ordering_req, - dfschema, - None, - false, - false, - false, - ) + let args = [col("b", schema)?]; + AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) + .order_by(ordering_req.to_vec()) + .schema(Arc::new(schema.clone())) + .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) + .build() } // This function either constructs the physical plan below, @@ -2070,7 +2034,7 @@ mod tests { Arc::new(TaskContext::default()) }; - let (schema, df_schema, data) = some_data_v2(); + let (schema, data) = some_data_v2(); let partition1 = data[0].clone(); let partition2 = data[1].clone(); let partition3 = data[2].clone(); @@ -2084,13 +2048,9 @@ mod tests { nulls_first: false, }; let aggregates: Vec> = if is_first_acc { - vec![test_first_value_agg_expr( - &schema, - &df_schema, - sort_options, - )?] + vec![test_first_value_agg_expr(&schema, sort_options)?] } else { - vec![test_last_value_agg_expr(&schema, &df_schema, sort_options)?] + vec![test_last_value_agg_expr(&schema, sort_options)?] }; let memory_exec = Arc::new(MemoryExec::try_new( @@ -2157,7 +2117,6 @@ mod tests { #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; - let test_df_schema = DFSchema::try_from(Arc::clone(&test_schema)).unwrap(); // Assume column a and b are aliases // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). @@ -2204,46 +2163,7 @@ mod tests { }, ]), ]; - let col_expr_a = Box::new(datafusion_expr::col("a")); - let col_expr_b = Box::new(datafusion_expr::col("b")); - let col_expr_c = Box::new(datafusion_expr::col("c")); - let sort_exprs = vec![ - None, - Some(vec![datafusion_expr::Expr::Sort(Sort::new( - col_expr_a.clone(), - options1.descending, - options1.nulls_first, - ))]), - Some(vec![ - datafusion_expr::Expr::Sort(Sort::new( - col_expr_a.clone(), - options1.descending, - options1.nulls_first, - )), - datafusion_expr::Expr::Sort(Sort::new( - col_expr_b.clone(), - options1.descending, - options1.nulls_first, - )), - datafusion_expr::Expr::Sort(Sort::new( - col_expr_c, - options1.descending, - options1.nulls_first, - )), - ]), - Some(vec![ - datafusion_expr::Expr::Sort(Sort::new( - col_expr_a, - options1.descending, - options1.nulls_first, - )), - datafusion_expr::Expr::Sort(Sort::new( - col_expr_b, - options1.descending, - options1.nulls_first, - )), - ]), - ]; + let common_requirement = vec![ PhysicalSortExpr { expr: Arc::clone(col_a), @@ -2256,23 +2176,13 @@ mod tests { ]; let mut aggr_exprs = order_by_exprs .into_iter() - .zip(sort_exprs.into_iter()) - .map(|(order_by_expr, sort_exprs)| { + .map(|order_by_expr| { let ordering_req = order_by_expr.unwrap_or_default(); - let sort_exprs = sort_exprs.unwrap_or_default(); - create_aggregate_expr_with_dfschema( - &array_agg_udaf(), - &[Arc::clone(col_a)], - &[], - &sort_exprs, - &ordering_req, - &test_df_schema, - None, - false, - false, - false, - ) - .unwrap() + AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) + .order_by(ordering_req.to_vec()) + .schema(Arc::clone(&test_schema)) + .build() + .unwrap() }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); @@ -2293,7 +2203,6 @@ mod tests { Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), ])); - let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); let col_a = col("a", &schema)?; let option_desc = SortOptions { @@ -2303,8 +2212,8 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); let aggregates: Vec> = vec![ - test_first_value_agg_expr(&schema, &df_schema, option_desc)?, - test_last_value_agg_expr(&schema, &df_schema, option_desc)?, + test_first_value_agg_expr(&schema, option_desc)?, + test_last_value_agg_expr(&schema, option_desc)?, ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2414,24 +2323,17 @@ mod tests { Field::new("key", DataType::Int32, true), Field::new("val", DataType::Int32, true), ])); - let df_schema = DFSchema::try_from(Arc::clone(&schema))?; let group_by = PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); - let aggr_expr: Vec> = - vec![create_aggregate_expr_with_dfschema( - &count_udaf(), - &[col("val", &schema)?], - &[datafusion_expr::col("val")], - &[], - &[], - &df_schema, - Some("COUNT(val)".to_string()), - false, - false, - false, - )?]; + let aggr_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build()?, + ]; let input_data = vec![ RecordBatch::try_new( @@ -2502,24 +2404,17 @@ mod tests { Field::new("key", DataType::Int32, true), Field::new("val", DataType::Int32, true), ])); - let df_schema = DFSchema::try_from(Arc::clone(&schema))?; let group_by = PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); - let aggr_expr: Vec> = - vec![create_aggregate_expr_with_dfschema( - &count_udaf(), - &[col("val", &schema)?], - &[datafusion_expr::col("val")], - &[], - &[], - &df_schema, - Some("COUNT(val)".to_string()), - false, - false, - false, - )?]; + let aggr_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build()?, + ]; let input_data = vec![ RecordBatch::try_new( diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index eeecc017c2af..59c5da6b6fb2 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -82,9 +82,7 @@ pub mod windows; pub mod work_table; pub mod udaf { - pub use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, create_aggregate_expr_with_dfschema, AggregateFunctionExpr, - }; + pub use datafusion_physical_expr_functions_aggregate::aggregate::AggregateFunctionExpr; } #[cfg(test)] diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 9321fdb2cadf..9ef29c833dcc 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -586,8 +586,8 @@ mod tests { use arrow_schema::{DataType, SortOptions}; use datafusion_common::ScalarValue; + use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_expr_common::expressions::column::col; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index b41f3ad71bb8..2e6ad4e1a14f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -31,10 +31,9 @@ use crate::{ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{col, Expr, SortExpr}; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, Expr, PartitionEvaluator, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -43,7 +42,7 @@ use datafusion_physical_expr::{ AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; -use datafusion_physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion_physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use itertools::Itertools; mod bounded_window_agg_exec; @@ -112,25 +111,10 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // Convert `Vec` into `Vec` - let sort_exprs = order_by - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let field_name = expr.to_string(); - let field_name = field_name.split('@').next().unwrap_or(&field_name); - Expr::Sort(SortExpr { - expr: Box::new(col(field_name)), - asc: !options.descending, - nulls_first: options.nulls_first, - }) - }) - .collect::>(); - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) .schema(Arc::new(input_schema.clone())) .alias(name) .order_by(order_by.to_vec()) - .sort_exprs(sort_exprs) .with_ignore_nulls(ignore_nulls) .build()?; window_expr_from_aggregate_expr( diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 59db791c7595..b5d28f40a68f 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -18,7 +18,7 @@ use std::fmt::Debug; use std::sync::Arc; -use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use prost::bytes::BufMut; use prost::Message; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 712182791b0b..1a9c6d40ebe6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -24,7 +24,7 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; -use datafusion::physical_expr_common::aggregate::AggregateExprBuilder; +use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message;