diff --git a/src/analyze.rs b/src/analyze.rs index 4ededbf..e2ea909 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -5,7 +5,11 @@ use std::collections::HashMap; use derive_visitor::{Drive, Visitor}; -use crate::ast::{FunctionCall, SpecialDateFunctionCall, SqlProgram}; +use crate::{ + ast::{FunctionCall, Name, Node, NodeVec, SpecialDateFunctionCall, SqlProgram}, + scope::{Scope, ScopeGet, ScopeHandle}, + tokenizer::Span, +}; /// A `phf` set of functions that are known to take any number of arguments. static KNOWN_VARARG_FUNCTIONS: phf::Set<&'static str> = phf::phf_set! { @@ -13,9 +17,10 @@ static KNOWN_VARARG_FUNCTIONS: phf::Set<&'static str> = phf::phf_set! { }; /// Count all the function calls in a [`SqlProgram`]. -#[derive(Debug, Default, Visitor)] +#[derive(Debug, Visitor)] #[visitor(FunctionCall(enter), SpecialDateFunctionCall(enter))] pub struct FunctionCallCounts { + root_scope: ScopeHandle, counts: HashMap, } @@ -25,6 +30,18 @@ impl FunctionCallCounts { sql_program.drive(self) } + /// Return true if we have at least one signature for a function which + /// could be called with the given number of arguments. 
+ fn is_known_function_and_airty(&self, name: &str, args: &NodeVec) -> bool { + let name = Name::new(name, Span::Unknown); + let ftype = match self.root_scope.get_function_type(&name) { + Ok(ftype) => ftype, + Err(_) => return false, + }; + let arg_count = args.node_iter().count(); + ftype.could_be_called_with_arg_count(arg_count) + } + fn record_call(&mut self, name: String) { let count = self.counts.entry(name).or_default(); *count += 1; @@ -48,6 +65,9 @@ impl FunctionCallCounts { if function_call.over_clause.is_some() { name.push_str(" OVER(..)"); } + if !self.is_known_function_and_airty(&base_name, &function_call.args) { + name.push_str(" (UNKNOWN)"); + } self.record_call(name); } @@ -55,14 +75,12 @@ impl FunctionCallCounts { &mut self, special_date_function_call: &SpecialDateFunctionCall, ) { - let mut name = format!( - "{}(", - special_date_function_call - .function_name - .ident - .name - .to_ascii_uppercase(), - ); + let base_name = special_date_function_call + .function_name + .ident + .name + .to_ascii_uppercase(); + let mut name = format!("{}(", base_name); for (i, _) in special_date_function_call.args.node_iter().enumerate() { if i > 0 { name.push(','); @@ -70,6 +88,9 @@ impl FunctionCallCounts { name.push('_'); } name.push_str(") (special)"); + if !self.is_known_function_and_airty(&base_name, &special_date_function_call.args) { + name.push_str(" (UNKNOWN)"); + } self.record_call(name); } @@ -81,3 +102,12 @@ impl FunctionCallCounts { counts } } + +impl Default for FunctionCallCounts { + fn default() -> Self { + Self { + root_scope: Scope::root(), + counts: HashMap::new(), + } + } +} diff --git a/src/ast.rs b/src/ast.rs index f83debd..4582692 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -644,22 +644,33 @@ pub struct QueryStatement { /// [official grammar]: /// https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#sql_syntax. 
#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] -pub enum QueryExpression { - SelectExpression(SelectExpression), +pub struct QueryExpression { + pub with_clause: Option, + pub query: QueryExpressionQuery, + pub order_by: Option, + pub limit: Option, +} + +/// The `WITH` clause of a `QueryExpression`. +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +pub struct QueryExpressionWithClause { + pub with_token: Keyword, + pub ctes: NodeVec, +} + +/// The actual query portion of a `QueryExpression`. +#[derive(Clone, Debug, Drive, DriveMut, Emit, EmitDefault, Spanned, ToTokens)] +pub enum QueryExpressionQuery { + Select(SelectExpression), Nested { paren1: Punct, query: Box, paren2: Punct, }, - With { - with_token: Keyword, - ctes: NodeVec, - query: Box, - }, SetOperation { - left: Box, + left: Box, set_operator: SetOperator, - right: Box, + right: Box, }, } @@ -2139,28 +2150,47 @@ peg::parser! { } } - pub rule query_expression() -> QueryExpression = precedence! { + pub rule query_expression() -> QueryExpression = + with_clause:query_expression_with_clause()? + query:query_expression_query() + order_by:order_by()? + limit:limit()? + { + QueryExpression { + with_clause, + query, + order_by, + limit, + } + } + + rule query_expression_with_clause() -> QueryExpressionWithClause + = with_token:k("WITH") ctes:sep_opt_trailing(, ",") { + QueryExpressionWithClause { + with_token, + ctes, + } + } + + pub rule query_expression_query() -> QueryExpressionQuery = precedence! 
{ left:(@) set_operator:set_operator() right:@ { - QueryExpression::SetOperation { - left: Box::new(left), set_operator, right: Box::new(right) + QueryExpressionQuery::SetOperation { + left: Box::new(left), + set_operator, + right: Box::new(right), } } -- - select_expression:select_expression() { QueryExpression::SelectExpression(select_expression) } + select_expression:select_expression() { + QueryExpressionQuery::Select(select_expression) + } paren1:p("(") query:query_statement() paren2:p(")") { - QueryExpression::Nested { + QueryExpressionQuery::Nested { paren1, query: Box::new(query), paren2, } } - with_token:k("WITH") ctes:sep_opt_trailing(, ",") query:query_statement() { - QueryExpression::With { - with_token, - ctes, - query: Box::new(query), - } - } } rule set_operator() -> SetOperator @@ -2587,6 +2617,7 @@ peg::parser! { rule special_date_function_name() -> PseudoKeyword = pk("DATE_DIFF") / pk("DATE_TRUNC") / pk("DATE_ADD") / pk("DATE_SUB") / pk("DATETIME_DIFF") / pk("DATETIME_TRUNC") / pk("DATETIME_ADD") / pk("DATETIME_SUB") + / pk("GENERATE_DATE_ARRAY") rule special_date_expression() -> SpecialDateExpression = interval:interval_expression() { SpecialDateExpression::Interval(interval) } diff --git a/src/cmd/parse.rs b/src/cmd/parse.rs index 832d8d8..52f141f 100644 --- a/src/cmd/parse.rs +++ b/src/cmd/parse.rs @@ -70,7 +70,7 @@ pub fn cmd_parse(files: &mut KnownFiles, opt: &ParseOpt) -> Result<()> { Ok(sql_program) => { ok_count += 1; ok_line_count += row.query.lines().count(); - println!("OK {}", row.id); + //println!("OK {}", row.id); if opt.count_function_calls { function_call_counts.visit(&sql_program); } diff --git a/src/infer/mod.rs b/src/infer/mod.rs index 0ca98c2..08bcc7f 100644 --- a/src/infer/mod.rs +++ b/src/infer/mod.rs @@ -225,25 +225,70 @@ impl InferTypes for ast::QueryExpression { type Scope = ScopeHandle; type Output = TableType; + fn infer_types(&mut self, scope: &ScopeHandle) -> Result { + let ast::QueryExpression { + with_clause, + 
query, + order_by, + limit, + } = self; + + let mut scope = scope.clone(); + if let Some(with_clause) = with_clause { + scope = with_clause.infer_types(&scope)?; + } + let (ty, column_set_scope) = query.infer_types(&scope)?; + if let Some(order_by) = order_by { + let column_set_scope = column_set_scope + .unwrap_or_else(|| ColumnSetScope::new_from_table_type(&scope, &ty)); + order_by.infer_types(&column_set_scope)?; + } + if let Some(limit) = limit { + limit.infer_types(&())?; + } + Ok(ty) + } +} + +impl InferTypes for ast::QueryExpressionWithClause { + type Scope = ScopeHandle; + type Output = ScopeHandle; + + fn infer_types(&mut self, scope: &ScopeHandle) -> Result { + let mut scope = scope.clone(); + for cte in self.ctes.node_iter_mut() { + scope = cte.infer_types(&scope)?; + } + Ok(scope) + } +} + +impl InferTypes for ast::QueryExpressionQuery { + type Scope = ScopeHandle; + + /// We return both a `TableType` and _possibly_ a `ColumnSetScope` because + /// `ORDER BY` may need a full `ColumnSetScope` to support things like + /// `ORDER BY table1.col1, table2.col2`. But the `ColumnSetScope` is easily + /// lost in more complicated cases. See the test `order_and_limit.sql` for + /// example code. + type Output = (TableType, Option); + fn infer_types(&mut self, scope: &ScopeHandle) -> Result { match self { - ast::QueryExpression::SelectExpression(expr) => expr.infer_types(scope), - ast::QueryExpression::Nested { query, .. } => query.infer_types(scope), - ast::QueryExpression::With { ctes, query, .. } => { - // Non-recursive CTEs, so each will create a new namespace. - let mut scope = scope.to_owned(); - for cte in ctes.node_iter_mut() { - scope = cte.infer_types(&scope)?; - } - query.infer_types(&scope) + ast::QueryExpressionQuery::Select(expr) => { + let (ty, column_set_scope) = expr.infer_types(scope)?; + Ok((ty, Some(column_set_scope))) + } + ast::QueryExpressionQuery::Nested { query, .. 
} => { + Ok((query.infer_types(scope)?, None)) } - ast::QueryExpression::SetOperation { + ast::QueryExpressionQuery::SetOperation { left, set_operator, right, } => { - let left_ty = left.infer_types(scope)?; - let right_ty = right.infer_types(scope)?; + let (left_ty, _) = left.infer_types(scope)?; + let (right_ty, _) = right.infer_types(scope)?; let result_ty = left_ty.common_supertype(&right_ty).ok_or_else(|| { Error::annotated( format!("cannot combine {} and {}", left_ty, right_ty), @@ -251,7 +296,7 @@ impl InferTypes for ast::QueryExpression { "incompatible types", ) })?; - Ok(result_ty) + Ok((result_ty, None)) } } } @@ -271,7 +316,11 @@ impl InferTypes for ast::CommonTableExpression { impl InferTypes for ast::SelectExpression { type Scope = ScopeHandle; - type Output = TableType; + + /// We return both a `TableType` and a `ColumnSetScope` because `ORDER BY` + /// may need a full `ColumnSetScope` to support things like `ORDER BY + /// table1.col1, table2.col2`, which is allowed in certain contexts. + type Output = (TableType, ColumnSetScope); fn infer_types(&mut self, scope: &ScopeHandle) -> Result { // In order of type inference: @@ -431,7 +480,7 @@ impl InferTypes for ast::SelectExpression { } *ty = Some(table_type.clone()); - Ok(table_type) + Ok((table_type, column_set_scope)) } } diff --git a/src/scope.rs b/src/scope.rs index 6f164f3..cb7b7a1 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -614,6 +614,16 @@ impl ColumnSetScope { } } + /// Create a column set scope from a [`TableType`]. This is used to + /// implement some of the more complicated rules around things like `(... + /// UNION ALL ...) ORDER BY ...` + pub fn new_from_table_type(parent: &ScopeHandle, table_type: &TableType) -> Self { + Self { + parent: parent.to_owned(), + column_set: ColumnSet::from_table(None, table_type.to_owned()), + } + } + /// Get our [`ColumnSet`]. 
pub fn column_set(&self) -> &ColumnSet { &self.column_set diff --git a/src/tokenizer.rs b/src/tokenizer.rs index a9953f1..484c6bd 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -616,10 +616,16 @@ impl TokenStream { } /// Try to parse this stream as a [`ast::QueryExpression`]. + #[allow(dead_code)] pub fn try_into_query_expression(self) -> Result { self.try_into_parsed(ast::sql_program::query_expression) } + /// Try to parse this stream as a [`ast::QueryExpressionQuery`]. + pub fn try_into_query_expression_query(self) -> Result { + self.try_into_parsed(ast::sql_program::query_expression_query) + } + /// Try to parse this stream as a [`ast::SelectExpression`]. pub fn try_into_select_expression(self) -> Result { self.try_into_parsed(ast::sql_program::select_expression) @@ -1076,6 +1082,11 @@ peg::parser! { let value = LiteralValue::String(s.into_iter().collect()); Literal { token, value } } } + / quiet! { s_and_token:t(<"\"" s:(([^ '\\' | '\"'] / escape())*) "\"" { s }>) { + let (s, token) = s_and_token; + let value = LiteralValue::String(s.into_iter().collect()); + Literal { token, value } + } } / quiet! { s_and_token:t(<"r'" s:[^ '\'']* "'" { s }>) { let (s, token) = s_and_token; let value = LiteralValue::String(s.into_iter().collect()); diff --git a/src/transforms/wrap_nested_queries.rs b/src/transforms/wrap_nested_queries.rs index 0a66438..d59c630 100644 --- a/src/transforms/wrap_nested_queries.rs +++ b/src/transforms/wrap_nested_queries.rs @@ -2,7 +2,7 @@ use derive_visitor::{DriveMut, VisitorMut}; use joinery_macros::sql_quote; use crate::{ - ast::{self, QueryExpression}, + ast::{self, QueryExpressionQuery}, errors::Result, }; @@ -12,21 +12,21 @@ use super::{Transform, TransformExtra}; /// Needed for SQLite, which doesn't allow using parentheses when working with /// `UNION` or `EXCEPT`. 
#[derive(VisitorMut)] -#[visitor(QueryExpression(enter))] +#[visitor(QueryExpressionQuery(enter))] pub struct WrapNestedQueries; impl WrapNestedQueries { - fn enter_query_expression(&mut self, query_expression: &mut QueryExpression) { - if let QueryExpression::Nested { + fn enter_query_expression_query(&mut self, query_expression_query: &mut QueryExpressionQuery) { + if let QueryExpressionQuery::Nested { paren1, query, paren2, - } = query_expression + } = query_expression_query { let replacement = sql_quote! { SELECT * FROM #paren1 #query #paren2 } - .try_into_query_expression() + .try_into_query_expression_query() .expect("generated SQL should always parse"); - *query_expression = replacement; + *query_expression_query = replacement; } } } diff --git a/src/types.rs b/src/types.rs index 20cf1ba..42ff8a0 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1125,6 +1125,15 @@ pub struct FunctionType { } impl FunctionType { + /// Could this function be called with the specified number of arguments? + /// We use this to search for unknown functions in SQL code we can't type + /// check. + pub fn could_be_called_with_arg_count(&self, arg_count: usize) -> bool { + self.signatures + .iter() + .any(|sig| sig.could_be_called_with_arg_count(arg_count)) + } + /// Is this an aggregate function? /// /// This is used to detect an implicit `GROUP BY` clause, like in `SELECT @@ -1219,6 +1228,17 @@ pub struct FunctionSignature { } impl FunctionSignature { + /// Could this function be called with the specified number of arguments, + /// ignoring the argument types? We use this to search for unknown functions + /// in SQL code we can't type check. + pub fn could_be_called_with_arg_count(&self, arg_count: usize) -> bool { + if self.rest_params.is_some() { + arg_count >= self.params.len() + } else { + arg_count == self.params.len() + } + } + /// Does this signature match a set of argument types? /// /// TODO: Distinguish between failed matches and errors. 
diff --git a/tests/sql/queries/set_operations/order_and_limit.sql b/tests/sql/queries/set_operations/order_and_limit.sql new file mode 100644 index 0000000..e311629 --- /dev/null +++ b/tests/sql/queries/set_operations/order_and_limit.sql @@ -0,0 +1,65 @@ +-- ORDER BY and LIMIT can be applied to set operations. + +CREATE OR REPLACE TABLE __result1 AS +( + SELECT 1 AS i + UNION ALL + SELECT 3 + UNION ALL + SELECT 2 +) +ORDER BY i ASC +LIMIT 2; + +CREATE OR REPLACE TABLE __expected1 ( + i INT64, +); +INSERT INTO __expected1 VALUES + (1), + (2); + +-- Now let's try hard mode. This actually works in BigQuery and Trino, +-- unchanged. +CREATE OR REPLACE TABLE __result2 AS +WITH names1 AS ( + SELECT 1 AS id, 'a' AS name +), +names2 AS ( + SELECT 2 AS id, 'c' AS name +), +streets AS ( + SELECT 1 AS id, 'b' AS street + UNION ALL + SELECT 2 AS id, 'd' AS street +) + +( + ( + SELECT * FROM names1 AS n + JOIN streets USING (id) + -- Can use streets.street here. + ORDER BY streets.street + ) + + UNION ALL + + ( + ( + SELECT * FROM names2 AS n + JOIN streets USING (id) + ) + -- Cannot use streets.street here. + ORDER BY street + ) +) +-- Cannot use streets.street here. +ORDER BY name, street; + +CREATE OR REPLACE TABLE __expected2 ( + id INT64, + name STRING, + street STRING +); +INSERT INTO __expected2 VALUES + (1, 'a', 'b'), + (2, 'c', 'd');