Skip to content

Commit

Permalink
fix: qualified projection columns
Browse files Browse the repository at this point in the history
  • Loading branch information
joshua-spacetime committed Dec 18, 2024
1 parent 1b31209 commit b9d9b8d
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 2 deletions.
1 change: 1 addition & 0 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ mod tests {
(
"t",
ProductType::from([
("int", AlgebraicType::U32),
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
Expand Down
58 changes: 58 additions & 0 deletions crates/expr/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,61 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<
source: StatementSource::Query,
})
}

#[cfg(test)]
mod tests {
use spacetimedb_lib::{AlgebraicType, ProductType};
use spacetimedb_schema::def::ModuleDef;

use crate::{
check::test_utils::{build_module_def, SchemaViewer},
statement::parse_and_type_sql,
};

fn module_def() -> ModuleDef {
build_module_def(vec![
(
"t",
ProductType::from([
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
("arr", AlgebraicType::array(AlgebraicType::String)),
]),
),
(
"s",
ProductType::from([
("id", AlgebraicType::identity()),
("u32", AlgebraicType::U32),
("arr", AlgebraicType::array(AlgebraicType::String)),
("bytes", AlgebraicType::bytes()),
]),
),
])
}

#[test]
fn valid() {
let tx = SchemaViewer(module_def());

for sql in [
"select str from t",
"select str, arr from t",
"select t.str, arr from t",
] {
let result = parse_and_type_sql(sql, &tx);
assert!(result.is_ok());
}
}

#[test]
fn invalid() {
let tx = SchemaViewer(module_def());

// Unqualified columns in a join
let sql = "select id, str from s join t";
let result = parse_and_type_sql(sql, &tx);
assert!(result.is_err());
}
}
32 changes: 32 additions & 0 deletions crates/sql-parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ pub enum SqlFrom {
Join(SqlIdent, SqlIdent, Vec<SqlJoin>),
}

impl SqlFrom {
pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Join(_, _, joins) => joins.iter().any(|join| join.has_unqualified_vars()),
_ => false,
}
}
}

/// An inner join in a FROM clause
#[derive(Debug)]
pub struct SqlJoin {
Expand All @@ -20,6 +29,12 @@ pub struct SqlJoin {
pub on: Option<SqlExpr>,
}

impl SqlJoin {
pub fn has_unqualified_vars(&self) -> bool {
self.on.as_ref().is_some_and(|expr| expr.has_unqualified_vars())
}
}

/// A projection expression in a SELECT clause
#[derive(Debug)]
pub struct ProjectElem(pub ProjectExpr, pub SqlIdent);
Expand Down Expand Up @@ -73,6 +88,15 @@ impl Project {
Self::Exprs(elems) => Self::Exprs(elems.into_iter().map(|elem| elem.qualify_vars(with.clone())).collect()),
}
}

pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Exprs(exprs) => exprs
.iter()
.any(|ProjectElem(expr, _)| matches!(expr, ProjectExpr::Var(_))),
_ => false,
}
}
}

/// A scalar SQL expression
Expand Down Expand Up @@ -107,6 +131,14 @@ impl SqlExpr {
),
}
}

pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Var(_) => true,
Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_unqualified_vars() || b.has_unqualified_vars(),
_ => false,
}
}
}

/// A SQL identifier or named reference.
Expand Down
27 changes: 27 additions & 0 deletions crates/sql-parser/src/ast/sql.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::parser::{errors::SqlUnsupported, SqlParseResult};

use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};

/// The AST for the SQL DML and query language
#[derive(Debug)]
pub enum SqlAst {
/// SELECT ...
Select(SqlSelect),
Expand Down Expand Up @@ -36,9 +39,17 @@ impl SqlAst {
_ => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
match self {
Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
_ => Ok(self),
}
}
}

/// A SELECT statement in the SQL query language
#[derive(Debug)]
pub struct SqlSelect {
pub project: Project,
pub from: SqlFrom,
Expand All @@ -56,33 +67,49 @@ impl SqlSelect {
SqlFrom::Join(..) => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
if self.from.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
if self.project.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
Ok(self)
}
}

/// INSERT INTO table cols VALUES literals
#[derive(Debug)]
pub struct SqlInsert {
pub table: SqlIdent,
pub fields: Vec<SqlIdent>,
pub values: SqlValues,
}

/// VALUES literals
#[derive(Debug)]
pub struct SqlValues(pub Vec<Vec<SqlLiteral>>);

/// UPDATE table SET cols [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlUpdate {
pub table: SqlIdent,
pub assignments: Vec<SqlSet>,
pub filter: Option<SqlExpr>,
}

/// DELETE FROM table [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlDelete {
pub table: SqlIdent,
pub filter: Option<SqlExpr>,
}

/// SET var '=' literal
#[derive(Debug)]
pub struct SqlSet(pub SqlIdent, pub SqlLiteral);

/// SHOW var
#[derive(Debug)]
pub struct SqlShow(pub SqlIdent);
12 changes: 12 additions & 0 deletions crates/sql-parser/src/ast/sub.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::parser::{errors::SqlUnsupported, SqlParseResult};

use super::{Project, SqlExpr, SqlFrom};

/// A SELECT statement in the SQL subscription language
Expand All @@ -18,4 +20,14 @@ impl SqlSelect {
SqlFrom::Join(..) => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
if self.from.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
if self.project.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
Ok(self)
}
}
2 changes: 2 additions & 0 deletions crates/sql-parser/src/parser/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub enum SqlUnsupported {
MultiTableDelete,
#[error("Empty SQL query")]
Empty,
#[error("Names must be qualified when using joins")]
UnqualifiedNames,
}

impl SqlUnsupported {
Expand Down
6 changes: 5 additions & 1 deletion crates/sql-parser/src/parser/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
if stmts.is_empty() {
return Err(SqlUnsupported::Empty.into());
}
parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars())
parse_statement(stmts.swap_remove(0))
.map(|ast| ast.qualify_vars())
.and_then(|ast| ast.find_unqualified_vars())
}

/// Parse a SQL statement
Expand Down Expand Up @@ -416,6 +418,8 @@ mod tests {
"update t set a = 1 from s where t.id = s.id and s.b = 2",
// Implicit joins
"select a.* from t as a, s as b where a.id = b.id and b.c = 1",
// Joins require qualified vars
"select t.* from t join s on int = u32",
] {
assert!(parse_sql(sql).is_err());
}
Expand Down
4 changes: 3 additions & 1 deletion crates/sql-parser/src/parser/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub fn parse_subscription(sql: &str) -> SqlParseResult<SqlSelect> {
let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
match stmts.len() {
0 => Err(SqlUnsupported::Empty.into()),
1 => parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars()),
1 => parse_statement(stmts.swap_remove(0))
.map(|ast| ast.qualify_vars())
.and_then(|ast| ast.find_unqualified_vars()),
_ => Err(SqlUnsupported::MultiStatement.into()),
}
}
Expand Down

0 comments on commit b9d9b8d

Please sign in to comment.