Skip to content

Commit

Permalink
fix: fix postprocessing projection in provable GROUP BYs & allow `W…
Browse files Browse the repository at this point in the history
…HERE` clause to be omitted (#50)

# Rationale for this change
Recently we found weird postprocessing bugs for group bys. What we found
is that
1. Currently provable `GROUP BY`s without a `WHERE` clause are not
supported even though to support it is trivial.
2. Currently if a `GROUP BY` clause is provable the `ResultExpr` still
incorrectly contains aggregate functions.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked Jira ticket then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->

# What changes are included in this PR?
- fix postprocessing projection with provable `GROUP BY`
- allow provable group by without `WHERE` clause
- enable 3 ignored tests
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->

# Are these changes tested?
Yes
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
3. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
  • Loading branch information
iajoiner authored Jul 15, 2024
1 parent ba6e53a commit b52d19a
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 23 deletions.
6 changes: 2 additions & 4 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,6 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
type Error = ConversionError;

fn try_from(value: &QueryContext) -> Result<Option<GroupByExpr<C>>, Self::Error> {
// Currently if there is no where clause, we can't prove the query
if value.where_expr.is_none() {
return Ok(None);
}
let where_clause = WhereExprBuilder::new(&value.column_mapping)
.build(value.where_expr.clone())?
.unwrap_or_else(|| ProvableExprPlan::new_literal(LiteralValue::Boolean(true)));
Expand Down Expand Up @@ -283,6 +279,7 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
false
}
});

// Check sums
let sum_expr = sum_expr_columns
.iter()
Expand Down Expand Up @@ -318,6 +315,7 @@ impl<C: Commitment> TryFrom<&QueryContext> for Option<GroupByExpr<C>> {
} else {
false
};

if !group_by_compliance || sum_expr.is_none() || !count_column_compliant {
return Ok(None);
}
Expand Down
17 changes: 15 additions & 2 deletions crates/proof-of-sql/src/sql/parse/query_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use crate::{
transform::ResultExpr,
},
};
use proof_of_sql_parser::{intermediate_ast::SetExpression, Identifier, SelectStatement};
use proof_of_sql_parser::{
intermediate_ast::{AliasedResultExpr, Expression, SetExpression},
Identifier, SelectStatement,
};
use serde::{Deserialize, Serialize};
use std::fmt;

Expand Down Expand Up @@ -62,10 +65,20 @@ impl<C: Commitment> QueryExpr<C> {
let group_by = context.get_group_by_exprs();
if !group_by.is_empty() {
if let Some(group_by_expr) = Option::<GroupByExpr<C>>::try_from(&context)? {
// If the group by expression is provable the projection step is just identity.
let new_result_aliased_exprs = result_aliased_exprs
.iter()
.map(|aliased_expr| {
AliasedResultExpr::new(
Expression::Column(aliased_expr.alias),
aliased_expr.alias,
)
})
.collect::<Vec<_>>();
return Ok(Self {
proof_expr: ProofPlan::GroupBy(group_by_expr),
result: ResultExprBuilder::default()
.add_select_exprs(result_aliased_exprs)
.add_select_exprs(&new_result_aliased_exprs)
.add_order_by_exprs(context.get_order_by_exprs()?)
.add_slice_expr(context.get_slice_expr())
.build(),
Expand Down
33 changes: 16 additions & 17 deletions crates/proof-of-sql/src/sql/parse/query_expr_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,6 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause_preceded_by_wher
///////////////////////////
// Group By Expressions - Prover
///////////////////////////
#[ignore]
#[test]
fn we_can_do_provable_group_by() {
let t = "sxt.employees".parse().unwrap();
Expand All @@ -1121,14 +1120,14 @@ fn we_can_do_provable_group_by() {
const_bool(true),
),
composite_result(vec![select(&[
pc("department").first().alias("department"),
pc("salary").sum().alias("total_salary"),
pc("department").count().alias("num_employee"),
pc("department").alias("department"),
pc("total_salary").alias("total_salary"),
pc("num_employee").alias("num_employee"),
])]),
);
assert_eq!(ast, expected_ast);
}
#[ignore]

#[test]
fn we_can_do_provable_group_by_without_sum() {
let t = "sxt.employees".parse().unwrap();
Expand All @@ -1155,13 +1154,13 @@ fn we_can_do_provable_group_by_without_sum() {
const_bool(true),
),
composite_result(vec![select(&[
pc("department").first().alias("department"),
pc("department").count().alias("num_employee"),
pc("department").alias("department"),
pc("num_employee").alias("num_employee"),
])]),
);
assert_eq!(ast, expected_ast);
}
#[ignore]

#[test]
fn we_can_do_provable_group_by_with_two_group_by_columns() {
let t = "sxt.employees".parse().unwrap();
Expand Down Expand Up @@ -1189,10 +1188,10 @@ fn we_can_do_provable_group_by_with_two_group_by_columns() {
const_bool(true),
),
composite_result(vec![select(&[
pc("state").first().alias("state"),
pc("department").first().alias("department"),
pc("salary").sum().alias("total_salary"),
pc("department").count().alias("num_employee"),
pc("state").alias("state"),
pc("department").alias("department"),
pc("total_salary").alias("total_salary"),
pc("num_employee").alias("num_employee"),
])]),
);
assert_eq!(ast, expected_ast);
Expand Down Expand Up @@ -1228,17 +1227,17 @@ fn we_can_do_provable_group_by_with_two_sums_and_dense_filter() {
lte(column(t, "tax", &accessor), const_bigint(1)),
),
composite_result(vec![select(&[
pc("department").first().alias("department"),
pc("salary").sum().alias("total_salary"),
pc("tax").sum().alias("total_tax"),
pc("department").count().alias("num_employee"),
pc("department").alias("department"),
pc("total_salary").alias("total_salary"),
pc("total_tax").alias("total_tax"),
pc("num_employee").alias("num_employee"),
])]),
);
assert_eq!(ast, expected_ast);
}

///////////////////////////
// Group By Expressions - Polars
// Group By Expressions - Postprocessing
///////////////////////////
#[test]
fn we_can_group_by_without_using_aggregate_functions() {
Expand Down
125 changes: 125 additions & 0 deletions crates/proof-of-sql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,131 @@ fn we_can_prove_a_basic_group_by_query_with_curve25519() {
assert_eq!(owned_table_result, expected_result);
}

#[test]
#[cfg(feature = "blitzar")]
fn we_can_prove_a_cat_group_by_query_with_curve25519() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(
"sxt.cats".parse().unwrap(),
owned_table([
int("id", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
varchar(
"name",
[
"Chloe",
"Margaret",
"Prudence",
"Lucy",
"Ms. Kitty",
"Pepper",
"Rocky",
"Smokey",
"Tiger",
"Whiskers",
],
),
smallint("age", [12_i16, 2, 3, 3, 10, 2, 2, 4, 5, 6]),
varchar(
"human",
[
"Ian", "Ian", "Gretta", "Gretta", "Gretta", "Gretta", "Gretta", "Alice", "Bob",
"Charlie",
],
),
bigint("proof_order", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
]),
0,
);
let query = QueryExpr::try_new(
"select human, sum(age) as total_cat_age, count(*) as num_cats from sxt.cats where age = 2 group by human"
.parse()
.unwrap(),
"sxt".parse().unwrap(),
&accessor,
)
.unwrap();
let (proof, serialized_result) =
QueryProof::<InnerProductProof>::new(query.proof_expr(), &accessor, &());
let owned_table_result = proof
.verify(query.proof_expr(), &accessor, &serialized_result, &())
.unwrap()
.table;
let expected_result = owned_table([
varchar("human", ["Gretta", "Ian"]),
smallint("total_cat_age", [4_i16, 2]),
bigint("num_cats", [2, 1]),
]);
assert_eq!(owned_table_result, expected_result);
}

#[test]
fn we_can_prove_a_cat_group_by_query_with_dory() {
let public_parameters = PublicParameters::rand(4, &mut test_rng());
let prover_setup = ProverSetup::from(&public_parameters);
let verifier_setup = VerifierSetup::from(&public_parameters);
let dory_prover_setup = DoryProverPublicSetup::new(&prover_setup, 3);
let dory_verifier_setup = DoryVerifierPublicSetup::new(&verifier_setup, 3);

let mut accessor =
OwnedTableTestAccessor::<DoryEvaluationProof>::new_empty_with_setup(dory_prover_setup);
accessor.add_table(
"sxt.cats".parse().unwrap(),
owned_table([
int("id", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
varchar(
"name",
[
"Chloe",
"Margaret",
"Prudence",
"Lucy",
"Ms. Kitty",
"Pepper",
"Rocky",
"Smokey",
"Tiger",
"Whiskers",
],
),
smallint("age", [12_i16, 2, 3, 3, 10, 2, 2, 4, 5, 6]),
varchar(
"human",
[
"Ian", "Ian", "Gretta", "Gretta", "Gretta", "Gretta", "Gretta", "Alice", "Bob",
"Charlie",
],
),
bigint("proof_order", [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
]),
0,
);
let query = QueryExpr::try_new(
"select human, sum(age) as total_cat_age, count(*) as num_cats from sxt.cats where age = 2 group by human"
.parse()
.unwrap(),
"sxt".parse().unwrap(),
&accessor,
)
.unwrap();
let (proof, serialized_result) =
QueryProof::<DoryEvaluationProof>::new(query.proof_expr(), &accessor, &dory_prover_setup);
let owned_table_result = proof
.verify(
query.proof_expr(),
&accessor,
&serialized_result,
&dory_verifier_setup,
)
.unwrap()
.table;
let expected_result = owned_table([
varchar("human", ["Gretta", "Ian"]),
smallint("total_cat_age", [4_i16, 2]),
bigint("num_cats", [2, 1]),
]);
assert_eq!(owned_table_result, expected_result);
}

#[test]
fn we_can_prove_a_basic_group_by_query_with_dory() {
let public_parameters = PublicParameters::rand(4, &mut test_rng());
Expand Down

0 comments on commit b52d19a

Please sign in to comment.