diff --git a/src/binder/create_table.rs b/src/binder/create_table.rs index 7bdb2e31..09c34531 100644 --- a/src/binder/create_table.rs +++ b/src/binder/create_table.rs @@ -60,7 +60,7 @@ impl<'a, T: Transaction> Binder<'a, T> { for column_name in column_names.iter().map(|ident| ident.value.to_lowercase()) { if let Some(column) = columns .iter_mut() - .find(|column| column.name() == column_name.to_string()) + .find(|column| column.name() == column_name) { if *is_primary { column.desc.is_primary = true; diff --git a/src/binder/delete.rs b/src/binder/delete.rs index 841cf289..0b1d7d56 100644 --- a/src/binder/delete.rs +++ b/src/binder/delete.rs @@ -23,7 +23,7 @@ impl<'a, T: Transaction> Binder<'a, T> { .find(|(_, column)| column.desc.is_primary) .map(|(_, column)| Arc::clone(column)) .unwrap(); - let mut plan = ScanOperator::build(table_name.clone(), &table_catalog); + let mut plan = ScanOperator::build(table_name.clone(), table_catalog); if let Some(alias) = alias { self.context diff --git a/src/binder/select.rs b/src/binder/select.rs index 957e7ca6..09c05001 100644 --- a/src/binder/select.rs +++ b/src/binder/select.rs @@ -190,7 +190,7 @@ impl<'a, T: Transaction> Binder<'a, T> { table_name: TableName, ) -> Result<(), DatabaseError> { if !alias_column.is_empty() { - let aliases = alias_column.into_iter().map(lower_ident).collect_vec(); + let aliases = alias_column.iter().map(lower_ident).collect_vec(); let table = self .context .table(table_name.clone()) @@ -222,7 +222,7 @@ impl<'a, T: Transaction> Binder<'a, T> { let table_name = Arc::new(table.to_string()); let table_catalog = self.context.table_and_bind(table_name.clone(), join_type)?; - let scan_op = ScanOperator::build(table_name.clone(), &table_catalog); + let scan_op = ScanOperator::build(table_name.clone(), table_catalog); if let Some(TableAlias { name, columns }) = alias { self.register_alias(columns, name.value.to_lowercase(), table_name.clone())?; diff --git a/src/db.rs b/src/db.rs index 23a41ebd..ae5560ee 
100644 --- a/src/db.rs +++ b/src/db.rs @@ -158,6 +158,7 @@ impl Database { HepBatchStrategy::fix_point_topdown(10), vec![ NormalizationRuleImpl::CollapseProject, + NormalizationRuleImpl::CollapseGroupByAgg, NormalizationRuleImpl::CombineFilter, ], ) diff --git a/src/optimizer/rule/normalization/combine_operators.rs b/src/optimizer/rule/normalization/combine_operators.rs index 419122a9..bcc66105 100644 --- a/src/optimizer/rule/normalization/combine_operators.rs +++ b/src/optimizer/rule/normalization/combine_operators.rs @@ -7,6 +7,7 @@ use crate::optimizer::rule::normalization::is_subset_exprs; use crate::planner::operator::Operator; use crate::types::LogicalType; use lazy_static::lazy_static; +use std::collections::HashSet; lazy_static! { static ref COLLAPSE_PROJECT_RULE: Pattern = { @@ -27,6 +28,21 @@ lazy_static! { }]), } }; + static ref COLLAPSE_GROUP_BY_AGG: Pattern = { + Pattern { + predicate: |op| match op { + Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(), + _ => false, + }, + children: PatternChildrenPredicate::Predicate(vec![Pattern { + predicate: |op| match op { + Operator::Aggregate(agg_op) => !agg_op.groupby_exprs.is_empty(), + _ => false, + }, + children: PatternChildrenPredicate::None, + }]), + } + }; } /// Combine two adjacent project operators into one. 
@@ -87,6 +103,47 @@ impl NormalizationRule for CombineFilter { } } +pub struct CollapseGroupByAgg; + +impl MatchPattern for CollapseGroupByAgg { + fn pattern(&self) -> &Pattern { + &COLLAPSE_GROUP_BY_AGG + } +} + +impl NormalizationRule for CollapseGroupByAgg { + fn apply(&self, node_id: HepNodeId, graph: &mut HepGraph) -> Result<(), DatabaseError> { + if let Operator::Aggregate(op) = graph.operator(node_id).clone() { + // if it is an aggregation operator containing agg_call + if !op.agg_calls.is_empty() { + return Ok(()); + } + + if let Some(Operator::Aggregate(child_op)) = graph + .eldest_child_at(node_id) + .and_then(|child_id| Some(graph.operator_mut(child_id))) + { + if op.groupby_exprs.len() != child_op.groupby_exprs.len() { + return Ok(()); + } + let mut expr_set = HashSet::new(); + + for expr in op.groupby_exprs.iter() { + expr_set.insert(expr); + } + for expr in child_op.groupby_exprs.iter() { + expr_set.remove(expr); + } + if expr_set.len() == 0 { + graph.remove_node(node_id, false); + } + } + } + + Ok(()) + } +} + #[cfg(test)] mod tests { use crate::binder::test::select_sql_run; @@ -181,4 +238,26 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_collapse_group_by_agg() -> Result<(), DatabaseError> { + let plan = select_sql_run("select distinct c1, c2 from t1 group by c1, c2").await?; + + let optimizer = HepOptimizer::new(plan.clone()).batch( + "test_collapse_group_by_agg".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::CollapseGroupByAgg], + ); + + let best_plan = optimizer.find_best::<KipTransaction>(None)?; + + if let Operator::Aggregate(_) = &best_plan.childrens[0].operator { + if let Operator::Aggregate(_) = &best_plan.childrens[0].childrens[0].operator { + unreachable!("Should not be a agg operator") + } else { + return Ok(()); + } + } + unreachable!("Should be a agg operator") + } } diff --git a/src/optimizer/rule/normalization/mod.rs b/src/optimizer/rule/normalization/mod.rs index e143c93c..1d7b895d 100644 ---
a/src/optimizer/rule/normalization/mod.rs +++ b/src/optimizer/rule/normalization/mod.rs @@ -4,7 +4,9 @@ use crate::optimizer::core::pattern::Pattern; use crate::optimizer::core::rule::{MatchPattern, NormalizationRule}; use crate::optimizer::heuristic::graph::{HepGraph, HepNodeId}; use crate::optimizer::rule::normalization::column_pruning::ColumnPruning; -use crate::optimizer::rule::normalization::combine_operators::{CollapseProject, CombineFilter}; +use crate::optimizer::rule::normalization::combine_operators::{ + CollapseGroupByAgg, CollapseProject, CombineFilter, +}; use crate::optimizer::rule::normalization::pushdown_limit::{ EliminateLimits, LimitProjectTranspose, PushLimitIntoScan, PushLimitThroughJoin, }; @@ -24,6 +26,7 @@ pub enum NormalizationRuleImpl { ColumnPruning, // Combine operators CollapseProject, + CollapseGroupByAgg, CombineFilter, // PushDown limit LimitProjectTranspose, @@ -44,6 +47,7 @@ impl MatchPattern for NormalizationRuleImpl { match self { NormalizationRuleImpl::ColumnPruning => ColumnPruning.pattern(), NormalizationRuleImpl::CollapseProject => CollapseProject.pattern(), + NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.pattern(), NormalizationRuleImpl::CombineFilter => CombineFilter.pattern(), NormalizationRuleImpl::LimitProjectTranspose => LimitProjectTranspose.pattern(), NormalizationRuleImpl::EliminateLimits => EliminateLimits.pattern(), @@ -62,6 +66,7 @@ impl NormalizationRule for NormalizationRuleImpl { match self { NormalizationRuleImpl::ColumnPruning => ColumnPruning.apply(node_id, graph), NormalizationRuleImpl::CollapseProject => CollapseProject.apply(node_id, graph), + NormalizationRuleImpl::CollapseGroupByAgg => CollapseGroupByAgg.apply(node_id, graph), NormalizationRuleImpl::CombineFilter => CombineFilter.apply(node_id, graph), NormalizationRuleImpl::LimitProjectTranspose => { LimitProjectTranspose.apply(node_id, graph) diff --git a/tests/slt/group_by.slt b/tests/slt/group_by.slt index 7036f3b7..f022ba56 100644 
--- a/tests/slt/group_by.slt +++ b/tests/slt/group_by.slt @@ -5,20 +5,20 @@ statement ok insert into t values (0,1,1), (1,2,1), (2,3,2), (3,4,2), (4,5,3) # TODO: check on binder -# statement error -# select v2 + 1, v1 from t group by v2 + 1 +statement error +select v2 + 1, v1 from t group by v2 + 1 -# statement error -# select v2 + 1 as a, v1 as b from t group by a +statement error +select v2 + 1 as a, v1 as b from t group by a -# statement error -# select v2, v2 + 1, sum(v1) from t group by v2 + 1 +statement error +select v2, v2 + 1, sum(v1) from t group by v2 + 1 -# statement error -# select v2 + 2 + count(*) from t group by v2 + 1 +statement error +select v2 + 2 + count(*) from t group by v2 + 1 -# statement error -# select v2 + count(*) from t group by v2 order by v1; +statement error +select v2 + count(*) from t group by v2 order by v1; query II rowsort select v2 + 1, sum(v1) from t group by v2 + 1