Skip to content

Commit

Permalink
feat(substrait): AggregateRel grouping_expression support (apache#13173)
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshchiy authored Nov 3, 2024
1 parent b40a298 commit a9d4d52
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 36 deletions.
77 changes: 55 additions & 22 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion::logical_expr::{
expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;
Expand Down Expand Up @@ -665,39 +666,48 @@ pub async fn from_substrait_rel(
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
let mut group_expr = vec![];
let mut aggr_expr = vec![];
let mut ref_group_exprs = vec![];

for e in &agg.grouping_expressions {
let x =
from_substrait_rex(ctx, e, input.schema(), extensions).await?;
ref_group_exprs.push(x);
}

let mut group_exprs = vec![];
let mut aggr_exprs = vec![];

match agg.groupings.len() {
1 => {
for e in &agg.groupings[0].grouping_expressions {
let x =
from_substrait_rex(ctx, e, input.schema(), extensions)
.await?;
group_expr.push(x);
}
group_exprs.extend_from_slice(
&from_substrait_grouping(
ctx,
&agg.groupings[0],
&ref_group_exprs,
input.schema(),
extensions,
)
.await?,
);
}
_ => {
let mut grouping_sets = vec![];
for grouping in &agg.groupings {
let mut grouping_set = vec![];
for e in &grouping.grouping_expressions {
let x = from_substrait_rex(
ctx,
e,
input.schema(),
extensions,
)
.await?;
grouping_set.push(x);
}
let grouping_set = from_substrait_grouping(
ctx,
grouping,
&ref_group_exprs,
input.schema(),
extensions,
)
.await?;
grouping_sets.push(grouping_set);
}
// Single-element grouping expression of type Expr::GroupingSet.
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
// parsed by the producer and consumer, since Substrait does not have a type dedicated
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
grouping_sets,
)));
}
Expand Down Expand Up @@ -755,9 +765,9 @@ pub async fn from_substrait_rel(
"Aggregate without aggregate function is not supported"
),
};
aggr_expr.push(agg_func?.as_ref().clone());
aggr_exprs.push(agg_func?.as_ref().clone());
}
input.aggregate(group_expr, aggr_expr)?.build()
input.aggregate(group_exprs, aggr_exprs)?.build()
} else {
not_impl_err!("Aggregate without an input is not valid")
}
Expand Down Expand Up @@ -2762,6 +2772,29 @@ fn from_substrait_null(
}
}

#[allow(deprecated)]
async fn from_substrait_grouping(
ctx: &SessionContext,
grouping: &Grouping,
expressions: &[Expr],
input_schema: &DFSchemaRef,
extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut group_exprs = vec![];
if !grouping.grouping_expressions.is_empty() {
for e in &grouping.grouping_expressions {
let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?;
group_exprs.push(expr);
}
return Ok(group_exprs);
}
for idx in &grouping.expression_references {
let e = &expressions[*idx as usize];
group_exprs.push(e.clone());
}
Ok(group_exprs)
}

fn from_substrait_field_reference(
field_ref: &FieldReference,
input_schema: &DFSchema,
Expand Down
58 changes: 44 additions & 14 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ pub fn to_substrait_rel(
}
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
let groupings = to_substrait_groupings(
let (grouping_expressions, groupings) = to_substrait_groupings(
ctx,
&agg.group_expr,
agg.input.schema(),
Expand All @@ -377,7 +377,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
grouping_expressions,
groupings,
measures,
advanced_extension: None,
Expand Down Expand Up @@ -774,14 +774,20 @@ pub fn parse_flat_grouping_exprs(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
ref_group_exprs: &mut Vec<Expression>,
) -> Result<Grouping> {
let grouping_expressions = exprs
.iter()
.map(|e| to_substrait_rex(ctx, e, schema, 0, extensions))
.collect::<Result<Vec<_>>>()?;
let mut expression_references = vec![];
let mut grouping_expressions = vec![];

for e in exprs {
let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
grouping_expressions.push(rex.clone());
ref_group_exprs.push(rex);
expression_references.push((ref_group_exprs.len() - 1) as u32);
}
Ok(Grouping {
grouping_expressions,
expression_references: vec![],
expression_references,
})
}

Expand All @@ -790,16 +796,25 @@ pub fn to_substrait_groupings(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
) -> Result<Vec<Grouping>> {
match exprs.len() {
) -> Result<(Vec<Expression>, Vec<Grouping>)> {
let mut ref_group_exprs = vec![];
let groupings = match exprs.len() {
1 => match &exprs[0] {
Expr::GroupingSet(gs) => match gs {
GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
"GroupingSet CUBE is not yet supported".to_string(),
)),
GroupingSet::GroupingSets(sets) => Ok(sets
.iter()
.map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions))
.map(|set| {
parse_flat_grouping_exprs(
ctx,
set,
schema,
extensions,
&mut ref_group_exprs,
)
})
.collect::<Result<Vec<_>>>()?),
GroupingSet::Rollup(set) => {
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
Expand All @@ -810,19 +825,34 @@ pub fn to_substrait_groupings(
.iter()
.rev()
.map(|set| {
parse_flat_grouping_exprs(ctx, set, schema, extensions)
parse_flat_grouping_exprs(
ctx,
set,
schema,
extensions,
&mut ref_group_exprs,
)
})
.collect::<Result<Vec<_>>>()?)
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
ctx, exprs, schema, extensions,
ctx,
exprs,
schema,
extensions,
&mut ref_group_exprs,
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
ctx, exprs, schema, extensions,
ctx,
exprs,
schema,
extensions,
&mut ref_group_exprs,
)?]),
}
}?;
Ok((ref_group_exprs, groupings))
}

#[allow(deprecated)]
Expand Down
13 changes: 13 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
.await
}

/// Reads the `aggregate_no_project_group_expression_ref` Substrait plan from the
/// test data directory and checks that consuming it yields an aggregate grouped
/// by `data.a` on top of a scan of `data`.
#[tokio::test]
async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
    let plan_path =
        "tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json";
    let proto_plan = read_json(plan_path);

    // The trailing backslash joins the two lines; the `\n` is the only separator.
    let expected_plan = "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
        \n TableScan: data projection=[a]";

    assert_expected_plan_substrait(proto_plan, expected_plan).await
}

#[tokio::test]
async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
let proto_plan =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
{
"extensionUris": [
{
"uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"functionAnchor": 185,
"name": "count:any"
}
}
],
"relations": [
{
"root": {
"input": {
"aggregate": {
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"a"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_NULLABLE"
}
},
"namedTable": {
"names": [
"data"
]
}
}
},
"grouping_expressions": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
],
"groupings": [
{
"expression_references": [0]
}
],
"measures": [
{
"measure": {
"functionReference": 185,
"phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
"outputType": {
"i64": {}
},
"arguments": [
{
"value": {
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
}
]
}
}
]
}
},
"names": [
"a",
"countA"
]
}
}
],
"version": {
"minorNumber": 54,
"producer": "subframe"
}
}

0 comments on commit a9d4d52

Please sign in to comment.