Skip to content

Commit

Permalink
feat(window): implement window partition only sink
Browse files Browse the repository at this point in the history
  • Loading branch information
f4t4nt committed Mar 11, 2025
1 parent bd44148 commit 407c7ec
Show file tree
Hide file tree
Showing 4 changed files with 309 additions and 54 deletions.
58 changes: 7 additions & 51 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use crate::{
pivot::PivotSink,
sort::SortSink,
streaming_sink::StreamingSinkNode,
window_partition_only::WindowPartitionOnlySink,
write::{WriteFormat, WriteSink},
},
sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource, source::SourceNode},
Expand Down Expand Up @@ -126,60 +127,15 @@ pub fn physical_plan_to_pipeline(
input,
partition_by,
schema,
stats_state: _,
window_functions,
stats_state,
aggregations,

Check warning on line 131 in src/daft-local-execution/src/pipeline.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/pipeline.rs#L127-L131

Added lines #L127 - L131 were not covered by tests
}) => {
// First, ensure the input is processed
let input_node = physical_plan_to_pipeline(input, psets, cfg)?;

// Create a project node that actually adds window_0 columns
println!("Basic window partition implementation");
println!(" Partition by: {:?}", partition_by);
println!(" Window functions: {:?}", window_functions);
println!(" Output schema: {:?}", schema);

// For test_single_partition_sum, we need to calculate sum(value) grouped by category
// A=22, B=29, C=21
use daft_dsl::{lit, resolved_col};

// Add the original columns
let category_col = resolved_col("category");
let value_col = resolved_col("value");

// Create an expression to select the correct sum based on category
// We'll use nested if_else expressions to handle all categories
let cat_equal_a = category_col.clone().eq(lit("A"));
let cat_equal_b = category_col.clone().eq(lit("B"));

// Creates an expression that returns:
// - 22 if category is "A"
// - 29 if category is "B"
// - 21 otherwise (for "C")
let window_expr = cat_equal_a.if_else(
lit(22), // If category is "A"
cat_equal_b.if_else(
lit(29), // If category is "B"
lit(21), // Else (category is "C")
),
);

// Alias the result as "window_0"
let window_col = window_expr.alias("window_0");

// Create the projection with all columns
let projection = vec![category_col, value_col, window_col];

let proj_op =
ProjectOperator::new(projection).with_context(|_| PipelineCreationSnafu {
plan_name: "WindowPartitionOnly",
let agg_sink = WindowPartitionOnlySink::new(aggregations, partition_by, schema)
.with_context(|_| PipelineCreationSnafu {
plan_name: physical_plan.name(),
})?;

IntermediateNode::new(
Arc::new(proj_op),
vec![input_node],
StatsState::NotMaterialized,
)
.boxed()
BlockingSinkNode::new(Arc::new(agg_sink), input_node, stats_state.clone()).boxed()

Check warning on line 138 in src/daft-local-execution/src/pipeline.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/pipeline.rs#L133-L138

Added lines #L133 - L138 were not covered by tests
}
LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, stats_state }) => {
let cache_key: Arc<str> = info.cache_key.clone().into();
Expand Down
1 change: 1 addition & 0 deletions src/daft-local-execution/src/sinks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ pub mod outer_hash_join_probe;
pub mod pivot;
pub mod sort;
pub mod streaming_sink;
pub mod window_partition_only;
pub mod write;
298 changes: 298 additions & 0 deletions src/daft-local-execution/src/sinks/window_partition_only.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,298 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::prelude::SchemaRef;
use daft_dsl::{Expr, ExprRef};
use daft_micropartition::MicroPartition;
use daft_physical_plan::extract_agg_expr;
use itertools::Itertools;
use tracing::{instrument, Span};

use super::blocking_sink::{
BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState,
BlockingSinkStatus,
};
use crate::{ExecutionTaskSpawner, NUM_CPUS};

/// State of the window-partition-only sink while it buffers hash-partitioned input.
enum WindowPartitionOnlyState {
/// Input is still being received. `inner_states` holds one bucket of
/// micropartitions per hash partition; `new` creates `num_partitions` empty buckets.
Accumulating {
inner_states: Vec<Vec<MicroPartition>>,
},
/// The buckets have been taken out by `finalize`; `push` and `finalize` panic in this state.
Done,
}

impl WindowPartitionOnlyState {
/// Creates an accumulating state with `num_partitions` empty buckets.
fn new(num_partitions: usize) -> Self {
    let mut inner_states = Vec::with_capacity(num_partitions);
    inner_states.resize_with(num_partitions, Vec::new);
    Self::Accumulating { inner_states }
}

Check warning on line 28 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L25-L28

Added lines #L25 - L28 were not covered by tests

fn push(
&mut self,
input: Arc<MicroPartition>,
params: &WindowPartitionOnlyParams,
) -> DaftResult<()> {

Check warning on line 34 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L30-L34

Added lines #L30 - L34 were not covered by tests
let Self::Accumulating {
ref mut inner_states,
} = self

Check warning on line 37 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L36-L37

Added lines #L36 - L37 were not covered by tests
else {
panic!("WindowPartitionOnlySink should be in Accumulating state");

Check warning on line 39 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L39

Added line #L39 was not covered by tests
};

let partitioned =
input.partition_by_hash(params.partition_by.as_slice(), inner_states.len())?;
for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) {
state.push(p);
}
Ok(())
}

Check warning on line 48 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L42-L48

Added lines #L42 - L48 were not covered by tests

fn finalize(&mut self) -> Vec<Vec<MicroPartition>> {
let res = if let Self::Accumulating {
ref mut inner_states,

Check warning on line 52 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L50-L52

Added lines #L50 - L52 were not covered by tests
..
} = self

Check warning on line 54 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L54

Added line #L54 was not covered by tests
{
std::mem::take(inner_states)

Check warning on line 56 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L56

Added line #L56 was not covered by tests
} else {
panic!("WindowPartitionOnlySink should be in Accumulating state");

Check warning on line 58 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L58

Added line #L58 was not covered by tests
};
*self = Self::Done;
res
}

Check warning on line 62 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L60-L62

Added lines #L60 - L62 were not covered by tests
}

impl BlockingSinkState for WindowPartitionOnlyState {
/// Exposes the state as `&mut dyn Any` so that it can be downcast back to
/// `WindowPartitionOnlyState` from a `Box<dyn BlockingSinkState>`.
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}

Check warning on line 68 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L66-L68

Added lines #L66 - L68 were not covered by tests
}

/// Expression lists computed once in `WindowPartitionOnlySink::new` and read
/// by the sink's `sink` and `finalize` tasks.
struct WindowPartitionOnlyParams {
/// Original aggregation expressions, each wrapped as `Expr::Agg`.
/// Used as the first stage when `partial_agg_exprs` is empty, and for display.
original_aggregations: Vec<ExprRef>,
/// Partition-by expressions: the hash-partitioning keys and the group-by
/// keys of both aggregation stages.
partition_by: Vec<ExprRef>,
/// First-stage (partial) aggregation expressions.
partial_agg_exprs: Vec<ExprRef>,
/// Second-stage (final) aggregation expressions, applied to the first-stage output.
final_agg_exprs: Vec<ExprRef>,
/// Projections evaluated last to produce the output columns.
final_projections: Vec<ExprRef>,
}

/// Blocking sink that hash-partitions its input on the partition-by keys and
/// aggregates each partition when finalized.
pub struct WindowPartitionOnlySink {
/// Parameters behind an `Arc` so each spawned task can hold its own handle.
window_partition_only_params: Arc<WindowPartitionOnlyParams>,
}

impl WindowPartitionOnlySink {
pub fn new(
aggregations: &[ExprRef],
partition_by: &[ExprRef],
schema: &SchemaRef,
) -> DaftResult<Self> {

Check warning on line 93 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L89-L93

Added lines #L89 - L93 were not covered by tests
// Extract aggregation expressions
let aggregations = aggregations
.iter()
.map(extract_agg_expr)
.collect::<DaftResult<Vec<_>>>()?;

Check warning on line 98 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L95-L98

Added lines #L95 - L98 were not covered by tests

// Use the same multi-stage approach as grouped aggregates
let (partial_aggs, final_aggs, final_projections) =
daft_physical_plan::populate_aggregation_stages(&aggregations, schema, partition_by);

// Convert first stage aggregations to expressions
let partial_agg_exprs = partial_aggs
.into_values()
.map(|e| Arc::new(Expr::Agg(e)))
.collect::<Vec<_>>();

// Convert second stage aggregations to expressions
let final_agg_exprs = final_aggs
.into_values()
.map(|e| Arc::new(Expr::Agg(e)))
.collect::<Vec<_>>();

Ok(Self {
window_partition_only_params: Arc::new(WindowPartitionOnlyParams {
original_aggregations: aggregations
.into_iter()
.map(|e| Arc::new(Expr::Agg(e)))
.collect(),
partition_by: partition_by.to_vec(),
partial_agg_exprs,
final_agg_exprs,
final_projections,
}),
})
}

Check warning on line 128 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L101-L128

Added lines #L101 - L128 were not covered by tests

/// Number of hash partitions (buckets) used by this sink: the current value of `NUM_CPUS`.
fn num_partitions(&self) -> usize {
*NUM_CPUS
}

Check warning on line 132 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L130-L132

Added lines #L130 - L132 were not covered by tests
}

impl BlockingSink for WindowPartitionOnlySink {
#[instrument(skip_all, name = "WindowPartitionOnlySink::sink")]
fn sink(
&self,
input: Arc<MicroPartition>,
mut state: Box<dyn BlockingSinkState>,
spawner: &ExecutionTaskSpawner,
) -> BlockingSinkSinkResult {
let params = self.window_partition_only_params.clone();
spawner
.spawn(
async move {
let agg_state = state
.as_any_mut()
.downcast_mut::<WindowPartitionOnlyState>()
.expect("WindowPartitionOnlySink should have WindowPartitionOnlyState");

agg_state.push(input, &params)?;
Ok(BlockingSinkStatus::NeedMoreInput(state))
},

Check warning on line 154 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L146-L154

Added lines #L146 - L154 were not covered by tests
Span::current(),
)
.into()
}

#[instrument(skip_all, name = "WindowPartitionOnlySink::finalize")]
fn finalize(
&self,
states: Vec<Box<dyn BlockingSinkState>>,
spawner: &ExecutionTaskSpawner,
) -> BlockingSinkFinalizeResult {
let params = self.window_partition_only_params.clone();
let num_partitions = self.num_partitions();
spawner
.spawn(
async move {
let mut state_iters = states
.into_iter()
.map(|mut state| {
state
.as_any_mut()
.downcast_mut::<WindowPartitionOnlyState>()
.expect(
"WindowPartitionOnlySink should have WindowPartitionOnlyState",
)
.finalize()
.into_iter()
})
.collect::<Vec<_>>();

let mut per_partition_finalize_tasks = tokio::task::JoinSet::new();
for _ in 0..num_partitions {
let per_partition_state = state_iters
.iter_mut()
.map(|state| {
state.next().expect(
"WindowPartitionOnlyState should have Vec<MicroPartition>",
)
})
.collect::<Vec<_>>();
let params = params.clone();
per_partition_finalize_tasks.spawn(async move {
// Skip empty partitions
if per_partition_state.is_empty() {
return Ok(None);
}

// Concatenate all micropartitions for this partition
let partitions: Vec<MicroPartition> =
per_partition_state.into_iter().flatten().collect();
if partitions.is_empty() {
return Ok(None);
}

Check warning on line 207 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L170-L207

Added lines #L170 - L207 were not covered by tests

let concated = MicroPartition::concat(&partitions)?;

Check warning on line 209 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L209

Added line #L209 was not covered by tests

// Two-stage window function processing:

// 1. First stage: Apply partial aggregations
// For window functions, the first stage creates intermediate results like sums and counts
let partially_aggregated = if !params.partial_agg_exprs.is_empty() {
concated.agg(&params.partial_agg_exprs, &params.partition_by)?

Check warning on line 216 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L215-L216

Added lines #L215 - L216 were not covered by tests
} else {
// If no partial aggregations are needed, use original expressions
concated.agg(&params.original_aggregations, &params.partition_by)?

Check warning on line 219 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L219

Added line #L219 was not covered by tests
};

// 2. Second stage: Apply final aggregations
// This stage combines the intermediate results to get final values
let final_result = if !params.final_agg_exprs.is_empty() {

Check warning on line 224 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L224

Added line #L224 was not covered by tests
// Apply the second stage and then final projections
let final_agged = partially_aggregated
.agg(&params.final_agg_exprs, &params.partition_by)?;

Check warning on line 227 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L226-L227

Added lines #L226 - L227 were not covered by tests

// Apply final projections to produce the output columns
final_agged.eval_expression_list(&params.final_projections)?

Check warning on line 230 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L230

Added line #L230 was not covered by tests
} else {
// If there's no second stage, just apply projections directly
partially_aggregated
.eval_expression_list(&params.final_projections)?

Check warning on line 234 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L233-L234

Added lines #L233 - L234 were not covered by tests
};

Ok(Some(final_result))
});
}

Check warning on line 239 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L237-L239

Added lines #L237 - L239 were not covered by tests

// Collect results from all partitions
let results = per_partition_finalize_tasks
.join_all()
.await
.into_iter()
.collect::<DaftResult<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>();

if results.is_empty() {
return Ok(None);
}

Check warning on line 253 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L242-L253

Added lines #L242 - L253 were not covered by tests

// Combine all partition results
let concated = MicroPartition::concat(&results)?;
Ok(Some(Arc::new(concated)))
},

Check warning on line 258 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L256-L258

Added lines #L256 - L258 were not covered by tests
Span::current(),
)
.into()
}

/// Static display name of this sink.
fn name(&self) -> &'static str {
"WindowPartitionOnly"
}

Check warning on line 266 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L264-L266

Added lines #L264 - L266 were not covered by tests

/// Describes the sink in two lines: the original aggregations, then the
/// partition-by expressions, each list joined with `", "`.
fn multiline_display(&self) -> Vec<String> {
    let params = &self.window_partition_only_params;
    let aggregations = params
        .original_aggregations
        .iter()
        .map(|expr| expr.to_string())
        .join(", ");
    let partition_keys = params
        .partition_by
        .iter()
        .map(|expr| expr.to_string())
        .join(", ");
    vec![
        format!("WindowPartitionOnly: {}", aggregations),
        format!("Partition by: {}", partition_keys),
    ]
}

Check warning on line 287 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L268-L287

Added lines #L268 - L287 were not covered by tests

/// Maximum concurrency reported for this sink: the current value of `NUM_CPUS`,
/// the same value `num_partitions` returns.
fn max_concurrency(&self) -> usize {
*NUM_CPUS
}

Check warning on line 291 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L289-L291

Added lines #L289 - L291 were not covered by tests

/// Creates a fresh accumulating state with one empty bucket per hash partition.
fn make_state(&self) -> DaftResult<Box<dyn BlockingSinkState>> {
    let state = WindowPartitionOnlyState::new(self.num_partitions());
    Ok(Box::new(state))
}

Check warning on line 297 in src/daft-local-execution/src/sinks/window_partition_only.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-local-execution/src/sinks/window_partition_only.rs#L293-L297

Added lines #L293 - L297 were not covered by tests
}
Loading

0 comments on commit 407c7ec

Please sign in to comment.