Skip to content

Commit

Permalink
refactor!: generalize ProverEvaluate::result_evaluate and `final_ro…
Browse files Browse the repository at this point in the history
…und_evaluate` to multiple tables
  • Loading branch information
iajoiner committed Oct 15, 2024
1 parent 4211dea commit 035e77e
Show file tree
Hide file tree
Showing 17 changed files with 99 additions and 40 deletions.
3 changes: 2 additions & 1 deletion crates/proof-of-sql/src/sql/proof/proof_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub trait ProverEvaluate<S: Scalar> {
/// Evaluate the query and modify `FirstRoundBuilder` to track the result of the query.
fn result_evaluate<'a>(
&self,
input_length: usize,
input_lengths: &[usize],
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<S>,
) -> Vec<Column<'a, S>>;
Expand All @@ -68,6 +68,7 @@ pub trait ProverEvaluate<S: Scalar> {
/// will be bulk deallocated once the proof is formed.
fn final_round_evaluate<'a>(
&self,
input_lengths: &[usize],
builder: &mut FinalRoundBuilder<'a, S>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<S>,
Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,17 @@ impl<C: Commitment> ProofExpr<C> for AddSubtractExpr<C> {
)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column: Column<'a, C::Scalar> =
self.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs
.prover_evaluate(table_length, builder, alloc, accessor);
Column::Scalar(add_subtract_columns(
lhs_column,
rhs_column,
Expand Down
4 changes: 3 additions & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ impl<C: Commitment> ProofExpr<C> for AggregateExpr<C> {
#[tracing::instrument(name = "AggregateExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
self.expr.prover_evaluate(builder, alloc, accessor)
self.expr
.prover_evaluate(table_length, builder, alloc, accessor)
}

fn verifier_evaluate(
Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/and_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@ impl<C: Commitment> ProofExpr<C> for AndExpr<C> {
#[tracing::instrument(name = "AndExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column: Column<'a, C::Scalar> =
self.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs
.prover_evaluate(table_length, builder, alloc, accessor);
let lhs = lhs_column.as_boolean().expect("lhs is not boolean");
let rhs = rhs_column.as_boolean().expect("rhs is not boolean");
let n = lhs.len();
Expand Down
2 changes: 2 additions & 0 deletions crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ impl<C: Commitment> ProofExpr<C> for ColumnExpr<C> {
/// add the components needed to prove the result
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
_alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let column = accessor.get_column(self.column_ref);
assert!(column.len() == table_length);
builder.produce_anchored_mle(column);
column
}
Expand Down
21 changes: 11 additions & 10 deletions crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,40 +253,41 @@ impl<C: Commitment> ProofExpr<C> for DynProofExpr<C> {

fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
match self {
DynProofExpr::Column(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::And(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Or(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Not(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Literal(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Equals(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Inequality(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::AddSubtract(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Multiply(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
DynProofExpr::Aggregate(expr) => {
ProofExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
ProofExpr::<C>::prover_evaluate(table_length, expr, builder, alloc, accessor)
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/equals_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,17 @@ impl<C: Commitment> ProofExpr<C> for EqualsExpr<C> {
#[tracing::instrument(name = "EqualsExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column = self
.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column = self
.rhs
.prover_evaluate(table_length, builder, alloc, accessor);
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res = scale_and_subtract(alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, true)
Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/inequality_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ impl<C: Commitment> ProofExpr<C> for InequalityExpr<C> {
#[tracing::instrument(name = "InequalityExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column = self
.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column = self
.rhs
.prover_evaluate(table_length, builder, alloc, accessor);
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let diff = if self.is_lte {
Expand Down
4 changes: 2 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ impl<C: Commitment> ProofExpr<C> for LiteralExpr<C::Scalar> {
#[tracing::instrument(name = "LiteralExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
table_length: usize,
_builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
_accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let table_length = builder.table_length();
Column::from_literal_with_length(&self.value, table_length, alloc)
}

Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/multiply_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ impl<C: Commitment> ProofExpr<C> for MultiplyExpr<C> {
)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column: Column<'a, C::Scalar> =
self.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs
.prover_evaluate(table_length, builder, alloc, accessor);

// lhs_times_rhs
let lhs_times_rhs: &'a [C::Scalar] = multiply_columns(&lhs_column, &rhs_column, alloc);
Expand Down
4 changes: 3 additions & 1 deletion crates/proof-of-sql/src/sql/proof_exprs/not_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ impl<C: Commitment> ProofExpr<C> for NotExpr<C> {
#[tracing::instrument(name = "NotExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let expr_column: Column<'a, C::Scalar> =
self.expr.prover_evaluate(builder, alloc, accessor);
self.expr
.prover_evaluate(table_length, builder, alloc, accessor);
let expr = expr_column.as_boolean().expect("expr is not boolean");
Column::Boolean(alloc.alloc_slice_fill_with(expr.len(), |i| !expr[i]))
}
Expand Down
9 changes: 7 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/or_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ impl<C: Commitment> ProofExpr<C> for OrExpr<C> {
#[tracing::instrument(name = "OrExpr::prover_evaluate", level = "debug", skip_all)]
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_column: Column<'a, C::Scalar> =
self.lhs
.prover_evaluate(table_length, builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs
.prover_evaluate(table_length, builder, alloc, accessor);
let lhs = lhs_column.as_boolean().expect("lhs is not boolean");
let rhs = rhs_column.as_boolean().expect("rhs is not boolean");
Column::Boolean(prover_evaluate_or(builder, alloc, lhs, rhs))
Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/sql/proof_exprs/proof_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pub trait ProofExpr<C: Commitment>: Debug + Send + Sync {
/// of values
fn prover_evaluate<'a>(
&self,
table_length: usize,
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
Expand Down
11 changes: 8 additions & 3 deletions crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for FilterExec<C> {
#[tracing::instrument(name = "FilterExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a>(
&self,
input_length: usize,
input_lengths: &[usize],
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Vec<Column<'a, C::Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, C::Scalar> =
self.where_clause
Expand Down Expand Up @@ -184,13 +186,16 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for FilterExec<C> {
#[allow(unused_variables)]
fn final_round_evaluate<'a>(
&self,
input_lengths: &[usize],
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Vec<Column<'a, C::Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, C::Scalar> =
self.where_clause.prover_evaluate(builder, alloc, accessor);
self.where_clause.prover_evaluate(input_length, builder, alloc, accessor);
let selection = selection_column
.as_boolean()
.expect("selection is not boolean");
Expand All @@ -199,7 +204,7 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for FilterExec<C> {
let columns: Vec<_> = self
.aliased_results
.iter()
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor))
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(input_length, builder, alloc, accessor))
.collect();
// Compute filtered_columns
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ impl ProverEvaluate<Curve25519Scalar> for DishonestFilterExec<RistrettoPoint> {
)]
fn result_evaluate<'a>(
&self,
input_length: usize,
input_lengths: &[usize],
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<Curve25519Scalar>,
) -> Vec<Column<'a, Curve25519Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, Curve25519Scalar> =
self.where_clause
Expand Down Expand Up @@ -76,21 +78,24 @@ impl ProverEvaluate<Curve25519Scalar> for DishonestFilterExec<RistrettoPoint> {
#[allow(unused_variables)]
fn final_round_evaluate<'a>(
&self,
table_lengths: &[usize],
builder: &mut FinalRoundBuilder<'a, Curve25519Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<Curve25519Scalar>,
) -> Vec<Column<'a, Curve25519Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, Curve25519Scalar> =
self.where_clause.prover_evaluate(builder, alloc, accessor);
self.where_clause.prover_evaluate(input_length, builder, alloc, accessor);
let selection = selection_column
.as_boolean()
.expect("selection is not boolean");
// 2. columns
let columns: Vec<_> = self
.aliased_results
.iter()
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor))
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(input_length, builder, alloc, accessor))
.collect();
// Compute filtered_columns
let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection);
Expand Down
13 changes: 9 additions & 4 deletions crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExec<C> {
#[tracing::instrument(name = "GroupByExec::result_evaluate", level = "debug", skip_all)]
fn result_evaluate<'a>(
&self,
input_length: usize,
input_lengths: &[usize],
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Vec<Column<'a, C::Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, C::Scalar> =
self.where_clause
Expand Down Expand Up @@ -260,13 +262,16 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExec<C> {
#[allow(unused_variables)]
fn final_round_evaluate<'a>(
&self,
input_lengths: &[usize],
builder: &mut FinalRoundBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Vec<Column<'a, C::Scalar>> {
assert!(input_lengths.len() == 1);
let input_length = input_lengths[0];
// 1. selection
let selection_column: Column<'a, C::Scalar> =
self.where_clause.prover_evaluate(builder, alloc, accessor);
self.where_clause.prover_evaluate(input_length, builder, alloc, accessor);
let selection = selection_column
.as_boolean()
.expect("selection is not boolean");
Expand All @@ -275,12 +280,12 @@ impl<C: Commitment> ProverEvaluate<C::Scalar> for GroupByExec<C> {
let group_by_columns = self
.group_by_exprs
.iter()
.map(|expr| expr.prover_evaluate(builder, alloc, accessor))
.map(|expr| expr.prover_evaluate(input_length, builder, alloc, accessor))
.collect::<Vec<_>>();
let sum_columns = self
.sum_expr
.iter()
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor))
.map(|aliased_expr| aliased_expr.expr.prover_evaluate(input_length, builder, alloc, accessor))
.collect::<Vec<_>>();
// 3. Compute filtered_columns
let AggregatedColumns {
Expand Down
Loading

0 comments on commit 035e77e

Please sign in to comment.