Skip to content

Commit

Permalink
revert schema change v1
Browse files Browse the repository at this point in the history
  • Loading branch information
xinlifoobar committed Aug 4, 2024
1 parent 52cf9ee commit 5d97345
Show file tree
Hide file tree
Showing 18 changed files with 58 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ pub(crate) mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::cast::as_int64_array;
use datafusion_common::ToDFSchema;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::expressions::cast;
use datafusion_physical_expr::PhysicalExpr;
Expand Down Expand Up @@ -411,7 +410,7 @@ pub(crate) mod tests {
// Return appropriate expr depending if COUNT is for col or table (*)
pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn AggregateExpr> {
AggregateExprBuilder::new(count_udaf(), vec![self.column()])
.dfschema(schema.clone().to_dfschema().unwrap())
.schema(Arc::new(schema.clone()))
.name(self.column_name())
.build()
.unwrap()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ mod tests {
use crate::physical_plan::{displayable, Partitioning};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::ToDFSchema;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::col;
Expand Down Expand Up @@ -280,7 +279,7 @@ mod tests {
schema: &Schema,
) -> Arc<dyn AggregateExpr> {
AggregateExprBuilder::new(count_udaf(), vec![expr])
.dfschema(schema.clone().to_dfschema().unwrap())
.schema(schema.clone())
.name(name)
.build()
.unwrap()
Expand Down
3 changes: 1 addition & 2 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::ToDFSchema;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::PhysicalSortExpr;
Expand Down Expand Up @@ -107,7 +106,7 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, group_by_columns: Vec<&str
let aggregate_expr =
vec![
AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()])
.dfschema(Arc::clone(&schema).to_dfschema().unwrap())
.schema(Arc::clone(&schema))
.name("sum1")
.build()
.unwrap(),
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
use crate::physical_expr::PhysicalExpr;
use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{DFSchema, Result};
use std::sync::Arc;

Expand Down Expand Up @@ -55,6 +55,9 @@ pub struct AccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,

/// The schema of the input arguments
pub schema: &'a Schema,

/// The schema of the input arguments
pub dfschema: &'a DFSchema,

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ impl AggregateUDFImpl for ApproxDistinct {

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let accumulator: Box<dyn Accumulator> =
match &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())? {
match &acc_args.input_exprs[0].data_type(acc_args.schema)? {
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
// TODO support for boolean (trivial case)
// https://github.com/apache/datafusion/issues/1109
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian {

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?,
acc_args.input_exprs[0].data_type(acc_args.schema)?,
)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl ApproxPercentileCont {
None
};

let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.dfschema.as_arrow())? {
let accumulator: ApproxPercentileAccumulator = match &args.input_exprs[0].data_type(args.schema)? {
t @ (DataType::UInt8
| DataType::UInt16
| DataType::UInt32
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl AggregateUDFImpl for ArrayAgg {

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type =
acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?;
acc_args.input_exprs[0].data_type(acc_args.schema)?;

if acc_args.is_distinct {
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?));
Expand All @@ -135,7 +135,7 @@ impl AggregateUDFImpl for ArrayAgg {

let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.dfschema.as_arrow()))
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;

OrderSensitiveArrayAggAccumulator::try_new(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl AggregateUDFImpl for Avg {
use DataType::*;
// instantiate specialized accumulator based for the type
let input_type =
acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?;
acc_args.input_exprs[0].data_type(acc_args.schema)?;

match (&input_type, acc_args.data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
Expand Down Expand Up @@ -157,7 +157,7 @@ impl AggregateUDFImpl for Avg {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
// instantiate specialized accumulator based for the type
let sum_data_type = &args.input_exprs[0].data_type(args.dfschema.as_arrow())?;
let sum_data_type = &args.input_exprs[0].data_type(args.schema)?;

match (sum_data_type, args.data_type) {
(Float64, Float64) => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl AggregateUDFImpl for Count {
}

let data_type =
&acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?;
&acc_args.input_exprs[0].data_type(acc_args.schema)?;
Ok(match data_type {
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
DataType::Int8 => Box::new(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl AggregateUDFImpl for FirstValue {

let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.dfschema.as_arrow()))
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;

// When requirement is empty, or it is signalled by outside caller that
Expand Down Expand Up @@ -423,7 +423,7 @@ impl AggregateUDFImpl for LastValue {

let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.dfschema.as_arrow()))
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;

let requirement_satisfied = ordering_req.is_empty() || self.requirement_satisfied;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median {
};
}

let dt = &acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?;
let dt = &acc_args.input_exprs[0].data_type(acc_args.schema)?;
downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ impl AggregateUDFImpl for NthValueAgg {

let ordering_dtypes = ordering_req
.iter()
.map(|e| e.expr.data_type(acc_args.dfschema.as_arrow()))
.map(|e| e.expr.data_type(acc_args.schema))
.collect::<Result<Vec<_>>>()?;

NthValueAccumulator::try_new(
n,
&acc_args.input_exprs[0].data_type(acc_args.dfschema.as_arrow())?,
&acc_args.input_exprs[0].data_type(acc_args.schema)?,
&ordering_dtypes,
ordering_req,
)
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ mod tests {
let dfschema = DFSchema::empty();
let args1 = AccumulatorArgs {
data_type: &DataType::Float64,
schema: &schema,
dfschema: &dfschema,
ignore_nulls: false,
sort_exprs: &[],
Expand All @@ -339,6 +340,7 @@ mod tests {

let args2 = AccumulatorArgs {
data_type: &DataType::Float64,
schema: &schema,
dfschema: &dfschema,
ignore_nulls: false,
sort_exprs: &[],
Expand Down
28 changes: 22 additions & 6 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use arrow::datatypes::{DataType, Field, Schema};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

use datafusion_common::{exec_err, ToDFSchema};
use datafusion_common::exec_err;
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
Expand Down Expand Up @@ -76,7 +76,7 @@ pub fn create_aggregate_expr(
builder = builder.sort_exprs(sort_exprs.to_vec());
builder = builder.order_by(ordering_req.to_vec());
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.dfschema(Arc::new(schema.clone()).to_dfschema()?);
builder = builder.schema(Arc::new(schema.clone()));
builder = builder.name(name);

if ignore_nulls {
Expand Down Expand Up @@ -108,7 +108,8 @@ pub fn create_aggregate_expr_with_dfschema(
builder = builder.sort_exprs(sort_exprs.to_vec());
builder = builder.order_by(ordering_req.to_vec());
builder = builder.logical_exprs(input_exprs.to_vec());
builder = builder.dfschema(dfschema.clone());
let schema: Schema = dfschema.into();
builder = builder.schema(Arc::new(schema));
builder = builder.name(name);

if ignore_nulls {
Expand Down Expand Up @@ -136,6 +137,8 @@ pub struct AggregateExprBuilder {
/// Logical expressions of the aggregate function, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
logical_args: Vec<Expr>,
name: String,
/// Arrow Schema for the aggregate function
schema: SchemaRef,
/// Datafusion Schema for the aggregate function
dfschema: DFSchema,
/// The logical order by expressions, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
Expand All @@ -157,6 +160,7 @@ impl AggregateExprBuilder {
args,
logical_args: vec![],
name: String::new(),
schema: Arc::new(Schema::empty()),
dfschema: DFSchema::empty(),
sort_exprs: vec![],
ordering_req: vec![],
Expand All @@ -172,6 +176,7 @@ impl AggregateExprBuilder {
args,
logical_args,
name,
schema,
dfschema,
sort_exprs,
ordering_req,
Expand All @@ -189,15 +194,15 @@ impl AggregateExprBuilder {
if !ordering_req.is_empty() {
let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(dfschema.as_arrow()))
.map(|e| e.expr.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

ordering_fields = utils::ordering_fields(&ordering_req, &ordering_types);
}

let input_exprs_types = args
.iter()
.map(|arg| arg.data_type(dfschema.as_arrow()))
.map(|arg| arg.data_type(&schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
Expand All @@ -214,6 +219,7 @@ impl AggregateExprBuilder {
logical_args,
data_type,
name,
schema: Arc::unwrap_or_clone(schema),
dfschema,
sort_exprs,
ordering_req,
Expand All @@ -230,6 +236,11 @@ impl AggregateExprBuilder {
self
}

pub fn schema(mut self, schema: SchemaRef) -> Self {
self.schema = schema;
self
}

pub fn dfschema(mut self, dfschema: DFSchema) -> Self {
self.dfschema = dfschema;
self
Expand Down Expand Up @@ -444,6 +455,7 @@ pub struct AggregateFunctionExpr {
/// Output / return type of this aggregate
data_type: DataType,
name: String,
schema: Schema,
dfschema: DFSchema,
// The logical order by expressions
sort_exprs: Vec<Expr>,
Expand Down Expand Up @@ -509,6 +521,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let acc_args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
dfschema: &self.dfschema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
Expand All @@ -524,6 +537,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
dfschema: &self.dfschema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
Expand Down Expand Up @@ -594,6 +608,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn groups_accumulator_supported(&self) -> bool {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
dfschema: &self.dfschema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
Expand All @@ -608,6 +623,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
let args = AccumulatorArgs {
data_type: &self.data_type,
schema: &self.schema,
dfschema: &self.dfschema,
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-plan/src/windows/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::{

use arrow::datatypes::Schema;
use arrow_schema::{DataType, Field, SchemaRef};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue, ToDFSchema};
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{col, Expr, SortExpr};
use datafusion_expr::{
BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition,
Expand Down Expand Up @@ -127,7 +127,7 @@ pub fn create_window_expr(
.collect::<Vec<_>>();

let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec())
.dfschema(Arc::new(input_schema.clone()).to_dfschema()?)
.schema(Arc::new(input_schema.clone()))
.name(name)
.order_by(order_by.to_vec())
.sort_exprs(sort_exprs)
Expand Down
7 changes: 2 additions & 5 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@ use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
use datafusion::physical_plan::{
AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
};
use datafusion_common::{
internal_err, not_impl_err, DataFusionError, Result, ToDFSchema,
};
use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
use datafusion_expr::{AggregateUDF, ScalarUDF};

use crate::common::{byte_to_string, str_to_byte};
Expand Down Expand Up @@ -491,8 +489,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {

// TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
// TODO: `order by` is not supported for UDAF yet
AggregateExprBuilder::new(agg_udf, input_phy_expr).dfschema(Arc::clone(&physical_schema).to_dfschema()?).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build()
}
AggregateExprBuilder::new(agg_udf, input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build() }
}
}).transpose()?.ok_or_else(|| {
proto_error("Invalid AggregateExpr, missing aggregate_function")
Expand Down
Loading

0 comments on commit 5d97345

Please sign in to comment.