Skip to content

Commit

Permalink
feat: generalize GroupByExpr
Browse files Browse the repository at this point in the history
- add `AliasedGroupByExpr`
- generalize `GroupByExpr`
  • Loading branch information
iajoiner committed Jun 25, 2024
1 parent 650ac2a commit e364713
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 102 deletions.
11 changes: 11 additions & 0 deletions crates/proof-of-sql/src/sql/ast/aliased_provable_expr_plan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use super::ProvableExprPlan;
use crate::base::commitment::Commitment;
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};

/// A `ProvableExprPlan` with an alias.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AliasedProvableExprPlan<C: Commitment> {
pub expr: ProvableExprPlan<C>,
pub alias: Identifier,
}
31 changes: 14 additions & 17 deletions crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use super::{
dense_filter_util::{fold_columns, fold_vals},
filter_columns,
provable_expr_plan::ProvableExprPlan,
ProvableExpr, TableExpr,
filter_columns, AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr,
};
use crate::{
base::{
Expand All @@ -23,7 +21,6 @@ use crate::{
use bumpalo::Bump;
use core::iter::repeat_with;
use num_traits::{One, Zero};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, marker::PhantomData};

Expand All @@ -35,7 +32,7 @@ use std::{collections::HashSet, marker::PhantomData};
/// This differs from the [`FilterExpr`] in that the result is not a sparse table.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
pub(super) aliased_results: Vec<(ProvableExprPlan<C>, Identifier)>,
pub(super) aliased_results: Vec<AliasedProvableExprPlan<C>>,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
phantom: PhantomData<H>,
Expand All @@ -44,7 +41,7 @@ pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
impl<C: Commitment, H: ProverHonestyMarker> OstensibleDenseFilterExpr<C, H> {
/// Creates a new dense_filter expression.
pub fn new(
aliased_results: Vec<(ProvableExprPlan<C>, Identifier)>,
aliased_results: Vec<AliasedProvableExprPlan<C>>,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
) -> Self {
Expand All @@ -68,7 +65,7 @@ where
) -> Result<(), ProofError> {
self.where_clause.count(builder)?;
for aliased_expr in self.aliased_results.iter() {
aliased_expr.0.count(builder)?;
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_intermediate_mles(2);
Expand Down Expand Up @@ -99,7 +96,7 @@ where
let columns_evals = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?,
);
// 3. indexes
Expand Down Expand Up @@ -128,15 +125,15 @@ where
fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.aliased_results
.iter()
.map(|(expr, alias)| ColumnField::new(*alias, expr.data_type()))
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

for (col, _) in self.aliased_results.iter() {
col.get_column_references(&mut columns);
for aliased_expr in self.aliased_results.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
Expand Down Expand Up @@ -165,11 +162,11 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
.expect("selection is not boolean");

// 2. columns
let columns = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let columns = Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
// 3. set indexes
Expand Down Expand Up @@ -200,7 +197,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
let columns = Vec::from_iter(
self.aliased_results
.iter()
.map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
Expand Down
76 changes: 39 additions & 37 deletions crates/proof-of-sql/src/sql/ast/group_by_expr.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::{
aggregate_columns, fold_columns, fold_vals,
group_by_util::{compare_indexes_by_owned_columns, AggregatedColumns},
provable_expr_plan::ProvableExprPlan,
ColumnExpr, ProvableExpr, TableExpr,
AliasedProvableExprPlan, ProvableExpr, ProvableExprPlan, TableExpr,
};
use crate::{
base::{
Expand All @@ -29,19 +28,19 @@ use std::collections::HashSet;

/// Provable expressions for queries of the form
/// ```ignore
/// SELECT <group_by_expr1>, ..., <group_by_exprM>,
/// SELECT <group_by_expr1>.0 as <group_by_expr1>.1, ..., <group_by_exprM>.0 as <group_by_exprM>.1,
/// SUM(<sum_expr1>.0) as <sum_expr1>.1, ..., SUM(<sum_exprN>.0) as <sum_exprN>.1,
/// COUNT(*) as count_alias
/// FROM <table>
/// WHERE <where_clause>
/// GROUP BY <group_by_expr1>, ..., <group_by_exprM>
/// GROUP BY <group_by_expr1>.0, ..., <group_by_exprM>.0
/// ```
///
/// Note: if `group_by_exprs` is empty, then the query is equivalent to removing the `GROUP BY` clause.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GroupByExpr<C: Commitment> {
pub(super) group_by_exprs: Vec<ColumnExpr<C>>,
pub(super) sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
pub(super) group_by_exprs: Vec<AliasedProvableExprPlan<C>>,
pub(super) sum_expr: Vec<AliasedProvableExprPlan<C>>,
pub(super) count_alias: Identifier,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
Expand All @@ -50,8 +49,8 @@ pub struct GroupByExpr<C: Commitment> {
impl<C: Commitment> GroupByExpr<C> {
/// Creates a new group_by expression.
pub fn new(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
group_by_exprs: Vec<AliasedProvableExprPlan<C>>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: Identifier,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
Expand All @@ -73,12 +72,12 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
_accessor: &dyn MetadataAccessor,
) -> Result<(), ProofError> {
self.where_clause.count(builder)?;
for expr in self.group_by_exprs.iter() {
expr.count(builder)?;
for aliased_expr in self.group_by_exprs.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
for expr in self.sum_expr.iter() {
expr.0.count(builder)?;
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_result_columns(1);
Expand Down Expand Up @@ -110,12 +109,12 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
let group_by_evals = self
.group_by_exprs
.iter()
.map(|expr| expr.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
let aggregate_evals = self
.sum_expr
.iter()
.map(|expr| expr.0.verifier_evaluate(builder, accessor))
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
// 3. indexes
let indexes_eval = builder
Expand Down Expand Up @@ -150,7 +149,7 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
let cols = self
.group_by_exprs
.iter()
.map(|col| table.inner_table().get(&col.column_id()))
.map(|aliased_expr| table.inner_table().get(&aliased_expr.alias))
.collect::<Option<Vec<_>>>()
.ok_or(ProofError::VerificationError(
"Result does not all correct group by columns.",
Expand All @@ -169,25 +168,27 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
let mut fields = Vec::new();
for col in self.group_by_exprs.iter() {
fields.push(col.get_column_field());
}
for col in self.sum_expr.iter() {
fields.push(col.1);
}
fields.push(ColumnField::new(self.count_alias, ColumnType::BigInt));
fields
self.group_by_exprs
.iter()
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()))
.chain(self.sum_expr.iter().map(|aliased_expr| {
ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())
}))
.chain(std::iter::once(ColumnField::new(
self.count_alias,
ColumnType::BigInt,
)))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

for col in self.group_by_exprs.iter() {
columns.insert(col.get_column_reference());
for aliased_expr in self.group_by_exprs.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}
for col in self.sum_expr.iter() {
columns.insert(col.0.get_column_reference());
for aliased_expr in self.sum_expr.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
Expand All @@ -214,13 +215,14 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
.expect("selection is not boolean");

// 2. columns
let group_by_columns = Vec::from_iter(
self.group_by_exprs
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|expr| {
expr.0
let group_by_columns = Vec::from_iter(self.group_by_exprs.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
// Compute filtered_columns and indexes
Expand Down Expand Up @@ -262,12 +264,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
let group_by_columns = Vec::from_iter(
self.group_by_exprs
.iter()
.map(|expr| expr.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
let sum_columns = Vec::from_iter(
self.sum_expr
.iter()
.map(|expr| expr.0.prover_evaluate(builder, alloc, accessor)),
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let AggregatedColumns {
Expand Down
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ mod bitwise_verification_test;
mod provable_expr_plan;
pub(crate) use provable_expr_plan::ProvableExprPlan;

mod aliased_provable_expr_plan;
pub(crate) use aliased_provable_expr_plan::AliasedProvableExprPlan;

mod provable_expr;
pub(crate) use provable_expr::ProvableExpr;
#[cfg(all(test, feature = "blitzar"))]
Expand Down
57 changes: 15 additions & 42 deletions crates/proof-of-sql/src/sql/ast/test_utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,29 +145,29 @@ pub fn aliased_col_expr_plan<C: Commitment>(
old_name: &str,
new_name: &str,
accessor: &impl SchemaAccessor,
) -> (ProvableExprPlan<C>, Identifier) {
(
ProvableExprPlan::Column(ColumnExpr::<C>::new(col_ref(tab, old_name, accessor))),
new_name.parse().unwrap(),
)
) -> AliasedProvableExprPlan<C> {
AliasedProvableExprPlan {
expr: ProvableExprPlan::Column(ColumnExpr::<C>::new(col_ref(tab, old_name, accessor))),
alias: new_name.parse().unwrap(),
}
}

pub fn col_expr_plan<C: Commitment>(
tab: TableRef,
name: &str,
accessor: &impl SchemaAccessor,
) -> (ProvableExprPlan<C>, Identifier) {
(
ProvableExprPlan::Column(ColumnExpr::<C>::new(col_ref(tab, name, accessor))),
name.parse().unwrap(),
)
) -> AliasedProvableExprPlan<C> {
AliasedProvableExprPlan {
expr: ProvableExprPlan::Column(ColumnExpr::<C>::new(col_ref(tab, name, accessor))),
alias: name.parse().unwrap(),
}
}

pub fn aliased_cols_expr_plan<C: Commitment>(
tab: TableRef,
names: &[(&str, &str)],
accessor: &impl SchemaAccessor,
) -> Vec<(ProvableExprPlan<C>, Identifier)> {
) -> Vec<AliasedProvableExprPlan<C>> {
names
.iter()
.map(|(old_name, new_name)| aliased_col_expr_plan(tab, old_name, new_name, accessor))
Expand All @@ -178,7 +178,7 @@ pub fn cols_expr_plan<C: Commitment>(
tab: TableRef,
names: &[&str],
accessor: &impl SchemaAccessor,
) -> Vec<(ProvableExprPlan<C>, Identifier)> {
) -> Vec<AliasedProvableExprPlan<C>> {
names
.iter()
.map(|name| col_expr_plan(tab, name, accessor))
Expand All @@ -205,43 +205,16 @@ pub fn cols_expr<C: Commitment>(
}

pub fn dense_filter<C: Commitment>(
results: Vec<(ProvableExprPlan<C>, Identifier)>,
results: Vec<AliasedProvableExprPlan<C>>,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
) -> ProofPlan<C> {
ProofPlan::DenseFilter(DenseFilterExpr::new(results, table, where_clause))
}

pub fn sum_expr<C: Commitment>(
tab: TableRef,
name: &str,
alias: &str,
column_type: ColumnType,
accessor: &impl SchemaAccessor,
) -> (ColumnExpr<C>, ColumnField) {
(
col_expr(tab, name, accessor),
ColumnField::new(alias.parse().unwrap(), column_type),
)
}

pub fn sums_expr<C: Commitment>(
tab: TableRef,
names: &[&str],
aliases: &[&str],
column_types: &[ColumnType],
accessor: &impl SchemaAccessor,
) -> Vec<(ColumnExpr<C>, ColumnField)> {
names
.iter()
.zip(aliases.iter().zip(column_types.iter()))
.map(|(name, (alias, column_type))| sum_expr(tab, name, alias, *column_type, accessor))
.collect()
}

pub fn group_by<C: Commitment>(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<(ColumnExpr<C>, ColumnField)>,
group_by_exprs: Vec<AliasedProvableExprPlan<C>>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: &str,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
Expand Down
Loading

0 comments on commit e364713

Please sign in to comment.