Skip to content

Commit

Permalink
feat: add aggregation-related intermediate AST conversions
Browse files Browse the repository at this point in the history
- add `AggregateExpr` which allows expressions such as `sum(a+b)` to become `ProvableExprPlan`
- allow `Expression` -> `ProvableExprPlan` for aggregations
  • Loading branch information
iajoiner committed Jun 27, 2024
1 parent 026019a commit 22d4c40
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 8 deletions.
2 changes: 1 addition & 1 deletion crates/proof-of-sql-parser/src/intermediate_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub enum UnaryOperator {
}

// Aggregation operators
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Copy)]
/// Aggregation operators
pub enum AggregationOperator {
/// Maximum
Expand Down
75 changes: 75 additions & 0 deletions crates/proof-of-sql/src/sql/ast/aggregate_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::{ProvableExpr, ProvableExprPlan};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use proof_of_sql_parser::intermediate_ast::AggregationOperator;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable aggregate expression
///
/// Pairs an [`AggregationOperator`] with the expression it is applied to, so that
/// expressions such as `sum(a+b)` can be represented as a `ProvableExprPlan`.
///
/// Currently it doesn't do much since aggregation logic is implemented elsewhere:
/// the evaluation methods of its `ProvableExpr` implementation simply forward to
/// the wrapped expression.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregateExpr<C: Commitment> {
/// Aggregation operator applied to `expr`
op: AggregationOperator,
/// Expression being aggregated
expr: Box<ProvableExprPlan<C>>,
}

impl<C: Commitment> AggregateExpr<C> {
/// Create a new aggregate expression
pub fn new(op: AggregationOperator, expr: Box<ProvableExprPlan<C>>) -> Self {
Self { op, expr }
}
}

impl<C: Commitment> ProvableExpr<C> for AggregateExpr<C> {
fn count(&self, _builder: &mut CountBuilder) -> Result<(), ProofError> {
Ok(())
}

fn data_type(&self) -> ColumnType {
match self.op {
AggregationOperator::Count => ColumnType::BigInt,
AggregationOperator::Sum => self.expr.data_type(),
_ => todo!("Aggregation operator not supported here yet"),
}
}

#[tracing::instrument(name = "AggregateExpr::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
self.expr.result_evaluate(table_length, alloc, accessor)
}

#[tracing::instrument(name = "AggregateExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
self.expr.prover_evaluate(builder, alloc, accessor)
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
self.expr.verifier_evaluate(builder, accessor)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.expr.get_column_references(columns)
}
}
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub(crate) use add_subtract_expr::AddSubtractExpr;
#[cfg(all(test, feature = "blitzar"))]
mod add_subtract_expr_test;

mod aggregate_expr;
pub(crate) use aggregate_expr::AggregateExpr;

mod multiply_expr;
use multiply_expr::MultiplyExpr;
#[cfg(all(test, feature = "blitzar"))]
Expand Down
25 changes: 22 additions & 3 deletions crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
NotExpr, OrExpr, ProvableExpr,
AddSubtractExpr, AggregateExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr,
MultiplyExpr, NotExpr, OrExpr, ProvableExpr,
};
use crate::{
base::{
Expand All @@ -14,7 +14,7 @@ use crate::{
},
};
use bumpalo::Bump;
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt::Debug};

Expand All @@ -39,6 +39,8 @@ pub enum ProvableExprPlan<C: Commitment> {
AddSubtract(AddSubtractExpr<C>),
/// Provable numeric `*` expression
Multiply(MultiplyExpr<C>),
/// Provable aggregate expression
Aggregate(AggregateExpr<C>),
}
impl<C: Commitment> ProvableExprPlan<C> {
/// Create column expression
Expand Down Expand Up @@ -176,6 +178,11 @@ impl<C: Commitment> ProvableExprPlan<C> {
}
}

/// Wraps `expr` in an `Aggregate` plan node that applies the aggregation operator `op`.
pub fn new_aggregate(op: AggregationOperator, expr: ProvableExprPlan<C>) -> Self {
    let inner = Box::new(expr);
    Self::Aggregate(AggregateExpr::new(op, inner))
}

/// Check that the plan has the correct data type
fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> {
if self.data_type() == data_type {
Expand All @@ -201,6 +208,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::AddSubtract(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Multiply(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Aggregate(expr) => ProvableExpr::<C>::count(expr, builder),
}
}

Expand All @@ -209,6 +217,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Column(expr) => expr.data_type(),
ProvableExprPlan::AddSubtract(expr) => expr.data_type(),
ProvableExprPlan::Multiply(expr) => expr.data_type(),
ProvableExprPlan::Aggregate(expr) => expr.data_type(),
ProvableExprPlan::Literal(expr) => ProvableExpr::<C>::data_type(expr),
ProvableExprPlan::And(_)
| ProvableExprPlan::Or(_)
Expand Down Expand Up @@ -252,6 +261,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
ProvableExprPlan::Aggregate(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
}
}

Expand Down Expand Up @@ -289,6 +301,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
ProvableExprPlan::Aggregate(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
}
}

Expand All @@ -309,6 +324,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::AddSubtract(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Multiply(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Aggregate(expr) => expr.verifier_evaluate(builder, accessor),
}
}

Expand All @@ -335,6 +351,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
ProvableExprPlan::Aggregate(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
}
}
}
12 changes: 12 additions & 0 deletions crates/proof-of-sql/src/sql/ast/test_utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ pub fn dense_filter<C: Commitment>(
ProofPlan::DenseFilter(DenseFilterExpr::new(results, table, where_clause))
}

/// Test helper: wraps `expr` in a `Count` aggregate plan.
pub fn count<C: Commitment>(expr: ProvableExprPlan<C>) -> ProvableExprPlan<C> {
    let op = AggregationOperator::Count;
    ProvableExprPlan::new_aggregate(op, expr)
}

/// Test helper: wraps `expr` in a `Sum` aggregate plan.
pub fn sum<C: Commitment>(expr: ProvableExprPlan<C>) -> ProvableExprPlan<C> {
    let op = AggregationOperator::Sum;
    ProvableExprPlan::new_aggregate(op, expr)
}

pub fn sum_expr<C: Commitment>(
tab: TableRef,
name: &str,
Expand Down
44 changes: 41 additions & 3 deletions crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use crate::{
math::decimal::{try_into_to_scalar, DecimalError::InvalidPrecision, Precision},
},
sql::{
ast::{ColumnExpr, ProvableExprPlan},
ast::{ColumnExpr, ProvableExpr, ProvableExprPlan},
parse::ConversionError::DecimalConversionError,
},
};
use proof_of_sql_parser::{
intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator},
intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal, UnaryOperator},
Identifier,
};
use std::collections::HashMap;
Expand All @@ -20,12 +20,23 @@ use std::collections::HashMap;
/// a `proof_of_sql_parser::intermediate_ast::Expression`.
pub struct ProvableExprPlanBuilder<'a> {
/// Lookup table from identifiers to column references
column_mapping: &'a HashMap<Identifier, ColumnRef>,
/// `true` when the builder was created with `new_agg`, i.e. while converting the
/// argument of an aggregation; `visit_aggregate_expr` uses it to reject nested aggregations
in_agg_scope: bool,
}

impl<'a> ProvableExprPlanBuilder<'a> {
/// Creates a new `ProvableExprPlanBuilder` with the given column mapping.
pub fn new(column_mapping: &'a HashMap<Identifier, ColumnRef>) -> Self {
Self { column_mapping }
Self {
column_mapping,
in_agg_scope: false,
}
}
/// Creates a `ProvableExprPlanBuilder` over the given column mapping that is already
/// marked as being inside an aggregation scope.
fn new_agg(column_mapping: &'a HashMap<Identifier, ColumnRef>) -> Self {
    let in_agg_scope = true;
    Self {
        column_mapping,
        in_agg_scope,
    }
}
/// Builds a `proofs::sql::ast::ProvableExprPlan` from a `proof_of_sql_parser::intermediate_ast::Expression`
pub fn build<C: Commitment>(
Expand All @@ -47,6 +58,7 @@ impl ProvableExprPlanBuilder<'_> {
Expression::Literal(lit) => self.visit_literal(lit),
Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right),
Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr),
Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr),
_ => Err(ConversionError::Unprovable(format!(
"Expression {:?} is not supported yet",
expr
Expand Down Expand Up @@ -155,4 +167,30 @@ impl ProvableExprPlanBuilder<'_> {
))),
}
}

/// Converts an aggregation `op(expr)` into an aggregate `ProvableExprPlan`.
///
/// The argument is converted with a builder that is marked as being inside an
/// aggregation, so an aggregation nested in `expr` is rejected.
///
/// # Errors
///
/// * `ConversionError::InvalidExpression` if this builder is already inside an
///   aggregation, or if `Sum` is applied to a non-numeric expression.
/// * `ConversionError::Unprovable` for any operator other than `Count` and `Sum`.
/// * Any error produced while converting `expr` itself.
fn visit_aggregate_expr<C: Commitment>(
    &self,
    op: AggregationOperator,
    expr: &Expression,
) -> Result<ProvableExprPlan<C>, ConversionError> {
    // Guard clause: an aggregation may not appear inside another aggregation.
    if self.in_agg_scope {
        return Err(ConversionError::InvalidExpression(
            "nested aggregations are invalid".to_string(),
        ));
    }

    let inner: ProvableExprPlan<C> =
        ProvableExprPlanBuilder::new_agg(self.column_mapping).visit_expr(expr)?;
    // Evaluated up front, for every operator, exactly as the tuple match did.
    let numeric = inner.data_type().is_numeric();

    match op {
        AggregationOperator::Count => Ok(ProvableExprPlan::new_aggregate(op, inner)),
        AggregationOperator::Sum if numeric => Ok(ProvableExprPlan::new_aggregate(op, inner)),
        AggregationOperator::Sum => Err(ConversionError::InvalidExpression(format!(
            "Aggregation operator {:?} doesn't work with non-numeric types",
            op
        ))),
        _ => Err(ConversionError::Unprovable(format!(
            "Aggregation operator {:?} is not supported at this location",
            op
        ))),
    }
}
}
6 changes: 5 additions & 1 deletion crates/proof-of-sql/src/sql/parse/query_expr_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1470,7 +1470,11 @@ fn count_aggregate_functions_can_be_used_with_non_numeric_columns() {
);
let expected_ast = QueryExpr::new(
dense_filter(
cols_expr_plan(t, &["bonus", "department"], &accessor),
vec![
col_expr_plan(t, "department", &accessor),
(count(column(t, "bonus", &accessor)), "__count__".parse().unwrap()),
(count(column(t, "department", &accessor)), "dep".parse().unwrap()),
],
tab(t),
const_bool(true),
),
Expand Down

0 comments on commit 22d4c40

Please sign in to comment.