From 4e1d994815eb3dc8c4dd32a7ffba9cdf5d2df8b0 Mon Sep 17 00:00:00 2001 From: Ian Alexander Joiner <14581281+iajoiner@users.noreply.github.com> Date: Tue, 18 Jun 2024 15:19:44 -0400 Subject: [PATCH] feat: generalize `DenseFilterExpr` (#8) # Rationale for this change It is necessary to move provable projections into the prover so that provable queries do not need postprocessing. This is the first PR of the change which is to move provable projections to the prover for filter queries. # What changes are included in this PR? - Move the `Expression` -> `ProvableExprPlan` conversion to a separate struct, `ProvableExprPlanBuilder` since the process is no longer for where clauses only. - Add `EnrichedExpr` to be equivalent to the provable version of [Expr](https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#method.alias) in DataFusion for tamperproof queries. - Allow arbitrary aliased provable expressions as result expressions in dense filters. - Remove constraints in `QueryContext` that prevent constants from being used in select. # Are these changes tested? Existing tests should pass. New tests will be added to cover nontrivial results such as `select a and b as c, d, 4 as f from tab where d = e`. 
--- crates/proof-of-sql/src/sql/ast/and_expr.rs | 2 +- .../src/sql/ast/dense_filter_expr.rs | 43 ++-- .../src/sql/ast/dense_filter_expr_test.rs | 94 +++++--- ...dense_filter_expr_test_dishonest_prover.rs | 12 +- .../proof-of-sql/src/sql/ast/equals_expr.rs | 2 +- .../src/sql/ast/inequality_expr.rs | 2 +- .../proof-of-sql/src/sql/ast/literal_expr.rs | 2 +- crates/proof-of-sql/src/sql/ast/not_expr.rs | 2 +- crates/proof-of-sql/src/sql/ast/or_expr.rs | 2 +- .../src/sql/ast/provable_expr_plan.rs | 2 +- .../proof-of-sql/src/sql/ast/test_utility.rs | 48 +++- .../src/sql/parse/enriched_expr.rs | 66 ++++++ .../src/sql/parse/filter_expr_builder.rs | 39 ++-- crates/proof-of-sql/src/sql/parse/mod.rs | 6 + .../sql/parse/provable_expr_plan_builder.rs | 142 ++++++++++++ .../src/sql/parse/query_context.rs | 24 +- .../proof-of-sql/src/sql/parse/query_expr.rs | 18 +- .../src/sql/parse/query_expr_tests.rs | 218 +++++++++++------- .../src/sql/parse/where_expr_builder.rs | 125 +--------- .../proof-of-sql/tests/integration_tests.rs | 8 +- 20 files changed, 542 insertions(+), 315 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/parse/enriched_expr.rs create mode 100644 crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs diff --git a/crates/proof-of-sql/src/sql/ast/and_expr.rs b/crates/proof-of-sql/src/sql/ast/and_expr.rs index bf2d128b1..504a14ce1 100644 --- a/crates/proof-of-sql/src/sql/ast/and_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/and_expr.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical AND expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AndExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs index 254336c88..cc7e82e51 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs +++ 
b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs @@ -2,7 +2,7 @@ use super::{ dense_filter_util::{fold_columns, fold_vals}, filter_columns, provable_expr_plan::ProvableExprPlan, - ColumnExpr, ProvableExpr, TableExpr, + ProvableExpr, TableExpr, }; use crate::{ base::{ @@ -22,6 +22,7 @@ use crate::{ use bumpalo::Bump; use core::iter::repeat_with; use num_traits::{One, Zero}; +use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, marker::PhantomData}; @@ -33,7 +34,7 @@ use std::{collections::HashSet, marker::PhantomData}; /// This differs from the [`FilterExpr`] in that the result is not a sparse table. #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct OstensibleDenseFilterExpr { - pub(super) results: Vec>, + pub(super) aliased_results: Vec<(ProvableExprPlan, Identifier)>, pub(super) table: TableExpr, pub(super) where_clause: ProvableExprPlan, phantom: PhantomData, @@ -42,12 +43,12 @@ pub struct OstensibleDenseFilterExpr { impl OstensibleDenseFilterExpr { /// Creates a new dense_filter expression. pub fn new( - results: Vec>, + aliased_results: Vec<(ProvableExprPlan, Identifier)>, table: TableExpr, where_clause: ProvableExprPlan, ) -> Self { Self { - results, + aliased_results, table, where_clause, phantom: PhantomData, @@ -65,8 +66,8 @@ where _accessor: &dyn MetadataAccessor, ) -> Result<(), ProofError> { self.where_clause.count(builder)?; - for expr in self.results.iter() { - expr.count(builder)?; + for aliased_expr in self.aliased_results.iter() { + aliased_expr.0.count(builder)?; builder.count_result_columns(1); } builder.count_intermediate_mles(2); @@ -94,9 +95,9 @@ where let selection_eval = self.where_clause.verifier_evaluate(builder, accessor)?; // 2. columns let columns_evals = Vec::from_iter( - self.results + self.aliased_results .iter() - .map(|expr| expr.verifier_evaluate(builder, accessor)) + .map(|(expr, _)| expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?, ); // 3. 
indexes @@ -105,8 +106,9 @@ where .result_indexes_evaluation .ok_or(ProofError::VerificationError("invalid indexes"))?; // 4. filtered_columns - let filtered_columns_evals = - Vec::from_iter(repeat_with(|| builder.consume_result_mle()).take(self.results.len())); + let filtered_columns_evals = Vec::from_iter( + repeat_with(|| builder.consume_result_mle()).take(self.aliased_results.len()), + ); let alpha = builder.consume_post_result_challenge(); let beta = builder.consume_post_result_challenge(); @@ -122,18 +124,17 @@ where } fn get_column_result_fields(&self) -> Vec { - let mut columns = Vec::with_capacity(self.results.len()); - for col in self.results.iter() { - columns.push(col.get_column_field()); - } - columns + self.aliased_results + .iter() + .map(|(expr, alias)| ColumnField::new(*alias, expr.data_type())) + .collect() } fn get_column_references(&self) -> HashSet { let mut columns = HashSet::new(); - for col in self.results.iter() { - columns.insert(col.get_column_reference()); + for (col, _) in self.aliased_results.iter() { + col.get_column_references(&mut columns); } self.where_clause.get_column_references(&mut columns); @@ -163,9 +164,9 @@ impl ProverEvaluate for DenseFilterExpr { // 2. columns let columns = Vec::from_iter( - self.results + self.aliased_results .iter() - .map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)), + .map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)), ); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); @@ -195,9 +196,9 @@ impl ProverEvaluate for DenseFilterExpr { // 2. 
columns let columns = Vec::from_iter( - self.results + self.aliased_results .iter() - .map(|expr| expr.prover_evaluate(builder, alloc, accessor)), + .map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)), ); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs index b0c75000c..e0e1276a9 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs @@ -18,7 +18,9 @@ use crate::{ ast::{ // Making this explicit to ensure that we don't accidentally use the // sparse filter for these tests - test_utility::{cols_expr, column, const_int128, dense_filter, equal, tab}, + test_utility::{ + col_expr_plan, cols_expr_plan, column, const_int128, dense_filter, equal, tab, + }, ColumnExpr, DenseFilterExpr, LiteralExpr, @@ -40,18 +42,26 @@ use std::{collections::HashSet, sync::Arc}; #[test] fn we_can_correctly_fetch_the_query_result_schema() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); + let a = Identifier::try_new("a").unwrap(); + let b = Identifier::try_new("b").unwrap(); let provable_ast = DenseFilterExpr::::new( vec![ - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt, - )), - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("b").unwrap(), - ColumnType::BigInt, - )), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + a, + ColumnType::BigInt, + ))), + a, + ), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + b, + ColumnType::BigInt, + ))), + b, + ), ], TableExpr { table_ref }, ProvableExprPlan::try_new_equals( @@ -84,18 +94,26 @@ fn we_can_correctly_fetch_the_query_result_schema() { #[test] fn 
we_can_correctly_fetch_all_the_referenced_columns() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); + let a = Identifier::try_new("a").unwrap(); + let f = Identifier::try_new("f").unwrap(); let provable_ast = DenseFilterExpr::new( vec![ - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt, - )), - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("f").unwrap(), - ColumnType::BigInt, - )), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + a, + ColumnType::BigInt, + ))), + a, + ), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + f, + ColumnType::BigInt, + ))), + f, + ), ], TableExpr { table_ref }, not::(and( @@ -170,7 +188,7 @@ fn we_can_prove_and_get_the_correct_result_from_a_basic_dense_filter() { let mut accessor = RecordBatchTestAccessor::new_empty(); accessor.add_table(t, data, 0); let where_clause = equal(column(t, "a", &accessor), const_int128(5_i128)); - let expr = dense_filter(cols_expr(t, &["b"], &accessor), tab(t), where_clause); + let expr = dense_filter(cols_expr_plan(t, &["b"], &accessor), tab(t), where_clause); let res = VerifiableQueryResult::::new(&expr, &accessor, &()); let res = res .verify(&expr, &accessor, &()) @@ -197,7 +215,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_on_an_empty_table_using_ let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(999)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); @@ -242,7 +260,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_using_result_evaluate() let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(999)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); 
@@ -287,7 +305,7 @@ fn we_can_get_no_columns_from_a_basic_dense_filter_with_no_selected_columns_usin accessor.add_table(t, data, 0); let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(5)); - let expr = dense_filter(cols_expr(t, &[], &accessor), tab(t), where_clause); + let expr = dense_filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause); let alloc = Bump::new(); let mut builder = ResultBuilder::new(5); expr.result_evaluate(&mut builder, &alloc, &accessor); @@ -315,7 +333,7 @@ fn we_can_get_the_correct_result_from_a_basic_dense_filter_using_result_evaluate let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(5)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); @@ -357,7 +375,7 @@ fn we_can_prove_a_dense_filter_on_an_empty_table() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(106)), ); @@ -386,7 +404,7 @@ fn we_can_prove_a_dense_filter_with_empty_results() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(106)), ); @@ -406,8 +424,8 @@ fn we_can_prove_a_dense_filter_with_empty_results() { fn we_can_prove_a_dense_filter() { let data = owned_table([ bigint("a", [101, 104, 105, 102, 105]), - bigint("b", [1, 2, 3, 4, 5]), - int128("c", [1, 2, 3, 4, 5]), + bigint("b", [1, 2, 3, 4, 7]), + int128("c", [1, 3, 3, 4, 5]), varchar("d", ["1", "2", "3", "4", "5"]), scalar("e", [1, 2, 3, 4, 5]), ]); @@ -415,7 +433,17 @@ fn 
we_can_prove_a_dense_filter() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + vec![ + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "c", &accessor), + col_expr_plan(t, "d", &accessor), + col_expr_plan(t, "e", &accessor), + (const_int128(105), "const".parse().unwrap()), + ( + equal(column(t, "b", &accessor), column(t, "c", &accessor)), + "bool".parse().unwrap(), + ), + ], tab(t), equal(column(t, "a", &accessor), const_int128(105)), ); @@ -423,10 +451,12 @@ fn we_can_prove_a_dense_filter() { exercise_verification(&res, &expr, &accessor, t); let res = res.verify(&expr, &accessor, &()).unwrap().table; let expected = owned_table([ - bigint("b", [3, 5]), + bigint("b", [3, 7]), int128("c", [3, 5]), varchar("d", ["3", "5"]), scalar("e", [3, 5]), + int128("const", [105, 105]), + boolean("bool", [true, false]), ]); assert_eq!(res, expected); } diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs index 05172a298..6550f7383 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs @@ -11,7 +11,7 @@ use crate::{ sql::{ // Making this explicit to ensure that we don't accidentally use the // sparse filter for these tests - ast::test_utility::{cols_expr, column, const_int128, equal, tab}, + ast::test_utility::{cols_expr_plan, column, const_int128, equal, tab}, proof::{ Indexes, ProofBuilder, ProverEvaluate, ProverHonestyMarker, QueryError, ResultBuilder, VerifiableQueryResult, @@ -49,9 +49,9 @@ impl ProverEvaluate for DishonestDenseFilterExpr for DishonestDenseFilterExpr::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = DishonestDenseFilterExpr::new( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + 
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(105_i128)), ); diff --git a/crates/proof-of-sql/src/sql/ast/equals_expr.rs b/crates/proof-of-sql/src/sql/ast/equals_expr.rs index 852990156..34befde36 100644 --- a/crates/proof-of-sql/src/sql/ast/equals_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/equals_expr.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable AST expression for an equals expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct EqualsExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/inequality_expr.rs b/crates/proof-of-sql/src/sql/ast/inequality_expr.rs index 69d21ab65..b6a8afc01 100644 --- a/crates/proof-of-sql/src/sql/ast/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/inequality_expr.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable AST expression for an inequality expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct InequalityExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/literal_expr.rs b/crates/proof-of-sql/src/sql/ast/literal_expr.rs index 47f8aa224..4f882cc4f 100644 --- a/crates/proof-of-sql/src/sql/ast/literal_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/literal_expr.rs @@ -23,7 +23,7 @@ use std::collections::HashSet; /// While this wouldn't be as efficient as using a new custom expression for /// such queries, it allows us to easily support projects with minimal code /// changes, and the performance is sufficient for present. 
-#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LiteralExpr { value: LiteralValue, } diff --git a/crates/proof-of-sql/src/sql/ast/not_expr.rs b/crates/proof-of-sql/src/sql/ast/not_expr.rs index b5f09d5c8..8b6a8477e 100644 --- a/crates/proof-of-sql/src/sql/ast/not_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/not_expr.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical NOT expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct NotExpr { expr: Box>, } diff --git a/crates/proof-of-sql/src/sql/ast/or_expr.rs b/crates/proof-of-sql/src/sql/ast/or_expr.rs index 82f2e041e..b2f405ec7 100644 --- a/crates/proof-of-sql/src/sql/ast/or_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/or_expr.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical OR expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct OrExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 73ffcef47..5ffaa6162 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use std::{collections::HashSet, fmt::Debug}; /// Enum of AST column expression types that implement `ProvableExpr`. Is itself a `ProvableExpr`. 
-#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum ProvableExprPlan { /// Column Column(ColumnExpr), diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index f7cc315ee..4ed303123 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -6,6 +6,7 @@ use crate::base::{ commitment::Commitment, database::{ColumnField, ColumnRef, ColumnType, LiteralValue, SchemaAccessor, TableRef}, }; +use proof_of_sql_parser::Identifier; pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> ColumnRef { let name = name.parse().unwrap(); @@ -112,6 +113,51 @@ pub fn filter( ProofPlan::Filter(FilterExpr::new(results, table, where_clause)) } +pub fn aliased_col_expr_plan( + tab: TableRef, + old_name: &str, + new_name: &str, + accessor: &impl SchemaAccessor, +) -> (ProvableExprPlan, Identifier) { + ( + ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, old_name, accessor))), + new_name.parse().unwrap(), + ) +} + +pub fn col_expr_plan( + tab: TableRef, + name: &str, + accessor: &impl SchemaAccessor, +) -> (ProvableExprPlan, Identifier) { + ( + ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, name, accessor))), + name.parse().unwrap(), + ) +} + +pub fn aliased_cols_expr_plan( + tab: TableRef, + names: &[(&str, &str)], + accessor: &impl SchemaAccessor, +) -> Vec<(ProvableExprPlan, Identifier)> { + names + .iter() + .map(|(old_name, new_name)| aliased_col_expr_plan(tab, old_name, new_name, accessor)) + .collect() +} + +pub fn cols_expr_plan( + tab: TableRef, + names: &[&str], + accessor: &impl SchemaAccessor, +) -> Vec<(ProvableExprPlan, Identifier)> { + names + .iter() + .map(|name| col_expr_plan(tab, name, accessor)) + .collect() +} + pub fn col_expr( tab: TableRef, name: &str, @@ -132,7 +178,7 @@ pub fn cols_expr( } pub fn dense_filter( - results: Vec>, + results: 
Vec<(ProvableExprPlan, Identifier)>, table: TableExpr, where_clause: ProvableExprPlan, ) -> ProofPlan { diff --git a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs new file mode 100644 index 000000000..d199f3086 --- /dev/null +++ b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs @@ -0,0 +1,66 @@ +use super::ProvableExprPlanBuilder; +use crate::{ + base::{commitment::Commitment, database::ColumnRef}, + sql::ast::ProvableExprPlan, +}; +use proof_of_sql_parser::{ + intermediate_ast::{AliasedResultExpr, Expression}, + Identifier, +}; +use std::collections::HashMap; +/// Enriched expression +/// +/// An enriched expression consists of an `proof_of_sql_parser::intermediate_ast::AliasedResultExpr` +/// and an optional `ProvableExprPlan`. +/// If the `ProvableExprPlan` is `None`, the `EnrichedExpr` is not provable. +pub struct EnrichedExpr { + /// The remaining expression after the provable expression plan has been extracted. + pub residue_expression: AliasedResultExpr, + /// The extracted provable expression plan if it exists. + pub provable_expr_plan: Option>, +} + +impl EnrichedExpr { + /// Create a new `EnrichedExpr` with a provable expression. + /// + /// If the expression is not provable, the `provable_expr_plan` will be `None`. + /// Otherwise the `provable_expr_plan` will contain the provable expression plan + /// and the `residue_expression` will contain the remaining expression. 
+ pub fn new( + expression: AliasedResultExpr, + column_mapping: HashMap, + ) -> Self { + let res_provable_expr_plan = + ProvableExprPlanBuilder::new(&column_mapping).build(&expression.expr); + match res_provable_expr_plan { + Ok(provable_expr_plan) => { + let alias = expression.alias; + Self { + residue_expression: AliasedResultExpr { + expr: Box::new(Expression::Column(alias)), + alias, + }, + provable_expr_plan: Some(provable_expr_plan), + } + } + Err(_) => Self { + residue_expression: expression, + provable_expr_plan: None, + }, + } + } + + /// Get the alias of the `EnrichedExpr`. + /// + /// Since we plan to support unaliased expressions in the future, this method returns an `Option`. + #[allow(dead_code)] + pub fn get_alias(&self) -> Option<&Identifier> { + self.residue_expression.try_as_identifier() + } + + /// Is the `EnrichedExpr` provable + #[allow(dead_code)] + pub fn is_provable(&self) -> bool { + self.provable_expr_plan.is_some() + } +} diff --git a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs index 762ed8164..893251f6a 100644 --- a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs @@ -1,18 +1,19 @@ -use super::{where_expr_builder::WhereExprBuilder, ConversionError}; +use super::{where_expr_builder::WhereExprBuilder, ConversionError, EnrichedExpr}; use crate::{ base::{ commitment::Commitment, database::{ColumnRef, LiteralValue, TableRef}, }, - sql::ast::{ColumnExpr, DenseFilterExpr, ProvableExprPlan, TableExpr}, + sql::ast::{DenseFilterExpr, ProvableExprPlan, TableExpr}, }; +use itertools::Itertools; use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; pub struct FilterExprBuilder { table_expr: Option, where_expr: Option>, - filter_result_expr_list: Vec>, + filter_result_expr_list: Vec<(ProvableExprPlan, Identifier)>, 
column_mapping: HashMap, } @@ -40,16 +41,26 @@ impl FilterExprBuilder { Ok(self) } - pub fn add_result_column_set(mut self, columns: HashSet) -> Self { - // Sorting is required to make the relative order of the columns deterministic - let mut columns = columns.into_iter().collect::>(); - columns.sort(); - - columns.into_iter().for_each(|column| { - let column = *self.column_mapping.get(&column).unwrap(); - self.filter_result_expr_list.push(ColumnExpr::new(column)); - }); - + pub fn add_result_columns(mut self, columns: &[EnrichedExpr]) -> Self { + // If a column is provable, add it to the filter result expression list + // If at least one column is non-provable, add all columns from the column mapping to the filter result expression list + let mut has_nonprovable_column = false; + for enriched_expr in columns { + if let Some(plan) = &enriched_expr.provable_expr_plan { + self.filter_result_expr_list + .push((plan.clone(), enriched_expr.residue_expression.alias)); + } else { + has_nonprovable_column = true; + } + } + if has_nonprovable_column { + // Has to keep them sorted to have deterministic order for tests + for alias in self.column_mapping.keys().sorted() { + let column_ref = self.column_mapping.get(alias).unwrap(); + self.filter_result_expr_list + .push((ProvableExprPlan::new_column(*column_ref), *alias)); + } + } self } diff --git a/crates/proof-of-sql/src/sql/parse/mod.rs b/crates/proof-of-sql/src/sql/parse/mod.rs index d15b7e31a..f79289054 100644 --- a/crates/proof-of-sql/src/sql/parse/mod.rs +++ b/crates/proof-of-sql/src/sql/parse/mod.rs @@ -4,6 +4,9 @@ mod where_expr_builder_tests; pub use error::ConversionError; pub(crate) use error::ConversionResult; +mod enriched_expr; +pub(crate) use enriched_expr::EnrichedExpr; + #[cfg(all(test, feature = "blitzar"))] mod query_expr_tests; @@ -22,5 +25,8 @@ pub(crate) use query_context::QueryContext; mod query_context_builder; pub(crate) use query_context_builder::{type_check_binary_operation, QueryContextBuilder}; 
+mod provable_expr_plan_builder; +pub(crate) use provable_expr_plan_builder::ProvableExprPlanBuilder; + mod where_expr_builder; pub(crate) use where_expr_builder::WhereExprBuilder; diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs new file mode 100644 index 000000000..0b08ad800 --- /dev/null +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -0,0 +1,142 @@ +use super::ConversionError; +use crate::{ + base::{ + commitment::Commitment, + database::{ColumnRef, LiteralValue}, + math::decimal::{try_into_to_scalar, Precision}, + }, + sql::ast::{ColumnExpr, ProvableExprPlan}, +}; +use proof_of_sql_parser::{ + intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, + Identifier, +}; +use std::collections::HashMap; + +/// Builder that enables building a `proofs::sql::ast::ProvableExprPlan` from +/// a `proof_of_sql_parser::intermediate_ast::Expression`. +pub struct ProvableExprPlanBuilder<'a> { + column_mapping: &'a HashMap, +} + +impl<'a> ProvableExprPlanBuilder<'a> { + /// Creates a new `ProvableExprPlanBuilder` with the given column mapping. 
+ pub fn new(column_mapping: &'a HashMap) -> Self { + Self { column_mapping } + } + /// Builds a `proofs::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` + pub fn build( + &self, + expr: &Expression, + ) -> Result, ConversionError> { + self.visit_expr(expr) + } +} + +// Private interface +impl ProvableExprPlanBuilder<'_> { + fn visit_expr( + &self, + expr: &Expression, + ) -> Result, ConversionError> { + match expr { + Expression::Column(identifier) => self.visit_column(*identifier), + Expression::Literal(lit) => self.visit_literal(lit), + Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr), + _ => Err(ConversionError::Unprovable(format!( + "Expression {:?} is not supported yet", + expr + ))), + } + } + + fn visit_column( + &self, + identifier: Identifier, + ) -> Result, ConversionError> { + Ok(ProvableExprPlan::Column(ColumnExpr::new( + *self.column_mapping.get(&identifier).ok_or( + ConversionError::MissingColumnWithoutTable(Box::new(identifier)), + )?, + ))) + } + + fn visit_literal( + &self, + lit: &Literal, + ) -> Result, ConversionError> { + match lit { + Literal::Boolean(b) => Ok(ProvableExprPlan::new_literal(LiteralValue::Boolean(*b))), + Literal::BigInt(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::BigInt(*i))), + Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(*i))), + Literal::Decimal(d) => { + let scale = d.scale(); + let precision = Precision::new(d.precision()) + .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; + Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( + precision, + scale, + try_into_to_scalar(d, precision, scale)?, + ))) + } + Literal::VarChar(s) => Ok(ProvableExprPlan::new_literal(LiteralValue::VarChar(( + s.clone(), + s.into(), + )))), + } + } + + fn visit_unary_expr( + &self, + op: UnaryOperator, + expr: &Expression, + ) -> Result, 
ConversionError> { + let expr = self.visit_expr(expr); + match op { + UnaryOperator::Not => ProvableExprPlan::try_new_not(expr?), + } + } + + fn visit_binary_expr( + &self, + op: BinaryOperator, + left: &Expression, + right: &Expression, + ) -> Result, ConversionError> { + match op { + BinaryOperator::And => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_and(left?, right?) + } + BinaryOperator::Or => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_or(left?, right?) + } + BinaryOperator::Equal => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_equals(left?, right?) + } + BinaryOperator::GreaterThanOrEqual => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_inequality(left?, right?, false) + } + BinaryOperator::LessThanOrEqual => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_inequality(left?, right?, true) + } + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Multiply + | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( + "Binary operator {:?} is not supported yet", + op + ))), + } + } +} diff --git a/crates/proof-of-sql/src/sql/parse/query_context.rs b/crates/proof-of-sql/src/sql/parse/query_context.rs index 716c8c395..2d927110a 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context.rs @@ -24,7 +24,6 @@ pub struct QueryContext { in_result_scope: bool, has_visited_group_by: bool, order_by_exprs: Vec, - fixed_col_ref_counter: usize, group_by_exprs: Vec, where_expr: Option>, result_column_set: HashSet, @@ -72,10 +71,11 @@ impl QueryContext { "aggregation context needs to be set before exiting" ); self.in_agg_scope = false; - return self.check_col_ref_counter(); + return Ok(()); } if 
self.in_agg_scope { + // TODO: Disable this once we support nested aggregations return Err(ConversionError::InvalidExpression( "nested aggregations are not supported".to_string(), )); @@ -84,10 +84,6 @@ impl QueryContext { self.agg_counter += 1; self.in_agg_scope = true; - // Resetting the counter to ensure that the - // aggregation expression references at least one column. - self.fixed_col_ref_counter = self.col_ref_counter; - Ok(()) } @@ -111,26 +107,10 @@ impl QueryContext { } } - fn check_col_ref_counter(&mut self) -> ConversionResult<()> { - if self.col_ref_counter == self.fixed_col_ref_counter { - return Err(ConversionError::InvalidExpression( - "at least one column must be referenced in the result expression".to_string(), - )); - } - - Ok(()) - } - pub fn push_aliased_result_expr(&mut self, expr: AliasedResultExpr) -> ConversionResult<()> { assert!(&self.has_visited_group_by, "Group by must be visited first"); - - self.check_col_ref_counter()?; self.res_aliased_exprs.push(expr); - // Resetting the counter to ensure consecutive aliased - // expression references include at least one column. 
- self.fixed_col_ref_counter = self.col_ref_counter; - Ok(()) } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr.rs b/crates/proof-of-sql/src/sql/parse/query_expr.rs index 150c5cd3c..b15868ea1 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr.rs @@ -1,4 +1,4 @@ -use super::{FilterExprBuilder, QueryContextBuilder, ResultExprBuilder}; +use super::{EnrichedExpr, FilterExprBuilder, QueryContextBuilder, ResultExprBuilder}; use crate::{ base::{commitment::Commitment, database::SchemaAccessor}, sql::{ @@ -72,15 +72,23 @@ impl QueryExpr { }); } } - + let column_mapping = context.get_column_mapping(); + let enriched_exprs = result_aliased_exprs + .iter() + .map(|aliased_expr| EnrichedExpr::new(aliased_expr.clone(), column_mapping.clone())) + .collect::>(); + let select_exprs = enriched_exprs + .iter() + .map(|enriched_expr| enriched_expr.residue_expression.clone()) + .collect::>(); let filter = FilterExprBuilder::new(context.get_column_mapping()) .add_table_expr(*context.get_table_ref()) .add_where_expr(context.get_where_expr().clone())? - .add_result_column_set(context.get_result_column_set()) + .add_result_columns(&enriched_exprs) .build(); let result = ResultExprBuilder::default() - .add_group_by_exprs(context.get_group_by_exprs(), result_aliased_exprs) - .add_select_exprs(result_aliased_exprs) + .add_group_by_exprs(context.get_group_by_exprs(), &select_exprs) + .add_select_exprs(&select_exprs) .add_order_by_exprs(context.get_order_by_exprs()?) .add_slice_expr(context.get_slice_expr()) .build(); diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index efd94e59b..e83615da4 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -126,8 +126,8 @@ macro_rules! expected_query { orderby_macro!($($order_by)?, $($order_dirs)?); macro_rules! 
filter_macro { - () => {dense_filter(cols_expr(t, &$result_columns, &accessor), tab(t), const_bool(true))}; - ($expr:expr) => { dense_filter(cols_expr(t, &$result_columns, &accessor), tab(t), $expr) }; + () => {dense_filter(cols_expr_plan(t, &$result_columns, &accessor), tab(t), const_bool(true))}; + ($expr:expr) => { dense_filter(cols_expr_plan(t, &$result_columns, &accessor), tab(t), $expr) }; } let filter = filter_macro!($($filter)?); @@ -148,7 +148,7 @@ fn we_can_convert_an_ast_with_one_column() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -170,7 +170,7 @@ fn we_can_convert_an_ast_with_one_column_and_i128_data() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3_i64)), ), @@ -192,7 +192,7 @@ fn we_can_convert_an_ast_with_one_column_and_a_filter_by_a_string_literal() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 'abc'", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_varchar("abc")), ), @@ -243,11 +243,11 @@ fn we_dont_have_duplicate_filter_result_expressions() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + aliased_cols_expr_plan(t, &[("a", "b"), ("a", "c")], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), - result(&[("a", "b"), ("a", "c")]), + result(&[("b", "b"), ("c", "c")]), ); assert_eq!(ast, expected_ast); } @@ -267,7 +267,7 @@ fn we_can_convert_an_ast_with_two_columns() { let ast = 
query_to_provable_ast(t, "select a, b from sxt_tab where c = 123", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["a", "b"], &accessor), tab(t), equal(column(t, "c", &accessor), const_bigint(123)), ), @@ -290,7 +290,7 @@ fn we_can_parse_all_result_columns_with_select_star() { let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["b", "a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -313,7 +313,7 @@ fn we_can_convert_an_ast_with_one_positive_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b = +4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "b", &accessor), const_bigint(4)), ), @@ -336,7 +336,7 @@ fn we_can_convert_an_ast_with_one_not_equals_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b <> +4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), not(equal(column(t, "b", &accessor), const_bigint(4))), ), @@ -359,7 +359,7 @@ fn we_can_convert_an_ast_with_one_negative_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b <= -4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), lte(column(t, "b", &accessor), const_bigint(-4)), ), @@ -387,7 +387,7 @@ fn we_can_convert_an_ast_with_cond_and() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), and( equal(column(t, "b", &accessor), const_bigint(3)), @@ -418,7 +418,7 @@ fn we_can_convert_an_ast_with_cond_or() { ); let 
expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), or( equal(column(t, "b", &accessor), const_bigint(3)), @@ -449,7 +449,7 @@ fn we_can_convert_an_ast_with_conds_or_not() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), or( lte(column(t, "b", &accessor), const_bigint(3)), @@ -476,12 +476,21 @@ fn we_can_convert_an_ast_with_conds_not_and_or() { ); let ast = query_to_provable_ast( t, - "select a from sxt_tab where not (((f >= 45) or (c <= -2)) and (b = 3))", + "select a, not (a = b or c = f) as boolean from sxt_tab where not (((f >= 45) or (c <= -2)) and (b = 3))", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + vec![ + col_expr_plan(t, "a", &accessor), + ( + not(or( + equal(column(t, "a", &accessor), column(t, "b", &accessor)), + equal(column(t, "c", &accessor), column(t, "f", &accessor)), + )), + "boolean".parse().unwrap(), + ), + ], tab(t), not(and( or( @@ -491,13 +500,13 @@ fn we_can_convert_an_ast_with_conds_not_and_or() { equal(column(t, "b", &accessor), const_bigint(3)), )), ), - result(&[("a", "a")]), + result(&[("a", "a"), ("boolean", "boolean")]), ); assert_eq!(ast, expected_ast); } #[test] -fn we_can_convert_an_ast_with_the_min_i128_filter_value() { +fn we_can_convert_an_ast_with_the_min_i128_filter_value_and_const() { let t = "sxt.sxt_tab".parse().unwrap(); let accessor = record_batch_to_accessor( t, @@ -508,22 +517,25 @@ fn we_can_convert_an_ast_with_the_min_i128_filter_value() { ); let ast = query_to_provable_ast( t, - "select a from sxt_tab where a = -170141183460469231731687303715884105728", + "select a, -170141183460469231731687303715884105728 as b from sxt_tab where a = -170141183460469231731687303715884105728", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + vec![ + col_expr_plan(t, 
"a", &accessor), + (const_int128(i128::MIN), "b".parse().unwrap()), + ], tab(t), equal(column(t, "a", &accessor), const_int128(i128::MIN)), ), - result(&[("a", "a")]), + result(&[("a", "a"), ("b", "b")]), ); assert_eq!(ast, expected_ast); } #[test] -fn we_can_convert_an_ast_with_the_max_i128_filter_value() { +fn we_can_convert_an_ast_with_the_max_i128_filter_value_and_const() { let t = "sxt.sxt_tab".parse().unwrap(); let accessor = record_batch_to_accessor( t, @@ -534,16 +546,19 @@ fn we_can_convert_an_ast_with_the_max_i128_filter_value() { ); let ast = query_to_provable_ast( t, - "select a from sxt_tab where a = 170141183460469231731687303715884105727", + "select a, 170141183460469231731687303715884105727 as ma from sxt_tab where a = 170141183460469231731687303715884105727", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + vec![ + col_expr_plan(t, "a", &accessor), + (const_int128(i128::MAX), "ma".parse().unwrap()), + ], tab(t), equal(column(t, "a", &accessor), const_int128(i128::MAX)), ), - result(&[("a", "a")]), + result(&[("a", "a"), ("ma", "ma")]), ); assert_eq!(ast, expected_ast); } @@ -561,16 +576,22 @@ fn we_can_convert_an_ast_using_an_aliased_column() { ); let ast = query_to_provable_ast( t, - "select a as b_rename from sxt_tab where b >= +4", + "select a as b_rename, a = b as boolean from sxt_tab where b >= +4", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - vec![col_expr(t, "a", &accessor)], + vec![ + aliased_col_expr_plan(t, "a", "b_rename", &accessor), + ( + equal(column(t, "a", &accessor), column(t, "b", &accessor)), + "boolean".parse().unwrap(), + ), + ], tab(t), gte(column(t, "b", &accessor), const_bigint(4)), ), - result(&[("a", "b_rename")]), + result(&[("b_rename", "b_rename"), ("boolean", "boolean")]), ); assert_eq!(ast, expected_ast); } @@ -614,7 +635,7 @@ fn we_can_convert_an_ast_with_a_schema() { let ast = query_to_provable_ast(t, "select a from eth.sxt_tab where a = 3", 
&accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -634,10 +655,20 @@ fn we_can_convert_an_ast_without_any_dense_filter() { 0, ); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), - result(&[("a", "a")]), + dense_filter( + vec![ + col_expr_plan(t, "a", &accessor), + (const_bigint(3), "b".parse().unwrap()), + ], + tab(t), + const_bool(true), + ), + result(&[("a", "a"), ("b", "b")]), ); - let queries = ["select * from eth.sxt_tab", "select a from eth.sxt_tab"]; + let queries = [ + "select *, 3 as b from eth.sxt_tab", + "select a, 3 as b from eth.sxt_tab", + ]; for query in queries { let ast = query_to_provable_ast(t, query, &accessor); assert_eq!(ast, expected_ast); @@ -661,7 +692,7 @@ fn we_can_parse_order_by_with_a_single_column() { let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3 order by b", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["b", "a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -691,7 +722,7 @@ fn we_can_parse_order_by_with_multiple_columns() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["a", "b"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -723,14 +754,14 @@ fn we_can_parse_order_by_referencing_an_alias_associated_with_column_b_but_with_ let expected_ast = QueryExpr::new( dense_filter( vec![ - col_expr(t, "name", &accessor), - col_expr(t, "salary", &accessor), + aliased_col_expr_plan(t, "salary", "s", &accessor), + aliased_col_expr_plan(t, "name", "salary", &accessor), ], tab(t), equal(column(t, "salary", &accessor), const_bigint(5)), ), composite_result(vec![ - select(&[pc("salary").alias("s"), 
pc("name").alias("salary")]), + select(&[pc("s").alias("s"), pc("salary").alias("salary")]), orders(&["salary"], &[Desc]), ]), ); @@ -836,17 +867,18 @@ fn we_can_parse_order_by_queries_with_the_same_column_name_appearing_more_than_o let expected_ast = QueryExpr::new( dense_filter( vec![ - col_expr(t, "name", &accessor), - col_expr(t, "salary", &accessor), + aliased_col_expr_plan(t, "salary", "s", &accessor), + col_expr_plan(t, "name", &accessor), + aliased_col_expr_plan(t, "salary", "d", &accessor), ], tab(t), const_bool(true), ), composite_result(vec![ select(&[ - pc("salary").alias("s"), + pc("s").alias("s"), pc("name").alias("name"), - pc("salary").alias("d"), + pc("d").alias("d"), ]), orders(&[order_by], &[Asc]), ]), @@ -872,7 +904,11 @@ fn we_can_parse_a_query_having_a_simple_limit_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab limit 3", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(3, 0)]), ); assert_eq!(ast, expected_ast); @@ -891,7 +927,11 @@ fn no_slice_is_applied_when_limit_is_u64_max_and_offset_is_zero() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset 0", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")])]), ); assert_eq!(ast, expected_ast); @@ -910,7 +950,11 @@ fn we_can_parse_a_query_having_a_simple_positive_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset 7", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + 
const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(u64::MAX, 7)]), ); assert_eq!(ast, expected_ast); @@ -929,7 +973,11 @@ fn we_can_parse_a_query_having_a_negative_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset -7", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(u64::MAX, -7)]), ); assert_eq!(ast, expected_ast); @@ -948,7 +996,11 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab limit 55 offset 3", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(55, 3)]), ); assert_eq!(ast, expected_ast); @@ -965,23 +1017,33 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause_preceded_by_wher t, record_batch!( "a" => [5_i64], + "boolean" => [true], ), 0, ); let ast = query_to_provable_ast( t, - "select a from sxt_tab where a = -3 order by a desc limit 55 offset 3", + "select a, boolean and a >= 4 as res from sxt_tab where a = -3 order by a desc limit 55 offset 3", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + vec![ + col_expr_plan(t, "a", &accessor), + ( + and( + column(t, "boolean", &accessor), + gte(column(t, "a", &accessor), const_bigint(4)), + ), + "res".parse().unwrap(), + ), + ], tab(t), equal(column(t, "a", &accessor), const_bigint(-3)), ), composite_result(vec![ - select(&[pc("a").alias("a")]), + select(&[pc("a").alias("a"), pc("res").alias("res")]), orders(&["a"], &[Desc]), slice(55, 3), ]), @@ -1166,21 +1228,27 @@ fn 
we_can_group_by_without_using_aggregate_functions() { let ast = query_to_provable_ast( t, - "select department from employees group by department", + "select department, true as is_remote from employees group by department", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + vec![ + (const_bool(true), "is_remote".parse().unwrap()), + col_expr_plan(t, "department", &accessor), + ], tab(t), const_bool(true), ), composite_result(vec![ groupby( vec![pc("department")], - vec![pc("department").first().alias("department")], + vec![ + pc("department").first().alias("department"), + pc("is_remote").alias("is_remote"), + ], ), - select(&[pc("department")]), + select(&[pc("department"), pc("is_remote")]), ]), ); assert_eq!(ast, expected_ast); @@ -1328,7 +1396,7 @@ fn we_can_parse_a_query_having_group_by_with_the_same_name_as_the_aggregation_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["bonus", "department"], &accessor), + cols_expr_plan(t, &["bonus", "department"], &accessor), tab(t), const_bool(true), ), @@ -1362,7 +1430,7 @@ fn count_aggregate_functions_can_be_used_with_non_numeric_columns() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["bonus", "department"], &accessor), + cols_expr_plan(t, &["bonus", "department"], &accessor), tab(t), const_bool(true), ), @@ -1400,7 +1468,7 @@ fn count_all_uses_the_first_group_by_identifier_as_default_result_column() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + cols_expr_plan(t, &["department", "salary"], &accessor), tab(t), equal(column(t, "salary", &accessor), const_bigint(4)), ), @@ -1454,7 +1522,7 @@ fn we_can_use_the_same_result_columns_with_different_aliases_and_associate_it_wi ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + cols_expr_plan(t, &["department"], &accessor), tab(t), const_bool(true), ), @@ -1516,7 +1584,7 @@ 
fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b", "f", "h"], &accessor), + cols_expr_plan(t, &["a", "b", "f", "h"], &accessor), tab(t), const_bool(true), ), @@ -1554,7 +1622,7 @@ fn we_can_parse_multiple_arithmetic_expression_where_multiplication_has_preceden ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["c", "f", "g", "h"], &accessor), + cols_expr_plan(t, &["c", "f", "g", "h"], &accessor), tab(t), const_bool(true), ), @@ -1587,7 +1655,7 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["c", "f"], &accessor), + cols_expr_plan(t, &["c", "f"], &accessor), tab(t), const_bool(true), ), @@ -1605,34 +1673,6 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { assert_eq!(ast, expected_ast); } -#[test] -fn we_need_to_reference_at_least_one_column_in_the_result_expr() { - assert_eq!( - query!(select: ["i", "-123 "], should_err: true), - ConversionError::InvalidExpression( - "at least one column must be referenced in the result expression".to_string() - ) - ); - assert_eq!( - query!(select: ["sum(-123)"], should_err: true), - ConversionError::InvalidExpression( - "at least one column must be referenced in the result expression".to_string() - ) - ); - assert_eq!( - query!(select: ["i + sum(-123)"], group: ["i"], should_err: true), - ConversionError::InvalidExpression( - "at least one column must be referenced in the result expression".to_string() - ) - ); - assert_eq!( - query!(select: ["sum(-123) + i"], group: ["i"], should_err: true), - ConversionError::InvalidExpression( - "at least one column must be referenced in the result expression".to_string() - ) - ); -} - #[test] fn we_cannot_use_non_grouped_columns_outside_agg() { assert_eq!( diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs 
b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs index f40d88f4a..c46a46785 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs @@ -1,27 +1,25 @@ -use super::ConversionError; +use super::{ConversionError, ProvableExprPlanBuilder}; use crate::{ base::{ commitment::Commitment, - database::{ColumnRef, ColumnType, LiteralValue}, - math::decimal::{try_into_to_scalar, Precision}, + database::{ColumnRef, ColumnType}, }, - sql::ast::{ColumnExpr, ProvableExpr, ProvableExprPlan}, -}; -use proof_of_sql_parser::{ - intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, - Identifier, + sql::ast::{ProvableExpr, ProvableExprPlan}, }; +use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; use std::collections::HashMap; /// Builder that enables building a `proof_of_sql::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` that is /// intended to be used as the where clause in a filter expression or group by expression. pub struct WhereExprBuilder<'a> { - column_mapping: &'a HashMap, + builder: ProvableExprPlanBuilder<'a>, } impl<'a> WhereExprBuilder<'a> { /// Creates a new `WhereExprBuilder` with the given column mapping. pub fn new(column_mapping: &'a HashMap) -> Self { - Self { column_mapping } + Self { + builder: ProvableExprPlanBuilder::new(column_mapping), + } } /// Builds a `proof_of_sql::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` that is /// intended to be used as the where clause in a filter expression or group by expression. 
@@ -31,7 +29,7 @@ impl<'a> WhereExprBuilder<'a> { ) -> Result>, ConversionError> { where_expr .map(|where_expr| { - let expr_plan = self.visit_expr(*where_expr)?; + let expr_plan = self.builder.build(&where_expr)?; // Ensure that the expression is a boolean expression match expr_plan.data_type() { ColumnType::Boolean => Ok(expr_plan), @@ -43,108 +41,3 @@ impl<'a> WhereExprBuilder<'a> { .transpose() } } - -// Private interface -impl WhereExprBuilder<'_> { - fn visit_expr( - &self, - expr: proof_of_sql_parser::intermediate_ast::Expression, - ) -> Result, ConversionError> { - match expr { - Expression::Column(identifier) => self.visit_column(identifier), - Expression::Literal(lit) => self.visit_literal(lit), - Expression::Binary { op, left, right } => self.visit_binary_expr(op, *left, *right), - Expression::Unary { op, expr } => self.visit_unary_expr(op, *expr), - _ => panic!("The parser must ensure that the expression is a boolean expression"), - } - } - - fn visit_column( - &self, - identifier: Identifier, - ) -> Result, ConversionError> { - Ok(ProvableExprPlan::Column(ColumnExpr::new( - *self.column_mapping.get(&identifier).ok_or( - ConversionError::MissingColumnWithoutTable(Box::new(identifier)), - )?, - ))) - } - - fn visit_literal( - &self, - lit: Literal, - ) -> Result, ConversionError> { - match lit { - Literal::Boolean(b) => Ok(ProvableExprPlan::new_literal(LiteralValue::Boolean(b))), - Literal::BigInt(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::BigInt(i))), - Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(i))), - Literal::Decimal(d) => { - let scale = d.scale(); - let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; - Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( - precision, - scale, - try_into_to_scalar(&d, precision, scale)?, - ))) - } - Literal::VarChar(s) => Ok(ProvableExprPlan::new_literal(LiteralValue::VarChar(( - s.clone(), - s.into(), - 
)))), - } - } - - fn visit_unary_expr( - &self, - op: UnaryOperator, - expr: Expression, - ) -> Result, ConversionError> { - let expr = self.visit_expr(expr); - match op { - UnaryOperator::Not => ProvableExprPlan::try_new_not(expr?), - } - } - - fn visit_binary_expr( - &self, - op: BinaryOperator, - left: Expression, - right: Expression, - ) -> Result, ConversionError> { - match op { - BinaryOperator::And => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_and(left?, right?) - } - BinaryOperator::Or => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_or(left?, right?) - } - BinaryOperator::Equal => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_equals(left?, right?) - } - BinaryOperator::GreaterThanOrEqual => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_inequality(left?, right?, false) - } - BinaryOperator::LessThanOrEqual => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_inequality(left?, right?, true) - } - BinaryOperator::Add - | BinaryOperator::Subtract - | BinaryOperator::Multiply - | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( - "Binary operator {:?} is not supported in the where clause", - op - ))), - } - } -} diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 6c6f69261..1d56894b3 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -284,7 +284,7 @@ fn we_can_prove_a_complex_query_with_curve25519() { 0, ); let query = QueryExpr::try_new( - "SELECT * FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" + "SELECT *, 45 as g, (a = b) or f as h FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" .parse() .unwrap(), 
"sxt".parse().unwrap(), @@ -304,6 +304,8 @@ fn we_can_prove_a_complex_query_with_curve25519() { bigint("d", [3]), varchar("e", ["f"]), boolean("f", [false]), + bigint("g", [45]), + boolean("h", [false]), ]); assert_eq!(owned_table_result, expected_result); } @@ -329,7 +331,7 @@ fn we_can_prove_a_complex_query_with_dory() { 0, ); let query = QueryExpr::try_new( - "SELECT * FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f" + "SELECT *, 32 as g, (c >= d) and f as h FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f" .parse() .unwrap(), "sxt".parse().unwrap(), @@ -354,6 +356,8 @@ fn we_can_prove_a_complex_query_with_dory() { bigint("d", [1]), varchar("e", ["d"]), boolean("f", [true]), + bigint("g", [32]), + boolean("h", [true]), ]); assert_eq!(owned_table_result, expected_result); }