diff --git a/crates/proof-of-sql/src/sql/ast/aliased_provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/aliased_provable_expr_plan.rs new file mode 100644 index 000000000..317c7bbe2 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/aliased_provable_expr_plan.rs @@ -0,0 +1,11 @@ +use super::ProvableExprPlan; +use crate::base::commitment::Commitment; +use proof_of_sql_parser::Identifier; +use serde::{Deserialize, Serialize}; + +/// A `ProvableExprPlan` with an alias. +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct AliasedProvableExprPlan { + pub expr: ProvableExprPlan, + pub alias: Identifier, +} diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs index 6cef2a89a..481983d9a 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs @@ -1,8 +1,6 @@ use super::{ dense_filter_util::{fold_columns, fold_vals}, - filter_columns, - provable_expr_plan::ProvableExprPlan, - ProvableExpr, TableExpr, + filter_columns, AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr, }; use crate::{ base::{ @@ -23,7 +21,6 @@ use crate::{ use bumpalo::Bump; use core::iter::repeat_with; use num_traits::{One, Zero}; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, marker::PhantomData}; @@ -35,7 +32,7 @@ use std::{collections::HashSet, marker::PhantomData}; /// This differs from the [`FilterExpr`] in that the result is not a sparse table. #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct OstensibleDenseFilterExpr { - pub(super) aliased_results: Vec<(ProvableExprPlan, Identifier)>, + pub(super) aliased_results: Vec>, pub(super) table: TableExpr, pub(super) where_clause: ProvableExprPlan, phantom: PhantomData, @@ -44,7 +41,7 @@ pub struct OstensibleDenseFilterExpr { impl OstensibleDenseFilterExpr { /// Creates a new dense_filter expression. pub fn new( - aliased_results: Vec<(ProvableExprPlan, Identifier)>, + aliased_results: Vec>, table: TableExpr, where_clause: ProvableExprPlan, ) -> Self { @@ -68,7 +65,7 @@ where ) -> Result<(), ProofError> { self.where_clause.count(builder)?; for aliased_expr in self.aliased_results.iter() { - aliased_expr.0.count(builder)?; + aliased_expr.expr.count(builder)?; builder.count_result_columns(1); } builder.count_intermediate_mles(2); @@ -99,7 +96,7 @@ where let columns_evals = Vec::from_iter( self.aliased_results .iter() - .map(|(expr, _)| expr.verifier_evaluate(builder, accessor)) + .map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?, ); // 3. indexes @@ -128,15 +125,15 @@ where fn get_column_result_fields(&self) -> Vec { self.aliased_results .iter() - .map(|(expr, alias)| ColumnField::new(*alias, expr.data_type())) + .map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())) .collect() } fn get_column_references(&self) -> HashSet { let mut columns = HashSet::new(); - for (col, _) in self.aliased_results.iter() { - col.get_column_references(&mut columns); + for aliased_expr in self.aliased_results.iter() { + aliased_expr.expr.get_column_references(&mut columns); } self.where_clause.get_column_references(&mut columns); @@ -165,11 +162,11 @@ impl ProverEvaluate for DenseFilterExpr { .expect("selection is not boolean"); // 2. columns - let columns = Vec::from_iter( - self.aliased_results - .iter() - .map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)), - ); + let columns = Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| { + aliased_expr + .expr + .result_evaluate(builder.table_length(), alloc, accessor) + })); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); // 3. set indexes @@ -200,7 +197,7 @@ impl ProverEvaluate for DenseFilterExpr { let columns = Vec::from_iter( self.aliased_results .iter() - .map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)), + .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), ); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); diff --git a/crates/proof-of-sql/src/sql/ast/group_by_expr.rs b/crates/proof-of-sql/src/sql/ast/group_by_expr.rs index d4b5b92d6..bfa56b890 100644 --- a/crates/proof-of-sql/src/sql/ast/group_by_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/group_by_expr.rs @@ -1,8 +1,7 @@ use super::{ aggregate_columns, fold_columns, fold_vals, group_by_util::{compare_indexes_by_owned_columns, AggregatedColumns}, - provable_expr_plan::ProvableExprPlan, - ColumnExpr, ProvableExpr, TableExpr, + AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr, }; use crate::{ base::{ @@ -29,19 +28,19 @@ use std::collections::HashSet; /// Provable expressions for queries of the form /// ```ignore -/// SELECT , ..., , +/// SELECT .0 as .1, ..., .0 as .1, /// SUM(.0) as .1, ..., SUM(.0) as .1, /// COUNT(*) as count_alias /// FROM /// WHERE -/// GROUP BY , ..., +/// GROUP BY .0, ..., .0 /// ``` /// /// Note: if `group_by_exprs` is empty, then the query is equivalent to removing the `GROUP BY` clause. #[derive(Debug, PartialEq, Serialize, Deserialize)] pub struct GroupByExpr { - pub(super) group_by_exprs: Vec>, - pub(super) sum_expr: Vec<(ColumnExpr, ColumnField)>, + pub(super) group_by_exprs: Vec>, + pub(super) sum_expr: Vec>, pub(super) count_alias: Identifier, pub(super) table: TableExpr, pub(super) where_clause: ProvableExprPlan, @@ -50,8 +49,8 @@ pub struct GroupByExpr { impl GroupByExpr { /// Creates a new group_by expression. pub fn new( - group_by_exprs: Vec>, - sum_expr: Vec<(ColumnExpr, ColumnField)>, + group_by_exprs: Vec>, + sum_expr: Vec>, count_alias: Identifier, table: TableExpr, where_clause: ProvableExprPlan, @@ -73,12 +72,12 @@ impl ProofExpr for GroupByExpr { _accessor: &dyn MetadataAccessor, ) -> Result<(), ProofError> { self.where_clause.count(builder)?; - for expr in self.group_by_exprs.iter() { - expr.count(builder)?; + for aliased_expr in self.group_by_exprs.iter() { + aliased_expr.expr.count(builder)?; builder.count_result_columns(1); } - for expr in self.sum_expr.iter() { - expr.0.count(builder)?; + for aliased_expr in self.sum_expr.iter() { + aliased_expr.expr.count(builder)?; builder.count_result_columns(1); } builder.count_result_columns(1); @@ -110,12 +109,12 @@ impl ProofExpr for GroupByExpr { let group_by_evals = self .group_by_exprs .iter() - .map(|expr| expr.verifier_evaluate(builder, accessor)) + .map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?; let aggregate_evals = self .sum_expr .iter() - .map(|expr| expr.0.verifier_evaluate(builder, accessor)) + .map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?; // 3. indexes let indexes_eval = builder @@ -150,7 +149,7 @@ impl ProofExpr for GroupByExpr { let cols = self .group_by_exprs .iter() - .map(|col| table.inner_table().get(&col.column_id())) + .map(|aliased_expr| table.inner_table().get(&aliased_expr.alias)) .collect::>>() .ok_or(ProofError::VerificationError( "Result does not all correct group by columns.", @@ -169,25 +168,27 @@ impl ProofExpr for GroupByExpr { } fn get_column_result_fields(&self) -> Vec { - let mut fields = Vec::new(); - for col in self.group_by_exprs.iter() { - fields.push(col.get_column_field()); - } - for col in self.sum_expr.iter() { - fields.push(col.1); - } - fields.push(ColumnField::new(self.count_alias, ColumnType::BigInt)); - fields + self.group_by_exprs + .iter() + .map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())) + .chain(self.sum_expr.iter().map(|aliased_expr| { + ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()) + })) + .chain(std::iter::once(ColumnField::new( + self.count_alias, + ColumnType::BigInt, + ))) + .collect() } fn get_column_references(&self) -> HashSet { let mut columns = HashSet::new(); - for col in self.group_by_exprs.iter() { - columns.insert(col.get_column_reference()); + for aliased_expr in self.group_by_exprs.iter() { + aliased_expr.expr.get_column_references(&mut columns); } - for col in self.sum_expr.iter() { - columns.insert(col.0.get_column_reference()); + for aliased_expr in self.sum_expr.iter() { + aliased_expr.expr.get_column_references(&mut columns); } self.where_clause.get_column_references(&mut columns); @@ -214,13 +215,14 @@ impl ProverEvaluate for GroupByExpr { .expect("selection is not boolean"); // 2. columns - let group_by_columns = Vec::from_iter( - self.group_by_exprs - .iter() - .map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)), - ); - let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|expr| { - expr.0 + let group_by_columns = Vec::from_iter(self.group_by_exprs.iter().map(|aliased_expr| { + aliased_expr + .expr + .result_evaluate(builder.table_length(), alloc, accessor) + })); + let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| { + aliased_expr + .expr .result_evaluate(builder.table_length(), alloc, accessor) })); // Compute filtered_columns and indexes @@ -262,12 +264,12 @@ impl ProverEvaluate for GroupByExpr { let group_by_columns = Vec::from_iter( self.group_by_exprs .iter() - .map(|expr| expr.prover_evaluate(builder, alloc, accessor)), + .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), ); let sum_columns = Vec::from_iter( self.sum_expr .iter() - .map(|expr| expr.0.prover_evaluate(builder, alloc, accessor)), + .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), ); // Compute filtered_columns and indexes let AggregatedColumns { diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index 9dbf0167a..7d7f18600 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -24,6 +24,9 @@ mod bitwise_verification_test; mod provable_expr_plan; pub(crate) use provable_expr_plan::ProvableExprPlan; +mod aliased_provable_expr_plan; +pub(crate) use aliased_provable_expr_plan::AliasedProvableExprPlan; + mod provable_expr; pub(crate) use provable_expr::ProvableExpr; #[cfg(all(test, feature = "blitzar"))] diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index d8bc6ddf7..110a39065 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -145,29 +145,29 @@ pub fn aliased_col_expr_plan( old_name: &str, new_name: &str, accessor: &impl SchemaAccessor, -) -> (ProvableExprPlan, Identifier) { - ( - ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, old_name, accessor))), - new_name.parse().unwrap(), - ) +) -> AliasedProvableExprPlan { + AliasedProvableExprPlan { + expr: ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, old_name, accessor))), + alias: new_name.parse().unwrap(), + } } pub fn col_expr_plan( tab: TableRef, name: &str, accessor: &impl SchemaAccessor, -) -> (ProvableExprPlan, Identifier) { - ( - ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, name, accessor))), - name.parse().unwrap(), - ) +) -> AliasedProvableExprPlan { + AliasedProvableExprPlan { + expr: ProvableExprPlan::Column(ColumnExpr::::new(col_ref(tab, name, accessor))), + alias: name.parse().unwrap(), + } } pub fn aliased_cols_expr_plan( tab: TableRef, names: &[(&str, &str)], accessor: &impl SchemaAccessor, -) -> Vec<(ProvableExprPlan, Identifier)> { +) -> Vec> { names .iter() .map(|(old_name, new_name)| aliased_col_expr_plan(tab, old_name, new_name, accessor)) @@ -178,7 +178,7 @@ pub fn cols_expr_plan( tab: TableRef, names: &[&str], accessor: &impl SchemaAccessor, -) -> Vec<(ProvableExprPlan, Identifier)> { +) -> Vec> { names .iter() .map(|name| col_expr_plan(tab, name, accessor)) @@ -205,43 +205,16 @@ pub fn cols_expr( } pub fn dense_filter( - results: Vec<(ProvableExprPlan, Identifier)>, + results: Vec>, table: TableExpr, where_clause: ProvableExprPlan, ) -> ProofPlan { ProofPlan::DenseFilter(DenseFilterExpr::new(results, table, where_clause)) } -pub fn sum_expr( - tab: TableRef, - name: &str, - alias: &str, - column_type: ColumnType, - accessor: &impl SchemaAccessor, -) -> (ColumnExpr, ColumnField) { - ( - col_expr(tab, name, accessor), - ColumnField::new(alias.parse().unwrap(), column_type), - ) -} - -pub fn sums_expr( - tab: TableRef, - names: &[&str], - aliases: &[&str], - column_types: &[ColumnType], - accessor: &impl SchemaAccessor, -) -> Vec<(ColumnExpr, ColumnField)> { - names - .iter() - .zip(aliases.iter().zip(column_types.iter())) - .map(|(name, (alias, column_type))| sum_expr(tab, name, alias, *column_type, accessor)) - .collect() -} - pub fn group_by( - group_by_exprs: Vec>, - sum_expr: Vec<(ColumnExpr, ColumnField)>, + group_by_exprs: Vec>, + sum_expr: Vec>, count_alias: &str, table: TableExpr, where_clause: ProvableExprPlan, diff --git a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs index 893251f6a..56f288420 100644 --- a/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/filter_expr_builder.rs @@ -4,7 +4,7 @@ use crate::{ commitment::Commitment, database::{ColumnRef, LiteralValue, TableRef}, }, - sql::ast::{DenseFilterExpr, ProvableExprPlan, TableExpr}, + sql::ast::{AliasedProvableExprPlan, DenseFilterExpr, ProvableExprPlan, TableExpr}, }; use itertools::Itertools; use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; @@ -13,7 +13,7 @@ use std::collections::HashMap; pub struct FilterExprBuilder { table_expr: Option, where_expr: Option>, - filter_result_expr_list: Vec<(ProvableExprPlan, Identifier)>, + filter_result_expr_list: Vec>, column_mapping: HashMap, } @@ -47,8 +47,10 @@ impl FilterExprBuilder { let mut has_nonprovable_column = false; for enriched_expr in columns { if let Some(plan) = &enriched_expr.provable_expr_plan { - self.filter_result_expr_list - .push((plan.clone(), enriched_expr.residue_expression.alias)); + self.filter_result_expr_list.push(AliasedProvableExprPlan { + expr: plan.clone(), + alias: enriched_expr.residue_expression.alias, + }); } else { has_nonprovable_column = true; } @@ -57,8 +59,10 @@ impl FilterExprBuilder { // Has to keep them sorted to have deterministic order for tests for alias in self.column_mapping.keys().sorted() { let column_ref = self.column_mapping.get(alias).unwrap(); - self.filter_result_expr_list - .push((ProvableExprPlan::new_column(*column_ref), *alias)); + self.filter_result_expr_list.push(AliasedProvableExprPlan { + expr: ProvableExprPlan::new_column(*column_ref), + alias: *alias, + }); } } self