Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Eliminate duplicate aggregations #132

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/binder/create_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
for column_name in column_names.iter().map(|ident| ident.value.to_lowercase()) {
if let Some(column) = columns
.iter_mut()
.find(|column| column.name() == column_name.to_string())
.find(|column| column.name() == column_name)
{
if *is_primary {
column.desc.is_primary = true;
Expand Down
2 changes: 1 addition & 1 deletion src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
.find(|(_, column)| column.desc.is_primary)
.map(|(_, column)| Arc::clone(column))
.unwrap();
let mut plan = ScanOperator::build(table_name.clone(), &table_catalog);
let mut plan = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(alias) = alias {
self.context
Expand Down
4 changes: 2 additions & 2 deletions src/binder/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
table_name: TableName,
) -> Result<(), DatabaseError> {
if !alias_column.is_empty() {
let aliases = alias_column.into_iter().map(lower_ident).collect_vec();
let aliases = alias_column.iter().map(lower_ident).collect_vec();
let table = self
.context
.table(table_name.clone())
Expand Down Expand Up @@ -222,7 +222,7 @@ impl<'a, T: Transaction> Binder<'a, T> {
let table_name = Arc::new(table.to_string());

let table_catalog = self.context.table_and_bind(table_name.clone(), join_type)?;
let scan_op = ScanOperator::build(table_name.clone(), &table_catalog);
let scan_op = ScanOperator::build(table_name.clone(), table_catalog);

if let Some(TableAlias { name, columns }) = alias {
self.register_alias(columns, name.value.to_lowercase(), table_name.clone())?;
Expand Down
1 change: 1 addition & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl<S: Storage> Database<S> {
HepBatchStrategy::fix_point_topdown(10),
vec![
NormalizationRuleImpl::CollapseProject,
NormalizationRuleImpl::CollapseGroupByAgg,
NormalizationRuleImpl::CombineFilter,
],
)
Expand Down
79 changes: 79 additions & 0 deletions src/optimizer/rule/normalization/combine_operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::optimizer::rule::normalization::is_subset_exprs;
use crate::planner::operator::Operator;
use crate::types::LogicalType;
use lazy_static::lazy_static;
use std::collections::HashSet;

lazy_static! {
static ref COLLAPSE_PROJECT_RULE: Pattern = {
Expand All @@ -27,6 +28,21 @@ lazy_static! {
}]),
}
};
// Matches an aggregate with a non-empty group-by list whose child is also an
// aggregate with a non-empty group-by list.
static ref COLLAPSE_GROUP_BY_AGG: Pattern = {
    // Shared by both levels of the pattern: true only for an aggregate that
    // has at least one group-by expression.
    fn is_group_by_agg(op: &Operator) -> bool {
        matches!(op, Operator::Aggregate(agg) if !agg.groupby_exprs.is_empty())
    }

    let child = Pattern {
        predicate: is_group_by_agg,
        children: PatternChildrenPredicate::None,
    };

    Pattern {
        predicate: is_group_by_agg,
        children: PatternChildrenPredicate::Predicate(vec![child]),
    }
};
}

/// Combine two adjacent project operators into one.
Expand Down Expand Up @@ -87,6 +103,47 @@ impl NormalizationRule for CombineFilter {
}
}

/// Drop a group-by-only aggregate that sits directly on top of another
/// aggregate grouping by the same expressions.
pub struct CollapseGroupByAgg;

impl MatchPattern for CollapseGroupByAgg {
/// Returns the shared pattern: an aggregate with group-by expressions whose
/// child is also an aggregate with group-by expressions.
fn pattern(&self) -> &Pattern {
&COLLAPSE_GROUP_BY_AGG
}
}

impl NormalizationRule for CollapseGroupByAgg {
    /// Removes the aggregate at `node_id` when it merely repeats the grouping
    /// already performed by its eldest child.
    ///
    /// The node is removed only if all of the following hold:
    /// * it is an aggregate without any `agg_calls`,
    /// * its eldest child is an aggregate as well,
    /// * both have the same number of group-by expressions, and
    /// * both group by exactly the same set of expressions.
    ///
    /// In every other case the graph is left untouched. This implementation
    /// itself never produces an error; it always returns `Ok(())`.
    fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> {
        // Resolve the child id first so that only shared borrows of the graph
        // are alive while the two operators are compared.
        let child_id = graph.eldest_child_at(node_id);

        let is_redundant = match (
            graph.operator(node_id),
            child_id.map(|id| graph.operator(id)),
        ) {
            (Operator::Aggregate(op), Some(Operator::Aggregate(child_op))) => {
                // An aggregate that evaluates agg calls still does real work.
                op.agg_calls.is_empty()
                    && op.groupby_exprs.len() == child_op.groupby_exprs.len()
                    && {
                        // Compare as sets in both directions. A one-sided
                        // subset test would accept `[a, a]` over `[a, b]`,
                        // where the parent still reduces the row count.
                        let parent_keys: HashSet<_> = op.groupby_exprs.iter().collect();
                        let child_keys: HashSet<_> = child_op.groupby_exprs.iter().collect();

                        parent_keys == child_keys
                    }
            }
            _ => false,
        };

        if is_redundant {
            graph.remove_node(node_id, false);
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
use crate::binder::test::select_sql_run;
Expand Down Expand Up @@ -181,4 +238,26 @@ mod tests {

Ok(())
}

/// Checks that after `CollapseGroupByAgg` runs, the first child of the plan
/// root is an aggregate and that aggregate's first child is not another one.
#[tokio::test]
async fn test_collapse_group_by_agg() -> Result<(), DatabaseError> {
    let plan = select_sql_run("select distinct c1, c2 from t1 group by c1, c2").await?;

    let optimizer = HepOptimizer::new(plan.clone()).batch(
        "test_collapse_group_by_agg".to_string(),
        HepBatchStrategy::once_topdown(),
        vec![NormalizationRuleImpl::CollapseGroupByAgg],
    );

    let best_plan = optimizer.find_best::<KipTransaction>(None)?;

    // One aggregate must survive directly below the plan root...
    let agg_plan = &best_plan.childrens[0];
    assert!(
        matches!(&agg_plan.operator, Operator::Aggregate(_)),
        "Should be a agg operator"
    );
    // ...and the duplicate that used to sit under it must be gone.
    assert!(
        !matches!(&agg_plan.childrens[0].operator, Operator::Aggregate(_)),
        "Should not be a agg operator"
    );

    Ok(())
}
}
7 changes: 6 additions & 1 deletion src/optimizer/rule/normalization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use crate::optimizer::core::pattern::Pattern;
use crate::optimizer::core::rule::{MatchPattern, NormalizationRule};
use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId};
use crate::optimizer::rule::normalization::column_pruning::ColumnPruning;
use crate::optimizer::rule::normalization::combine_operators::{CollapseProject, CombineFilter};
use crate::optimizer::rule::normalization::combine_operators::{
CollapseGroupByAgg, CollapseProject, CombineFilter,
};
use crate::optimizer::rule::normalization::pushdown_limit::{
EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin,
};
Expand All @@ -24,6 +26,7 @@ pub enum NormalizationRuleImpl {
ColumnPruning,
// Combine operators
CollapseProject,
CollapseGroupByAgg,
CombineFilter,
// PushDown limit
LimitProjectTranspose,
Expand All @@ -44,6 +47,7 @@ impl MatchPattern for NormalizationRuleImpl {
match self {
NormalizationRuleImpl::ColumnPruning => ColumnPruning.pattern(),
NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(),
NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(),
NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(),
NormalizationRuleImpl::EliminateLimits => EliminateLimits.pattern(),
Expand All @@ -62,6 +66,7 @@ impl NormalizationRule for NormalizationRuleImpl {
match self {
NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph),
NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph),
NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph),
NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph),
NormalizationRuleImpl::LimitProjectTranspose => {
LimitProjectTranspose.apply(node_id, graph)
Expand Down
20 changes: 10 additions & 10 deletions tests/slt/group_by.slt
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ statement ok
insert into t values (0,1,1), (1,2,1), (2,3,2), (3,4,2), (4,5,3)

# TODO: check on binder
# statement error
# select v2 + 1, v1 from t group by v2 + 1
statement error
select v2 + 1, v1 from t group by v2 + 1

# statement error
# select v2 + 1 as a, v1 as b from t group by a
statement error
select v2 + 1 as a, v1 as b from t group by a

# statement error
# select v2, v2 + 1, sum(v1) from t group by v2 + 1
statement error
select v2, v2 + 1, sum(v1) from t group by v2 + 1

# statement error
# select v2 + 2 + count(*) from t group by v2 + 1
statement error
select v2 + 2 + count(*) from t group by v2 + 1

# statement error
# select v2 + count(*) from t group by v2 order by v1;
statement error
select v2 + count(*) from t group by v2 order by v1;

query II rowsort
select v2 + 1, sum(v1) from t group by v2 + 1
Expand Down
Loading