From ad26df98798b424968d19e2f1f4dc8165d5340f0 Mon Sep 17 00:00:00 2001 From: Micah Wylde Date: Wed, 8 May 2024 15:50:51 -0700 Subject: [PATCH] Revert "Make builtin window function output datatype to be derived from schema (#9686)" This reverts commit 1d0171ab9d33fc7896861dee85804d7daf0a6390. --- datafusion/core/src/physical_planner.rs | 33 ++++++++----- .../core/tests/fuzz_cases/window_fuzz.rs | 39 ++------------- datafusion/physical-plan/src/windows/mod.rs | 47 +++++++++---------- 3 files changed, 45 insertions(+), 74 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 8446da0f55d4..13430b4bdd56 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -743,13 +743,13 @@ impl DefaultPhysicalPlanner { ); } - let logical_schema = logical_plan.schema(); + let logical_input_schema = input.schema(); let window_expr = window_expr .iter() .map(|e| { create_window_expr( e, - logical_schema, + logical_input_schema, session_state.execution_props(), ) }) @@ -1572,11 +1572,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { pub fn create_window_expr_with_name( e: &Expr, name: impl Into, - logical_schema: &DFSchema, + logical_input_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_schema: &Schema = &logical_schema.into(); + let physical_input_schema: &Schema = &logical_input_schema.into(); match e { Expr::WindowFunction(WindowFunction { fun, @@ -1586,11 +1586,20 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = create_physical_exprs(args, logical_schema, execution_props)?; - let partition_by = - create_physical_exprs(partition_by, logical_schema, execution_props)?; - let order_by = - create_physical_sort_exprs(order_by, logical_schema, execution_props)?; + let args = args + .iter() + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .collect::>>()?; + let partition_by = partition_by + .iter() + .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .collect::>>()?; + let order_by = order_by + .iter() + .map(|e| { + create_physical_sort_expr(e, logical_input_schema, execution_props) + }) + .collect::>>()?; if !is_window_frame_bound_valid(window_frame) { return plan_err!( @@ -1610,7 +1619,7 @@ pub fn create_window_expr_with_name( &partition_by, &order_by, window_frame, - physical_schema, + physical_input_schema, ignore_nulls, ) } @@ -1621,7 +1630,7 @@ pub fn create_window_expr_with_name( /// Create a window expression from a logical expression or an alias pub fn create_window_expr( e: &Expr, - logical_schema: &DFSchema, + logical_input_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" @@ -1629,7 +1638,7 @@ pub fn create_window_expr( Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), _ => (e.display_name()?, e), }; - create_window_expr_with_name(e, name, logical_schema, execution_props) + create_window_expr_with_name(e, name, logical_input_schema, execution_props) } type AggregateExprWithOptionalArgs = ( diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 2514324a9541..00c65995a5ff 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -22,7 +22,6 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ @@ -40,7 +39,6 @@ use datafusion_expr::{ }; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use itertools::Itertools; use test_utils::add_empty_batches; use hashbrown::HashMap; @@ -275,9 +273,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { window_frame.is_causal() }; - let extended_schema = - schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; - let window_expr = create_window_expr( &window_fn, fn_name.to_string(), @@ -285,7 +280,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), - &extended_schema, + schema.as_ref(), false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( @@ -683,8 +678,6 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } - let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; - let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &window_fn, @@ -693,7 +686,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - &extended_schema, + schema.as_ref(), false, )?], exec1, @@ -711,7 +704,7 @@ async fn run_window_test( &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), - &extended_schema, + schema.as_ref(), false, )?], exec2, @@ -754,32 +747,6 @@ async fn run_window_test( Ok(()) } -// The planner has fully updated schema before calling the `create_window_expr` -// Replicate the same for this test -fn schema_add_window_fields( - args: &[Arc], - schema: &Arc, - window_fn: &WindowFunctionDefinition, - fn_name: &str, -) -> Result> { - let data_types = args - .iter() - .map(|e| e.clone().as_ref().data_type(schema)) - .collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; - let mut window_fields = schema - .fields() - .iter() - .map(|f| f.as_ref().clone()) - .collect_vec(); - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - true, - )]); - Ok(Arc::new(Schema::new(window_fields))) -} - /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 21f42f41fb5c..da2b24487d02 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -174,15 +174,20 @@ fn create_built_in_window_expr( name: String, ignore_nulls: bool, ) -> Result> { - // derive the output datatype from incoming schema - let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); + // need to get the types into an owned vec for some reason + let input_types: Vec<_> = args + .iter() + .map(|arg| arg.data_type(input_schema)) + .collect::>()?; + // figure out the output type + let data_type = &fun.return_type(&input_types)?; Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)), - BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)), + BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( @@ -196,13 +201,13 @@ fn create_built_in_window_expr( if n.is_unsigned() { let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n, out_data_type)) + Arc::new(Ntile::new(name, n, data_type)) } else { let n: i64 = n.try_into()?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } - Arc::new(Ntile::new(name, n as u64, out_data_type)) + Arc::new(Ntile::new(name, n as u64, data_type)) } } BuiltInWindowFunction::Lag => { @@ -211,10 +216,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; Arc::new(lag( name, - out_data_type.clone(), + data_type.clone(), arg, shift_offset, default_value, @@ -227,10 +232,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; Arc::new(lead( name, - out_data_type.clone(), + data_type.clone(), arg, shift_offset, default_value, @@ -247,28 +252,18 @@ fn create_built_in_window_expr( Arc::new(NthValue::nth( name, arg, - out_data_type.clone(), + data_type.clone(), n, ignore_nulls, )?) } BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - Arc::new(NthValue::first( - name, - arg, - out_data_type.clone(), - ignore_nulls, - )) + Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls)) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - Arc::new(NthValue::last( - name, - arg, - out_data_type.clone(), - ignore_nulls, - )) + Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls)) } }) }