diff --git a/crates/proof-of-sql/src/sql/ast/and_expr.rs b/crates/proof-of-sql/src/sql/ast/and_expr.rs index bf2d128b1..504a14ce1 100644 --- a/crates/proof-of-sql/src/sql/ast/and_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/and_expr.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical AND expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AndExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs index 254336c88..a6ec6c7e8 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs @@ -2,7 +2,7 @@ use super::{ dense_filter_util::{fold_columns, fold_vals}, filter_columns, provable_expr_plan::ProvableExprPlan, - ColumnExpr, ProvableExpr, TableExpr, + ProvableExpr, TableExpr, }; use crate::{ base::{ @@ -22,6 +22,7 @@ use crate::{ use bumpalo::Bump; use core::iter::repeat_with; use num_traits::{One, Zero}; +use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, marker::PhantomData}; @@ -33,7 +34,7 @@ use std::{collections::HashSet, marker::PhantomData}; /// This differs from the [`FilterExpr`] in that the result is not a sparse table. #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct OstensibleDenseFilterExpr { - pub(super) results: Vec>, + pub(super) results: Vec<(ProvableExprPlan, Identifier)>, pub(super) table: TableExpr, pub(super) where_clause: ProvableExprPlan, phantom: PhantomData, @@ -42,7 +43,7 @@ pub struct OstensibleDenseFilterExpr { impl OstensibleDenseFilterExpr { /// Creates a new dense_filter expression. pub fn new( - results: Vec>, + results: Vec<(ProvableExprPlan, Identifier)>, table: TableExpr, where_clause: ProvableExprPlan, ) -> Self { @@ -65,8 +66,8 @@ where _accessor: &dyn MetadataAccessor, ) -> Result<(), ProofError> { self.where_clause.count(builder)?; - for expr in self.results.iter() { - expr.count(builder)?; + for aliased_expr in self.results.iter() { + aliased_expr.0.count(builder)?; builder.count_result_columns(1); } builder.count_intermediate_mles(2); @@ -96,7 +97,7 @@ where let columns_evals = Vec::from_iter( self.results .iter() - .map(|expr| expr.verifier_evaluate(builder, accessor)) + .map(|(expr, _)| expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?, ); // 3. indexes @@ -122,18 +123,17 @@ where } fn get_column_result_fields(&self) -> Vec { - let mut columns = Vec::with_capacity(self.results.len()); - for col in self.results.iter() { - columns.push(col.get_column_field()); - } - columns + self.results + .iter() + .map(|(expr, alias)| ColumnField::new(*alias, expr.data_type())) + .collect() } fn get_column_references(&self) -> HashSet { let mut columns = HashSet::new(); - for col in self.results.iter() { - columns.insert(col.get_column_reference()); + for (col, _) in self.results.iter() { + col.get_column_references(&mut columns); } self.where_clause.get_column_references(&mut columns); @@ -165,7 +165,7 @@ impl ProverEvaluate for DenseFilterExpr { let columns = Vec::from_iter( self.results .iter() - .map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)), + .map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)), ); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); @@ -197,7 +197,7 @@ impl ProverEvaluate for DenseFilterExpr { let columns = Vec::from_iter( self.results .iter() - .map(|expr| expr.prover_evaluate(builder, alloc, accessor)), + .map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)), ); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs index b0c75000c..e0e1276a9 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs @@ -18,7 +18,9 @@ use crate::{ ast::{ // Making this explicit to ensure that we don't accidentally use the // sparse filter for these tests - test_utility::{cols_expr, column, const_int128, dense_filter, equal, tab}, + test_utility::{ + col_expr_plan, cols_expr_plan, column, const_int128, dense_filter, equal, tab, + }, ColumnExpr, DenseFilterExpr, LiteralExpr, @@ -40,18 +42,26 @@ use std::{collections::HashSet, sync::Arc}; #[test] fn we_can_correctly_fetch_the_query_result_schema() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); + let a = Identifier::try_new("a").unwrap(); + let b = Identifier::try_new("b").unwrap(); let provable_ast = DenseFilterExpr::::new( vec![ - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt, - )), - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("b").unwrap(), - ColumnType::BigInt, - )), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + a, + ColumnType::BigInt, + ))), + a, + ), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + b, + ColumnType::BigInt, + ))), + b, + ), ], TableExpr { table_ref }, ProvableExprPlan::try_new_equals( @@ -84,18 +94,26 @@ fn we_can_correctly_fetch_the_query_result_schema() { #[test] fn we_can_correctly_fetch_all_the_referenced_columns() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); + let a = Identifier::try_new("a").unwrap(); + let f = Identifier::try_new("f").unwrap(); let provable_ast = DenseFilterExpr::new( vec![ - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt, - )), - ColumnExpr::new(ColumnRef::new( - table_ref, - Identifier::try_new("f").unwrap(), - ColumnType::BigInt, - )), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + a, + ColumnType::BigInt, + ))), + a, + ), + ( + ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new( + table_ref, + f, + ColumnType::BigInt, + ))), + f, + ), ], TableExpr { table_ref }, not::(and( @@ -170,7 +188,7 @@ fn we_can_prove_and_get_the_correct_result_from_a_basic_dense_filter() { let mut accessor = RecordBatchTestAccessor::new_empty(); accessor.add_table(t, data, 0); let where_clause = equal(column(t, "a", &accessor), const_int128(5_i128)); - let expr = dense_filter(cols_expr(t, &["b"], &accessor), tab(t), where_clause); + let expr = dense_filter(cols_expr_plan(t, &["b"], &accessor), tab(t), where_clause); let res = VerifiableQueryResult::::new(&expr, &accessor, &()); let res = res .verify(&expr, &accessor, &()) @@ -197,7 +215,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_on_an_empty_table_using_ let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(999)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); @@ -242,7 +260,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_using_result_evaluate() let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(999)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); @@ -287,7 +305,7 @@ fn we_can_get_no_columns_from_a_basic_dense_filter_with_no_selected_columns_usin accessor.add_table(t, data, 0); let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(5)); - let expr = dense_filter(cols_expr(t, &[], &accessor), tab(t), where_clause); + let expr = dense_filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause); let alloc = Bump::new(); let mut builder = ResultBuilder::new(5); expr.result_evaluate(&mut builder, &alloc, &accessor); @@ -315,7 +333,7 @@ fn we_can_get_the_correct_result_from_a_basic_dense_filter_using_result_evaluate let where_clause: ProvableExprPlan = equal(column(t, "a", &accessor), const_int128(5)); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), where_clause, ); @@ -357,7 +375,7 @@ fn we_can_prove_a_dense_filter_on_an_empty_table() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(106)), ); @@ -386,7 +404,7 @@ fn we_can_prove_a_dense_filter_with_empty_results() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(106)), ); @@ -406,8 +424,8 @@ fn we_can_prove_a_dense_filter_with_empty_results() { fn we_can_prove_a_dense_filter() { let data = owned_table([ bigint("a", [101, 104, 105, 102, 105]), - bigint("b", [1, 2, 3, 4, 5]), - int128("c", [1, 2, 3, 4, 5]), + bigint("b", [1, 2, 3, 4, 7]), + int128("c", [1, 3, 3, 4, 5]), varchar("d", ["1", "2", "3", "4", "5"]), scalar("e", [1, 2, 3, 4, 5]), ]); @@ -415,7 +433,17 @@ fn we_can_prove_a_dense_filter() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = dense_filter( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + vec![ + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "c", &accessor), + col_expr_plan(t, "d", &accessor), + col_expr_plan(t, "e", &accessor), + (const_int128(105), "const".parse().unwrap()), + ( + equal(column(t, "b", &accessor), column(t, "c", &accessor)), + "bool".parse().unwrap(), + ), + ], tab(t), equal(column(t, "a", &accessor), const_int128(105)), ); @@ -423,10 +451,12 @@ fn we_can_prove_a_dense_filter() { exercise_verification(&res, &expr, &accessor, t); let res = res.verify(&expr, &accessor, &()).unwrap().table; let expected = owned_table([ - bigint("b", [3, 5]), + bigint("b", [3, 7]), int128("c", [3, 5]), varchar("d", ["3", "5"]), scalar("e", [3, 5]), + int128("const", [105, 105]), + boolean("bool", [true, false]), ]); assert_eq!(res, expected); } diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs index 05172a298..ab91f4231 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test_dishonest_prover.rs @@ -11,7 +11,7 @@ use crate::{ sql::{ // Making this explicit to ensure that we don't accidentally use the // sparse filter for these tests - ast::test_utility::{cols_expr, column, const_int128, equal, tab}, + ast::test_utility::{cols_expr_plan, column, const_int128, equal, tab}, proof::{ Indexes, ProofBuilder, ProverEvaluate, ProverHonestyMarker, QueryError, ResultBuilder, VerifiableQueryResult, @@ -51,7 +51,7 @@ impl ProverEvaluate for DishonestDenseFilterExpr for DishonestDenseFilterExpr::new_empty_with_setup(()); accessor.add_table(t, data, 0); let expr = DishonestDenseFilterExpr::new( - cols_expr(t, &["b", "c", "d", "e"], &accessor), + cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(105_i128)), ); diff --git a/crates/proof-of-sql/src/sql/ast/equals_expr.rs b/crates/proof-of-sql/src/sql/ast/equals_expr.rs index 852990156..34befde36 100644 --- a/crates/proof-of-sql/src/sql/ast/equals_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/equals_expr.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable AST expression for an equals expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct EqualsExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/inequality_expr.rs b/crates/proof-of-sql/src/sql/ast/inequality_expr.rs index 69d21ab65..b6a8afc01 100644 --- a/crates/proof-of-sql/src/sql/ast/inequality_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/inequality_expr.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable AST expression for an inequality expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct InequalityExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/literal_expr.rs b/crates/proof-of-sql/src/sql/ast/literal_expr.rs index 47f8aa224..4f882cc4f 100644 --- a/crates/proof-of-sql/src/sql/ast/literal_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/literal_expr.rs @@ -23,7 +23,7 @@ use std::collections::HashSet; /// While this wouldn't be as efficient as using a new custom expression for /// such queries, it allows us to easily support projects with minimal code /// changes, and the performance is sufficient for present. -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LiteralExpr { value: LiteralValue, } diff --git a/crates/proof-of-sql/src/sql/ast/not_expr.rs b/crates/proof-of-sql/src/sql/ast/not_expr.rs index b5f09d5c8..8b6a8477e 100644 --- a/crates/proof-of-sql/src/sql/ast/not_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/not_expr.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical NOT expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct NotExpr { expr: Box>, } diff --git a/crates/proof-of-sql/src/sql/ast/or_expr.rs b/crates/proof-of-sql/src/sql/ast/or_expr.rs index 82f2e041e..b2f405ec7 100644 --- a/crates/proof-of-sql/src/sql/ast/or_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/or_expr.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashSet; /// Provable logical OR expression -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct OrExpr { lhs: Box>, rhs: Box>, diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 73ffcef47..5ffaa6162 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use std::{collections::HashSet, fmt::Debug}; /// Enum of AST column expression types that implement `ProvableExpr`. Is itself a `ProvableExpr`. -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub enum ProvableExprPlan { /// Column Column(ColumnExpr), diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index f7cc315ee..4ed303123 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -6,6 +6,7 @@ use crate::base::{ commitment::Commitment, database::{ColumnField, ColumnRef, ColumnType, LiteralValue, SchemaAccessor, TableRef}, }; +use proof_of_sql_parser::Identifier; pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> ColumnRef { let name = name.parse().unwrap(); @@ -112,6 +113,51 @@ pub fn filter( ProofPlan::Filter(FilterExpr::new(results, table, where_clause)) } +pub fn aliased_col_expr_plan( + tab: TableRef, + old_name: &str, + new_name: &str, + accessor: &impl SchemaAccessor, +) -> (ProvableExprPlan, Identifier) { + ( + ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, old_name, accessor))), + new_name.parse().unwrap(), + ) +} + +pub fn col_expr_plan( + tab: TableRef, + name: &str, + accessor: &impl SchemaAccessor, +) -> (ProvableExprPlan, Identifier) { + ( + ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, name, accessor))), + name.parse().unwrap(), + ) +} + +pub fn aliased_cols_expr_plan( + tab: TableRef, + names: &[(&str, &str)], + accessor: &impl SchemaAccessor, +) -> Vec<(ProvableExprPlan, Identifier)> { + names + .iter() + .map(|(old_name, new_name)| aliased_col_expr_plan(tab, old_name, new_name, accessor)) + .collect() +} + +pub fn cols_expr_plan( + tab: TableRef, + names: &[&str], + accessor: &impl SchemaAccessor, +) -> Vec<(ProvableExprPlan, Identifier)> { + names + .iter() + .map(|name| col_expr_plan(tab, name, accessor)) + .collect() +} + pub fn col_expr( tab: TableRef, name: &str, @@ -132,7 +178,7 @@ pub fn cols_expr( } pub fn dense_filter( - results: Vec>, + results: Vec<(ProvableExprPlan, Identifier)>, table: TableExpr, where_clause: ProvableExprPlan, ) -> ProofPlan { diff --git a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs new file mode 100644 index 000000000..80f3eecc5 --- /dev/null +++ b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs @@ -0,0 +1,66 @@ +use super::ProvableExprPlanBuilder; +use crate::{ + base::{commitment::Commitment, database::ColumnRef}, + sql::ast::ProvableExprPlan, +}; +use proof_of_sql_parser::{ + intermediate_ast::{AliasedResultExpr, Expression}, + Identifier, +}; +use std::collections::HashMap; +/// Enriched expression +/// +/// An enriched expression consists of an `proof_of_sql_parser::intermediate_ast::AliasedResultExpr` +/// and an optional `ProvableExprPlan`. +/// If the `ProvableExprPlan` is `None`, the `EnrichedExpr` is not provable. +pub struct EnrichedExpr { + /// The remaining expression after the provable expression plan has been extracted. + pub residue_expression: AliasedResultExpr, + /// The extracted provable expression plan if it exists. + pub provable_expr_plan: Option>, +} + +impl EnrichedExpr { + /// Create a new `EnrichedExpr` with a provable expression. + /// + /// If the expression is not provable, the `provable_expr_plan` will be `None`. + /// Otherwise the `provable_expr_plan` will contain the provable expression plan + /// and the `residue_expression` will contain the remaining expression. + pub fn new( + expression: AliasedResultExpr, + column_mapping: HashMap, + ) -> Self { + let res_provable_expr_plan = + ProvableExprPlanBuilder::new(&column_mapping).build(&expression.expr); + match res_provable_expr_plan { + Ok(provable_expr_plan) => { + let alias = expression.alias.clone(); + Self { + residue_expression: AliasedResultExpr { + expr: Box::new(Expression::Column(alias)), + alias, + }, + provable_expr_plan: Some(provable_expr_plan), + } + } + Err(_) => Self { + residue_expression: expression, + provable_expr_plan: None, + }, + } + } + + /// Get the alias of the `EnrichedExpr`. + /// + /// Since we plan to support unaliased expressions in the future, this method returns an `Option`. + #[allow(dead_code)] + pub fn get_alias(&self) -> Option<&Identifier> { + self.residue_expression.try_as_identifier() + } + + /// Is the `EnrichedExpr` provable + #[allow(dead_code)] + pub fn is_provable(&self) -> bool { + self.provable_expr_plan.is_some() + } +} diff --git a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs index 762ed8164..2cccee3bf 100644 --- a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs @@ -1,18 +1,19 @@ -use super::{where_expr_builder::WhereExprBuilder, ConversionError}; +use super::{where_expr_builder::WhereExprBuilder, ConversionError, EnrichedExpr}; use crate::{ base::{ commitment::Commitment, database::{ColumnRef, LiteralValue, TableRef}, }, - sql::ast::{ColumnExpr, DenseFilterExpr, ProvableExprPlan, TableExpr}, + sql::ast::{DenseFilterExpr, ProvableExprPlan, TableExpr}, }; +use itertools::Itertools; use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; pub struct FilterExprBuilder { table_expr: Option, where_expr: Option>, - filter_result_expr_list: Vec>, + filter_result_expr_list: Vec<(ProvableExprPlan, Identifier)>, column_mapping: HashMap, } @@ -40,16 +41,28 @@ impl FilterExprBuilder { Ok(self) } - pub fn add_result_column_set(mut self, columns: HashSet) -> Self { - // Sorting is required to make the relative order of the columns deterministic - let mut columns = columns.into_iter().collect::>(); - columns.sort(); - - columns.into_iter().for_each(|column| { - let column = *self.column_mapping.get(&column).unwrap(); - self.filter_result_expr_list.push(ColumnExpr::new(column)); - }); - + pub fn add_result_columns(mut self, columns: &[EnrichedExpr]) -> Self { + // If a column is provable, add it to the filter result expression list + // If at least one column is non-provable, add all columns from the column mapping to the filter result expression list + let mut has_nonprovable_column = false; + for enriched_expr in columns { + if let Some(plan) = &enriched_expr.provable_expr_plan { + self.filter_result_expr_list + .push((plan.clone(), enriched_expr.residue_expression.alias)); + } else { + has_nonprovable_column = true; + } + } + if has_nonprovable_column { + // Has to keep them sorted to have deterministic order for tests + for alias in self.column_mapping.keys().sorted() { + let column_ref = self.column_mapping.get(alias).unwrap(); + self.filter_result_expr_list.push(( + ProvableExprPlan::new_column(column_ref.clone()), + alias.clone(), + )); + } + } self } diff --git a/crates/proof-of-sql/src/sql/parse/mod.rs b/crates/proof-of-sql/src/sql/parse/mod.rs index d15b7e31a..f79289054 100644 --- a/crates/proof-of-sql/src/sql/parse/mod.rs +++ b/crates/proof-of-sql/src/sql/parse/mod.rs @@ -4,6 +4,9 @@ mod where_expr_builder_tests; pub use error::ConversionError; pub(crate) use error::ConversionResult; +mod enriched_expr; +pub(crate) use enriched_expr::EnrichedExpr; + #[cfg(all(test, feature = "blitzar"))] mod query_expr_tests; @@ -22,5 +25,8 @@ pub(crate) use query_context::QueryContext; mod query_context_builder; pub(crate) use query_context_builder::{type_check_binary_operation, QueryContextBuilder}; +mod provable_expr_plan_builder; +pub(crate) use provable_expr_plan_builder::ProvableExprPlanBuilder; + mod where_expr_builder; pub(crate) use where_expr_builder::WhereExprBuilder; diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs new file mode 100644 index 000000000..0b08ad800 --- /dev/null +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -0,0 +1,142 @@ +use super::ConversionError; +use crate::{ + base::{ + commitment::Commitment, + database::{ColumnRef, LiteralValue}, + math::decimal::{try_into_to_scalar, Precision}, + }, + sql::ast::{ColumnExpr, ProvableExprPlan}, +}; +use proof_of_sql_parser::{ + intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, + Identifier, +}; +use std::collections::HashMap; + +/// Builder that enables building a `proofs::sql::ast::ProvableExprPlan` from +/// a `proof_of_sql_parser::intermediate_ast::Expression`. +pub struct ProvableExprPlanBuilder<'a> { + column_mapping: &'a HashMap, +} + +impl<'a> ProvableExprPlanBuilder<'a> { + /// Creates a new `ProvableExprPlanBuilder` with the given column mapping. + pub fn new(column_mapping: &'a HashMap) -> Self { + Self { column_mapping } + } + /// Builds a `proofs::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` + pub fn build( + &self, + expr: &Expression, + ) -> Result, ConversionError> { + self.visit_expr(expr) + } +} + +// Private interface +impl ProvableExprPlanBuilder<'_> { + fn visit_expr( + &self, + expr: &Expression, + ) -> Result, ConversionError> { + match expr { + Expression::Column(identifier) => self.visit_column(*identifier), + Expression::Literal(lit) => self.visit_literal(lit), + Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr), + _ => Err(ConversionError::Unprovable(format!( + "Expression {:?} is not supported yet", + expr + ))), + } + } + + fn visit_column( + &self, + identifier: Identifier, + ) -> Result, ConversionError> { + Ok(ProvableExprPlan::Column(ColumnExpr::new( + *self.column_mapping.get(&identifier).ok_or( + ConversionError::MissingColumnWithoutTable(Box::new(identifier)), + )?, + ))) + } + + fn visit_literal( + &self, + lit: &Literal, + ) -> Result, ConversionError> { + match lit { + Literal::Boolean(b) => Ok(ProvableExprPlan::new_literal(LiteralValue::Boolean(*b))), + Literal::BigInt(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::BigInt(*i))), + Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(*i))), + Literal::Decimal(d) => { + let scale = d.scale(); + let precision = Precision::new(d.precision()) + .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; + Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( + precision, + scale, + try_into_to_scalar(d, precision, scale)?, + ))) + } + Literal::VarChar(s) => Ok(ProvableExprPlan::new_literal(LiteralValue::VarChar(( + s.clone(), + s.into(), + )))), + } + } + + fn visit_unary_expr( + &self, + op: UnaryOperator, + expr: &Expression, + ) -> Result, ConversionError> { + let expr = self.visit_expr(expr); + match op { + UnaryOperator::Not => ProvableExprPlan::try_new_not(expr?), + } + } + + fn visit_binary_expr( + &self, + op: BinaryOperator, + left: &Expression, + right: &Expression, + ) -> Result, ConversionError> { + match op { + BinaryOperator::And => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_and(left?, right?) + } + BinaryOperator::Or => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_or(left?, right?) + } + BinaryOperator::Equal => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_equals(left?, right?) + } + BinaryOperator::GreaterThanOrEqual => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_inequality(left?, right?, false) + } + BinaryOperator::LessThanOrEqual => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_inequality(left?, right?, true) + } + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Multiply + | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( + "Binary operator {:?} is not supported yet", + op + ))), + } + } +} diff --git a/crates/proof-of-sql/src/sql/parse/query_expr.rs b/crates/proof-of-sql/src/sql/parse/query_expr.rs index 150c5cd3c..b15868ea1 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr.rs @@ -1,4 +1,4 @@ -use super::{FilterExprBuilder, QueryContextBuilder, ResultExprBuilder}; +use super::{EnrichedExpr, FilterExprBuilder, QueryContextBuilder, ResultExprBuilder}; use crate::{ base::{commitment::Commitment, database::SchemaAccessor}, sql::{ @@ -72,15 +72,23 @@ impl QueryExpr { }); } } - + let column_mapping = context.get_column_mapping(); + let enriched_exprs = result_aliased_exprs + .iter() + .map(|aliased_expr| EnrichedExpr::new(aliased_expr.clone(), column_mapping.clone())) + .collect::>(); + let select_exprs = enriched_exprs + .iter() + .map(|enriched_expr| enriched_expr.residue_expression.clone()) + .collect::>(); let filter = FilterExprBuilder::new(context.get_column_mapping()) .add_table_expr(*context.get_table_ref()) .add_where_expr(context.get_where_expr().clone())? - .add_result_column_set(context.get_result_column_set()) + .add_result_columns(&enriched_exprs) .build(); let result = ResultExprBuilder::default() - .add_group_by_exprs(context.get_group_by_exprs(), result_aliased_exprs) - .add_select_exprs(result_aliased_exprs) + .add_group_by_exprs(context.get_group_by_exprs(), &select_exprs) + .add_select_exprs(&select_exprs) .add_order_by_exprs(context.get_order_by_exprs()?) .add_slice_expr(context.get_slice_expr()) .build(); diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index efd94e59b..85ebfdded 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -126,8 +126,8 @@ macro_rules! expected_query { orderby_macro!($($order_by)?, $($order_dirs)?); macro_rules! filter_macro { - () => {dense_filter(cols_expr(t, &$result_columns, &accessor), tab(t), const_bool(true))}; - ($expr:expr) => { dense_filter(cols_expr(t, &$result_columns, &accessor), tab(t), $expr) }; + () => {dense_filter(cols_expr_plan(t, &$result_columns, &accessor), tab(t), const_bool(true))}; + ($expr:expr) => { dense_filter(cols_expr_plan(t, &$result_columns, &accessor), tab(t), $expr) }; } let filter = filter_macro!($($filter)?); @@ -148,7 +148,7 @@ fn we_can_convert_an_ast_with_one_column() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -170,7 +170,7 @@ fn we_can_convert_an_ast_with_one_column_and_i128_data() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3_i64)), ), @@ -192,7 +192,7 @@ fn we_can_convert_an_ast_with_one_column_and_a_filter_by_a_string_literal() { let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 'abc'", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_varchar("abc")), ), @@ -243,11 +243,11 @@ fn we_dont_have_duplicate_filter_result_expressions() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + aliased_cols_expr_plan(t, &[("a", "b"), ("a", "c")], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), - result(&[("a", "b"), ("a", "c")]), + result(&[("b", "b"), ("c", "c")]), ); assert_eq!(ast, expected_ast); } @@ -267,7 +267,7 @@ fn we_can_convert_an_ast_with_two_columns() { let ast = query_to_provable_ast(t, "select a, b from sxt_tab where c = 123", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["a", "b"], &accessor), tab(t), equal(column(t, "c", &accessor), const_bigint(123)), ), @@ -290,7 +290,7 @@ fn we_can_parse_all_result_columns_with_select_star() { let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["b", "a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -313,7 +313,7 @@ fn we_can_convert_an_ast_with_one_positive_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b = +4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "b", &accessor), const_bigint(4)), ), @@ -336,7 +336,7 @@ fn we_can_convert_an_ast_with_one_not_equals_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b <> +4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), not(equal(column(t, "b", &accessor), const_bigint(4))), ), @@ -359,7 +359,7 @@ fn we_can_convert_an_ast_with_one_negative_cond() { let ast = query_to_provable_ast(t, "select a from sxt_tab where b <= -4", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), lte(column(t, "b", &accessor), const_bigint(-4)), ), @@ -387,7 +387,7 @@ fn we_can_convert_an_ast_with_cond_and() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), and( equal(column(t, "b", &accessor), const_bigint(3)), @@ -418,7 +418,7 @@ fn we_can_convert_an_ast_with_cond_or() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), or( equal(column(t, "b", &accessor), const_bigint(3)), @@ -449,7 +449,7 @@ fn we_can_convert_an_ast_with_conds_or_not() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), or( lte(column(t, "b", &accessor), const_bigint(3)), @@ -481,7 +481,7 @@ fn we_can_convert_an_ast_with_conds_not_and_or() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), not(and( or( @@ -513,7 +513,7 @@ fn we_can_convert_an_ast_with_the_min_i128_filter_value() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(i128::MIN)), ), @@ -539,7 +539,7 @@ fn we_can_convert_an_ast_with_the_max_i128_filter_value() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_int128(i128::MAX)), ), @@ -566,11 +566,11 @@ fn we_can_convert_an_ast_using_an_aliased_column() { ); let expected_ast = QueryExpr::new( dense_filter( - vec![col_expr(t, "a", &accessor)], + vec![aliased_col_expr_plan(t, "a", "b_rename", &accessor)], tab(t), gte(column(t, "b", &accessor), const_bigint(4)), ), - result(&[("a", "b_rename")]), + result(&[("b_rename", "b_rename")]), ); assert_eq!(ast, expected_ast); } @@ -614,7 +614,7 @@ fn we_can_convert_an_ast_with_a_schema() { let ast = query_to_provable_ast(t, "select a from eth.sxt_tab where a = 3", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -634,7 +634,11 @@ fn we_can_convert_an_ast_without_any_dense_filter() { 0, ); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), result(&[("a", "a")]), ); let queries = ["select * from eth.sxt_tab", "select a from eth.sxt_tab"]; @@ -661,7 +665,7 @@ fn we_can_parse_order_by_with_a_single_column() { let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3 order by b", &accessor); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["b", "a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -691,7 +695,7 @@ fn we_can_parse_order_by_with_multiple_columns() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b"], &accessor), + cols_expr_plan(t, &["a", "b"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(3)), ), @@ -723,14 +727,14 @@ fn we_can_parse_order_by_referencing_an_alias_associated_with_column_b_but_with_ let expected_ast = QueryExpr::new( dense_filter( vec![ - col_expr(t, "name", &accessor), - col_expr(t, "salary", &accessor), + aliased_col_expr_plan(t, "salary", "s", &accessor), + aliased_col_expr_plan(t, "name", "salary", &accessor), ], tab(t), equal(column(t, "salary", &accessor), const_bigint(5)), ), composite_result(vec![ - select(&[pc("salary").alias("s"), pc("name").alias("salary")]), + select(&[pc("s").alias("s"), pc("salary").alias("salary")]), orders(&["salary"], &[Desc]), ]), ); @@ -836,17 +840,18 @@ fn we_can_parse_order_by_queries_with_the_same_column_name_appearing_more_than_o let expected_ast = QueryExpr::new( dense_filter( vec![ - col_expr(t, "name", &accessor), - col_expr(t, "salary", &accessor), + aliased_col_expr_plan(t, "salary", "s", &accessor), + col_expr_plan(t, "name", &accessor), + aliased_col_expr_plan(t, "salary", "d", &accessor), ], tab(t), const_bool(true), ), composite_result(vec![ select(&[ - pc("salary").alias("s"), + pc("s").alias("s"), pc("name").alias("name"), - pc("salary").alias("d"), + pc("d").alias("d"), ]), orders(&[order_by], &[Asc]), ]), @@ -872,7 +877,11 @@ fn we_can_parse_a_query_having_a_simple_limit_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab limit 3", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(3, 0)]), ); assert_eq!(ast, expected_ast); @@ -891,7 +900,11 @@ fn no_slice_is_applied_when_limit_is_u64_max_and_offset_is_zero() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset 0", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")])]), ); assert_eq!(ast, expected_ast); @@ -910,7 +923,11 @@ fn we_can_parse_a_query_having_a_simple_positive_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset 7", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(u64::MAX, 7)]), ); assert_eq!(ast, expected_ast); @@ -929,7 +946,11 @@ fn we_can_parse_a_query_having_a_negative_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab offset -7", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(u64::MAX, -7)]), ); assert_eq!(ast, expected_ast); @@ -948,7 +969,11 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause() { let ast = query_to_provable_ast(t, "select a from sxt_tab limit 55 offset 3", &accessor); let expected_ast = QueryExpr::new( - dense_filter(cols_expr(t, &["a"], &accessor), tab(t), const_bool(true)), + dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ), composite_result(vec![select(&[pc("a").alias("a")]), slice(55, 3)]), ); assert_eq!(ast, expected_ast); @@ -976,7 +1001,7 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause_preceded_by_wher ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a"], &accessor), + cols_expr_plan(t, &["a"], &accessor), tab(t), equal(column(t, "a", &accessor), const_bigint(-3)), ), @@ -1171,7 +1196,7 @@ fn we_can_group_by_without_using_aggregate_functions() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + cols_expr_plan(t, &["department"], &accessor), tab(t), const_bool(true), ), @@ -1328,7 +1353,7 @@ fn we_can_parse_a_query_having_group_by_with_the_same_name_as_the_aggregation_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["bonus", "department"], &accessor), + cols_expr_plan(t, &["bonus", "department"], &accessor), tab(t), const_bool(true), ), @@ -1362,7 +1387,7 @@ fn count_aggregate_functions_can_be_used_with_non_numeric_columns() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["bonus", "department"], &accessor), + cols_expr_plan(t, &["bonus", "department"], &accessor), tab(t), const_bool(true), ), @@ -1400,7 +1425,7 @@ fn count_all_uses_the_first_group_by_identifier_as_default_result_column() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + cols_expr_plan(t, &["department", "salary"], &accessor), tab(t), equal(column(t, "salary", &accessor), const_bigint(4)), ), @@ -1454,7 +1479,7 @@ fn we_can_use_the_same_result_columns_with_different_aliases_and_associate_it_wi ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["department"], &accessor), + cols_expr_plan(t, &["department"], &accessor), tab(t), const_bool(true), ), @@ -1516,7 +1541,7 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["a", "b", "f", "h"], &accessor), + cols_expr_plan(t, &["a", "b", "f", "h"], &accessor), tab(t), const_bool(true), ), @@ -1554,7 +1579,7 @@ fn we_can_parse_multiple_arithmetic_expression_where_multiplication_has_preceden ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["c", "f", "g", "h"], &accessor), + cols_expr_plan(t, &["c", "f", "g", "h"], &accessor), tab(t), const_bool(true), ), @@ -1587,7 +1612,7 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr(t, &["c", "f"], &accessor), + cols_expr_plan(t, &["c", "f"], &accessor), tab(t), const_bool(true), ), diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs index f40d88f4a..c46a46785 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs @@ -1,27 +1,25 @@ -use super::ConversionError; +use super::{ConversionError, ProvableExprPlanBuilder}; use crate::{ base::{ commitment::Commitment, - database::{ColumnRef, ColumnType, LiteralValue}, - math::decimal::{try_into_to_scalar, Precision}, + database::{ColumnRef, ColumnType}, }, - sql::ast::{ColumnExpr, ProvableExpr, ProvableExprPlan}, -}; -use proof_of_sql_parser::{ - intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, - Identifier, + sql::ast::{ProvableExpr, ProvableExprPlan}, }; +use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; use std::collections::HashMap; /// Builder that enables building a `proof_of_sql::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` that is /// intended to be used as the where clause in a filter expression or group by expression. pub struct WhereExprBuilder<'a> { - column_mapping: &'a HashMap, + builder: ProvableExprPlanBuilder<'a>, } impl<'a> WhereExprBuilder<'a> { /// Creates a new `WhereExprBuilder` with the given column mapping. pub fn new(column_mapping: &'a HashMap) -> Self { - Self { column_mapping } + Self { + builder: ProvableExprPlanBuilder::new(column_mapping), + } } /// Builds a `proof_of_sql::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` that is /// intended to be used as the where clause in a filter expression or group by expression. @@ -31,7 +29,7 @@ impl<'a> WhereExprBuilder<'a> { ) -> Result>, ConversionError> { where_expr .map(|where_expr| { - let expr_plan = self.visit_expr(*where_expr)?; + let expr_plan = self.builder.build(&where_expr)?; // Ensure that the expression is a boolean expression match expr_plan.data_type() { ColumnType::Boolean => Ok(expr_plan), @@ -43,108 +41,3 @@ impl<'a> WhereExprBuilder<'a> { .transpose() } } - -// Private interface -impl WhereExprBuilder<'_> { - fn visit_expr( - &self, - expr: proof_of_sql_parser::intermediate_ast::Expression, - ) -> Result, ConversionError> { - match expr { - Expression::Column(identifier) => self.visit_column(identifier), - Expression::Literal(lit) => self.visit_literal(lit), - Expression::Binary { op, left, right } => self.visit_binary_expr(op, *left, *right), - Expression::Unary { op, expr } => self.visit_unary_expr(op, *expr), - _ => panic!("The parser must ensure that the expression is a boolean expression"), - } - } - - fn visit_column( - &self, - identifier: Identifier, - ) -> Result, ConversionError> { - Ok(ProvableExprPlan::Column(ColumnExpr::new( - *self.column_mapping.get(&identifier).ok_or( - ConversionError::MissingColumnWithoutTable(Box::new(identifier)), - )?, - ))) - } - - fn visit_literal( - &self, - lit: Literal, - ) -> Result, ConversionError> { - match lit { - Literal::Boolean(b) => Ok(ProvableExprPlan::new_literal(LiteralValue::Boolean(b))), - Literal::BigInt(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::BigInt(i))), - Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(i))), - Literal::Decimal(d) => { - let scale = d.scale(); - let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; - Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( - precision, - scale, - try_into_to_scalar(&d, precision, scale)?, - ))) - } - Literal::VarChar(s) => Ok(ProvableExprPlan::new_literal(LiteralValue::VarChar(( - s.clone(), - s.into(), - )))), - } - } - - fn visit_unary_expr( - &self, - op: UnaryOperator, - expr: Expression, - ) -> Result, ConversionError> { - let expr = self.visit_expr(expr); - match op { - UnaryOperator::Not => ProvableExprPlan::try_new_not(expr?), - } - } - - fn visit_binary_expr( - &self, - op: BinaryOperator, - left: Expression, - right: Expression, - ) -> Result, ConversionError> { - match op { - BinaryOperator::And => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_and(left?, right?) - } - BinaryOperator::Or => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_or(left?, right?) - } - BinaryOperator::Equal => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_equals(left?, right?) - } - BinaryOperator::GreaterThanOrEqual => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_inequality(left?, right?, false) - } - BinaryOperator::LessThanOrEqual => { - let left = self.visit_expr(left); - let right = self.visit_expr(right); - ProvableExprPlan::try_new_inequality(left?, right?, true) - } - BinaryOperator::Add - | BinaryOperator::Subtract - | BinaryOperator::Multiply - | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( - "Binary operator {:?} is not supported in the where clause", - op - ))), - } - } -}