Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: qualified projection columns #2070

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ mod tests {
(
"t",
ProductType::from([
("int", AlgebraicType::U32),
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
Expand Down
58 changes: 58 additions & 0 deletions crates/expr/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,61 @@ pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<
source: StatementSource::Query,
})
}

#[cfg(test)]
mod tests {
use spacetimedb_lib::{AlgebraicType, ProductType};
use spacetimedb_schema::def::ModuleDef;

use crate::{
check::test_utils::{build_module_def, SchemaViewer},
statement::parse_and_type_sql,
};

fn module_def() -> ModuleDef {
build_module_def(vec![
(
"t",
ProductType::from([
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
("arr", AlgebraicType::array(AlgebraicType::String)),
]),
),
(
"s",
ProductType::from([
("id", AlgebraicType::identity()),
("u32", AlgebraicType::U32),
("arr", AlgebraicType::array(AlgebraicType::String)),
("bytes", AlgebraicType::bytes()),
]),
),
])
}

#[test]
fn valid() {
let tx = SchemaViewer(module_def());

for sql in [
"select str from t",
"select str, arr from t",
"select t.str, arr from t",
] {
let result = parse_and_type_sql(sql, &tx);
assert!(result.is_ok());
}
}

#[test]
fn invalid() {
let tx = SchemaViewer(module_def());

// Unqualified columns in a join
let sql = "select id, str from s join t";
let result = parse_and_type_sql(sql, &tx);
assert!(result.is_err());
}
}
32 changes: 32 additions & 0 deletions crates/sql-parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@ pub enum SqlFrom {
Join(SqlIdent, SqlIdent, Vec<SqlJoin>),
}

impl SqlFrom {
pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Join(_, _, joins) => joins.iter().any(|join| join.has_unqualified_vars()),
_ => false,
}
}
}

/// An inner join in a FROM clause
#[derive(Debug)]
pub struct SqlJoin {
Expand All @@ -20,6 +29,12 @@ pub struct SqlJoin {
pub on: Option<SqlExpr>,
}

impl SqlJoin {
pub fn has_unqualified_vars(&self) -> bool {
self.on.as_ref().is_some_and(|expr| expr.has_unqualified_vars())
}
}

/// A projection expression in a SELECT clause
#[derive(Debug)]
pub struct ProjectElem(pub ProjectExpr, pub SqlIdent);
Expand Down Expand Up @@ -73,6 +88,15 @@ impl Project {
Self::Exprs(elems) => Self::Exprs(elems.into_iter().map(|elem| elem.qualify_vars(with.clone())).collect()),
}
}

pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Exprs(exprs) => exprs
.iter()
.any(|ProjectElem(expr, _)| matches!(expr, ProjectExpr::Var(_))),
_ => false,
}
}
}

/// A scalar SQL expression
Expand Down Expand Up @@ -107,6 +131,14 @@ impl SqlExpr {
),
}
}

pub fn has_unqualified_vars(&self) -> bool {
match self {
Self::Var(_) => true,
Self::Bin(a, b, _) | Self::Log(a, b, _) => a.has_unqualified_vars() || b.has_unqualified_vars(),
_ => false,
}
}
}

/// A SQL identifier or named reference.
Expand Down
27 changes: 27 additions & 0 deletions crates/sql-parser/src/ast/sql.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::parser::{errors::SqlUnsupported, SqlParseResult};

use super::{Project, SqlExpr, SqlFrom, SqlIdent, SqlLiteral};

/// The AST for the SQL DML and query language
#[derive(Debug)]
pub enum SqlAst {
/// SELECT ...
Select(SqlSelect),
Expand Down Expand Up @@ -36,9 +39,17 @@ impl SqlAst {
_ => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
match self {
Self::Select(select) => select.find_unqualified_vars().map(Self::Select),
_ => Ok(self),
}
}
}

/// A SELECT statement in the SQL query language
#[derive(Debug)]
pub struct SqlSelect {
pub project: Project,
pub from: SqlFrom,
Expand All @@ -56,33 +67,49 @@ impl SqlSelect {
SqlFrom::Join(..) => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
if self.from.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
if self.project.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
Ok(self)
}
}

/// INSERT INTO table cols VALUES literals
#[derive(Debug)]
pub struct SqlInsert {
pub table: SqlIdent,
pub fields: Vec<SqlIdent>,
pub values: SqlValues,
}

/// VALUES literals
#[derive(Debug)]
pub struct SqlValues(pub Vec<Vec<SqlLiteral>>);

/// UPDATE table SET cols [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlUpdate {
pub table: SqlIdent,
pub assignments: Vec<SqlSet>,
pub filter: Option<SqlExpr>,
}

/// DELETE FROM table [ WHERE predicate ]
#[derive(Debug)]
pub struct SqlDelete {
pub table: SqlIdent,
pub filter: Option<SqlExpr>,
}

/// SET var '=' literal
#[derive(Debug)]
pub struct SqlSet(pub SqlIdent, pub SqlLiteral);

/// SHOW var
#[derive(Debug)]
pub struct SqlShow(pub SqlIdent);
12 changes: 12 additions & 0 deletions crates/sql-parser/src/ast/sub.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::parser::{errors::SqlUnsupported, SqlParseResult};

use super::{Project, SqlExpr, SqlFrom};

/// A SELECT statement in the SQL subscription language
Expand All @@ -18,4 +20,14 @@ impl SqlSelect {
SqlFrom::Join(..) => self,
}
}

pub fn find_unqualified_vars(self) -> SqlParseResult<Self> {
if self.from.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
if self.project.has_unqualified_vars() {
return Err(SqlUnsupported::UnqualifiedNames.into());
}
Ok(self)
}
}
2 changes: 2 additions & 0 deletions crates/sql-parser/src/parser/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub enum SqlUnsupported {
MultiTableDelete,
#[error("Empty SQL query")]
Empty,
#[error("Names must be qualified when using joins")]
UnqualifiedNames,
}

impl SqlUnsupported {
Expand Down
6 changes: 5 additions & 1 deletion crates/sql-parser/src/parser/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ pub fn parse_sql(sql: &str) -> SqlParseResult<SqlAst> {
if stmts.is_empty() {
return Err(SqlUnsupported::Empty.into());
}
parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars())
parse_statement(stmts.swap_remove(0))
.map(|ast| ast.qualify_vars())
.and_then(|ast| ast.find_unqualified_vars())
}

/// Parse a SQL statement
Expand Down Expand Up @@ -416,6 +418,8 @@ mod tests {
"update t set a = 1 from s where t.id = s.id and s.b = 2",
// Implicit joins
"select a.* from t as a, s as b where a.id = b.id and b.c = 1",
// Joins require qualified vars
"select t.* from t join s on int = u32",
] {
assert!(parse_sql(sql).is_err());
}
Expand Down
4 changes: 3 additions & 1 deletion crates/sql-parser/src/parser/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ pub fn parse_subscription(sql: &str) -> SqlParseResult<SqlSelect> {
let mut stmts = Parser::parse_sql(&PostgreSqlDialect {}, sql)?;
match stmts.len() {
0 => Err(SqlUnsupported::Empty.into()),
1 => parse_statement(stmts.swap_remove(0)).map(|ast| ast.qualify_vars()),
1 => parse_statement(stmts.swap_remove(0))
.map(|ast| ast.qualify_vars())
.and_then(|ast| ast.find_unqualified_vars()),
_ => Err(SqlUnsupported::MultiStatement.into()),
}
}
Expand Down
Loading