Skip to content

Commit

Permalink
feat: generalize count_expr and add res_expr in GroupByExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Jul 1, 2024
1 parent 2dbd966 commit dd767f6
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 121 deletions.
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/ast/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use std::collections::HashSet;
/// Currently it doesn't do much since aggregation logic is implemented elsewhere
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AggregateExpr<C: Commitment> {
op: AggregationOperator,
expr: Box<ProvableExprPlan<C>>,
pub(crate) op: AggregationOperator,
pub(crate) expr: Box<ProvableExprPlan<C>>,
}

impl<C: Commitment> AggregateExpr<C> {
Expand Down
82 changes: 56 additions & 26 deletions crates/proof-of-sql/src/sql/ast/group_by_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,28 @@ use crate::{
use bumpalo::Bump;
use core::iter::repeat_with;
use num_traits::One;
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable expressions for queries of the form
/// ```ignore
/// SELECT <group_by_expr1>, ..., <group_by_exprM>,
/// SUM(<sum_expr1>.expr) as <sum_expr1>.alias, ..., SUM(<sum_exprN>.expr) as <sum_exprN>.alias,
/// COUNT(*) as count_alias
/// SELECT <result_exprs[1].expr> as <result_expr[1].alias>, ..., <result_exprs[K].expr> as <result_exprs[K].alias>,
/// SUM(<sum_exprs[1].expr>) as <sum_exprs[1].alias>, ..., SUM(<sum_exprs[M].expr>) as <sum_exprs[M].alias>,
/// COUNT(<count_exprs[1].expr>) as <count_exprs[1].alias>, ..., COUNT(<count_exprs[N].expr>) as <count_exprs[N].alias>
/// FROM <table>
/// WHERE <where_clause>
/// GROUP BY <group_by_expr1>, ..., <group_by_exprM>
/// GROUP BY <group_by_exprs[1]>, ..., <group_by_exprs[L]>
/// ```
///
/// Note: if `group_by_exprs` is empty, then the query is equivalent to removing the `GROUP BY` clause.
/// Note:
/// 1. If `group_by_exprs` is empty, then the query is equivalent to removing the `GROUP BY` clause.
/// 2. Result expressions must only contain columns that are in the `group_by_exprs`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct GroupByExpr<C: Commitment> {
pub(super) group_by_exprs: Vec<ColumnExpr<C>>,
pub(super) sum_expr: Vec<AliasedProvableExprPlan<C>>,
pub(super) count_alias: Identifier,
pub(super) result_exprs: Vec<AliasedProvableExprPlan<C>>,
pub(super) sum_exprs: Vec<AliasedProvableExprPlan<C>>,
pub(super) count_exprs: Vec<AliasedProvableExprPlan<C>>,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
}
Expand All @@ -50,16 +52,18 @@ impl<C: Commitment> GroupByExpr<C> {
/// Creates a new group_by expression.
pub fn new(
group_by_exprs: Vec<ColumnExpr<C>>,
sum_expr: Vec<AliasedProvableExprPlan<C>>,
count_alias: Identifier,
result_exprs: Vec<AliasedProvableExprPlan<C>>,
sum_exprs: Vec<AliasedProvableExprPlan<C>>,
count_exprs: Vec<AliasedProvableExprPlan<C>>,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
) -> Self {
Self {
group_by_exprs,
sum_expr,
result_exprs,
sum_exprs,
table,
count_alias,
count_exprs,
where_clause,
}
}
Expand All @@ -74,13 +78,16 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
self.where_clause.count(builder)?;
for expr in self.group_by_exprs.iter() {
expr.count(builder)?;
}
for aliased_expr in self.result_exprs.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
for aliased_expr in self.sum_expr.iter() {
for aliased_expr in self.sum_exprs.iter() {
aliased_expr.expr.count(builder)?;
builder.count_result_columns(1);
}
builder.count_result_columns(1);
builder.count_result_columns(self.count_exprs.len());
builder.count_intermediate_mles(2);
builder.count_subpolynomials(3);
builder.count_degree(3);
Expand Down Expand Up @@ -112,7 +119,12 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
.map(|expr| expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
let aggregate_evals = self
.sum_expr
.sum_exprs
.iter()
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
let result_evals = self
.result_exprs
.iter()
.map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?;
Expand All @@ -127,7 +139,10 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
repeat_with(|| builder.consume_result_mle()).take(self.group_by_exprs.len()),
);
let sum_result_columns_evals =
Vec::from_iter(repeat_with(|| builder.consume_result_mle()).take(self.sum_expr.len()));
Vec::from_iter(repeat_with(|| builder.consume_result_mle()).take(self.sum_exprs.len()));
let result_columns_evals = Vec::from_iter(
repeat_with(|| builder.consume_result_mle()).take(self.result_exprs.len()),
);
let count_column_eval = builder.consume_result_mle();

let alpha = builder.consume_post_result_challenge();
Expand Down Expand Up @@ -168,26 +183,28 @@ impl<C: Commitment> ProofExpr<C> for GroupByExpr<C> {
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.group_by_exprs
self.result_exprs
.iter()
.map(|col| col.get_column_field())
.chain(self.sum_expr.iter().map(|aliased_expr| {
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()))
.chain(self.sum_exprs.iter().map(|aliased_expr| {
ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())
}))
.chain(std::iter::once(ColumnField::new(
self.count_alias,
ColumnType::BigInt,
)))
.chain(
self.count_exprs
.iter()
.map(|aliased_expr| ColumnField::new(aliased_expr.alias, ColumnType::BigInt)),
)
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

// No need to add columns from result_exprs since they are already in group_by_exprs
for col in self.group_by_exprs.iter() {
columns.insert(col.get_column_reference());
}
for aliased_expr in self.sum_expr.iter() {
for aliased_expr in self.sum_exprs.iter() {
aliased_expr.expr.get_column_references(&mut columns);
}

Expand Down Expand Up @@ -220,7 +237,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| {
let sum_columns = Vec::from_iter(self.sum_exprs.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
}));
let result_columns = Vec::from_iter(self.result_exprs.iter().map(|aliased_expr| {
aliased_expr
.expr
.result_evaluate(builder.table_length(), alloc, accessor)
Expand All @@ -241,6 +263,9 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
for col in sum_result_columns {
builder.produce_result_column(col);
}
for col in result_columns {
builder.produce_result_column(col);
}
builder.produce_result_column(count_column);
builder.request_post_result_challenges(2);
}
Expand All @@ -267,7 +292,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExpr<C> {
.map(|expr| expr.prover_evaluate(builder, alloc, accessor)),
);
let sum_columns = Vec::from_iter(
self.sum_expr
self.sum_exprs
.iter()
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
let result_columns = Vec::from_iter(
self.result_exprs
.iter()
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)),
);
Expand Down
31 changes: 1 addition & 30 deletions crates/proof-of-sql/src/sql/ast/group_by_expr_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,9 @@ use crate::{
sql::proof::{exercise_verification, VerifiableQueryResult},
};

/// select a, sum(c) as sum_c, count(*) as __count__ from sxt.t where b = 99 group by a
#[test]
fn we_can_prove_a_simple_group_by_with_bigint_columns() {
let data = owned_table([
bigint("a", [1, 2, 2, 1, 2]),
bigint("b", [99, 99, 99, 99, 0]),
bigint("c", [101, 102, 103, 104, 105]),
]);
let t = "sxt.t".parse().unwrap();
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);
let expr = group_by(
cols_expr(t, &["a"], &accessor),
vec![sum_expr(column(t, "c", &accessor), "sum_c")],
"__count__",
tab(t),
equal(column(t, "b", &accessor), const_int128(99)),
);
let res = VerifiableQueryResult::new(&expr, &accessor, &());
exercise_verification(&res, &expr, &accessor, t);
let res = res.verify(&expr, &accessor, &()).unwrap().table;
let expected = owned_table([
bigint("a", [1, 2]),
bigint("sum_c", [101 + 104, 102 + 103]),
bigint("__count__", [2, 2]),
]);
assert_eq!(res, expected);
}

/// select a, sum(c * 2 + 1) as sum_c, count(*) as __count__ from sxt.t where b = 99 group by a
#[test]
fn we_can_prove_a_group_by_with_bigint_columns() {
fn we_can_prove_a_simple_group_by_with_bigint_columns() {
let data = owned_table([
bigint("a", [1, 2, 2, 1, 2]),
bigint("b", [99, 99, 99, 99, 0]),
Expand Down
123 changes: 62 additions & 61 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use crate::{
database::{ColumnRef, ColumnType, LiteralValue, TableRef},
},
sql::{
ast::{AliasedProvableExprPlan, ColumnExpr, GroupByExpr, ProvableExprPlan, TableExpr},
ast::{
AggregateExpr, AliasedProvableExprPlan, ColumnExpr, GroupByExpr, ProvableExprPlan,
TableExpr,
},
parse::{ConversionError, ConversionResult, ProvableExprPlanBuilder, WhereExprBuilder},
},
};
Expand Down Expand Up @@ -250,81 +253,79 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
})
.collect::<Result<Vec<ColumnExpr<C>>, ConversionError>>()?;
// For a query to be provable the result columns must be of one of three kinds below:
// 1. Group by columns (it is mandatory to have all of them in the correct order)
// 2. Sum(expr) expressions (it is optional to have any)
// 3. count(*) with an alias (it is mandatory to have one and only one)
let num_group_by_columns = group_by_exprs.len();
let num_result_columns = value.res_aliased_exprs.len();
if num_result_columns < num_group_by_columns + 1 {
// 1. Expressions exclusively consisting of group by columns
// 2. Sum(expr) expressions with an alias
// 3. count(expr) with an alias
let opt_res_expr_plans = value
.res_aliased_exprs
.iter()
.map(|res| {
let res_provable_expr_plan =
ProvableExprPlanBuilder::new(&value.column_mapping).build(&res.expr);
res_provable_expr_plan
.ok()
.map(|provable_expr_plan| AliasedProvableExprPlan {
alias: res.alias,
expr: provable_expr_plan,
})
})
.collect::<Option<Vec<AliasedProvableExprPlan<C>>>>();
if opt_res_expr_plans.is_none() {
return Ok(None);
}
let res_group_by_columns = &value.res_aliased_exprs[..num_group_by_columns].to_vec();
let sum_expr_columns =
&value.res_aliased_exprs[num_group_by_columns..num_result_columns - 1].to_vec();
// Check group by columns
let group_by_compliance = value
.group_by_exprs
let res_expr_plans = opt_res_expr_plans.expect("the none case was just checked");
let sum_exprs = res_expr_plans
.iter()
.zip(res_group_by_columns.iter())
.all(|(ident, res)| {
//TODO: This is due to a workaround related to polars
//Need to remove it when possible (PROOF-850)
if let Expression::Aggregation {
op: AggregationOperator::First,
.filter_map(|res| {
if let ProvableExprPlan::Aggregate(AggregateExpr {
op: AggregationOperator::Sum,
expr,
} = (*res.expr).clone()
}) = &res.expr
{
if let Expression::Column(res_ident) = *expr {
res_ident == *ident
} else {
false
}
Some(AliasedProvableExprPlan {
alias: res.alias.clone(),
expr: *expr.clone(),
})
} else {
false
None
}
});
// Check sums
let sum_expr = sum_expr_columns
})
.collect::<Vec<AliasedProvableExprPlan<C>>>();
let count_exprs = res_expr_plans
.iter()
.map(|res| {
if let Expression::Aggregation {
op: AggregationOperator::Sum,
..
} = (*res.expr).clone()
.filter_map(|res| {
if let ProvableExprPlan::Aggregate(AggregateExpr {
op: AggregationOperator::Count,
expr,
}) = &res.expr
{
let res_provable_expr_plan =
ProvableExprPlanBuilder::new(&value.column_mapping).build(&res.expr);
res_provable_expr_plan
.ok()
.map(|provable_expr_plan| AliasedProvableExprPlan {
alias: res.alias,
expr: provable_expr_plan,
})
Some(AliasedProvableExprPlan {
alias: res.alias.clone(),
expr: *expr.clone(),
})
} else {
None
}
})
.collect::<Option<Vec<AliasedProvableExprPlan<C>>>>();

// Check count(*)
let count_column = &value.res_aliased_exprs[num_result_columns - 1];
let count_column_compliant = if let Expression::Aggregation {
op: AggregationOperator::Count,
expr,
} = (*count_column.expr).clone()
{
//TODO: This is due to a workaround related to polars
matches!(*expr, Expression::Column(_))
} else {
false
};
if !group_by_compliance || sum_expr.is_none() || !count_column_compliant {
return Ok(None);
}
.collect::<Vec<AliasedProvableExprPlan<C>>>();
let res_exprs = res_expr_plans
.iter()
.filter_map(|res| {
if let ProvableExprPlan::Aggregate(_expr) = &res.expr {
None
} else {
Some(AliasedProvableExprPlan {
alias: res.alias.clone(),
expr: res.expr.clone(),
})
}
})
.collect::<Vec<AliasedProvableExprPlan<C>>>();
Ok(Some(GroupByExpr::new(
group_by_exprs,
sum_expr.expect("the none case was just checked"),
count_column.alias,
res_exprs,
sum_exprs,
count_exprs,
table,
where_clause,
)))
Expand Down
3 changes: 1 addition & 2 deletions crates/proof-of-sql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,9 +518,8 @@ fn we_can_prove_a_minimal_group_by_query_with_curve25519() {
}

#[test]
#[cfg(feature = "blitzar")]
fn we_can_prove_a_basic_group_by_query_with_curve25519() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let mut accessor = OwnedTableTestAccessor::<InterProductProof>::new_empty_with_setup(());
accessor.add_table(
"sxt.table".parse().unwrap(),
owned_table([
Expand Down

0 comments on commit dd767f6

Please sign in to comment.