Skip to content

Commit

Permalink
combine partial changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksonrnewhouse authored and Jackson Newhouse committed Apr 10, 2024
1 parent 4b9d841 commit 2a7d83b
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 9 deletions.
12 changes: 7 additions & 5 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub enum AggregateMode {
/// two operators.
/// This mode requires that the input is partitioned by group key (like FinalPartitioned)
SinglePartitioned,
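/// Combines partial aggregates: merges accumulator state coming from the input
/// and outputs accumulator state again (rather than final aggregate values).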
CombinePartial,
}

impl AggregateMode {
Expand All @@ -94,7 +96,7 @@ impl AggregateMode {
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned => true,
AggregateMode::Final | AggregateMode::FinalPartitioned => false,
AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => false,
}
}
}
Expand Down Expand Up @@ -651,7 +653,7 @@ impl ExecutionPlan for AggregateExec {

fn required_input_distribution(&self) -> Vec<Distribution> {
match &self.mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::CombinePartial => {
vec![Distribution::UnspecifiedDistribution]
}
AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => {
Expand Down Expand Up @@ -781,7 +783,7 @@ fn create_schema(
}

match mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::CombinePartial => {
// in partial mode, the fields of the accumulator's state
for expr in aggr_expr {
fields.extend(expr.state_fields()?.iter().cloned())
Expand Down Expand Up @@ -1050,7 +1052,7 @@ fn aggregate_expressions(
})
.collect()),
// In this mode, we build the merge expressions of the aggregation.
AggregateMode::Final | AggregateMode::FinalPartitioned => {
AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => {
let mut col_idx_base = col_idx_base;
aggr_expr
.iter()
Expand Down Expand Up @@ -1099,7 +1101,7 @@ fn finalize_aggregation(
mode: &AggregateMode,
) -> Result<Vec<ArrayRef>> {
match mode {
AggregateMode::Partial => {
AggregateMode::Partial | AggregateMode::CombinePartial => {
// Build the vector of states
accumulators
.iter_mut()
Expand Down
5 changes: 3 additions & 2 deletions datafusion/physical-plan/src/aggregates/no_grouping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ impl AggregateStream {
let filter_expressions = match agg.mode {
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned => agg_filter_expr,
| AggregateMode::SinglePartitioned
| AggregateMode::CombinePartial => agg_filter_expr,
AggregateMode::Final | AggregateMode::FinalPartitioned => {
vec![None; agg.aggr_expr.len()]
}
Expand Down Expand Up @@ -230,7 +231,7 @@ fn aggregate_batch(
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned => accum.update_batch(values),
AggregateMode::Final | AggregateMode::FinalPartitioned => {
AggregateMode::Final | AggregateMode::FinalPartitioned | AggregateMode::CombinePartial => {
accum.merge_batch(values)
}
};
Expand Down
6 changes: 4 additions & 2 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ impl GroupedHashAggregateStream {
let filter_expressions = match agg.mode {
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned => agg_filter_expr,
| AggregateMode::SinglePartitioned
| AggregateMode::CombinePartial => agg_filter_expr,
AggregateMode::Final | AggregateMode::FinalPartitioned => {
vec![None; agg.aggr_expr.len()]
}
Expand Down Expand Up @@ -640,7 +641,8 @@ impl GroupedHashAggregateStream {
// Next output each aggregate value
for acc in self.accumulators.iter_mut() {
match self.mode {
AggregateMode::Partial => output.extend(acc.state(emit_to)?),
AggregateMode::Partial
| AggregateMode::CombinePartial => output.extend(acc.state(emit_to)?),
_ if spilling => {
// If spilling, output partial state because the spilled data will be
// merged and re-evaluated later.
Expand Down
4 changes: 4 additions & 0 deletions datafusion/proto/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1389,8 +1389,12 @@ impl AsExecutionPlan for PhysicalPlanNode {
AggregateMode::Single => protobuf::AggregateMode::Single,
AggregateMode::SinglePartitioned => {
protobuf::AggregateMode::SinglePartitioned
},
AggregateMode::CombinePartial => {
unimplemented!()
}
};

let input_schema = exec.input_schema();
let input = protobuf::PhysicalPlanNode::try_from_physical_plan(
exec.input().to_owned(),
Expand Down

0 comments on commit 2a7d83b

Please sign in to comment.