diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs
index 2c0f52c701..5a015ff8d9 100644
--- a/src/daft-local-execution/src/pipeline.rs
+++ b/src/daft-local-execution/src/pipeline.rs
@@ -48,6 +48,7 @@ use crate::{
         pivot::PivotSink,
         sort::SortSink,
         streaming_sink::StreamingSinkNode,
+        window_partition_only::WindowPartitionOnlySink,
         write::{WriteFormat, WriteSink},
     },
     sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource, source::SourceNode},
@@ -126,60 +127,15 @@ pub fn physical_plan_to_pipeline(
             input,
             partition_by,
             schema,
-            stats_state: _,
-            window_functions,
+            stats_state,
+            aggregations,
         }) => {
-            // First, ensure the input is processed
             let input_node = physical_plan_to_pipeline(input, psets, cfg)?;
-
-            // Create a project node that actually adds window_0 columns
-            println!("Basic window partition implementation");
-            println!("  Partition by: {:?}", partition_by);
-            println!("  Window functions: {:?}", window_functions);
-            println!("  Output schema: {:?}", schema);
-
-            // For test_single_partition_sum, we need to calculate sum(value) grouped by category
-            // A=22, B=29, C=21
-            use daft_dsl::{lit, resolved_col};
-
-            // Add the original columns
-            let category_col = resolved_col("category");
-            let value_col = resolved_col("value");
-
-            // Create an expression to select the correct sum based on category
-            // We'll use nested if_else expressions to handle all categories
-            let cat_equal_a = category_col.clone().eq(lit("A"));
-            let cat_equal_b = category_col.clone().eq(lit("B"));
-
-            // Creates an expression that returns:
-            // - 22 if category is "A"
-            // - 29 if category is "B"
-            // - 21 otherwise (for "C")
-            let window_expr = cat_equal_a.if_else(
-                lit(22), // If category is "A"
-                cat_equal_b.if_else(
-                    lit(29), // If category is "B"
-                    lit(21), // Else (category is "C")
-                ),
-            );
-
-            // Alias the result as "window_0"
-            let window_col = window_expr.alias("window_0");
-
-            // Create the projection with all columns
-            let projection = vec![category_col, value_col, window_col];
-
-            let proj_op =
-                ProjectOperator::new(projection).with_context(|_| PipelineCreationSnafu {
-                    plan_name: "WindowPartitionOnly",
+            let agg_sink = WindowPartitionOnlySink::new(aggregations, partition_by, schema)
+                .with_context(|_| PipelineCreationSnafu {
+                    plan_name: physical_plan.name(),
                 })?;
-
-            IntermediateNode::new(
-                Arc::new(proj_op),
-                vec![input_node],
-                StatsState::NotMaterialized,
-            )
-            .boxed()
+            BlockingSinkNode::new(Arc::new(agg_sink), input_node, stats_state.clone()).boxed()
         }
         LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, stats_state }) => {
             let cache_key: Arc<str> = info.cache_key.clone().into();
diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs
index dbef15f6c3..73d863ac6c 100644
--- a/src/daft-local-execution/src/sinks/mod.rs
+++ b/src/daft-local-execution/src/sinks/mod.rs
@@ -11,4 +11,5 @@ pub mod outer_hash_join_probe;
 pub mod pivot;
 pub mod sort;
 pub mod streaming_sink;
+pub mod window_partition_only;
 pub mod write;
diff --git a/src/daft-local-execution/src/sinks/window_partition_only.rs b/src/daft-local-execution/src/sinks/window_partition_only.rs
new file mode 100644
index 0000000000..785f53e363
--- /dev/null
+++ b/src/daft-local-execution/src/sinks/window_partition_only.rs
@@ -0,0 +1,298 @@
+use std::sync::Arc;
+
+use common_error::DaftResult;
+use daft_core::prelude::SchemaRef;
+use daft_dsl::{Expr, ExprRef};
+use daft_micropartition::MicroPartition;
+use daft_physical_plan::extract_agg_expr;
+use itertools::Itertools;
+use tracing::{instrument, Span};
+
+use super::blocking_sink::{
+    BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState,
+    BlockingSinkStatus,
+};
+use crate::{ExecutionTaskSpawner, NUM_CPUS};
+
+enum WindowPartitionOnlyState {
+    Accumulating {
+        inner_states: Vec<Vec<MicroPartition>>,
+    },
+    Done,
+}
+
+impl WindowPartitionOnlyState {
+    fn new(num_partitions: usize) -> Self {
+        let inner_states = (0..num_partitions).map(|_| Vec::new()).collect::<Vec<_>>();
+        Self::Accumulating { inner_states }
+    }
+
+    fn push(
+        &mut self,
+        input: Arc<MicroPartition>,
+        params: &WindowPartitionOnlyParams,
+    ) -> DaftResult<()> {
+        let Self::Accumulating {
+            ref mut inner_states,
+        } = self
+        else {
+            panic!("WindowPartitionOnlySink should be in Accumulating state");
+        };
+
+        let partitioned =
+            input.partition_by_hash(params.partition_by.as_slice(), inner_states.len())?;
+        for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
+            state.push(p);
+        }
+        Ok(())
+    }
+
+    fn finalize(&mut self) -> Vec<Vec<MicroPartition>> {
+        let res = if let Self::Accumulating {
+            ref mut inner_states,
+            ..
+        } = self
+        {
+            std::mem::take(inner_states)
+        } else {
+            panic!("WindowPartitionOnlySink should be in Accumulating state");
+        };
+        *self = Self::Done;
+        res
+    }
+}
+
+impl BlockingSinkState for WindowPartitionOnlyState {
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+struct WindowPartitionOnlyParams {
+    // Original aggregation expressions
+    original_aggregations: Vec<ExprRef>,
+    // Partition by expressions
+    partition_by: Vec<ExprRef>,
+    // First stage aggregation expressions
+    partial_agg_exprs: Vec<ExprRef>,
+    // Second stage aggregation expressions
+    final_agg_exprs: Vec<ExprRef>,
+    // Final projections
+    final_projections: Vec<ExprRef>,
+}
+
+pub struct WindowPartitionOnlySink {
+    window_partition_only_params: Arc<WindowPartitionOnlyParams>,
+}
+
+impl WindowPartitionOnlySink {
+    pub fn new(
+        aggregations: &[ExprRef],
+        partition_by: &[ExprRef],
+        schema: &SchemaRef,
+    ) -> DaftResult<Self> {
+        // Extract aggregation expressions
+        let aggregations = aggregations
+            .iter()
+            .map(extract_agg_expr)
+            .collect::<DaftResult<Vec<_>>>()?;
+
+        // Use the same multi-stage approach as grouped aggregates
+        let (partial_aggs, final_aggs, final_projections) =
+            daft_physical_plan::populate_aggregation_stages(&aggregations, schema, partition_by);
+
+        // Convert first stage aggregations to expressions
+        let partial_agg_exprs = partial_aggs
+            .into_values()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect::<Vec<_>>();
+
+        // Convert second stage aggregations to expressions
+        let final_agg_exprs = final_aggs
+            .into_values()
+            .map(|e| Arc::new(Expr::Agg(e)))
+            .collect::<Vec<_>>();
+
+        Ok(Self {
+            window_partition_only_params: Arc::new(WindowPartitionOnlyParams {
+                original_aggregations: aggregations
+                    .into_iter()
+                    .map(|e| Arc::new(Expr::Agg(e)))
+                    .collect(),
+                partition_by: partition_by.to_vec(),
+                partial_agg_exprs,
+                final_agg_exprs,
+                final_projections,
+            }),
+        })
+    }
+
+    fn num_partitions(&self) -> usize {
+        *NUM_CPUS
+    }
+}
+
+impl BlockingSink for WindowPartitionOnlySink {
+    #[instrument(skip_all, name = "WindowPartitionOnlySink::sink")]
+    fn sink(
+        &self,
+        input: Arc<MicroPartition>,
+        mut state: Box<dyn BlockingSinkState>,
+        spawner: &ExecutionTaskSpawner,
+    ) -> BlockingSinkSinkResult {
+        let params = self.window_partition_only_params.clone();
+        spawner
+            .spawn(
+                async move {
+                    let agg_state = state
+                        .as_any_mut()
+                        .downcast_mut::<WindowPartitionOnlyState>()
+                        .expect("WindowPartitionOnlySink should have WindowPartitionOnlyState");
+
+                    agg_state.push(input, &params)?;
+                    Ok(BlockingSinkStatus::NeedMoreInput(state))
+                },
+                Span::current(),
+            )
+            .into()
+    }
+
+    #[instrument(skip_all, name = "WindowPartitionOnlySink::finalize")]
+    fn finalize(
+        &self,
+        states: Vec<Box<dyn BlockingSinkState>>,
+        spawner: &ExecutionTaskSpawner,
+    ) -> BlockingSinkFinalizeResult {
+        let params = self.window_partition_only_params.clone();
+        let num_partitions = self.num_partitions();
+        spawner
+            .spawn(
+                async move {
+                    let mut state_iters = states
+                        .into_iter()
+                        .map(|mut state| {
+                            state
+                                .as_any_mut()
+                                .downcast_mut::<WindowPartitionOnlyState>()
+                                .expect(
+                                    "WindowPartitionOnlySink should have WindowPartitionOnlyState",
+                                )
+                                .finalize()
+                                .into_iter()
+                        })
+                        .collect::<Vec<_>>();
+
+                    let mut per_partition_finalize_tasks = tokio::task::JoinSet::new();
+                    for _ in 0..num_partitions {
+                        let per_partition_state = state_iters
+                            .iter_mut()
+                            .map(|state| {
+                                state.next().expect(
+                                    "WindowPartitionOnlyState should have Vec<MicroPartition>",
+                                )
+                            })
+                            .collect::<Vec<_>>();
+                        let params = params.clone();
+                        per_partition_finalize_tasks.spawn(async move {
+                            // Skip empty partitions
+                            if per_partition_state.is_empty() {
+                                return Ok(None);
+                            }
+
+                            // Concatenate all micropartitions for this partition
+                            let partitions: Vec<MicroPartition> =
+                                per_partition_state.into_iter().flatten().collect();
+                            if partitions.is_empty() {
+                                return Ok(None);
+                            }
+
+                            let concated = MicroPartition::concat(&partitions)?;
+
+                            // Two-stage window function processing:
+
+                            // 1. First stage: Apply partial aggregations
+                            // For window functions, the first stage creates intermediate results like sums and counts
+                            let partially_aggregated = if !params.partial_agg_exprs.is_empty() {
+                                concated.agg(&params.partial_agg_exprs, &params.partition_by)?
+                            } else {
+                                // If no partial aggregations are needed, use original expressions
+                                concated.agg(&params.original_aggregations, &params.partition_by)?
+                            };
+
+                            // 2. Second stage: Apply final aggregations
+                            // This stage combines the intermediate results to get final values
+                            let final_result = if !params.final_agg_exprs.is_empty() {
+                                // Apply the second stage and then final projections
+                                let final_agged = partially_aggregated
+                                    .agg(&params.final_agg_exprs, &params.partition_by)?;
+
+                                // Apply final projections to produce the output columns
+                                final_agged.eval_expression_list(&params.final_projections)?
+                            } else {
+                                // If there's no second stage, just apply projections directly
+                                partially_aggregated
+                                    .eval_expression_list(&params.final_projections)?
+                            };
+
+                            Ok(Some(final_result))
+                        });
+                    }
+
+                    // Collect results from all partitions
+                    let results = per_partition_finalize_tasks
+                        .join_all()
+                        .await
+                        .into_iter()
+                        .collect::<DaftResult<Vec<_>>>()?
+                        .into_iter()
+                        .flatten()
+                        .collect::<Vec<_>>();
+
+                    if results.is_empty() {
+                        return Ok(None);
+                    }
+
+                    // Combine all partition results
+                    let concated = MicroPartition::concat(&results)?;
+                    Ok(Some(Arc::new(concated)))
+                },
+                Span::current(),
+            )
+            .into()
+    }
+
+    fn name(&self) -> &'static str {
+        "WindowPartitionOnly"
+    }
+
+    fn multiline_display(&self) -> Vec<String> {
+        let mut display = vec![];
+        display.push(format!(
+            "WindowPartitionOnly: {}",
+            self.window_partition_only_params
+                .original_aggregations
+                .iter()
+                .map(|e| e.to_string())
+                .join(", ")
+        ));
+        display.push(format!(
+            "Partition by: {}",
+            self.window_partition_only_params
+                .partition_by
+                .iter()
+                .map(|e| e.to_string())
+                .join(", ")
+        ));
+        display
+    }
+
+    fn max_concurrency(&self) -> usize {
+        *NUM_CPUS
+    }
+
+    fn make_state(&self) -> DaftResult<Box<dyn BlockingSinkState>> {
+        Ok(Box::new(WindowPartitionOnlyState::new(
+            self.num_partitions(),
+        )))
+    }
+}
diff --git a/src/daft-local-plan/src/plan.rs b/src/daft-local-plan/src/plan.rs
index f6597e92d4..bf9e7d55bb 100644
--- a/src/daft-local-plan/src/plan.rs
+++ b/src/daft-local-plan/src/plan.rs
@@ -235,14 +235,14 @@ impl LocalPhysicalPlan {
         partition_by: Vec<ExprRef>,
         schema: SchemaRef,
         stats_state: StatsState,
-        window_functions: Vec<ExprRef>,
+        aggregations: Vec<ExprRef>,
     ) -> LocalPhysicalPlanRef {
         Self::WindowPartitionOnly(WindowPartitionOnly {
             input,
             partition_by,
             schema,
             stats_state,
-            window_functions,
+            aggregations,
         })
         .arced()
    }
@@ -673,5 +673,5 @@ pub struct WindowPartitionOnly {
     pub partition_by: Vec<ExprRef>,
     pub schema: SchemaRef,
     pub stats_state: StatsState,
-    pub window_functions: Vec<ExprRef>,
+    pub aggregations: Vec<ExprRef>,
 }
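
Note on the partition fan-out in window_partition_only.rs: each worker state hashes the partition-by columns into the same number of buckets, so at finalize time bucket i can be gathered from every worker and aggregated independently. The following is a minimal, dependency-free sketch of that idea, not Daft code; Row, WorkerState, and bucket_of are hypothetical stand-ins for MicroPartition and partition_by_hash.

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Stand-in for a batch of rows; in the sink this role is played by MicroPartition.
type Row = (String, i64); // (partition key, value)

// Hash a key into one of `num_buckets` buckets, mirroring the role that
// partition_by_hash plays in WindowPartitionOnlyState::push.
fn bucket_of(key: &str, num_buckets: usize) -> usize {
    let mut h = DefaultHasher::new();
    key.hash(&mut h);
    (h.finish() as usize) % num_buckets
}

// Per-worker accumulator: one Vec of rows per hash bucket.
struct WorkerState {
    buckets: Vec<Vec<Row>>,
}

impl WorkerState {
    fn new(num_buckets: usize) -> Self {
        Self {
            buckets: (0..num_buckets).map(|_| Vec::new()).collect(),
        }
    }

    // Route each incoming row to its bucket (the "sink" phase).
    fn push(&mut self, rows: Vec<Row>) {
        let n = self.buckets.len();
        for row in rows {
            let b = bucket_of(&row.0, n);
            self.buckets[b].push(row);
        }
    }
}

// Finalize: gather bucket `i` from every worker so that all rows sharing a
// partition key land in the same group and can be aggregated independently.
fn finalize(workers: Vec<WorkerState>, num_buckets: usize) -> Vec<Vec<Row>> {
    let mut merged: Vec<Vec<Row>> = (0..num_buckets).map(|_| Vec::new()).collect();
    for worker in workers {
        for (i, bucket) in worker.buckets.into_iter().enumerate() {
            merged[i].extend(bucket);
        }
    }
    merged
}

fn main() {
    let num_buckets = 4;
    let mut w0 = WorkerState::new(num_buckets);
    let mut w1 = WorkerState::new(num_buckets);
    w0.push(vec![("A".into(), 5), ("B".into(), 7)]);
    w1.push(vec![("A".into(), 17), ("C".into(), 21)]);
    for (i, group) in finalize(vec![w0, w1], num_buckets).into_iter().enumerate() {
        println!("bucket {i}: {group:?}");
    }
}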
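
Note on the two-stage aggregation in WindowPartitionOnlyParams: the sink carries separate partial, final, and projection expression lists because populate_aggregation_stages splits each aggregation the same way the grouped-aggregate sink does. The plain-Rust sketch below illustrates why that split matters, using a mean as the example (per-chunk means cannot be merged, but per-chunk sum/count pairs can); MeanPartial, partial_agg, final_agg, and final_projection are illustrative names, not Daft APIs.

// Partial (first-stage) state for a mean aggregation: running sum and count.
#[derive(Default, Clone, Copy, Debug)]
struct MeanPartial {
    sum: f64,
    count: u64,
}

// Stage 1: fold one chunk of a partition into a partial aggregate.
fn partial_agg(values: &[f64]) -> MeanPartial {
    MeanPartial {
        sum: values.iter().sum(),
        count: values.len() as u64,
    }
}

// Stage 2: merge the partials produced for the same partition key.
fn final_agg(partials: &[MeanPartial]) -> MeanPartial {
    partials.iter().fold(MeanPartial::default(), |acc, p| MeanPartial {
        sum: acc.sum + p.sum,
        count: acc.count + p.count,
    })
}

// Final projection: turn the merged state into the user-visible value.
fn final_projection(merged: MeanPartial) -> f64 {
    if merged.count == 0 {
        f64::NAN
    } else {
        merged.sum / merged.count as f64
    }
}

fn main() {
    // Two chunks of the same partition, partially aggregated independently...
    let p1 = partial_agg(&[1.0, 2.0, 3.0]);
    let p2 = partial_agg(&[4.0, 5.0]);
    // ...then merged and projected, giving the same mean as a single pass.
    let mean = final_projection(final_agg(&[p1, p2]));
    assert!((mean - 3.0).abs() < 1e-9);
    println!("mean = {mean}");
}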