-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix: Don't trigger implicit GROUP BY because of subqueries
The old tree-walker for finding aggregate functions was wrong. We need to do this the hard way.
- Loading branch information
Showing
3 changed files
with
286 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
use crate::{ | ||
ast, | ||
scope::{ColumnSetScope, ScopeGet}, | ||
}; | ||
|
||
/// Interface used to search an AST node for aggregate functions. | ||
pub trait ContainsAggregate { | ||
/// Does this AST node contain an aggregate function? | ||
/// | ||
/// Does not recurse into sub-queries. | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool; | ||
} | ||
|
||
impl<T: ContainsAggregate + ast::Node> ContainsAggregate for Box<T> { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.as_ref().contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl<T: ContainsAggregate + ast::Node> ContainsAggregate for Vec<T> { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.iter().any(|item| item.contains_aggregate(scope)) | ||
} | ||
} | ||
|
||
impl<T: ContainsAggregate + ast::Node> ContainsAggregate for Option<T> { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.as_ref() | ||
.map_or(false, |item| item.contains_aggregate(scope)) | ||
} | ||
} | ||
|
||
impl<T: ContainsAggregate + ast::Node> ContainsAggregate for ast::NodeVec<T> { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.node_iter().any(|item| item.contains_aggregate(scope)) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::SelectList { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.items.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::SelectListItem { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
ast::SelectListItem::Expression { | ||
expression, | ||
alias: _, | ||
} => expression.contains_aggregate(scope), | ||
ast::SelectListItem::Wildcard { .. } | ||
| ast::SelectListItem::TableNameWildcard { .. } => false, | ||
} | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::Expression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
ast::Expression::Literal(_) => false, | ||
ast::Expression::BoolValue(_) => false, | ||
ast::Expression::Null(_) => false, | ||
ast::Expression::Interval(interval) => interval.contains_aggregate(scope), | ||
ast::Expression::ColumnName(_) => false, | ||
ast::Expression::Cast(cast) => cast.contains_aggregate(scope), | ||
ast::Expression::Is(is) => is.contains_aggregate(scope), | ||
ast::Expression::In(in_expr) => in_expr.contains_aggregate(scope), | ||
ast::Expression::Between(between) => between.contains_aggregate(scope), | ||
ast::Expression::KeywordBinop(binop) => binop.contains_aggregate(scope), | ||
ast::Expression::Not(not) => not.contains_aggregate(scope), | ||
ast::Expression::If(if_expr) => if_expr.contains_aggregate(scope), | ||
ast::Expression::Case(case) => case.contains_aggregate(scope), | ||
ast::Expression::Binop(binop) => binop.contains_aggregate(scope), | ||
// We never look into sub-queries. | ||
ast::Expression::Query { .. } => false, | ||
ast::Expression::Parens { expression, .. } => expression.contains_aggregate(scope), | ||
ast::Expression::Array(arr) => arr.contains_aggregate(scope), | ||
ast::Expression::Struct(st) => st.contains_aggregate(scope), | ||
// TODO: false if we add `OVER`. | ||
ast::Expression::Count(_) => true, | ||
ast::Expression::CurrentDate(_) => false, | ||
// TODO: false if we add `OVER`. | ||
ast::Expression::ArrayAgg(_) => true, | ||
ast::Expression::SpecialDateFunctionCall(fcall) => fcall.contains_aggregate(scope), | ||
ast::Expression::FunctionCall(fcall) => fcall.contains_aggregate(scope), | ||
ast::Expression::Index(idx) => idx.contains_aggregate(scope), | ||
} | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::IntervalExpression { | ||
fn contains_aggregate(&self, _scope: &ColumnSetScope) -> bool { | ||
// These are currently all literals, so they can't contain aggregates. | ||
false | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::Cast { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.expression.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::IsExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.left.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::InExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.left.contains_aggregate(scope) || self.value_set.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::InValueSet { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
// We never look into sub-queries. | ||
ast::InValueSet::QueryExpression { .. } => false, | ||
ast::InValueSet::ExpressionList { expressions, .. } => { | ||
expressions.contains_aggregate(scope) | ||
} | ||
// I am doubtful that anything good is happening if we hit this | ||
// case. | ||
ast::InValueSet::Unnest { expression, .. } => expression.contains_aggregate(scope), | ||
} | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::BetweenExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.left.contains_aggregate(scope) | ||
|| self.middle.contains_aggregate(scope) | ||
|| self.right.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::KeywordBinopExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.left.contains_aggregate(scope) || self.right.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::NotExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.expression.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::IfExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.condition.contains_aggregate(scope) | ||
|| self.then_expression.contains_aggregate(scope) | ||
|| self.else_expression.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::CaseExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.case_expr.contains_aggregate(scope) | ||
|| self.when_clauses.contains_aggregate(scope) | ||
|| self.else_clause.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::CaseWhenClause { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.condition.contains_aggregate(scope) || self.result.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::CaseElseClause { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.result.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::BinopExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.left.contains_aggregate(scope) || self.right.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::ArrayExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.definition.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::ArrayDefinition { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
// We never look into sub-queries. | ||
ast::ArrayDefinition::Query { .. } => false, | ||
ast::ArrayDefinition::Elements(expressions) => expressions.contains_aggregate(scope), | ||
} | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::StructExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.fields.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::SpecialDateFunctionCall { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
// We don't need to check the function name here. | ||
self.args.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::ExpressionOrDatePart { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
ast::ExpressionOrDatePart::Expression(expr) => expr.contains_aggregate(scope), | ||
ast::ExpressionOrDatePart::DatePart(_) => false, | ||
} | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::FunctionCall { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
// Check the function we're calling. | ||
if scope | ||
.get_function_type(&self.name) | ||
.map_or(false, |func_ty| func_ty.is_aggregate()) | ||
// If we have an OVER clause, we're not a normal aggregate. | ||
&& self.over_clause.is_none() | ||
{ | ||
return true; | ||
} | ||
|
||
// Check our arguments. | ||
self.args.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::IndexExpression { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
self.expression.contains_aggregate(scope) || self.index.contains_aggregate(scope) | ||
} | ||
} | ||
|
||
impl ContainsAggregate for ast::IndexOffset { | ||
fn contains_aggregate(&self, scope: &ColumnSetScope) -> bool { | ||
match self { | ||
ast::IndexOffset::Simple(expression) => expression.contains_aggregate(scope), | ||
ast::IndexOffset::Offset { expression, .. } => expression.contains_aggregate(scope), | ||
ast::IndexOffset::Ordinal { expression, .. } => expression.contains_aggregate(scope), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
-- Implicit aggregates.
--
-- Fixture: a single-column table holding the values 1, 2 and 3.

CREATE TEMP TABLE implicit_agg (val INT64);
INSERT INTO implicit_agg (val) VALUES (1), (2), (3);

-- Case 1: an aggregate in the select list with no GROUP BY collapses the
-- input to a single row.
CREATE OR REPLACE TABLE __result1 AS
SELECT
    SUM(val) AS sum_val
FROM implicit_agg;

CREATE OR REPLACE TABLE __expected1 (sum_val INT64);
INSERT INTO __expected1 (sum_val) VALUES (6);

-- Case 2: now make sure we _don't_ trigger if the aggregate appears in a
-- subquery.
CREATE OR REPLACE TABLE __result2 AS
SELECT
    x,
    (SELECT SUM(val) FROM implicit_agg) AS sum_val
FROM (SELECT 1 AS x);

CREATE OR REPLACE TABLE __expected2 (x INT64, sum_val INT64);
INSERT INTO __expected2 (x, sum_val) VALUES (1, 6);