diff --git a/crates/proof-of-sql/src/base/polynomial/composite_polynomial.rs b/crates/proof-of-sql/src/base/polynomial/composite_polynomial.rs index b7166635c..7547c7d19 100644 --- a/crates/proof-of-sql/src/base/polynomial/composite_polynomial.rs +++ b/crates/proof-of-sql/src/base/polynomial/composite_polynomial.rs @@ -6,6 +6,10 @@ use alloc::{rc::Rc, vec::Vec}; * See third_party/license/arkworks.LICENSE */ use core::cmp::max; +#[cfg(test)] +use core::iter; +#[cfg(test)] +use itertools::Itertools; /// Stores a list of products of `DenseMultilinearExtension` that is meant to be added together. /// @@ -84,6 +88,45 @@ impl CompositePolynomial { } self.products.push((coefficient, indexed_product)); } + /// Generate random `CompositePolnomial`. + #[cfg(test)] + pub fn rand( + num_variables: usize, + max_multiplicands: usize, + multiplicands_length: impl IntoIterator, + products: impl IntoIterator>, + rng: &mut (impl ark_std::rand::Rng + ?Sized), + ) -> Self { + let mut result = CompositePolynomial::new(num_variables); + result.max_multiplicands = max_multiplicands; + result.products = products + .into_iter() + .map(|p| (S::rand(rng), p.into_iter().collect())) + .collect(); + result.flattened_ml_extensions = multiplicands_length + .into_iter() + .map(|length| Rc::new(iter::repeat_with(|| S::rand(rng)).take(length).collect())) + .collect(); + result + } + + #[cfg(test)] + /// Returns the product of the flattened_ml_extensions with referenced (as usize) by `terms` at the index `i`. + fn term_product(&self, terms: &[usize], i: usize) -> S { + terms + .iter() + .map(|&j| *self.flattened_ml_extensions[j].get(i).unwrap_or(&S::ZERO)) + .product::() + } + /// Returns the sum of the evaluations of the `CompositePolynomial` on the boolean hypercube. + #[cfg(test)] + pub fn hypercube_sum(&self, length: usize) -> S { + (0..length) + .cartesian_product(&self.products) + .map(|(i, (coeff, terms))| *coeff * self.term_product(terms, i)) + .sum::() + } + /// Evaluate the polynomial at point `point` #[cfg(test)] pub fn evaluate(&self, point: &[S]) -> S { diff --git a/crates/proof-of-sql/src/base/polynomial/composite_polynomial_test.rs b/crates/proof-of-sql/src/base/polynomial/composite_polynomial_test.rs index db7aa8faa..6a3ac37ae 100644 --- a/crates/proof-of-sql/src/base/polynomial/composite_polynomial_test.rs +++ b/crates/proof-of-sql/src/base/polynomial/composite_polynomial_test.rs @@ -38,3 +38,36 @@ fn test_composite_polynomial_evaluation() { assert_eq!(prod01, calc01); assert_eq!(prod11, calc11); } + +#[allow(clippy::identity_op)] +#[test] +fn test_composite_polynomial_hypercube_sum() { + let a: Vec = vec![ + -Curve25519Scalar::from(7u32), + Curve25519Scalar::from(2u32), + -Curve25519Scalar::from(6u32), + Curve25519Scalar::from(17u32), + ]; + let b: Vec = vec![ + Curve25519Scalar::from(2u32), + -Curve25519Scalar::from(8u32), + Curve25519Scalar::from(4u32), + Curve25519Scalar::from(1u32), + ]; + let c: Vec = vec![ + Curve25519Scalar::from(1u32), + Curve25519Scalar::from(3u32), + -Curve25519Scalar::from(5u32), + -Curve25519Scalar::from(9u32), + ]; + let mut prod = CompositePolynomial::new(2); + prod.add_product([Rc::new(a), Rc::new(b)], Curve25519Scalar::from(3u32)); + prod.add_product([Rc::new(c)], Curve25519Scalar::from(2u32)); + let sum = prod.hypercube_sum(4); + assert_eq!( + sum, + Curve25519Scalar::from( + 3 * ((-7) * 2 + 2 * (-8) + (-6) * 4 + 17 * 1) + 2 * (1 + 3 + (-5) + (-9)) + ) + ); +} diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs index baa7c579c..8f7cefed2 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs @@ -11,3 +11,6 @@ pub use subclaim::Subclaim; mod prover_round; use prover_round::prove_round; + +#[cfg(test)] +mod test_cases; diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs index 6e7f51d97..69666d868 100644 --- a/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs @@ -1,5 +1,8 @@ +use super::test_cases::sumcheck_test_cases; use crate::base::{ - polynomial::CompositePolynomial, proof::Transcript as _, scalar::Curve25519Scalar, + polynomial::{CompositePolynomial, CompositePolynomialInfo}, + proof::Transcript as _, + scalar::{test_scalar::TestScalar, Curve25519Scalar, Scalar}, }; /** * Adopted from arkworks @@ -158,3 +161,88 @@ fn test_normal_polynomial() { test_polynomial(nv, num_multiplicands_range, num_products); } + +#[test] +fn we_can_verify_many_random_test_cases() { + let mut rng = ark_std::test_rng(); + + for test_case in sumcheck_test_cases::(&mut rng) { + let mut transcript = Transcript::new(b"sumchecktest"); + let mut evaluation_point = vec![Default::default(); test_case.num_vars]; + let proof = SumcheckProof::create( + &mut transcript, + &mut evaluation_point, + &test_case.polynomial, + ); + + let mut transcript = Transcript::new(b"sumchecktest"); + let subclaim = proof + .verify_without_evaluation( + &mut transcript, + CompositePolynomialInfo { + max_multiplicands: test_case.max_multiplicands, + num_variables: test_case.num_vars, + }, + &test_case.sum, + ) + .expect("verification should succeed with the correct setup"); + assert_eq!( + subclaim.evaluation_point, evaluation_point, + "the prover's evaluation point should match the verifier's" + ); + assert_eq!( + test_case.polynomial.evaluate(&evaluation_point), + subclaim.expected_evaluation, + "the claimed evaluation should match the actual evaluation" + ); + + let mut transcript = Transcript::new(b"sumchecktest"); + transcript.extend_serialize_as_le(&123u64); + let verify_result = proof.verify_without_evaluation( + &mut transcript, + CompositePolynomialInfo { + max_multiplicands: test_case.max_multiplicands, + num_variables: test_case.num_vars, + }, + &test_case.sum, + ); + if let Ok(subclaim) = verify_result { + assert_ne!( + subclaim.evaluation_point, evaluation_point, + "either verification should fail or we should have a different evaluation point with a different transcript" + ) + } + + let mut transcript = Transcript::new(b"sumchecktest"); + assert!( + proof + .verify_without_evaluation( + &mut transcript, + CompositePolynomialInfo { + max_multiplicands: test_case.max_multiplicands, + num_variables: test_case.num_vars, + }, + &(test_case.sum + TestScalar::ONE), + ) + .is_err(), + "verification should fail when the sum is wrong" + ); + + let mut modified_proof = proof; + modified_proof.evaluations[0][0] += TestScalar::ONE; + let mut transcript = Transcript::new(b"sumchecktest"); + assert!( + modified_proof + .verify_without_evaluation( + &mut transcript, + CompositePolynomialInfo { + max_multiplicands: test_case.max_multiplicands, + num_variables: test_case.num_vars, + }, + &test_case.sum, + ) + .is_err(), + "verification should fail when the proof is modified" + ); + } +} diff --git a/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs b/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs new file mode 100644 index 000000000..863c7b9ec --- /dev/null +++ b/crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs @@ -0,0 +1,96 @@ +use crate::base::{polynomial::CompositePolynomial, scalar::Scalar}; +use core::iter; +use itertools::Itertools; + +pub struct SumcheckTestCase { + pub polynomial: CompositePolynomial, + pub num_vars: usize, + pub max_multiplicands: usize, + pub sum: S, +} + +impl SumcheckTestCase { + fn rand( + num_vars: usize, + max_multiplicands: usize, + products: impl IntoIterator>, + rng: &mut (impl ark_std::rand::Rng + ?Sized), + ) -> Self { + let length = 1 << num_vars; + let products_vec: Vec> = products + .into_iter() + .map(|p| p.into_iter().collect()) + .collect(); + let num_multiplicands = products_vec + .iter() + .map(|p| p.iter().max().copied().unwrap_or(0)) + .max() + .unwrap_or(0) + + 1; + let polynomial = CompositePolynomial::::rand( + num_vars, + max_multiplicands, + iter::repeat(length).take(num_multiplicands), + products_vec, + rng, + ); + let sum = polynomial.hypercube_sum(length); + Self { + polynomial, + num_vars, + max_multiplicands, + sum, + } + } +} + +pub fn sumcheck_test_cases( + rng: &mut (impl ark_std::rand::Rng + ?Sized), +) -> impl Iterator> + '_ { + (1..=8) + .cartesian_product(1..=5) + .flat_map(|(num_vars, max_multiplicands)| { + [ + Some(vec![]), + Some(vec![vec![]]), + (max_multiplicands >= 1).then_some(vec![vec![0]]), + (max_multiplicands >= 2).then_some(vec![vec![0, 1]]), + (max_multiplicands >= 3).then_some(vec![ + vec![0, 1, 2], + vec![3, 4], + vec![0], + vec![], + ]), + (max_multiplicands >= 5).then_some(vec![ + vec![7, 0], + vec![2, 4, 8, 5], + vec![], + vec![3], + vec![1, 0, 8, 5, 0], + vec![3, 6, 9, 9], + vec![7, 8, 3], + vec![4, 3, 2], + vec![], + vec![9, 8, 2], + ]), + (max_multiplicands >= 3).then_some(vec![ + vec![], + vec![1, 0], + vec![3, 6, 1], + vec![], + vec![], + vec![1, 8], + vec![1], + vec![8], + vec![6, 6], + vec![4, 6, 7], + ]), + ] + .into_iter() + .flatten() + .map(move |products| (num_vars, max_multiplicands, products)) + }) + .map(|(num_vars, max_multiplicands, products)| { + SumcheckTestCase::rand(num_vars, max_multiplicands, products, rng) + }) +}