Skip to content

Commit

Permalink
Don't allow type vars where they don't belong
Browse files Browse the repository at this point in the history
  • Loading branch information
emk committed Oct 26, 2023
1 parent 527c26a commit 84f51da
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 30 deletions.
20 changes: 10 additions & 10 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
errors::{format_err, Result},
scope::{CaseInsensitiveIdent, Scope, ScopeHandle},
tokenizer::{Literal, LiteralValue},
types::{ColumnType, SimpleType, TableType, Type, TypeVar, ValueType},
types::{ColumnType, SimpleType, TableType, Type, ValueType},
};

// TODO: Remember this rather scary example. Verify BigQuery supports it
Expand Down Expand Up @@ -42,7 +42,7 @@ pub trait InferTypes {
}

impl InferTypes for ast::SqlProgram {
type Type = Option<TableType<TypeVar>>;
type Type = Option<TableType>;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let mut ty = None;
Expand Down Expand Up @@ -70,7 +70,7 @@ impl InferTypes for ast::SqlProgram {
}

impl InferTypes for ast::Statement {
type Type = Option<TableType<TypeVar>>;
type Type = Option<TableType>;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
Expand Down Expand Up @@ -118,7 +118,7 @@ impl InferTypes for ast::CreateTableStatement {
let column_decls = columns
.node_iter()
.map(|column| {
let ty = ValueType::<TypeVar>::try_from(&column.data_type)?;
let ty = ValueType::try_from(&column.data_type)?;
Ok(ColumnType {
name: column.name.clone(),
ty,
Expand Down Expand Up @@ -165,7 +165,7 @@ impl InferTypes for ast::DropTableStatement {
}

impl InferTypes for ast::QueryStatement {
type Type = TableType<TypeVar>;
type Type = TableType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let ast::QueryStatement { query_expression } = self;
Expand All @@ -174,7 +174,7 @@ impl InferTypes for ast::QueryStatement {
}

impl InferTypes for ast::QueryExpression {
type Type = TableType<TypeVar>;
type Type = TableType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
Expand All @@ -189,7 +189,7 @@ impl InferTypes for ast::QueryExpression {
}

impl InferTypes for ast::SelectExpression {
type Type = TableType<TypeVar>;
type Type = TableType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
// In order of type inference:
Expand Down Expand Up @@ -247,7 +247,7 @@ impl InferTypes for ast::SelectExpression {
}

impl InferTypes for ast::Expression {
type Type = ValueType<TypeVar>;
type Type = ValueType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
match self {
Expand All @@ -262,7 +262,7 @@ impl InferTypes for ast::Expression {
}

impl InferTypes for LiteralValue {
type Type = ValueType<TypeVar>;
type Type = ValueType;

fn infer_types(&mut self, scope: &ScopeHandle) -> Result<(Self::Type, ScopeHandle)> {
let simple_ty = match self {
Expand All @@ -289,7 +289,7 @@ mod tests {

use super::*;

fn infer(sql: &str) -> Result<(Option<TableType<TypeVar>>, ScopeHandle)> {
fn infer(sql: &str) -> Result<(Option<TableType>, ScopeHandle)> {
let mut program = match parse_sql(Path::new("test.sql"), sql) {
Ok(program) => program,
Err(e) => {
Expand Down
53 changes: 33 additions & 20 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,23 @@ use crate::{

/// Sometimes we want concrete types, and sometimes we want types with type
/// variables. This trait convers both those cases.
pub trait TypeVarSupport: fmt::Display {}
pub trait TypeVarSupport: fmt::Display + Sized {
/// Convert a [`TypeVar`] into a [`SimpleType`], if possible.
fn simple_type_from_type_var(tv: TypeVar) -> Result<SimpleType<Self>, &'static str>;
}

/// This type can never be instantiated. We use this to represent a type variable with
/// all type variables resolved.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ResolvedTypeVarsOnly {}

impl TypeVarSupport for ResolvedTypeVarsOnly {}
impl TypeVarSupport for ResolvedTypeVarsOnly {
fn simple_type_from_type_var(_tv: TypeVar) -> Result<SimpleType<Self>, &'static str> {
// This will be a parser error with `"expected "` prepended. So it's
// hard to word well.
Err("something other than a type variable")
}
}

impl fmt::Display for ResolvedTypeVarsOnly {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -76,7 +85,11 @@ impl TypeVar {
}
}

impl TypeVarSupport for TypeVar {}
impl TypeVarSupport for TypeVar {
fn simple_type_from_type_var(tv: TypeVar) -> Result<SimpleType<Self>, &'static str> {
Ok(SimpleType::Parameter(tv))
}
}

impl fmt::Display for TypeVar {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand All @@ -92,7 +105,7 @@ pub enum Type<TV: TypeVarSupport = ResolvedTypeVarsOnly> {
Argument(ArgumentType<TV>),
/// The type of a table, as seen in `CREATE TABLE` statements, or
/// as returned from a sub-`SELECT`, or as passed to `UNNEST`.
Table(TableType<TV>),
Table(TableType),
/// A function type.
Function(FunctionType),
}
Expand Down Expand Up @@ -229,11 +242,11 @@ impl<TV: TypeVarSupport> fmt::Display for StructElementType<TV> {

/// A table type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TableType<TV: TypeVarSupport = ResolvedTypeVarsOnly> {
pub columns: Vec<ColumnType<TV>>,
pub struct TableType {
pub columns: Vec<ColumnType>,
}

impl<TV: TypeVarSupport> fmt::Display for TableType<TV> {
impl fmt::Display for TableType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TABLE<")?;
for (i, column) in self.columns.iter().enumerate() {
Expand All @@ -249,13 +262,13 @@ impl<TV: TypeVarSupport> fmt::Display for TableType<TV> {

/// A column type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ColumnType<TV: TypeVarSupport = ResolvedTypeVarsOnly> {
pub struct ColumnType {
pub name: Ident,
pub ty: ValueType<TV>,
pub ty: ValueType,
pub not_null: bool,
}

impl<TV: TypeVarSupport> fmt::Display for ColumnType<TV> {
impl fmt::Display for ColumnType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} {}", BigQueryName(&self.name.name), self.ty)?;
if self.not_null {
Expand Down Expand Up @@ -441,21 +454,21 @@ peg::parser! {
(name, ty)
}

pub rule ty() -> Type<TypeVar>
pub rule ty<TV: TypeVarSupport>() -> Type<TV>
= t:argument_type() { Type::Argument(t) }
/ t:table_type() { Type::Table(t) }
/ t:function_type() { Type::Function(t) }

rule argument_type() -> ArgumentType<TypeVar>
rule argument_type<TV: TypeVarSupport>() -> ArgumentType<TV>
= "Agg" _? "<" _? t:value_type() _? ">" { ArgumentType::Aggregating(t) }
/ t:value_type() { ArgumentType::Value(t) }

rule value_type() -> ValueType<TypeVar>
rule value_type<TV: TypeVarSupport>() -> ValueType<TV>
= t:simple_type() { ValueType::Simple(t) }
/ "ARRAY" _? "<" _? t:simple_type() _? ">" { ValueType::Array(Box::new(t)) }

// Longest match first.
rule simple_type() -> SimpleType<TypeVar>
rule simple_type<TV: TypeVarSupport>() -> SimpleType<TV>
= "BOOL" { SimpleType::Bool }
/ "BYTES" { SimpleType::Bytes }
/ "DATETIME" { SimpleType::Datetime }
Expand All @@ -471,9 +484,9 @@ peg::parser! {
/ "TIME" { SimpleType::Time }
/ "STRUCT" _? "<" _? fields:(struct_field() ** (_? "," _?)) _? ">" {
SimpleType::Struct(StructType { fields }) }
/ type_var:type_var() { SimpleType::Parameter(type_var) }
/ type_var:type_var() {? TV::simple_type_from_type_var(type_var) }

rule struct_field() -> StructElementType<TypeVar>
rule struct_field<TV: TypeVarSupport>() -> StructElementType<TV>
= t:value_type() { StructElementType { name: None, ty: t } }
/ name:ident() _ t:value_type() { StructElementType { name: Some(name), ty: t } }

Expand Down Expand Up @@ -513,14 +526,14 @@ peg::parser! {
/ ".." _? rest_params:value_type() { (Vec::new(), Some(rest_params)) }
/ { (Vec::new(), None) }

rule table_type() -> TableType<TypeVar>
rule table_type() -> TableType
= "TABLE" _? "<" _? columns:(column_type() ** (_? "," _?)) _? ">" {
TableType { columns }
}

rule column_type() -> ColumnType<TypeVar>
= name:ident() _ t:value_type() not_null:not_null() {
ColumnType { name, ty: t, not_null }
rule column_type() -> ColumnType
= name:ident() _ ty:value_type() not_null:not_null() {
ColumnType { name, ty, not_null }
}

rule not_null() -> bool
Expand Down

0 comments on commit 84f51da

Please sign in to comment.