From 6c8f329b2404b1ec5992c8347db0db241cf6b642 Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Thu, 12 Sep 2024 16:36:16 -0400 Subject: [PATCH] refactor: add `is_top_level` to `prover_evaluate` and `verifier_evaluate` for `ProofPlan`s --- .../proof-of-sql/src/base/database/column.rs | 17 ++++++++ .../proof-of-sql/src/sql/proof/proof_plan.rs | 2 + .../proof-of-sql/src/sql/proof/query_proof.rs | 2 +- .../src/sql/proof/query_proof_test.rs | 11 +++-- .../src/sql/proof/verification_builder.rs | 12 ++++++ .../sql/proof/verification_builder_test.rs | 10 ++++- .../src/sql/proof_plans/dyn_proof_plan.rs | 26 +++++++++--- .../src/sql/proof_plans/filter_exec.rs | 20 +++++---- .../filter_exec_test_dishonest_prover.rs | 16 +++---- .../src/sql/proof_plans/group_by_exec.rs | 42 ++++++++++++------- .../src/sql/proof_plans/projection_exec.rs | 15 ++++--- 11 files changed, 124 insertions(+), 49 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index c79fcc784..cefd264fb 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -161,6 +161,23 @@ impl<'a, S: Scalar> Column<'a, S> { } } + /// Returns the column as a slice of scalars + pub(crate) fn as_scalar(&self, alloc: &'a Bump) -> &'a [S] { + match self { + Self::Boolean(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::SmallInt(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::Int(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::BigInt(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::Int128(col) => alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])), + Self::Scalar(col) => col, + Self::Decimal75(_, _, col) => col, + Self::VarChar((_, scals)) => scals, + Self::TimestampTZ(_, _, col) => { + alloc.alloc_slice_fill_with(col.len(), |i| S::from(col[i])) + } + } + } + /// Returns element at index as scalar /// /// Note that if index is out of bounds, this function will return None diff --git a/crates/proof-of-sql/src/sql/proof/proof_plan.rs b/crates/proof-of-sql/src/sql/proof/proof_plan.rs index 1becb16c5..53c29f35c 100644 --- a/crates/proof-of-sql/src/sql/proof/proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof/proof_plan.rs @@ -38,6 +38,7 @@ pub trait ProofPlan: Debug + Send + Sync + ProverEvaluate, accessor: &dyn CommitmentAccessor, result: Option<&OwnedTable>, + is_top_level: bool, ) -> Result, ProofError>; /// Return all the result column fields @@ -67,6 +68,7 @@ pub trait ProverEvaluate { builder: &mut ProofBuilder<'a, S>, alloc: &'a Bump, accessor: &'a dyn DataAccessor, + is_top_level: bool, ) -> Vec>; } diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 04be6d2e5..9bdcece5c 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -76,7 +76,7 @@ impl QueryProof { let mut builder = ProofBuilder::new(table_length, num_sumcheck_variables, post_result_challenges); - expr.prover_evaluate(&mut builder, &alloc, accessor); + expr.prover_evaluate(&mut builder, &alloc, accessor, true); let num_sumcheck_variables = builder.num_sumcheck_variables(); let table_length = builder.table_length(); diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index 75fb2b4cd..7119b5edf 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -91,7 +91,10 @@ impl ProofPlan for TrivialTestProofPlan { _accessor: &dyn CommitmentAccessor, _result: Option<&OwnedTable>, ) -> Result, ProofError> { - assert_eq!(builder.consume_result_mle(), C::Scalar::ZERO); + assert_eq!( + builder.consume_result_or_intermediate_mle(self.is_top_level), + C::Scalar::ZERO + ); builder.produce_sumcheck_subpolynomial_evaluation(&C::Scalar::from(self.evaluation)); Ok(vec![C::Scalar::ZERO]) } @@ -265,7 +268,7 @@ impl ProofPlan for SquareTestProofPlan { accessor: &dyn CommitmentAccessor, _result: Option<&OwnedTable>, ) -> Result, ProofError> { - let res_eval = builder.consume_result_mle(); + let res_eval = builder.consume_result_or_intermediate_mle(self.is_top_level); let x_commit = C::Scalar::from(self.anchored_commit_multiplier) * accessor.get_commitment(ColumnRef::new( "sxt.test".parse().unwrap(), @@ -462,7 +465,7 @@ impl ProofPlan for DoubleSquareTestProofPlan { "x".parse().unwrap(), ColumnType::BigInt, )); - let res_eval = builder.consume_result_mle(); + let res_eval = builder.consume_result_or_intermediate_mle(self.is_top_level); let x_eval = builder.consume_anchored_mle(x_commit); let z_eval = builder.consume_intermediate_mle(); @@ -655,7 +658,7 @@ impl ProofPlan for ChallengeTestProofPlan { ) -> Result, ProofError> { let alpha = builder.consume_post_result_challenge(); let _beta = builder.consume_post_result_challenge(); - let res_eval = builder.consume_result_mle(); + let res_eval = builder.consume_result_or_intermediate_mle(self.is_top_level); let x_commit = accessor.get_commitment(ColumnRef::new( "sxt.test".parse().unwrap(), "x".parse().unwrap(), diff --git a/crates/proof-of-sql/src/sql/proof/verification_builder.rs b/crates/proof-of-sql/src/sql/proof/verification_builder.rs index 1ee6a2e6a..616986c36 100644 --- a/crates/proof-of-sql/src/sql/proof/verification_builder.rs +++ b/crates/proof-of-sql/src/sql/proof/verification_builder.rs @@ -104,6 +104,18 @@ impl<'a, C: Commitment> VerificationBuilder<'a, C> { self.mle_evaluations.result_evaluations[index] } + /// Consume the evaluation of an intermediate MLE if NOT at the top level + /// and consume the evaluation of a result MLE if at the top level + /// + /// This is used to consume the evaluation of an MLE used in sumcheck + pub fn consume_result_or_intermediate_mle(&mut self, is_top_level: bool) -> C::Scalar { + if is_top_level { + self.consume_result_mle() + } else { + self.consume_intermediate_mle() + } + } + /// Produce the evaluation of a subpolynomial used in sumcheck pub fn produce_sumcheck_subpolynomial_evaluation(&mut self, eval: &C::Scalar) { self.sumcheck_evaluation += diff --git a/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs b/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs index 5c69c76c0..0373fa85c 100644 --- a/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs +++ b/crates/proof-of-sql/src/sql/proof/verification_builder_test.rs @@ -120,8 +120,14 @@ fn we_can_consume_result_evaluations() { &[][..], Vec::new(), ); - assert_eq!(builder.consume_result_mle(), Curve25519Scalar::from(123u64)); - assert_eq!(builder.consume_result_mle(), Curve25519Scalar::from(456u64)); + assert_eq!( + builder.consume_result_or_intermediate_mle(true), + Curve25519Scalar::from(123u64) + ); + assert_eq!( + builder.consume_result_or_intermediate_mle(true), + Curve25519Scalar::from(456u64) + ); } #[test] diff --git a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs index fc905f962..5b0fc04d3 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/dyn_proof_plan.rs @@ -65,11 +65,18 @@ impl ProofPlan for DynProofPlan { builder: &mut crate::sql::proof::VerificationBuilder, accessor: &dyn crate::base::database::CommitmentAccessor, result: Option<&crate::base::database::OwnedTable>, + is_top_level: bool, ) -> Result, crate::base::proof::ProofError> { match self { - DynProofPlan::Projection(expr) => expr.verifier_evaluate(builder, accessor, result), - DynProofPlan::GroupBy(expr) => expr.verifier_evaluate(builder, accessor, result), - DynProofPlan::Filter(expr) => expr.verifier_evaluate(builder, accessor, result), + DynProofPlan::Projection(expr) => { + expr.verifier_evaluate(builder, accessor, result, is_top_level) + } + DynProofPlan::GroupBy(expr) => { + expr.verifier_evaluate(builder, accessor, result, is_top_level) + } + DynProofPlan::Filter(expr) => { + expr.verifier_evaluate(builder, accessor, result, is_top_level) + } } } @@ -111,11 +118,18 @@ impl ProverEvaluate for DynProofPlan { builder: &mut crate::sql::proof::ProofBuilder<'a, C::Scalar>, alloc: &'a bumpalo::Bump, accessor: &'a dyn crate::base::database::DataAccessor, + is_top_level: bool, ) -> Vec> { match self { - DynProofPlan::Projection(expr) => expr.prover_evaluate(builder, alloc, accessor), - DynProofPlan::GroupBy(expr) => expr.prover_evaluate(builder, alloc, accessor), - DynProofPlan::Filter(expr) => expr.prover_evaluate(builder, alloc, accessor), + DynProofPlan::Projection(expr) => { + expr.prover_evaluate(builder, alloc, accessor, is_top_level) + } + DynProofPlan::GroupBy(expr) => { + expr.prover_evaluate(builder, alloc, accessor, is_top_level) + } + DynProofPlan::Filter(expr) => { + expr.prover_evaluate(builder, alloc, accessor, is_top_level) + } } } } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index e78cd2548..51bc7042e 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -89,6 +89,7 @@ where builder: &mut VerificationBuilder, accessor: &dyn CommitmentAccessor, _result: Option<&OwnedTable>, + is_top_level: bool, ) -> Result, ProofError> { // 1. selection let selection_eval = self.where_clause.verifier_evaluate(builder, accessor)?; @@ -106,7 +107,8 @@ where .ok_or(ProofError::VerificationError("invalid indexes"))?; // 4. filtered_columns let filtered_columns_evals = Vec::from_iter( - repeat_with(|| builder.consume_result_mle()).take(self.aliased_results.len()), + repeat_with(|| builder.consume_result_or_intermediate_mle(is_top_level)) + .take(self.aliased_results.len()), ); let alpha = builder.consume_post_result_challenge(); @@ -182,20 +184,22 @@ impl ProverEvaluate for FilterExec { builder: &mut ProofBuilder<'a, C::Scalar>, alloc: &'a Bump, accessor: &'a dyn DataAccessor, + is_top_level: bool, ) -> Vec> { // 1. selection - let selection_column: Column<'a, C::Scalar> = - self.where_clause.prover_evaluate(builder, alloc, accessor); + let selection_column: Column<'a, C::Scalar> = self + .where_clause + .prover_evaluate(builder, alloc, accessor, false); let selection = selection_column .as_boolean() .expect("selection is not boolean"); // 2. columns - let columns = Vec::from_iter( - self.aliased_results - .iter() - .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), - ); + let columns = Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| { + aliased_expr + .expr + .prover_evaluate(builder, alloc, accessor, false) + })); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs index 14ef39996..5d188d652 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs @@ -74,19 +74,21 @@ impl ProverEvaluate for DishonestFilterExec { builder: &mut ProofBuilder<'a, Curve25519Scalar>, alloc: &'a Bump, accessor: &'a dyn DataAccessor, + is_top_level: bool, ) -> Vec> { // 1. selection - let selection_column: Column<'a, Curve25519Scalar> = - self.where_clause.prover_evaluate(builder, alloc, accessor); + let selection_column: Column<'a, Curve25519Scalar> = self + .where_clause + .prover_evaluate(builder, alloc, accessor, false); let selection = selection_column .as_boolean() .expect("selection is not boolean"); // 2. columns - let columns = Vec::from_iter( - self.aliased_results - .iter() - .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), - ); + let columns = Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| { + aliased_expr + .expr + .prover_evaluate(builder, alloc, accessor, false) + })); // Compute filtered_columns and indexes let (filtered_columns, result_len) = filter_columns(alloc, &columns, selection); let filtered_columns = tamper_column(alloc, filtered_columns); diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index e7377e25f..2d9912830 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -104,19 +104,26 @@ impl ProofPlan for GroupByExec { builder: &mut VerificationBuilder, accessor: &dyn CommitmentAccessor, result: Option<&OwnedTable>, + is_top_level: bool, ) -> Result, ProofError> { // 1. selection - let where_eval = self.where_clause.verifier_evaluate(builder, accessor)?; + let where_eval = self + .where_clause + .verifier_evaluate(builder, accessor, false)?; // 2. columns let group_by_evals = self .group_by_exprs .iter() - .map(|expr| expr.verifier_evaluate(builder, accessor)) + .map(|expr| expr.verifier_evaluate(builder, accessor, false)) .collect::, _>>()?; let aggregate_evals = self .sum_expr .iter() - .map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor)) + .map(|aliased_expr| { + aliased_expr + .expr + .verifier_evaluate(builder, accessor, false) + }) .collect::, _>>()?; // 3. indexes let indexes_eval = builder @@ -126,11 +133,14 @@ impl ProofPlan for GroupByExec { // 4. filtered_columns let group_by_result_columns_evals = Vec::from_iter( - repeat_with(|| builder.consume_result_mle()).take(self.group_by_exprs.len()), + repeat_with(|| builder.consume_result_or_intermediate_mle(is_top_level)) + .take(self.group_by_exprs.len()), ); - let sum_result_columns_evals = - Vec::from_iter(repeat_with(|| builder.consume_result_mle()).take(self.sum_expr.len())); - let count_column_eval = builder.consume_result_mle(); + let sum_result_columns_evals = Vec::from_iter( + repeat_with(|| builder.consume_result_or_intermediate_mle(is_top_level)) + .take(self.sum_expr.len()), + ); + let count_column_eval = builder.consume_result_or_intermediate_mle(is_top_level); let alpha = builder.consume_post_result_challenge(); let beta = builder.consume_post_result_challenge(); @@ -260,10 +270,12 @@ impl ProverEvaluate for GroupByExec { builder: &mut ProofBuilder<'a, C::Scalar>, alloc: &'a Bump, accessor: &'a dyn DataAccessor, + is_top_level: bool, ) -> Vec> { // 1. selection - let selection_column: Column<'a, C::Scalar> = - self.where_clause.prover_evaluate(builder, alloc, accessor); + let selection_column: Column<'a, C::Scalar> = self + .where_clause + .prover_evaluate(builder, alloc, accessor, false); let selection = selection_column .as_boolean() .expect("selection is not boolean"); @@ -272,13 +284,13 @@ impl ProverEvaluate for GroupByExec { let group_by_columns = Vec::from_iter( self.group_by_exprs .iter() - .map(|expr| expr.prover_evaluate(builder, alloc, accessor)), - ); - let sum_columns = Vec::from_iter( - self.sum_expr - .iter() - .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), + .map(|expr| expr.prover_evaluate(builder, alloc, accessor, false)), ); + let sum_columns = Vec::from_iter(self.sum_expr.iter().map(|aliased_expr| { + aliased_expr + .expr + .prover_evaluate(builder, alloc, accessor, false) + })); // Compute filtered_columns and indexes let AggregatedColumns { group_by_columns: group_by_result_columns, diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs index 6e1020e4f..8d6e7a209 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs @@ -67,13 +67,15 @@ impl ProofPlan for ProjectionExec { builder: &mut VerificationBuilder, accessor: &dyn CommitmentAccessor, _result: Option<&OwnedTable>, + is_top_level: bool, ) -> Result, ProofError> { self.aliased_results .iter() .map(|aliased_expr| aliased_expr.expr.verifier_evaluate(builder, accessor)) .collect::, _>>()?; Ok(Vec::from_iter( - repeat_with(|| builder.consume_result_mle()).take(self.aliased_results.len()), + repeat_with(|| builder.consume_result_or_intermediate_mle(is_top_level)) + .take(self.aliased_results.len()), )) } @@ -117,11 +119,12 @@ impl ProverEvaluate for ProjectionExec { builder: &mut ProofBuilder<'a, C::Scalar>, alloc: &'a Bump, accessor: &'a dyn DataAccessor, + is_top_level: bool, ) -> Vec> { - Vec::from_iter( - self.aliased_results - .iter() - .map(|aliased_expr| aliased_expr.expr.prover_evaluate(builder, alloc, accessor)), - ) + Vec::from_iter(self.aliased_results.iter().map(|aliased_expr| { + aliased_expr + .expr + .prover_evaluate(builder, alloc, accessor, is_top_level) + })) } }