Skip to content

Commit

Permalink
Move min and max to user defined aggregate function, remove `AggregateFunction` / `AggregateFunctionDefinition::BuiltIn` (apache#11013)
Browse files Browse the repository at this point in the history

* Moving min and max to new API and removing from protobuf

* Using input_type rather than data_type

* Adding type coercion

* Fixed doctests

* Implementing feedback from code review

* Implementing feedback from code review

* Fixed wrong name

* Fixing name
  • Loading branch information
edmondop authored Aug 3, 2024
1 parent 9e90e17 commit f4e519f
Show file tree
Hide file tree
Showing 56 changed files with 937 additions and 1,813 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::Arc;

use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg;
use datafusion::functions_aggregate::min_max::max;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
8 changes: 6 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ use datafusion_common::{
};
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_functions_aggregate::expr_fn::{
avg, count, max, median, min, stddev, sum,
};
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};

use async_trait::async_trait;
use datafusion_catalog::Session;
Expand Down Expand Up @@ -144,6 +146,7 @@ impl Default for DataFrameWriteOptions {
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
Expand Down Expand Up @@ -407,6 +410,7 @@ impl DataFrame {
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/file_format/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use datafusion_common::{
use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
use datafusion_physical_plan::metrics::MetricsSet;

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use super::listing::PartitionedFile;
use crate::arrow::datatypes::{Schema, SchemaRef};
use crate::error::Result;
use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator};
use crate::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics};
use arrow_schema::DataType;

Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ where
///
/// ```
/// use datafusion::prelude::*;
/// # use datafusion::functions_aggregate::expr_fn::min;
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
//! ```rust
//! # use datafusion::prelude::*;
//! # use datafusion::error::Result;
//! # use datafusion::functions_aggregate::expr_fn::min;
//! # use datafusion::arrow::record_batch::RecordBatch;
//!
//! # #[tokio::main]
Expand Down
15 changes: 2 additions & 13 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,39 +272,28 @@ fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
if agg_expr.as_any().is::<expressions::Min>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "min" {
if agg_expr.fun().name().to_lowercase() == "min" {
return true;
}
}

false
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
if agg_expr.as_any().is::<expressions::Max>() {
return true;
}

if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "max" {
if agg_expr.fun().name().to_lowercase() == "max" {
return true;
}
}

false
}

Expand Down
28 changes: 3 additions & 25 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ use crate::physical_plan::unnest::UnnestExec;
use crate::physical_plan::values::ValuesExec;
use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use crate::physical_plan::{
aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan,
ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties,
InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
};

use arrow::compute::SortOptions;
Expand Down Expand Up @@ -1812,7 +1812,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
e: &Expr,
name: impl Into<String>,
logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
_physical_input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<AggregateExprWithOptionalArgs> {
match e {
Expand Down Expand Up @@ -1840,28 +1840,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
&physical_args,
&ordering_reqs,
physical_input_schema,
name,
ignore_nulls,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery,
when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
21 changes: 6 additions & 15 deletions datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ use datafusion::physical_plan::{collect, InputOrderMode};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
use datafusion_expr::{
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
Expand Down Expand Up @@ -361,14 +361,14 @@ fn get_random_function(
window_fn_map.insert(
"min",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
WindowFunctionDefinition::AggregateUDF(min_udaf()),
vec![arg.clone()],
),
);
window_fn_map.insert(
"max",
(
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
WindowFunctionDefinition::AggregateUDF(max_udaf()),
vec![arg.clone()],
),
);
Expand Down Expand Up @@ -465,16 +465,7 @@ fn get_random_function(
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
let mut args = args.clone();
if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
if !args.is_empty() {
// Do type coercion on the first argument
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}
} else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
if !args.is_empty() {
// Do type coercion on the first argument
let a = args[0].clone();
Expand Down
156 changes: 0 additions & 156 deletions datafusion/expr/src/aggregate_function.rs

This file was deleted.

Loading

0 comments on commit f4e519f

Please sign in to comment.