feat: generalize DenseFilterExpr
- Move the `Expression` -> `ProvableExprPlan` conversion into a separate struct, `ProvableExprPlanBuilder`, since the process is no longer used only for where clauses.
- Add `EnrichedExpr` as the provable equivalent of DataFusion's `Expr`.
- Allow arbitrary aliased provable expressions as result expressions in dense filters (see the sketch below).
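
A minimal sketch of what the last bullet enables, modeled on the updated `we_can_prove_a_dense_filter` test further down in this diff and assuming the same imports as those tests. The helpers used (`dense_filter`, `col_expr_plan`, `equal`, `column`, `const_int128`, `tab`, `owned_table`, `bigint`, `boolean`) come from the crate's test utilities as they appear in the changed test files; the table data, the `b_eq_c` alias, and the test name are illustrative assumptions, not part of this commit.

```rust
#[test]
fn sketch_dense_filter_with_an_aliased_computed_result() {
    let data = owned_table([
        bigint("a", [101, 105, 105]),
        bigint("b", [1, 2, 3]),
        bigint("c", [9, 2, 4]),
    ]);
    let t = "sxt.t".parse().unwrap();
    let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
    accessor.add_table(t, data, 0);
    // Each result is now a (ProvableExprPlan, Identifier) pair, so a dense
    // filter can return computed expressions under arbitrary aliases,
    // not just plain columns.
    let expr = dense_filter(
        vec![
            col_expr_plan(t, "b", &accessor),
            (
                equal(column(t, "b", &accessor), column(t, "c", &accessor)),
                "b_eq_c".parse().unwrap(),
            ),
        ],
        tab(t),
        equal(column(t, "a", &accessor), const_int128(105)),
    );
    let res = VerifiableQueryResult::new(&expr, &accessor, &());
    let res = res.verify(&expr, &accessor, &()).unwrap().table;
    // Rows with a = 105 are kept; `b_eq_c` is evaluated per kept row.
    let expected = owned_table([bigint("b", [2, 3]), boolean("b_eq_c", [true, false])]);
    assert_eq!(res, expected);
}
```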
iajoiner committed Jun 14, 2024
1 parent 1ed302b commit 656c36d
Showing 18 changed files with 468 additions and 239 deletions.
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/and_expr.rs
@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable logical AND expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AndExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
30 changes: 15 additions & 15 deletions crates/proof-of-sql/src/sql/ast/dense_filter_expr.rs
@@ -2,7 +2,7 @@ use super::{
dense_filter_util::{fold_columns, fold_vals},
filter_columns,
provable_expr_plan::ProvableExprPlan,
ColumnExpr, ProvableExpr, TableExpr,
ProvableExpr, TableExpr,
};
use crate::{
base::{
@@ -22,6 +22,7 @@ use crate::{
use bumpalo::Bump;
use core::iter::repeat_with;
use num_traits::{One, Zero};
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, marker::PhantomData};

@@ -33,7 +34,7 @@ use std::{collections::HashSet, marker::PhantomData};
/// This differs from the [`FilterExpr`] in that the result is not a sparse table.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
pub(super) results: Vec<ColumnExpr<C>>,
pub(super) results: Vec<(ProvableExprPlan<C>, Identifier)>,
pub(super) table: TableExpr,
pub(super) where_clause: ProvableExprPlan<C>,
phantom: PhantomData<H>,
@@ -42,7 +43,7 @@ pub struct OstensibleDenseFilterExpr<C: Commitment, H: ProverHonestyMarker> {
impl<C: Commitment, H: ProverHonestyMarker> OstensibleDenseFilterExpr<C, H> {
/// Creates a new dense_filter expression.
pub fn new(
results: Vec<ColumnExpr<C>>,
results: Vec<(ProvableExprPlan<C>, Identifier)>,
table: TableExpr,
where_clause: ProvableExprPlan<C>,
) -> Self {
@@ -65,8 +66,8 @@ where
_accessor: &dyn MetadataAccessor,
) -> Result<(), ProofError> {
self.where_clause.count(builder)?;
for expr in self.results.iter() {
expr.count(builder)?;
for aliased_expr in self.results.iter() {
aliased_expr.0.count(builder)?;
builder.count_result_columns(1);
}
builder.count_intermediate_mles(2);
@@ -96,7 +97,7 @@ where
let columns_evals = Vec::from_iter(
self.results
.iter()
.map(|expr| expr.verifier_evaluate(builder, accessor))
.map(|(expr, _)| expr.verifier_evaluate(builder, accessor))
.collect::<Result<Vec<_>, _>>()?,
);
// 3. indexes
@@ -122,18 +123,17 @@
}

fn get_column_result_fields(&self) -> Vec<ColumnField> {
let mut columns = Vec::with_capacity(self.results.len());
for col in self.results.iter() {
columns.push(col.get_column_field());
}
columns
self.results
.iter()
.map(|(expr, alias)| ColumnField::new(*alias, expr.data_type()))
.collect()
}

fn get_column_references(&self) -> HashSet<ColumnRef> {
let mut columns = HashSet::new();

for col in self.results.iter() {
columns.insert(col.get_column_reference());
for (col, _) in self.results.iter() {
col.get_column_references(&mut columns);
}

self.where_clause.get_column_references(&mut columns);
@@ -165,7 +165,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
let columns = Vec::from_iter(
self.results
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
.map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
@@ -197,7 +197,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for DenseFilterExpr<C> {
let columns = Vec::from_iter(
self.results
.iter()
.map(|expr| expr.prover_evaluate(builder, alloc, accessor)),
.map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
94 changes: 62 additions & 32 deletions crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs
@@ -18,7 +18,9 @@ use crate::{
ast::{
// Making this explicit to ensure that we don't accidentally use the
// sparse filter for these tests
test_utility::{cols_expr, column, const_int128, dense_filter, equal, tab},
test_utility::{
col_expr_plan, cols_expr_plan, column, const_int128, dense_filter, equal, tab,
},
ColumnExpr,
DenseFilterExpr,
LiteralExpr,
@@ -40,18 +42,26 @@ use std::{collections::HashSet, sync::Arc};
#[test]
fn we_can_correctly_fetch_the_query_result_schema() {
let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap());
let a = Identifier::try_new("a").unwrap();
let b = Identifier::try_new("b").unwrap();
let provable_ast = DenseFilterExpr::<RistrettoPoint>::new(
vec![
ColumnExpr::new(ColumnRef::new(
table_ref,
Identifier::try_new("a").unwrap(),
ColumnType::BigInt,
)),
ColumnExpr::new(ColumnRef::new(
table_ref,
Identifier::try_new("b").unwrap(),
ColumnType::BigInt,
)),
(
ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new(
table_ref,
a,
ColumnType::BigInt,
))),
a,
),
(
ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new(
table_ref,
b,
ColumnType::BigInt,
))),
b,
),
],
TableExpr { table_ref },
ProvableExprPlan::try_new_equals(
@@ -84,18 +94,26 @@
#[test]
fn we_can_correctly_fetch_all_the_referenced_columns() {
let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap());
let a = Identifier::try_new("a").unwrap();
let f = Identifier::try_new("f").unwrap();
let provable_ast = DenseFilterExpr::new(
vec![
ColumnExpr::new(ColumnRef::new(
table_ref,
Identifier::try_new("a").unwrap(),
ColumnType::BigInt,
)),
ColumnExpr::new(ColumnRef::new(
table_ref,
Identifier::try_new("f").unwrap(),
ColumnType::BigInt,
)),
(
ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new(
table_ref,
a,
ColumnType::BigInt,
))),
a,
),
(
ProvableExprPlan::Column(ColumnExpr::new(ColumnRef::new(
table_ref,
f,
ColumnType::BigInt,
))),
f,
),
],
TableExpr { table_ref },
not::<RistrettoPoint>(and(
@@ -170,7 +188,7 @@ fn we_can_prove_and_get_the_correct_result_from_a_basic_dense_filter() {
let mut accessor = RecordBatchTestAccessor::new_empty();
accessor.add_table(t, data, 0);
let where_clause = equal(column(t, "a", &accessor), const_int128(5_i128));
let expr = dense_filter(cols_expr(t, &["b"], &accessor), tab(t), where_clause);
let expr = dense_filter(cols_expr_plan(t, &["b"], &accessor), tab(t), where_clause);
let res = VerifiableQueryResult::<InnerProductProof>::new(&expr, &accessor, &());
let res = res
.verify(&expr, &accessor, &())
@@ -197,7 +215,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_on_an_empty_table_using_
let where_clause: ProvableExprPlan<RistrettoPoint> =
equal(column(t, "a", &accessor), const_int128(999));
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
where_clause,
);
@@ -242,7 +260,7 @@ fn we_can_get_an_empty_result_from_a_basic_dense_filter_using_result_evaluate()
let where_clause: ProvableExprPlan<RistrettoPoint> =
equal(column(t, "a", &accessor), const_int128(999));
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
where_clause,
);
@@ -287,7 +305,7 @@ fn we_can_get_no_columns_from_a_basic_dense_filter_with_no_selected_columns_usin
accessor.add_table(t, data, 0);
let where_clause: ProvableExprPlan<RistrettoPoint> =
equal(column(t, "a", &accessor), const_int128(5));
let expr = dense_filter(cols_expr(t, &[], &accessor), tab(t), where_clause);
let expr = dense_filter(cols_expr_plan(t, &[], &accessor), tab(t), where_clause);
let alloc = Bump::new();
let mut builder = ResultBuilder::new(5);
expr.result_evaluate(&mut builder, &alloc, &accessor);
@@ -315,7 +333,7 @@ fn we_can_get_the_correct_result_from_a_basic_dense_filter_using_result_evaluate
let where_clause: ProvableExprPlan<RistrettoPoint> =
equal(column(t, "a", &accessor), const_int128(5));
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
where_clause,
);
@@ -357,7 +375,7 @@ fn we_can_prove_a_dense_filter_on_an_empty_table() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
equal(column(t, "a", &accessor), const_int128(106)),
);
@@ -386,7 +404,7 @@ fn we_can_prove_a_dense_filter_with_empty_results() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
equal(column(t, "a", &accessor), const_int128(106)),
);
@@ -406,27 +424,39 @@
fn we_can_prove_a_dense_filter() {
let data = owned_table([
bigint("a", [101, 104, 105, 102, 105]),
bigint("b", [1, 2, 3, 4, 5]),
int128("c", [1, 2, 3, 4, 5]),
bigint("b", [1, 2, 3, 4, 7]),
int128("c", [1, 3, 3, 4, 5]),
varchar("d", ["1", "2", "3", "4", "5"]),
scalar("e", [1, 2, 3, 4, 5]),
]);
let t = "sxt.t".parse().unwrap();
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);
let expr = dense_filter(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
vec![
col_expr_plan(t, "b", &accessor),
col_expr_plan(t, "c", &accessor),
col_expr_plan(t, "d", &accessor),
col_expr_plan(t, "e", &accessor),
(const_int128(105), "const".parse().unwrap()),
(
equal(column(t, "b", &accessor), column(t, "c", &accessor)),
"bool".parse().unwrap(),
),
],
tab(t),
equal(column(t, "a", &accessor), const_int128(105)),
);
let res = VerifiableQueryResult::new(&expr, &accessor, &());
exercise_verification(&res, &expr, &accessor, t);
let res = res.verify(&expr, &accessor, &()).unwrap().table;
let expected = owned_table([
bigint("b", [3, 5]),
bigint("b", [3, 7]),
int128("c", [3, 5]),
varchar("d", ["3", "5"]),
scalar("e", [3, 5]),
int128("const", [105, 105]),
boolean("bool", [true, false]),
]);
assert_eq!(res, expected);
}
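
As a sanity check on the updated expectation above: the filter `a = 105` keeps rows 2 and 4 (0-indexed), where `b` is `[3, 7]` and `c` is `[3, 5]`, so the aliased `b = c` result is `[true, false]`. The snippet below is standalone Rust that reproduces that arithmetic; it does not use the crate's API.

```rust
fn main() {
    let a = [101i64, 104, 105, 102, 105];
    let b = [1i64, 2, 3, 4, 7];
    let c = [1i128, 3, 3, 4, 5];
    // WHERE a = 105 keeps rows 2 and 4.
    let keep: Vec<usize> = (0..a.len()).filter(|&i| a[i] == 105).collect();
    assert_eq!(keep, vec![2, 4]);
    // Result columns as in the expected table: b and the computed b = c flag.
    let b_out: Vec<i64> = keep.iter().map(|&i| b[i]).collect();
    let eq_out: Vec<bool> = keep.iter().map(|&i| i128::from(b[i]) == c[i]).collect();
    assert_eq!(b_out, vec![3, 7]);
    assert_eq!(eq_out, vec![true, false]);
}
```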
@@ -11,7 +11,7 @@ use crate::{
sql::{
// Making this explicit to ensure that we don't accidentally use the
// sparse filter for these tests
ast::test_utility::{cols_expr, column, const_int128, equal, tab},
ast::test_utility::{cols_expr_plan, column, const_int128, equal, tab},
proof::{
Indexes, ProofBuilder, ProverEvaluate, ProverHonestyMarker, QueryError, ResultBuilder,
VerifiableQueryResult,
@@ -51,7 +51,7 @@ impl ProverEvaluate<Curve25519Scalar> for DishonestDenseFilterExpr<RistrettoPoin
let columns = Vec::from_iter(
self.results
.iter()
.map(|expr| expr.result_evaluate(builder.table_length(), alloc, accessor)),
.map(|(expr, _)| expr.result_evaluate(builder.table_length(), alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
@@ -87,7 +87,7 @@ impl ProverEvaluate<Curve25519Scalar> for DishonestDenseFilterExpr<RistrettoPoin
let columns = Vec::from_iter(
self.results
.iter()
.map(|expr| expr.prover_evaluate(builder, alloc, accessor)),
.map(|(expr, _)| expr.prover_evaluate(builder, alloc, accessor)),
);
// Compute filtered_columns and indexes
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
@@ -141,7 +141,7 @@ fn we_fail_to_verify_a_basic_dense_filter_with_a_dishonest_prover() {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
accessor.add_table(t, data, 0);
let expr = DishonestDenseFilterExpr::new(
cols_expr(t, &["b", "c", "d", "e"], &accessor),
cols_expr_plan(t, &["b", "c", "d", "e"], &accessor),
tab(t),
equal(column(t, "a", &accessor), const_int128(105_i128)),
);
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/equals_expr.rs
@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable AST expression for an equals expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct EqualsExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/inequality_expr.rs
@@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable AST expression for an inequality expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct InequalityExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/literal_expr.rs
@@ -23,7 +23,7 @@ use std::collections::HashSet;
/// While this wouldn't be as efficient as using a new custom expression for
/// such queries, it allows us to easily support projects with minimal code
/// changes, and the performance is sufficient for present.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LiteralExpr<S: Scalar> {
value: LiteralValue<S>,
}
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/not_expr.rs
@@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable logical NOT expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct NotExpr<C: Commitment> {
expr: Box<ProvableExprPlan<C>>,
}
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/or_expr.rs
@@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable logical OR expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct OrExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs
@@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fmt::Debug};

/// Enum of AST column expression types that implement `ProvableExpr`. Is itself a `ProvableExpr`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum ProvableExprPlan<C: Commitment> {
/// Column
Column(ColumnExpr<C>),
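
The `Clone` derives added to `AndExpr`, `EqualsExpr`, `InequalityExpr`, `LiteralExpr`, `NotExpr`, `OrExpr`, and `ProvableExprPlan` make plan fragments freely duplicable. One plausible use, sketched here under that assumption with the same test utilities as above (`t` and `accessor` as in the tests), is reusing a single expression both as an aliased result and as the where clause:

```rust
// Sketch only: `flag` is built once and reused, which relies on the
// ProvableExprPlan Clone impl introduced in this commit.
let flag = equal(column(t, "b", &accessor), column(t, "c", &accessor));
let expr = dense_filter(
    vec![(flag.clone(), "bool".parse().unwrap())],
    tab(t),
    flag,
);
```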