diff --git a/src/ast.rs b/src/ast.rs index 311458b..ee76f1a 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -471,6 +471,14 @@ impl Name { ident.into() } + /// Create a new name of the form `table.column`. + pub fn new_table_column(table: &str, column: &str, span: Span) -> Name { + let mut node_vec = NodeVec::new("."); + node_vec.push(Ident::new(table, span.clone())); + node_vec.push(Ident::new(column, span)); + node_vec.into() + } + /// Split this name into a table name and a column name. pub fn split_table_and_column(&self) -> (Option, Ident) { // No table part. diff --git a/src/infer.rs b/src/infer.rs index db8ac02..995bbd7 100644 --- a/src/infer.rs +++ b/src/infer.rs @@ -5,7 +5,7 @@ use std::collections::HashSet; use crate::{ ast::{self, Name}, errors::{Error, Result}, - scope::{Scope, ScopeHandle}, + scope::{Scope, ScopeGet, ScopeHandle}, tokenizer::{Ident, Literal, LiteralValue, Spanned}, types::{ArgumentType, ColumnType, SimpleType, TableType, Type, ValueType}, unification::{UnificationTable, Unify}, @@ -162,21 +162,19 @@ impl InferTypes for ast::InsertIntoStatement { type Type = (); fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { - let table_type = scope - .get_or_err(&self.table_name)? - .try_as_table_type(&self.table_name)?; + let table_type = scope.get_table_type(&self.table_name)?; match &mut self.inserted_data { ast::InsertedData::Values { rows, .. } => { for row in rows.node_iter_mut() { let (ty, _scope) = row.infer_types(scope)?; - ty.expect_subtype_ignoring_nullability_of(table_type, row)?; + ty.expect_subtype_ignoring_nullability_of(&table_type, row)?; } Ok(((), scope.clone())) } ast::InsertedData::Select { query, .. 
} => { let (ty, _scope) = query.infer_types(scope)?; - ty.expect_subtype_ignoring_nullability_of(table_type, query)?; + ty.expect_subtype_ignoring_nullability_of(&table_type, query)?; Ok(((), scope.clone())) } } @@ -353,10 +351,8 @@ impl InferTypes for ast::SelectExpression { ast::SelectListItem::TableNameWildcard { table_name, except, .. } => { - let table_type = scope - .get_or_err(table_name)? - .try_as_table_type(table_name)?; - add_table_cols(&mut cols, table_type, except); + let table_type = scope.get_table_type(table_name)?; + add_table_cols(&mut cols, &table_type, except); } } } @@ -389,9 +385,7 @@ impl InferTypes for ast::FromItem { fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { match self { ast::FromItem::TableName { table_name, alias } => { - let table_type = scope - .get_or_err(table_name)? - .try_as_table_type(table_name)?; + let table_type = scope.get_table_type(table_name)?; let name = match alias { Some(alias) => alias.ident.clone().into(), None => table_name.clone(), @@ -407,7 +401,7 @@ impl InferTypes for ast::FromItem { )?; } } - Ok((table_type.clone(), scope.into_handle())) + Ok((table_type, scope.into_handle())) } ast::FromItem::Subquery { .. } => Err(nyi(self, "from subquery")), ast::FromItem::Unnest { .. } => Err(nyi(self, "from unnest")), @@ -462,8 +456,8 @@ impl InferTypes for Ident { fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { let ident = self.to_owned().into(); - let ty = scope.get_or_err(&ident)?.try_as_argument_type(&ident)?; - Ok((ty.to_owned(), scope.clone())) + let ty = scope.get_argument_type(&ident)?; + Ok((ty, scope.clone())) } } @@ -473,14 +467,12 @@ impl InferTypes for ast::Name { fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> { let (table_name, column_name) = self.split_table_and_column(); if let Some(table_name) = table_name { - let table_type = scope - .get_or_err(&table_name)? 
- .try_as_table_type(&table_name)?; + let table_type = scope.get_table_type(&table_name)?; let column_type = table_type.column_by_name_or_err(&column_name)?; Ok((column_type.ty.to_owned(), scope.clone())) } else { - let ty = scope.get_or_err(self)?.try_as_argument_type(self)?; - Ok((ty.to_owned(), scope.clone())) + let ty = scope.get_argument_type(self)?; + Ok((ty, scope.clone())) } } } @@ -508,9 +500,7 @@ impl InferTypes for ast::IsExpression { // We need to do this manually because our second argument isn't an // expression. let func_name = &Name::new("%IS", self.is_token.span()); - let func_ty = scope - .get_or_err(func_name)? - .try_as_function_type(func_name)?; + let func_ty = scope.get_function_type(func_name)?; let arg_types = [ self.left.infer_types(scope)?.0, self.predicate.infer_types(scope)?.0, @@ -545,9 +535,7 @@ impl InferTypes for ast::InExpression { // We need to do this manually because our second argument isn't an // expression. let func_name = &Name::new("%IN", self.in_token.span()); - let func_ty = scope - .get_or_err(func_name)? - .try_as_function_type(func_name)?; + let func_ty = scope.get_function_type(func_name)?; let left_ty = self.left.infer_types(scope)?.0; let value_set_ty = self.value_set.infer_types(scope)?.0; let elem_ty = value_set_ty.expect_one_column(&self.value_set)?.ty.clone(); @@ -821,9 +809,7 @@ fn infer_call<'args, ArgExprs>( where ArgExprs: IntoIterator, { - let func_ty = scope - .get_or_err(func_name)? - .try_as_function_type(func_name)?; + let func_ty = scope.get_function_type(func_name)?; let mut arg_types = vec![]; for arg in args { arg_types.push(arg.infer_types(scope)?.0); @@ -876,7 +862,7 @@ mod tests { } fn lookup(scope: &ScopeHandle, name: &str) -> Option { - scope.get(&Name::new(name, Span::Unknown)).cloned() + scope.get(&Name::new(name, Span::Unknown)).unwrap() } macro_rules! 
assert_defines { diff --git a/src/scope.rs b/src/scope.rs index 74b495d..800428c 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,7 +1,21 @@ -//! Namespace for SQL. +//! Scope (aka namespace) support for SQL. +//! +//! This is surprisingly complicated. We need to support: +//! +//! - Built-in functions. +//! - `CREATE` and `DROP` statements that add or remove items from the +//! top-level scope. +//! - CTEs, which add items to a query's scope. +//! - `FROM` and `JOIN` clauses, which create special scopes containing column +//! names. +//! - `USING` clauses, which remove the table name from some columns. +//! - `GROUP BY` and `PARTITION BY` clauses, which aggregate all columns not +//! mentioned in the clause. +//! - There are also implicit aggregations, like `SELECT COUNT(*) FROM t` +//! and `SUM(x) OVER ()`, but we leave those to our callers. use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashMap}, sync::Arc, }; @@ -10,12 +24,52 @@ use crate::{ errors::{format_err, Error, Result}, known_files::KnownFiles, tokenizer::Spanned, - types::{parse_function_decls, ArgumentType, TableType, Type}, + types::{parse_function_decls, ArgumentType, FunctionType, TableType, Type}, }; /// A value we can store in a scope. Details may change. pub type ScopeValue = Type; +/// Common interface to all things that support a scope-like `get` function. +pub trait ScopeGet { + /// Get the [`ScopeValue`] associated with `name`. + /// + /// Returns `Ok(None)` if `name` is not defined. Returns an error if `name` + /// is ambiguous. Returns an owned value because some implementations may need to + /// modify types before returning them, and because trying to use + /// [`std::borrow::Cow`] pays too high a "Rust tax". + fn get(&self, name: &Name) -> Result>; + + /// Get a value, or return an error if it is not defined. 
+ fn get_or_err(&self, name: &Name) -> Result { + self.get(name)?.ok_or_else(|| { + Error::annotated( + format!("unknown name: {}", name.unescaped_bigquery()), + name.span(), + "not defined", + ) + }) + } + + /// Look up `name` as an [`ArgumentType`], if possible. + fn get_argument_type(&self, name: &Name) -> Result { + self.get_or_err(name)?.try_as_argument_type(name).cloned() + } + + /// Look up `name` as a [`TableType`], if possible. + fn get_table_type(&self, name: &Name) -> Result { + self.get_or_err(name)?.try_as_table_type(name).cloned() + } + + /// Look up `name` as a [`FunctionType`], if possible. + fn get_function_type(&self, name: &Name) -> Result { + self.get_or_err(name)?.try_as_function_type(name).cloned() + } +} + +/// A handle to a scope. +pub type ScopeHandle = Arc; + /// We need to both define and hide names in a scope. #[derive(Clone, Debug)] enum ScopeEntry { @@ -25,9 +79,6 @@ enum ScopeEntry { Hidden, } -/// A handle to a scope. -pub type ScopeHandle = Arc; - /// A scope is a namespace for SQL code. We use it to look up names, /// and associate them with types and other information. #[derive(Clone, Debug)] @@ -110,32 +161,22 @@ impl Scope { self.names.insert(name.clone(), ScopeEntry::Hidden); Ok(()) } +} - /// Get a value from the scope. - pub fn get<'scope>(&'scope self, name: &Name) -> Option<&'scope ScopeValue> { +impl ScopeGet for Scope { + fn get(&self, name: &Name) -> Result> { match self.names.get(name) { - Some(ScopeEntry::Defined(value)) => Some(value), - Some(ScopeEntry::Hidden) => None, + Some(ScopeEntry::Defined(value)) => Ok(Some(value.clone())), + Some(ScopeEntry::Hidden) => Ok(None), None => { if let Some(parent) = self.parent.as_ref() { parent.get(name) } else { - None + Ok(None) } } } } - - /// Get a value, or return an error if it is not defined. 
- pub fn get_or_err(&self, name: &Name) -> Result<&ScopeValue> { - self.get(name).ok_or_else(|| { - Error::annotated( - format!("unknown name: {}", name.unescaped_bigquery()), - name.span(), - "not defined", - ) - }) - } } /// Built-in function declarations in the default scope. @@ -211,13 +252,18 @@ TRIM = Fn(STRING) -> STRING; UPPER = Fn(STRING) -> STRING; "; -#[derive(Clone, Debug)] -pub struct ColumnName { +#[derive(Clone, Debug, PartialEq)] +pub struct ColumnSetColumnName { table: Option, column: Name, } -impl ColumnName { +impl ColumnSetColumnName { + /// Create a new column name. + pub fn new(table: Option, column: Name) -> Self { + Self { table, column } + } + /// Does this column name match `name`? `name` may be either a column name /// or a table name and a column name. Matches follow SQL rules, so `x` /// matches `t.x`. @@ -232,32 +278,47 @@ impl ColumnName { } } -#[derive(Clone, Debug)] -pub struct Column { - column_name: Option, +#[derive(Clone, Debug, PartialEq)] +pub struct ColumnSetColumn { + column_name: Option, ty: ArgumentType, } +impl ColumnSetColumn { + /// Create a new column. + pub fn new(column_name: Option, ty: ArgumentType) -> Self { + Self { column_name, ty } + } +} + /// A set of columns and types. /// /// This is output by a `FROM`, `JOIN`, `GROUP BY` or `PARTITION BY` clause. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct ColumnSet { - columns: Vec, + columns: Vec, } impl ColumnSet { + /// Create a new column set. + pub fn new(columns: Vec) -> Self { + Self { columns } + } + /// Build a column set from a table type. - pub fn from_table(table_name: Name, table_type: TableType) -> Self { + /// + /// The `table_name` may be missing if the table type was defined by + /// something like `FROM (SELECT 1 AS a)` without an `AS` alias. 
+ pub fn from_table(table_name: Option, table_type: TableType) -> Self { let columns = table_type .columns .into_iter() .map(|col| { - let column_name = col.name.map(|col_name| ColumnName { - table: Some(table_name.clone()), + let column_name = col.name.map(|col_name| ColumnSetColumnName { + table: table_name.clone(), column: col_name.into(), }); - Column { + ColumnSetColumn { column_name, ty: col.ty, } @@ -283,9 +344,9 @@ impl ColumnSet { /// copy of each without the table name. pub fn join_using(&self, other: &Self, using: &[Name]) -> Result { // Create a hash map, indicating which columns we have seen so far. - let mut seen_with_type = HashMap::new(); + let mut seen_at_output_index: HashMap> = HashMap::new(); for name in using { - seen_with_type.insert(name.clone(), None); + seen_at_output_index.insert(name.clone(), None); } // Iterate over all our columns, and add them to the output, being sure @@ -298,23 +359,26 @@ impl ColumnSet { let mut columns = vec![]; for col in columns_iter { if let Some(name) = &col.column_name { - match seen_with_type.get_mut(&name.column) { + match seen_at_output_index.get_mut(&name.column) { Some(None) => { // We have not seen this column yet. Add it to the // output, removing the table name. - seen_with_type.insert(name.column.clone(), Some(col.ty.clone())); - columns.push(Column { - column_name: Some(ColumnName { + seen_at_output_index.insert(name.column.clone(), Some(columns.len())); + columns.push(ColumnSetColumn { + column_name: Some(ColumnSetColumnName { table: None, column: name.column.clone(), }), ty: col.ty, }); } - Some(Some(ty)) => { + Some(Some(idx)) => { // We have already seen this column. Make sure the types - // match. - if col.ty.common_supertype(ty).is_none() { + // are compatible, and update the type if necessary. 
+ let ty = &mut columns[*idx].ty; + if let Some(supertype) = col.ty.common_supertype(ty) { + *ty = supertype; + } else { return Err(Error::annotated( format!( "column {} has type {} in one table and type {} in another", @@ -338,6 +402,18 @@ impl ColumnSet { columns.push(col); } } + + // Make sure we saw all the columns in the `USING` clause. + for name in using { + if let Some(None) = seen_at_output_index.get(name) { + return Err(Error::annotated( + format!("column {} not found", name.unescaped_bigquery()), + name.span(), + "not found", + )); + } + } + Ok(Self { columns }) } @@ -356,14 +432,14 @@ impl ColumnSet { // This column is not mentioned in the `GROUP BY` clause. // Wrap it in `ArgumentType::Aggregating` and add it to the // output. - columns.push(Column { + columns.push(ColumnSetColumn { column_name: col.column_name.clone(), ty: ArgumentType::Aggregating(Box::new(col.ty.clone())), }); } } else { // Not mentioned in the `GROUP BY` clause, so aggregate it. - columns.push(Column { + columns.push(ColumnSetColumn { column_name: col.column_name.clone(), ty: ArgumentType::Aggregating(Box::new(col.ty.clone())), }); @@ -371,9 +447,10 @@ impl ColumnSet { } Ok(Self { columns }) } +} - /// Look up a column type by name. Returns an error if ambiguous. 
- pub fn get(&self, name: &Name) -> Result<&ArgumentType> { +impl ScopeGet for ColumnSet { + fn get(&self, name: &Name) -> Result> { let mut matches = vec![]; for col in &self.columns { if let Some(column_name) = &col.column_name { @@ -383,12 +460,8 @@ impl ColumnSet { } } match matches.len() { - 0 => Err(Error::annotated( - format!("unknown column: {}", name.unescaped_bigquery()), - name.span(), - "not defined", - )), - 1 => Ok(matches[0]), + 0 => Ok(None), + 1 => Ok(Some(Type::Argument(matches[0].clone()))), _ => Err(Error::annotated( format!("ambiguous column: {}", name.unescaped_bigquery()), name.span(), @@ -398,12 +471,135 @@ impl ColumnSet { } } +/// Wraps a [`ColumnSet`] and allows it to have a [`Scope`] as a parent. +#[derive(Clone, Debug)] +struct ColumnSetScope { + parent: ScopeHandle, + column_set: ColumnSet, +} + +impl ColumnSetScope { + /// Create a new column set scope. + pub fn new(parent: &ScopeHandle, column_set: ColumnSet) -> Self { + Self { + parent: parent.clone(), + column_set, + } + } + + /// Try to transform the underlying [`ColumnSet`] using `f`. + pub fn try_transform(self, f: F) -> Result + where + F: FnOnce(ColumnSet) -> Result, + { + Ok(Self { + parent: self.parent.clone(), + column_set: f(self.column_set)?, + }) + } +} + +impl ScopeGet for ColumnSetScope { + fn get(&self, name: &Name) -> Result> { + match self.column_set.get(name) { + Ok(Some(value)) => Ok(Some(value)), + Ok(None) => self.parent.get(name), + Err(err) => Err(err), + } + } +} + #[cfg(test)] mod tests { use super::*; + use crate::{ + tokenizer::Span, + types::tests::{column_set, ty}, + }; + #[test] fn parse_built_in_functions() { Scope::root(); } + + /// Make a column set for a test. 
+ fn column_set_from(name: &str, table_ty: &str) -> ColumnSet { + let t1_name = Name::new("t1", Span::Unknown); + let table_type = ty(table_ty).try_as_table_type(&t1_name).unwrap().clone(); + ColumnSet::from_table(Some(Name::new(name, Span::Unknown)), table_type) + } + + #[test] + fn join_column_sets() { + let left = column_set_from("t1", "TABLE"); + let right = column_set_from("t2", "TABLE"); + let joined = left.join(&right); + let expected = column_set("t1.a INT64, t1.b STRING, t2.a INT64, t2.c STRING"); + assert_eq!(joined, expected); + } + + #[test] + fn join_using_column_sets() { + let left = column_set_from("t1", "TABLE"); + let right = column_set_from("t2", "TABLE"); + let joined = left + .join_using(&right, &[Name::new("a", Span::Unknown)]) + .unwrap(); + let expected = column_set("a FLOAT64, t1.b STRING, t2.c STRING"); + assert_eq!(joined, expected); + } + + #[test] + fn join_with_overlapping_columns_includes_both() { + let left = column_set_from("t1", "TABLE"); + let right = column_set_from("t2", "TABLE"); + let joined = left.join(&right); + let expected = column_set("t1.a INT64, t1.b STRING, t2.a INT64, t2.b STRING"); + assert_eq!(joined, expected); + } + + #[test] + fn group_by() { + let input = column_set("a INT64, b STRING, c INT64"); + let group_by = vec![Name::new("a", Span::Unknown)]; + let output = input.group_by(&group_by).unwrap(); + let expected = column_set("a INT64, b Agg, c Agg"); + assert_eq!(output, expected); + } + + #[test] + fn group_by_empty() { + let input = column_set("a INT64, b STRING, c INT64"); + let output = input.group_by(&[]).unwrap(); + let expected = column_set("a Agg, b Agg, c Agg"); + assert_eq!(output, expected); + } + + #[test] + fn column_set_get() { + // Check absent, present, and ambiguous columns. 
+ let column_set = column_set("t1.a INT64, t1.b STRING, t2.b FLOAT64"); + assert_eq!( + column_set.get(&Name::new("a", Span::Unknown)).unwrap(), + Some(ty("INT64")) + ); + assert!(column_set.get(&Name::new("b", Span::Unknown)).is_err()); + assert_eq!( + column_set + .get(&Name::new_table_column("t1", "b", Span::Unknown)) + .unwrap(), + Some(ty("STRING")) + ); + assert_eq!( + column_set + .get(&Name::new_table_column("t2", "b", Span::Unknown)) + .unwrap(), + Some(ty("FLOAT64")) + ); + assert_eq!( + column_set.get(&Name::new("c", Span::Unknown)).unwrap(), + None + ); + } } diff --git a/src/types.rs b/src/types.rs index 366d4d1..2f2d98e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -21,6 +21,7 @@ use crate::{ drivers::bigquery::BigQueryName, errors::{format_err, Error, Result}, known_files::{FileId, KnownFiles}, + scope::{ColumnSet, ColumnSetColumn, ColumnSetColumnName}, tokenizer::{Ident, Span, Spanned}, unification::{UnificationTable, Unify}, util::is_c_ident, @@ -118,21 +119,8 @@ impl Type { } } - /// Convert this type into a [`ValueType`], if possible. - #[allow(dead_code)] - pub fn try_as_value_type(&self, spanned: &dyn Spanned) -> Result<&ValueType> { - match self { - Type::Argument(ArgumentType::Value(t)) => Ok(t), - _ => Err(Error::annotated( - format!("expected value type, found {}", self), - spanned.span(), - "type mismatch", - )), - } - } - /// Convert this type into a [`TableType`], if possible. - pub fn try_as_table_type(&self, spanned: &dyn Spanned) -> Result<&TableType> { + pub fn try_as_table_type<'ty>(&'ty self, spanned: &dyn Spanned) -> Result<&'ty TableType> { match self { Type::Table(t) => Ok(t), _ => Err(Error::annotated( @@ -1156,6 +1144,29 @@ pub fn parse_function_decls( peg::parser! { grammar type_grammar() for str { + /// Used by [`tests::column_set`] as a test helper for working with + /// [`ColumnSet`]. + pub rule column_set() -> ColumnSet + = _? columns:(column_set_column() ** (_? "," _?)) _? 
{ + ColumnSet::new(columns) + } + + rule column_set_column() -> ColumnSetColumn + = column_name:column_set_column_name() _? ty:argument_type() { + ColumnSetColumn::new(Some(column_name), ty) + } + / ty:argument_type() { + ColumnSetColumn::new(None, ty) + } + + rule column_set_column_name() -> ColumnSetColumnName + = table_name:ident() _? "." _? column:ident() { + ColumnSetColumnName::new(Some(table_name.into()), column.into()) + } + / column:ident() { + ColumnSetColumnName::new(None, column.into()) + } + pub rule function_decls() -> Vec<(Ident, FunctionType)> = _? decls:(function_decl() ** (_? ";" _?)) _? (";" _?)? { decls @@ -1299,6 +1310,21 @@ pub mod tests { } } + /// Parse a column set declaration. + pub fn column_set(s: &str) -> ColumnSet { + // We use local `KnownFiles` here, because we panic on parse errors, and + // our caller doesn't need to know we parse at all. + let mut files = KnownFiles::new(); + let file_id = files.add_string("column set declaration", s); + match parse_helper(&files, file_id, type_grammar::column_set) { + Ok(column_set) => column_set, + Err(e) => { + e.emit(&files); + panic!("parse error"); + } + } + } + #[test] fn common_supertype() { let examples = &[ diff --git a/tests/sql/functions/windows/README.md b/tests/sql/functions/windows/README.md new file mode 100644 index 0000000..9f961a5 --- /dev/null +++ b/tests/sql/functions/windows/README.md @@ -0,0 +1,6 @@ +# Window function tests + +TODO: + +- [ ] What would `SUM(x) + y OVER ()` do to the type of `y`? This isn't + valid BigQuery SQL, so maybe we can ignore it.