From b9d9b8dbbfdcbae51e18fcc53d4b5b488cf152ca Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 18 Dec 2024 09:32:37 -0800 Subject: [PATCH] fix: qualified projection columns --- crates/expr/src/check.rs | 1 + crates/expr/src/statement.rs | 58 ++++++++++++++++++++++++++ crates/sql-parser/src/ast/mod.rs | 32 ++++++++++++++ crates/sql-parser/src/ast/sql.rs | 27 ++++++++++++ crates/sql-parser/src/ast/sub.rs | 12 ++++++ crates/sql-parser/src/parser/errors.rs | 2 + crates/sql-parser/src/parser/sql.rs | 6 ++- crates/sql-parser/src/parser/sub.rs | 4 +- 8 files changed, 140 insertions(+), 2 deletions(-) diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index b0fbb04ddca..5fa2cb0d7d8 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -208,6 +208,7 @@ mod tests { ( "t", ProductType::from([ + ("int", AlgebraicType::U32), ("u32", AlgebraicType::U32), ("f32", AlgebraicType::F32), ("str", AlgebraicType::String), diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index 4fac99cbaa5..b9fe8ca8069 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -286,3 +286,61 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult< source: StatementSource::Query, }) } + +#[cfg(test)] +mod tests { + use spacetimedb_lib::{AlgebraicType, ProductType}; + use spacetimedb_schema::def::ModuleDef; + + use crate::{ + check::test_utils::{build_module_def, SchemaViewer}, + statement::parse_and_type_sql, + }; + + fn module_def() -> ModuleDef { + build_module_def(vec![ + ( + "t", + ProductType::from([ + ("u32", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("str", AlgebraicType::String), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ]), + ), + ( + "s", + ProductType::from([ + ("id", AlgebraicType::identity()), + ("u32", AlgebraicType::U32), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ("bytes", AlgebraicType::bytes()), + ]), + ), + ]) + } + + #[test] + fn valid() { + let tx = SchemaViewer(module_def()); + + for sql in [ + "select str from t", + "select str, arr from t", + "select t.str, arr from t", + ] { + let result = parse_and_type_sql(sql, &tx); + assert!(result.is_ok()); + } + } + + #[test] + fn invalid() { + let tx = SchemaViewer(module_def()); + + // Unqualified columns in a join + let sql = "select id, str from s join t"; + let result = parse_and_type_sql(sql, &tx); + assert!(result.is_err()); + } +} diff --git a/crates/sql-parser/src/ast/mod.rs b/crates/sql-parser/src/ast/mod.rs index 2149b48f51f..2671278a75c 100644 --- a/crates/sql-parser/src/ast/mod.rs +++ b/crates/sql-parser/src/ast/mod.rs @@ -12,6 +12,15 @@ pub enum SqlFrom { Join(SqlIdent, SqlIdent, Vec), } +impl SqlFrom { + pub fn has_unqualified_vars(&self) -> bool { + match self { + Self::Join(_, _, joins) => joins.iter().any(|join| join.has_unqualified_vars()), + _ => false, + } + } +} + /// An inner join in a FROM clause #[derive(Debug)] pub struct SqlJoin { @@ -20,6 +29,12 @@ pub struct SqlJoin { pub on: Option, } +impl SqlJoin { + pub fn has_unqualified_vars(&self) -> bool { + self.on.as_ref().is_some_and(|expr| expr.has_unqualified_vars()) + } +} + /// A projection expression in a SELECT clause #[derive(Debug)] pub struct ProjectElem(pub ProjectExpr, pub SqlIdent); @@ -73,6 +88,15 @@ impl Project { Self::Exprs(elems) => Self::Exprs(elems.into_iter().map(|elem| elem.qualify_vars(with.clone())).collect()), } } + + pub fn has_unqualified_vars(&self) -> bool { + match self { + Self::Exprs(exprs) => exprs + .iter() + .any(|ProjectElem(expr, _)| matches!(expr, ProjectExpr::Var(_))), + _ => false, + } + } } /// A scalar SQL expression @@ -107,6 +131,14 @@ impl SqlExpr { ), } } + + pub fn has_unqualified_vars(&self) -> bool { + match self { + Self::Var(_) => true, + Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_unqualified_vars() || b.has_unqualified_vars(), + _ => false, + } + } } /// A SQL identifier or named reference. diff --git a/crates/sql-parser/src/ast/sql.rs b/crates/sql-parser/src/ast/sql.rs index 1c9132ce907..0b747080492 100644 --- a/crates/sql-parser/src/ast/sql.rs +++ b/crates/sql-parser/src/ast/sql.rs @@ -1,6 +1,9 @@ +use crate::parser::{errors::SqlUnsupported, SqlParseResult}; + use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral}; /// The AST for the SQL DML and query language +#[derive(Debug)] pub enum SqlAst { /// SELECT ... Select(SqlSelect), @@ -36,9 +39,17 @@ impl SqlAst { _ => self, } } + + pub fn find_unqualified_vars(self) -> SqlParseResult { + match self { + Self::Select(select) => select.find_unqualified_vars().map(Self::Select), + _ => Ok(self), + } + } } /// A SELECT statement in the SQL query language +#[derive(Debug)] pub struct SqlSelect { pub project: Project, pub from: SqlFrom, @@ -56,9 +67,20 @@ impl SqlSelect { SqlFrom::Join(..) => self, } } + + pub fn find_unqualified_vars(self) -> SqlParseResult { + if self.from.has_unqualified_vars() { + return Err(SqlUnsupported::UnqualifiedNames.into()); + } + if self.project.has_unqualified_vars() { + return Err(SqlUnsupported::UnqualifiedNames.into()); + } + Ok(self) + } } /// INSERT INTO table cols VALUES literals +#[derive(Debug)] pub struct SqlInsert { pub table: SqlIdent, pub fields: Vec, @@ -66,9 +88,11 @@ pub struct SqlInsert { } /// VALUES literals +#[derive(Debug)] pub struct SqlValues(pub Vec>); /// UPDATE table SET cols [ WHERE predicate ] +#[derive(Debug)] pub struct SqlUpdate { pub table: SqlIdent, pub assignments: Vec, @@ -76,13 +100,16 @@ pub struct SqlUpdate { } /// DELETE FROM table [ WHERE predicate ] +#[derive(Debug)] pub struct SqlDelete { pub table: SqlIdent, pub filter: Option, } /// SET var '=' literal +#[derive(Debug)] pub struct SqlSet(pub SqlIdent, pub SqlLiteral); /// SHOW var +#[derive(Debug)] pub struct SqlShow(pub SqlIdent); diff --git a/crates/sql-parser/src/ast/sub.rs b/crates/sql-parser/src/ast/sub.rs index 33f32d7cce5..8a793d56eb1 100644 --- a/crates/sql-parser/src/ast/sub.rs +++ b/crates/sql-parser/src/ast/sub.rs @@ -1,3 +1,5 @@ +use crate::parser::{errors::SqlUnsupported, SqlParseResult}; + use super::{Project, SqlExpr, SqlFrom}; /// A SELECT statement in the SQL subscription language @@ -18,4 +20,14 @@ impl SqlSelect { SqlFrom::Join(..) => self, } } + + pub fn find_unqualified_vars(self) -> SqlParseResult { + if self.from.has_unqualified_vars() { + return Err(SqlUnsupported::UnqualifiedNames.into()); + } + if self.project.has_unqualified_vars() { + return Err(SqlUnsupported::UnqualifiedNames.into()); + } + Ok(self) + } } diff --git a/crates/sql-parser/src/parser/errors.rs b/crates/sql-parser/src/parser/errors.rs index 9a783d9a9cd..e49ec40563f 100644 --- a/crates/sql-parser/src/parser/errors.rs +++ b/crates/sql-parser/src/parser/errors.rs @@ -64,6 +64,8 @@ pub enum SqlUnsupported { MultiTableDelete, #[error("Empty SQL query")] Empty, + #[error("Names must be qualified when using joins")] + UnqualifiedNames, } impl SqlUnsupported { diff --git a/crates/sql-parser/src/parser/sql.rs b/crates/sql-parser/src/parser/sql.rs index 7277d32a580..3c93063fd12 100644 --- a/crates/sql-parser/src/parser/sql.rs +++ b/crates/sql-parser/src/parser/sql.rs @@ -155,7 +155,9 @@ pub fn parse_sql(sql: &str) -> SqlParseResult { if stmts.is_empty() { return Err(SqlUnsupported::Empty.into()); } - parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars()) + parse_statement(stmts.swap_remove(0)) + .map(|ast| ast.qualify_vars()) + .and_then(|ast| ast.find_unqualified_vars()) } /// Parse a SQL statement @@ -416,6 +418,8 @@ mod tests { "update t set a = 1 from s where t.id = s.id and s.b = 2", // Implicit joins "select a.* from t as a, s as b where a.id = b.id and b.c = 1", + // Joins require qualified vars + "select t.* from t join s on int = u32", ] { assert!(parse_sql(sql).is_err()); } diff --git a/crates/sql-parser/src/parser/sub.rs b/crates/sql-parser/src/parser/sub.rs index e4ef09cabf0..828a42b9dae 100644 --- a/crates/sql-parser/src/parser/sub.rs +++ b/crates/sql-parser/src/parser/sub.rs @@ -71,7 +71,9 @@ pub fn parse_subscription(sql: &str) -> SqlParseResult { let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?; match stmts.len() { 0 => Err(SqlUnsupported::Empty.into()), - 1 => parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars()), + 1 => parse_statement(stmts.swap_remove(0)) + .map(|ast| ast.qualify_vars()) + .and_then(|ast| ast.find_unqualified_vars()), _ => Err(SqlUnsupported::MultiStatement.into()), } }