Skip to content

Commit

Permalink
feat: generalize GroupByExpr
Browse files Browse the repository at this point in the history
- add `AliasedProvableExprPlan` and `AggregateFunctionExpr`
- generalize `GroupByExpr`
  • Loading branch information
iajoiner committed Jun 27, 2024
1 parent f0b2003 commit 93d5290
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 202 deletions.
78 changes: 78 additions & 0 deletions crates/proof-of-sql/src/sql/ast/aggregate_function_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use super::ProvableExprPlan;
use crate::base::commitment::Commitment;
use proof_of_sql_parser::intermediate_ast::AggregationOperator;

use super::{ProvableExpr, ProvableExprPlan};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable aggregate function expression
///
/// Pairs an aggregation operator with the expression it is applied to.
///
/// Currently it doesn't do much since aggregation logic is implemented elsewhere
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregateFunctionExpr<C: Commitment> {
/// The aggregation operator applied to `expr`.
op: AggregationOperator,
/// The expression being aggregated, boxed so it can be stored behind a pointer.
expr: Box<ProvableExprPlan<C>>,
}

impl<C: Commitment> AggregateFunctionExpr<C> {
    /// Create a new aggregate function expression.
    ///
    /// * `op` - the aggregation operator to apply.
    /// * `expr` - the (boxed) expression the operator is applied to.
    pub fn new(op: AggregationOperator, expr: Box<ProvableExprPlan<C>>) -> Self {
        // The field is named `expr`; the shorthand must use the same name as the parameter.
        Self { op, expr }
    }
}

impl<C: Commitment> ProvableExpr<C> for NotExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
Ok(())
}

fn data_type(&self) -> ColumnType {
match self.op {
AggregationOperator::Count => ColumnType::BigInt,
AggregationOperator::Sum => self.expr.data_type(),
_ => todo!("Aggregation operator not supported here yet"),
}
}

#[tracing::instrument(name = "AggregateFunctionExpr::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
self.expr.result_evaluate(table_length, alloc, accessor)
}

#[tracing::instrument(name = "AggregateFunctionExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
self.expr.prover_evaluate(builder, alloc, accessor)
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
self.expr.verifier_evaluate(builder, accessor)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.expr.get_column_references(columns)
}
}
11 changes: 11 additions & 0 deletions crates/proof-of-sql/src/sql/ast/aliased_provable_expr_plan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use super::ProvableExprPlan;
use crate::base::commitment::Commitment;
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};

/// A `ProvableExprPlan` with an alias.
///
/// Replaces ad-hoc `(ProvableExprPlan<C>, Identifier)` tuples so that the two
/// parts can be accessed by name (`.expr` / `.alias`) instead of by position.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AliasedProvableExprPlan<C: Commitment> {
/// The expression being aliased.
pub expr: ProvableExprPlan<C>,
/// The identifier under which the expression's result is exposed.
pub alias: Identifier,
}
31 changes: 14 additions & 17 deletions crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::{
dense_filter_util::{fold_columns, fold_vals},
filter_columns,
provable_expr_plan::ProvableExprPlan,
ProvableExpr, TableExpr,
filter_columns, AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr,
};
use crate::{
base::{
Expand All @@ -23,7 +21,6 @@ use crate::{
use bumpalo::Bump;
use core::iter::repeat_with;
use num_traits::{One, Zero};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, marker::PhantomData};

Expand All @@ -35,7 +32,7 @@ use std::{collections::HashSet, marker::PhantomData};
/// This differs from the [`FilterExpr`] in that the result is not a sparse table.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
pub(super) aliased_results: Vec<(ProvableExprPlan<C>, Identifier)>,
pub(super) aliased_results: Vec<AliasedProvableExprPlan<C>>,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
phantom: PhantomData<H>,
Expand All @@ -44,7 +41,7 @@ pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
impl<C: Commitment, H: ProverHonestyMarker> OstensibleDenseFilterExpr<C, H> {
/// Creates a new dense_filter expression.
pub fn new(
aliased_results: Vec<(ProvableExprPlan<C>, Identifier)>,
aliased_results: Vec<AliasedProvableExprPlan<C>>,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
) -> Self {
Expand All @@ -68,7 +65,7 @@ where
) -> Result<(), ProofError> {
self.where_clause.count(builder)?;
for aliased_expr in self.aliased_results.iter() {
aliased_expr.0.count(builder)?;
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_intermediate_mles(2);
Expand Down Expand Up @@ -99,7 +96,7 @@ where
let columns_evals = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?,
);
// 3. indexes
Expand Down Expand Up @@ -128,15 +125,15 @@ where
fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.aliased_results
.iter()
.map(|(expr, alias)| ColumnField::new(*alias, expr.data_type()))
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

for (col, _) in self.aliased_results.iter() {
col.get_column_references(&mut columns);
for aliased_expr in self.aliased_results.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
Expand Down Expand Up @@ -165,11 +162,11 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
.expect("selection is not boolean");

// 2. columns
let columns = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let columns = Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
// 3. set indexes
Expand Down Expand Up @@ -200,7 +197,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
let columns = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
Expand Down
76 changes: 39 additions & 37 deletions crates/proof-of-sql/src/sql/ast/group_by_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::{
aggregate_columns, fold_columns, fold_vals,
group_by_util::{compare_indexes_by_owned_columns, AggregatedColumns},
provable_expr_plan::ProvableExprPlan,
ColumnExpr, ProvableExpr, TableExpr,
AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr,
};
use crate::{
base::{
Expand All @@ -29,19 +28,19 @@ use std::collections::HashSet;

/// Provable expressions for queries of the form
/// ```ignore
/// SELECT <group_by_expr1>, ..., <group_by_exprM>,
/// SELECT <group_by_expr1>.0 as <group_by_expr1>.1, ..., <group_by_exprM>.0 as <group_by_exprM>.1,
/// SUM(<sum_expr1>.0) as <sum_expr1>.1, ..., SUM(<sum_exprN>.0) as <sum_exprN>.1,
/// COUNT(*) as count_alias
/// FROM <table>
/// WHERE <where_clause>
/// GROUP BY <group_by_expr1>, ..., <group_by_exprM>
/// GROUP BY <group_by_expr1>.0, ..., <group_by_exprM>.0
/// ```
///
/// Note: if `group_by_exprs` is empty, then the query is equivalent to removing the `GROUP BY` clause.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GroupByExpr<C: Commitment> {
pub(super) group_by_exprs: Vec<ColumnExpr<C>>,
pub(super) sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
pub(super) group_by_exprs: Vec<AliasedProvableExprPlan<C>>,
pub(super) sum_expr: Vec<AliasedProvableExprPlan<C>>,
pub(super) count_alias: Identifier,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
Expand All @@ -50,8 +49,8 @@ pub struct GroupByExpr<C: Commitment> {
impl<C: Commitment> GroupByExpr<C> {
/// Creates a new group_by expression.
pub fn new(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
group_by_exprs: Vec<AliasedProvableExprPlan<C>>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: Identifier,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
Expand All @@ -73,12 +72,12 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
_accessor: &dyn MetadataAccessor,
) -> Result<(), ProofError> {
self.where_clause.count(builder)?;
for expr in self.group_by_exprs.iter() {
expr.count(builder)?;
for aliased_expr in self.group_by_exprs.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
for expr in self.sum_expr.iter() {
expr.0.count(builder)?;
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_result_columns(1);
Expand Down Expand Up @@ -110,12 +109,12 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
let group_by_evals = self
.group_by_exprs
.iter()
.map(|expr| expr.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
let aggregate_evals = self
.sum_expr
.iter()
.map(|expr| expr.0.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
// 3. indexes
let indexes_eval = builder
Expand Down Expand Up @@ -150,7 +149,7 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
let cols = self
.group_by_exprs
.iter()
.map(|col| table.inner_table().get(&col.column_id()))
.map(|aliased_expr| table.inner_table().get(&aliased_expr.alias))
.collect::<Option<Vec<_>>>()
.ok_or(ProofError::VerificationError(
"Result does not all correct group by columns.",
Expand All @@ -169,25 +168,27 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
let mut fields = Vec::new();
for col in self.group_by_exprs.iter() {
fields.push(col.get_column_field());
}
for col in self.sum_expr.iter() {
fields.push(col.1);
}
fields.push(ColumnField::new(self.count_alias, ColumnType::BigInt));
fields
self.group_by_exprs
.iter()
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()))
.chain(self.sum_expr.iter().map(|aliased_expr| {
ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())
}))
.chain(std::iter::once(ColumnField::new(
self.count_alias,
ColumnType::BigInt,
)))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

for col in self.group_by_exprs.iter() {
columns.insert(col.get_column_reference());
for aliased_expr in self.group_by_exprs.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}
for col in self.sum_expr.iter() {
columns.insert(col.0.get_column_reference());
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
Expand All @@ -214,13 +215,14 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
.expect("selection is not boolean");

// 2. columns
let group_by_columns = Vec::from_iter(
self.group_by_exprs
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|expr| {
expr.0
let group_by_columns = Vec::from_iter(self.group_by_exprs.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
// Compute filtered_columns and indexes
Expand Down Expand Up @@ -262,12 +264,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
let group_by_columns = Vec::from_iter(
self.group_by_exprs
.iter()
.map(|expr| expr.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
let sum_columns = Vec::from_iter(
self.sum_expr
.iter()
.map(|expr| expr.0.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let AggregatedColumns {
Expand Down
5 changes: 5 additions & 0 deletions crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ mod bitwise_verification_test;

mod provable_expr_plan;
pub(crate) use provable_expr_plan::ProvableExprPlan;
mod aggregate_function_expr;
pub(crate) use aggregate_function_expr::AggregateFunctionExpr;

mod aliased_provable_expr_plan;
pub(crate) use aliased_provable_expr_plan::AliasedProvableExprPlan;

mod provable_expr;
pub(crate) use provable_expr::ProvableExpr;
Expand Down
Loading

0 comments on commit 93d5290

Please sign in to comment.