Skip to content

Commit

Permalink
Infer aliased and anonymous columns
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 26, 2023
1 parent f940946 commit 2385edd
Showing 1 changed file with 64 additions and 28 deletions.
92 changes: 64 additions & 28 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
ast::{self, SelectList},
errors::{format_err, Result},
scope::{CaseInsensitiveIdent, Scope, ScopeHandle},
tokenizer::{Literal, LiteralValue},
tokenizer::{Ident, Literal, LiteralValue, Spanned},
types::{ArgumentType, ColumnType, SimpleType, TableType, Type, ValueType},
};

Expand Down Expand Up @@ -74,11 +74,9 @@ impl InferTypes for ast::Statement {

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
// TODO: This can't bind anything into our scope, but we should
// check types anyway.
ast::Statement::Query(stmt) => {
let (ty, scope) = stmt.infer_types(scope)?;
Ok((Some(ty), scope))
let (ty, _scope) = stmt.infer_types(scope)?;
Ok((Some(ty), scope.clone()))
}
ast::Statement::DeleteFrom(_) => todo!("DELETE FROM"),
ast::Statement::InsertInto(_) => todo!("INSERT INTO"),
Expand Down Expand Up @@ -248,36 +246,27 @@ impl InferTypes for ast::SelectExpression {
((), scope) = from_clause.infer_types(&scope)?;
}

let mut next_anon_col_id: u64 = 0;
let mut cols = vec![];
for item in select_list.node_iter_mut() {
match item {
ast::SelectListItem::Expression {
expression: ast::Expression::ColumnName(ident),
alias: None,
} => match scope.get(&ident.to_owned().into()) {
Some(Type::Argument(ArgumentType::Value(ty))) => {
cols.push(ColumnType {
name: ident.clone(),
ty: ty.clone(),
not_null: false,
});
}
Some(ty) => Err(format_err!(
"column {:?} is does not have a value type: {:?}",
ident,
ty
))?,
None => Err(format_err!("column {:?} not found", ident))?,
},
ast::SelectListItem::Expression {
expression,
alias: Some(ast::Alias { ident, .. }),
} => {
ast::SelectListItem::Expression { expression, alias } => {
// BigQuery does not allow select list items to see names
// bound by other select list items.
let (ty, _scope) = expression.infer_types(&scope)?;
let name = alias
.infer_column_name()
.or_else(|| expression.infer_column_name())
.unwrap_or_else(|| {
// We try to predict how BigQuery will name these.
let ident =
Ident::new(&format!("_f{}", next_anon_col_id), expression.span());
next_anon_col_id += 1;
ident
});

cols.push(ColumnType {
name: ident.clone(),
name,
ty,
not_null: false,
});
Expand Down Expand Up @@ -355,6 +344,17 @@ impl InferTypes for ast::Expression {
scope.clone(),
)),
ast::Expression::Literal(Literal { value, .. }) => value.infer_types(scope),
ast::Expression::ColumnName(ident) => {
let ident = ident.to_owned().into();
let ty = scope
.get(&ident)
.ok_or_else(|| format_err!("column {:?} not found in scope", ident))?;
let ty = match ty {
Type::Argument(ArgumentType::Value(ty)) => ty,
_ => Err(format_err!("column {:?} is not a value: {:?}", ident, ty))?,
};
Ok((ty.to_owned(), scope.clone()))
}
_ => todo!("expression"),
}
}
Expand All @@ -373,6 +373,36 @@ impl InferTypes for LiteralValue {
}
}

/// Figure out whether an expression defines an implicit column name.
pub trait InferColumnName {
/// Infer the column name, if any.
fn infer_column_name(&mut self) -> Option<Ident>;
}

impl<T: InferColumnName> InferColumnName for Option<T> {
fn infer_column_name(&mut self) -> Option<Ident> {
match self {
Some(expr) => expr.infer_column_name(),
None => None,
}
}
}

impl InferColumnName for ast::Expression {
fn infer_column_name(&mut self) -> Option<Ident> {
match self {
ast::Expression::ColumnName(ident) => Some(ident.clone()),
_ => None,
}
}
}

impl InferColumnName for ast::Alias {
fn infer_column_name(&mut self) -> Option<Ident> {
Some(self.ident.clone())
}
}

#[cfg(test)]
mod tests {
use std::path::Path;
Expand Down Expand Up @@ -456,6 +486,12 @@ SELECT x FROM t2";
assert_defines!(scope, "foo", "TABLE<x STRING>");
}

#[test]
fn anon_and_aliased_columns() {
let (_, scope) = infer("CREATE TABLE foo AS SELECT 1, 2 AS x, 3").unwrap();
assert_defines!(scope, "foo", "TABLE<_f0 INT64, x INT64, _f1 INT64>");
}

// #[test]
// fn columns_scoped_by_table() {
// let sql = "
Expand Down

0 comments on commit 2385edd

Please sign in to comment.