diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 2514324a9541..fe0c408dc114 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,11 +22,10 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use arrow_schema::{Field, Schema};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::windows::{
-    create_window_expr, BoundedWindowAggExec, WindowAggExec,
+    create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec,
 };
 use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted};
 use datafusion::physical_plan::{collect, InputOrderMode};
@@ -40,7 +39,6 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::expressions::{cast, col, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
-use itertools::Itertools;
 use test_utils::add_empty_batches;
 
 use hashbrown::HashMap;
@@ -276,7 +274,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
                 };
 
                 let extended_schema =
-                    schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
+                    schema_add_window_field(&args, &schema, &window_fn, fn_name)?;
 
                 let window_expr = create_window_expr(
                     &window_fn,
@@ -683,7 +681,7 @@ async fn run_window_test(
         exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
     }
 
-    let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
+    let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?;
 
     let usual_window_exec = Arc::new(WindowAggExec::try_new(
         vec![create_window_expr(
@@ -754,32 +752,6 @@ async fn run_window_test(
     Ok(())
 }
 
-// The planner has fully updated schema before calling the `create_window_expr`
-// Replicate the same for this test
-fn schema_add_window_fields(
-    args: &[Arc<dyn PhysicalExpr>],
-    schema: &Arc<Schema>,
-    window_fn: &WindowFunctionDefinition,
-    fn_name: &str,
-) -> Result<Arc<Schema>> {
-    let data_types = args
-        .iter()
-        .map(|e| e.clone().as_ref().data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-    let window_expr_return_type = window_fn.return_type(&data_types)?;
-    let mut window_fields = schema
-        .fields()
-        .iter()
-        .map(|f| f.as_ref().clone())
-        .collect_vec();
-    window_fields.extend_from_slice(&[Field::new(
-        fn_name,
-        window_expr_return_type,
-        true,
-    )]);
-    Ok(Arc::new(Schema::new(window_fields)))
-}
-
 /// Return randomly sized record batches with:
 /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
 /// one random int32 column x
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index ff60329ce179..ee17a28108a2 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -42,6 +42,7 @@ use datafusion_physical_expr::{
     window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr},
     AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement,
 };
+use itertools::Itertools;
 
 mod bounded_window_agg_exec;
 mod window_agg_exec;
@@ -52,6 +53,31 @@ pub use datafusion_physical_expr::window::{
 };
 pub use window_agg_exec::WindowAggExec;
 
+/// Build field from window function and add it into schema
+pub fn schema_add_window_field(
+    args: &[Arc<dyn PhysicalExpr>],
+    schema: &Schema,
+    window_fn: &WindowFunctionDefinition,
+    fn_name: &str,
+) -> Result<Arc<Schema>> {
+    let data_types = args
+        .iter()
+        .map(|e| e.clone().as_ref().data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+    let window_expr_return_type = window_fn.return_type(&data_types)?;
+    let mut window_fields = schema
+        .fields()
+        .iter()
+        .map(|f| f.as_ref().clone())
+        .collect_vec();
+    window_fields.extend_from_slice(&[Field::new(
+        fn_name,
+        window_expr_return_type,
+        false,
+    )]);
+    Ok(Arc::new(Schema::new(window_fields)))
+}
+
 /// Create a physical expression for window function
 #[allow(clippy::too_many_arguments)]
 pub fn create_window_expr(
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index c907e991fb86..a290f30586ce 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -40,7 +40,7 @@ use datafusion::physical_plan::expressions::{
     in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr,
     Literal, NegativeExpr, NotExpr, TryCastExpr,
 };
-use datafusion::physical_plan::windows::create_window_expr;
+use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field};
 use datafusion::physical_plan::{
     ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr,
 };
@@ -155,14 +155,18 @@ pub fn parse_physical_window_expr(
             )
         })?;
 
+    let fun: WindowFunctionDefinition = convert_required!(proto.window_function)?;
+    let name = proto.name.clone();
+    let extended_schema =
+        schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?;
     create_window_expr(
-        &convert_required!(proto.window_function)?,
-        proto.name.clone(),
+        &fun,
+        name,
         &window_node_expr,
         &partition_by,
         &order_by,
         Arc::new(window_frame),
-        input_schema,
+        &extended_schema,
         false,
     )
 }
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index c2018352c7cf..3b452db9e460 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -253,8 +253,7 @@ fn roundtrip_nested_loop_join() -> Result<()> {
 fn roundtrip_window() -> Result<()> {
     let field_a = Field::new("a", DataType::Int64, false);
     let field_b = Field::new("b", DataType::Int64, false);
-    let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false);
-    let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c]));
+    let schema = Arc::new(Schema::new(vec![field_a, field_b]));
 
     let window_frame = WindowFrame::new_bounds(
         datafusion_expr::WindowFrameUnits::Range,