diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs index 54edcedd3c3..b5b39028c75 100644 --- a/crates/physical-plan/src/compile.rs +++ b/crates/physical-plan/src/compile.rs @@ -4,34 +4,34 @@ use crate::plan; use crate::plan::{CrossJoin, Filter, PhysicalCtx, PhysicalExpr, PhysicalPlan}; use spacetimedb_expr::expr::{Expr, Let, LetCtx, Project, RelExpr, Select}; use spacetimedb_expr::statement::Statement; -use spacetimedb_expr::ty::TyId; +use spacetimedb_expr::ty::{TyCtx, TyId}; use spacetimedb_expr::StatementCtx; use spacetimedb_sql_parser::ast::BinOp; -fn compile_expr(ctx: &LetCtx, expr: Expr) -> PhysicalExpr { +fn compile_expr(_ctx: &TyCtx, vars: &LetCtx, expr: Expr) -> PhysicalExpr { match expr { Expr::Bin(op, lhs, rhs) => { - let lhs = compile_expr(ctx, *lhs); - let rhs = compile_expr(ctx, *rhs); + let lhs = compile_expr(_ctx, vars, *lhs); + let rhs = compile_expr(_ctx, vars, *rhs); PhysicalExpr::BinOp(op, Box::new(lhs), Box::new(rhs)) } Expr::Var(sym, _ty) => { - let var = ctx.get_var(sym).cloned().unwrap(); - compile_expr(ctx, var) + let var = vars.get_var(sym).cloned().unwrap(); + compile_expr(_ctx, vars, var) } Expr::Row(row, ty) => { PhysicalExpr::Tuple( row.into_vec() .into_iter() // The `sym` is inline in `expr` - .map(|(_sym, expr)| compile_expr(ctx, expr)) + .map(|(_sym, expr)| compile_expr(_ctx, vars, expr)) .collect(), ty, ) } Expr::Lit(value, ty) => PhysicalExpr::Value(value, ty), Expr::Field(expr, pos, ty) => { - let expr = compile_expr(ctx, *expr); + let expr = compile_expr(_ctx, vars, *expr); PhysicalExpr::Field(Box::new(expr), pos, ty) } Expr::Input(ty) => PhysicalExpr::Input(ty), @@ -44,15 +44,16 @@ fn join_exprs(exprs: Vec) -> Option { .reduce(|lhs, rhs| PhysicalExpr::BinOp(BinOp::And, Box::new(lhs), Box::new(rhs))) } -fn compile_let(expr: Let) -> Vec { - let ctx = LetCtx { vars: &expr.vars }; - - expr.exprs.into_iter().map(|expr| compile_expr(&ctx, expr)).collect() +fn compile_let(ctx: &TyCtx, Let { vars, exprs }: Let) -> Vec { + exprs + .into_iter() + .map(|expr| compile_expr(ctx, &LetCtx { vars: &vars }, expr)) + .collect() } -fn compile_filter(select: Select) -> PhysicalPlan { - let input = compile_rel_expr(select.input); - if let Some(op) = join_exprs(compile_let(select.expr)) { +fn compile_filter(ctx: &TyCtx, select: Select) -> PhysicalPlan { + let input = compile_rel_expr(ctx, select.input); + if let Some(op) = join_exprs(compile_let(ctx, select.expr)) { PhysicalPlan::Filter(Filter { input: Box::new(input), op, @@ -62,19 +63,19 @@ fn compile_filter(select: Select) -> PhysicalPlan { } } -fn compile_project(expr: Project) -> PhysicalPlan { +fn compile_project(ctx: &TyCtx, expr: Project) -> PhysicalPlan { let proj = plan::Project { - input: Box::new(compile_rel_expr(expr.input)), - op: join_exprs(compile_let(expr.expr)).unwrap(), + input: Box::new(compile_rel_expr(ctx, expr.input)), + op: join_exprs(compile_let(ctx, expr.expr)).unwrap(), }; PhysicalPlan::Project(proj) } -fn compile_cross_joins(joins: Vec, ty: TyId) -> PhysicalPlan { +fn compile_cross_joins(ctx: &TyCtx, joins: Vec, ty: TyId) -> PhysicalPlan { joins .into_iter() - .map(compile_rel_expr) + .map(|expr| compile_rel_expr(ctx, expr)) .reduce(|lhs, rhs| { PhysicalPlan::CrossJoin(CrossJoin { lhs: Box::new(lhs), @@ -85,12 +86,12 @@ fn compile_cross_joins(joins: Vec, ty: TyId) -> PhysicalPlan { .unwrap() } -fn compile_rel_expr(ast: RelExpr) -> PhysicalPlan { +fn compile_rel_expr(ctx: &TyCtx, ast: RelExpr) -> PhysicalPlan { match ast { RelExpr::RelVar(table, _ty) => PhysicalPlan::TableScan(table), - RelExpr::Select(select) => compile_filter(*select), - RelExpr::Proj(proj) => compile_project(*proj), - RelExpr::Join(joins, ty) => compile_cross_joins(joins.into_vec(), ty), + RelExpr::Select(select) => compile_filter(ctx, *select), + RelExpr::Proj(proj) => compile_project(ctx, *proj), + RelExpr::Join(joins, ty) => compile_cross_joins(ctx, joins.into_vec(), ty), RelExpr::Union(_, _) | RelExpr::Minus(_, _) | RelExpr::Dedup(_) => { unreachable!("DISTINCT is not implemented") } @@ -102,9 +103,9 @@ fn compile_rel_expr(ast: RelExpr) -> PhysicalPlan { /// The input [Statement] is assumed to be valid so the lowering is not expected to fail. /// /// **NOTE:** It does not optimize the plan. -pub fn compile(ast: StatementCtx) -> PhysicalCtx { +pub fn compile<'a>(ctx: &TyCtx, ast: StatementCtx<'a>) -> PhysicalCtx<'a> { let plan = match ast.statement { - Statement::Select(expr) => compile_rel_expr(expr), + Statement::Select(expr) => compile_rel_expr(ctx, expr), _ => { unreachable!("Only `SELECT` is implemented") } @@ -151,10 +152,11 @@ mod tests { ]) } - fn compile_sql_sub_test(sql: &str) -> ResultTest { + fn compile_sql_sub_test(sql: &str) -> ResultTest<(StatementCtx, TyCtx)> { let tx = SchemaViewer(module_def()); - let expr = compile_sql_sub(&mut TyCtx::default(), sql, &tx)?; - Ok(expr) + let mut ctx = TyCtx::default(); + let expr = compile_sql_sub(&mut ctx, sql, &tx)?; + Ok((expr, ctx)) } fn compile_sql_stmt_test(sql: &str) -> ResultTest { @@ -165,29 +167,30 @@ mod tests { #[test] fn test_project() -> ResultTest<()> { - let ast = compile_sql_sub_test("SELECT * FROM t")?; - assert!(matches!(compile(ast).plan, PhysicalPlan::TableScan(_))); + let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t")?; + assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::TableScan(_))); let ast = compile_sql_stmt_test("SELECT u32 FROM t")?; - assert!(matches!(compile(ast).plan, PhysicalPlan::Project(_))); + assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Project(_))); Ok(()) } #[test] fn test_select() -> ResultTest<()> { - let ast = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1")?; - assert!(matches!(compile(ast).plan, PhysicalPlan::Filter(_))); + let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1")?; + assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(_))); - let ast = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1 AND f32 = f32")?; - assert!(matches!(compile(ast).plan, PhysicalPlan::Filter(_))); + let (ast, ctx) = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1 AND f32 = f32")?; + assert!(matches!(compile(&ctx, ast).plan, PhysicalPlan::Filter(_))); Ok(()) } #[test] fn test_joins() -> ResultTest<()> { // Check we can do a cross join - let ast = compile(compile_sql_sub_test("SELECT t.* FROM t JOIN u")?).plan; + let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u")?; + let ast = compile(&ctx, ast).plan; let plan::Project { input, op } = ast.as_project().unwrap(); let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap(); @@ -196,7 +199,8 @@ mod tests { assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); // Check we can do multiple joins - let ast = compile(compile_sql_sub_test("SELECT t.* FROM t JOIN u JOIN x")?).plan; + let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u JOIN x")?; + let ast = compile(&ctx, ast).plan; let plan::Project { input, op: _ } = ast.as_project().unwrap(); let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap(); assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); @@ -206,7 +210,8 @@ mod tests { assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); // Check we can do a join with a filter - let ast = compile(compile_sql_stmt_test("SELECT t.* FROM t JOIN u ON t.u32 = u.u32")?).plan; + let (ast, ctx) = compile_sql_sub_test("SELECT t.* FROM t JOIN u ON t.u32 = u.u32")?; + let ast = compile(&ctx, ast).plan; let plan::Project { input, op: _ } = ast.as_project().unwrap(); let Filter { input, op } = input.as_filter().unwrap();