From 0831cdb06fcf1444f79337c5978f87d3766c1bdb Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:14:56 -0400 Subject: [PATCH] feat: add aggregation-related intermediate AST conversions - add `AggregateExpr` which allows expressions such as `sum(a+b)` to become `ProvableExprPlan` - allow `Expression` -> `ProvableExprPlan` for aggregations --- .../src/intermediate_ast.rs | 2 +- .../src/sql/ast/aggregate_expr.rs | 75 +++++++++++++++++++ crates/proof-of-sql/src/sql/ast/mod.rs | 3 + .../src/sql/ast/provable_expr_plan.rs | 25 ++++++- .../src/sql/parse/enriched_expr.rs | 4 +- .../sql/parse/provable_expr_plan_builder.rs | 44 ++++++++++- 6 files changed, 145 insertions(+), 8 deletions(-) create mode 100644 crates/proof-of-sql/src/sql/ast/aggregate_expr.rs diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index 93df7b9a7..ddda7e1c4 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -112,7 +112,7 @@ pub enum UnaryOperator { } // Aggregation operators -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Copy)] /// Aggregation operators pub enum AggregationOperator { /// Maximum diff --git a/crates/proof-of-sql/src/sql/ast/aggregate_expr.rs b/crates/proof-of-sql/src/sql/ast/aggregate_expr.rs new file mode 100644 index 000000000..8fb46be11 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/aggregate_expr.rs @@ -0,0 +1,75 @@ +use super::{ProvableExpr, ProvableExprPlan}; +use crate::{ + base::{ + commitment::Commitment, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + proof::ProofError, + }, + sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder}, +}; +use bumpalo::Bump; +use proof_of_sql_parser::intermediate_ast::AggregationOperator; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Provable aggregate expression +/// +/// Currently it doesn't do much since aggregation logic is implemented elsewhere +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct AggregateExpr { + op: AggregationOperator, + expr: Box>, +} + +impl AggregateExpr { + /// Create a new aggregate expression + pub fn new(op: AggregationOperator, expr: Box>) -> Self { + Self { op, expr } + } +} + +impl ProvableExpr for AggregateExpr { + fn count(&self, _builder: &mut CountBuilder) -> Result<(), ProofError> { + Ok(()) + } + + fn data_type(&self) -> ColumnType { + match self.op { + AggregationOperator::Count => ColumnType::BigInt, + AggregationOperator::Sum => self.expr.data_type(), + _ => todo!("Aggregation operator not supported here yet"), + } + } + + #[tracing::instrument(name = "AggregateExpr::result_evaluate", level = "debug", skip_all)] + fn result_evaluate<'a>( + &self, + table_length: usize, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + self.expr.result_evaluate(table_length, alloc, accessor) + } + + #[tracing::instrument(name = "AggregateExpr::prover_evaluate", level = "debug", skip_all)] + fn prover_evaluate<'a>( + &self, + builder: &mut ProofBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + self.expr.prover_evaluate(builder, alloc, accessor) + } + + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + ) -> Result { + self.expr.verifier_evaluate(builder, accessor) + } + + fn get_column_references(&self, columns: &mut HashSet) { + self.expr.get_column_references(columns) + } +} diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index e9a44b273..45d75404e 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -10,6 +10,9 @@ pub(crate) use add_subtract_expr::AddSubtractExpr; #[cfg(all(test, feature = "blitzar"))] mod add_subtract_expr_test; +mod aggregate_expr; +pub(crate) use aggregate_expr::AggregateExpr; + mod multiply_expr; use multiply_expr::MultiplyExpr; #[cfg(all(test, feature = "blitzar"))] diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 5f83c4c67..e37a83950 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -1,6 +1,6 @@ use super::{ - AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr, - NotExpr, OrExpr, ProvableExpr, + AddSubtractExpr, AggregateExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, + MultiplyExpr, NotExpr, OrExpr, ProvableExpr, }; use crate::{ base::{ @@ -14,7 +14,7 @@ use crate::{ }, }; use bumpalo::Bump; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; +use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator}; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, fmt::Debug}; @@ -39,6 +39,8 @@ pub enum ProvableExprPlan { AddSubtract(AddSubtractExpr), /// Provable numeric `*` expression Multiply(MultiplyExpr), + /// Provable aggregate expression + Aggregate(AggregateExpr), } impl ProvableExprPlan { /// Create column expression @@ -176,6 +178,11 @@ impl ProvableExprPlan { } } + /// Create a new aggregate expression + pub fn new_aggregate(op: AggregationOperator, expr: ProvableExprPlan) -> Self { + Self::Aggregate(AggregateExpr::new(op, Box::new(expr))) + } + /// Check that the plan has the correct data type fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> { if self.data_type() == data_type { @@ -201,6 +208,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::AddSubtract(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Multiply(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Aggregate(expr) => ProvableExpr::::count(expr, builder), } } @@ -209,6 +217,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Column(expr) => expr.data_type(), ProvableExprPlan::AddSubtract(expr) => expr.data_type(), ProvableExprPlan::Multiply(expr) => expr.data_type(), + ProvableExprPlan::Aggregate(expr) => expr.data_type(), ProvableExprPlan::Literal(expr) => ProvableExpr::::data_type(expr), ProvableExprPlan::And(_) | ProvableExprPlan::Or(_) @@ -252,6 +261,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Multiply(expr) => { ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) } + ProvableExprPlan::Aggregate(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } } } @@ -289,6 +301,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Multiply(expr) => { ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) } + ProvableExprPlan::Aggregate(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } } } @@ -309,6 +324,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::AddSubtract(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Multiply(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Aggregate(expr) => expr.verifier_evaluate(builder, accessor), } } @@ -335,6 +351,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Multiply(expr) => { ProvableExpr::::get_column_references(expr, columns) } + ProvableExprPlan::Aggregate(expr) => { + ProvableExpr::::get_column_references(expr, columns) + } } } } diff --git a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs index d199f3086..c5c49368c 100644 --- a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs @@ -30,8 +30,10 @@ impl EnrichedExpr { expression: AliasedResultExpr, column_mapping: HashMap, ) -> Self { + // TODO: Using new_agg (ironically) disables aggregations in `QueryExpr` for now. + // Re-enable aggregations when we add `GroupByExpr` generalizations. let res_provable_expr_plan = - ProvableExprPlanBuilder::new(&column_mapping).build(&expression.expr); + ProvableExprPlanBuilder::new_agg(&column_mapping).build(&expression.expr); match res_provable_expr_plan { Ok(provable_expr_plan) => { let alias = expression.alias; diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index 271594615..e0e96c1be 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -6,12 +6,12 @@ use crate::{ math::decimal::{try_into_to_scalar, DecimalError::InvalidPrecision, Precision}, }, sql::{ - ast::{ColumnExpr, ProvableExprPlan}, + ast::{ColumnExpr, ProvableExpr, ProvableExprPlan}, parse::ConversionError::DecimalConversionError, }, }; use proof_of_sql_parser::{ - intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, + intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal, UnaryOperator}, Identifier, }; use std::collections::HashMap; @@ -20,12 +20,23 @@ use std::collections::HashMap; /// a `proof_of_sql_parser::intermediate_ast::Expression`. pub struct ProvableExprPlanBuilder<'a> { column_mapping: &'a HashMap, + in_agg_scope: bool, } impl<'a> ProvableExprPlanBuilder<'a> { /// Creates a new `ProvableExprPlanBuilder` with the given column mapping. pub fn new(column_mapping: &'a HashMap) -> Self { - Self { column_mapping } + Self { + column_mapping, + in_agg_scope: false, + } + } + /// Creates a new `ProvableExprPlanBuilder` with the given column mapping and within aggregation scope. + pub(crate) fn new_agg(column_mapping: &'a HashMap) -> Self { + Self { + column_mapping, + in_agg_scope: true, + } } /// Builds a `proofs::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression` pub fn build( @@ -47,6 +58,7 @@ impl ProvableExprPlanBuilder<'_> { Expression::Literal(lit) => self.visit_literal(lit), Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr), + Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr), _ => Err(ConversionError::Unprovable(format!( "Expression {:?} is not supported yet", expr @@ -155,4 +167,30 @@ impl ProvableExprPlanBuilder<'_> { ))), } } + + fn visit_aggregate_expr( + &self, + op: AggregationOperator, + expr: &Expression, + ) -> Result, ConversionError> { + if self.in_agg_scope { + return Err(ConversionError::InvalidExpression( + "nested aggregations are invalid".to_string(), + )); + } + let expr = ProvableExprPlanBuilder::new_agg(self.column_mapping).visit_expr(expr)?; + match (op, expr.data_type().is_numeric()) { + (AggregationOperator::Count, _) | (AggregationOperator::Sum, true) => { + Ok(ProvableExprPlan::new_aggregate(op, expr)) + } + (AggregationOperator::Sum, false) => Err(ConversionError::InvalidExpression(format!( + "Aggregation operator {:?} doesn't work with non-numeric types", + op + ))), + _ => Err(ConversionError::Unprovable(format!( + "Aggregation operator {:?} is not supported at this location", + op + ))), + } + } }