Skip to content

Commit

Permalink
test: improve sumcheck test coverage (#204)
Browse files Browse the repository at this point in the history
# Rationale for this change

The tests we have for sumcheck do not cover as many edge cases as they
should. This PR adds many more tests cases.

# What changes are included in this PR?

See the individual commits.

* Added CompositePolynomial::rand and
`CompositePolynomial::hypercube_sum` - These two methods are useful for
generating test cases for sumcheck.
* Added `sumcheck::test_cases` module that enumerates a bunch of test
cases that we should be testing. These test cases are intended to be
re-used. In particular,
* Added a test for all of these test cases.


# Are these changes tested?

Yes
  • Loading branch information
JayWhite2357 authored Oct 2, 2024
2 parents f3c9dca + acc3d9f commit 22e06ae
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 1 deletion.
43 changes: 43 additions & 0 deletions crates/proof-of-sql/src/base/polynomial/composite_polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ use alloc::{rc::Rc, vec::Vec};
* See third_party/license/arkworks.LICENSE
*/
use core::cmp::max;
#[cfg(test)]
use core::iter;
#[cfg(test)]
use itertools::Itertools;

/// Stores a list of products of `DenseMultilinearExtension` that is meant to be added together.
///
Expand Down Expand Up @@ -84,6 +88,45 @@ impl<S: Scalar> CompositePolynomial<S> {
}
self.products.push((coefficient, indexed_product));
}
/// Generate random `CompositePolnomial`.
#[cfg(test)]
pub fn rand(
num_variables: usize,
max_multiplicands: usize,
multiplicands_length: impl IntoIterator<Item = usize>,
products: impl IntoIterator<Item = impl IntoIterator<Item = usize>>,
rng: &mut (impl ark_std::rand::Rng + ?Sized),
) -> Self {
let mut result = CompositePolynomial::new(num_variables);
result.max_multiplicands = max_multiplicands;
result.products = products
.into_iter()
.map(|p| (S::rand(rng), p.into_iter().collect()))
.collect();
result.flattened_ml_extensions = multiplicands_length
.into_iter()
.map(|length| Rc::new(iter::repeat_with(|| S::rand(rng)).take(length).collect()))
.collect();
result
}

#[cfg(test)]
/// Returns the product of the flattened_ml_extensions with referenced (as usize) by `terms` at the index `i`.
fn term_product(&self, terms: &[usize], i: usize) -> S {
terms
.iter()
.map(|&j| *self.flattened_ml_extensions[j].get(i).unwrap_or(&S::ZERO))
.product::<S>()
}
/// Returns the sum of the evaluations of the `CompositePolynomial` on the boolean hypercube.
#[cfg(test)]
pub fn hypercube_sum(&self, length: usize) -> S {
(0..length)
.cartesian_product(&self.products)
.map(|(i, (coeff, terms))| *coeff * self.term_product(terms, i))
.sum::<S>()
}

/// Evaluate the polynomial at point `point`
#[cfg(test)]
pub fn evaluate(&self, point: &[S]) -> S {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,36 @@ fn test_composite_polynomial_evaluation() {
assert_eq!(prod01, calc01);
assert_eq!(prod11, calc11);
}

#[allow(clippy::identity_op)]
#[test]
fn test_composite_polynomial_hypercube_sum() {
let a: Vec<Curve25519Scalar> = vec![
-Curve25519Scalar::from(7u32),
Curve25519Scalar::from(2u32),
-Curve25519Scalar::from(6u32),
Curve25519Scalar::from(17u32),
];
let b: Vec<Curve25519Scalar> = vec![
Curve25519Scalar::from(2u32),
-Curve25519Scalar::from(8u32),
Curve25519Scalar::from(4u32),
Curve25519Scalar::from(1u32),
];
let c: Vec<Curve25519Scalar> = vec![
Curve25519Scalar::from(1u32),
Curve25519Scalar::from(3u32),
-Curve25519Scalar::from(5u32),
-Curve25519Scalar::from(9u32),
];
let mut prod = CompositePolynomial::new(2);
prod.add_product([Rc::new(a), Rc::new(b)], Curve25519Scalar::from(3u32));
prod.add_product([Rc::new(c)], Curve25519Scalar::from(2u32));
let sum = prod.hypercube_sum(4);
assert_eq!(
sum,
Curve25519Scalar::from(
3 * ((-7) * 2 + 2 * (-8) + (-6) * 4 + 17 * 1) + 2 * (1 + 3 + (-5) + (-9))
)
);
}
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ pub use subclaim::Subclaim;

mod prover_round;
use prover_round::prove_round;

#[cfg(test)]
mod test_cases;
90 changes: 89 additions & 1 deletion crates/proof-of-sql/src/proof_primitive/sumcheck/proof_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::test_cases::sumcheck_test_cases;
use crate::base::{
polynomial::CompositePolynomial, proof::Transcript as _, scalar::Curve25519Scalar,
polynomial::{CompositePolynomial, CompositePolynomialInfo},
proof::Transcript as _,
scalar::{test_scalar::TestScalar, Curve25519Scalar, Scalar},
};
/**
* Adopted from arkworks
Expand Down Expand Up @@ -158,3 +161,88 @@ fn test_normal_polynomial() {

test_polynomial(nv, num_multiplicands_range, num_products);
}

#[test]
fn we_can_verify_many_random_test_cases() {
let mut rng = ark_std::test_rng();

for test_case in sumcheck_test_cases::<TestScalar>(&mut rng) {
let mut transcript = Transcript::new(b"sumchecktest");
let mut evaluation_point = vec![Default::default(); test_case.num_vars];
let proof = SumcheckProof::create(
&mut transcript,
&mut evaluation_point,
&test_case.polynomial,
);

let mut transcript = Transcript::new(b"sumchecktest");
let subclaim = proof
.verify_without_evaluation(
&mut transcript,
CompositePolynomialInfo {
max_multiplicands: test_case.max_multiplicands,
num_variables: test_case.num_vars,
},
&test_case.sum,
)
.expect("verification should succeed with the correct setup");
assert_eq!(
subclaim.evaluation_point, evaluation_point,
"the prover's evaluation point should match the verifier's"
);
assert_eq!(
test_case.polynomial.evaluate(&evaluation_point),
subclaim.expected_evaluation,
"the claimed evaluation should match the actual evaluation"
);

let mut transcript = Transcript::new(b"sumchecktest");
transcript.extend_serialize_as_le(&123u64);
let verify_result = proof.verify_without_evaluation(
&mut transcript,
CompositePolynomialInfo {
max_multiplicands: test_case.max_multiplicands,
num_variables: test_case.num_vars,
},
&test_case.sum,
);
if let Ok(subclaim) = verify_result {
assert_ne!(
subclaim.evaluation_point, evaluation_point,
"either verification should fail or we should have a different evaluation point with a different transcript"
)
}

let mut transcript = Transcript::new(b"sumchecktest");
assert!(
proof
.verify_without_evaluation(
&mut transcript,
CompositePolynomialInfo {
max_multiplicands: test_case.max_multiplicands,
num_variables: test_case.num_vars,
},
&(test_case.sum + TestScalar::ONE),
)
.is_err(),
"verification should fail when the sum is wrong"
);

let mut modified_proof = proof;
modified_proof.evaluations[0][0] += TestScalar::ONE;
let mut transcript = Transcript::new(b"sumchecktest");
assert!(
modified_proof
.verify_without_evaluation(
&mut transcript,
CompositePolynomialInfo {
max_multiplicands: test_case.max_multiplicands,
num_variables: test_case.num_vars,
},
&test_case.sum,
)
.is_err(),
"verification should fail when the proof is modified"
);
}
}
96 changes: 96 additions & 0 deletions crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use crate::base::{polynomial::CompositePolynomial, scalar::Scalar};
use core::iter;
use itertools::Itertools;

pub struct SumcheckTestCase<S: Scalar> {
pub polynomial: CompositePolynomial<S>,
pub num_vars: usize,
pub max_multiplicands: usize,
pub sum: S,
}

impl<S: Scalar> SumcheckTestCase<S> {
fn rand(
num_vars: usize,
max_multiplicands: usize,
products: impl IntoIterator<Item = impl IntoIterator<Item = usize>>,
rng: &mut (impl ark_std::rand::Rng + ?Sized),
) -> Self {
let length = 1 << num_vars;
let products_vec: Vec<Vec<usize>> = products
.into_iter()
.map(|p| p.into_iter().collect())
.collect();
let num_multiplicands = products_vec
.iter()
.map(|p| p.iter().max().copied().unwrap_or(0))
.max()
.unwrap_or(0)
+ 1;
let polynomial = CompositePolynomial::<S>::rand(
num_vars,
max_multiplicands,
iter::repeat(length).take(num_multiplicands),
products_vec,
rng,
);
let sum = polynomial.hypercube_sum(length);
Self {
polynomial,
num_vars,
max_multiplicands,
sum,
}
}
}

pub fn sumcheck_test_cases<S: Scalar>(
rng: &mut (impl ark_std::rand::Rng + ?Sized),
) -> impl Iterator<Item = SumcheckTestCase<S>> + '_ {
(1..=8)
.cartesian_product(1..=5)
.flat_map(|(num_vars, max_multiplicands)| {
[
Some(vec![]),
Some(vec![vec![]]),
(max_multiplicands >= 1).then_some(vec![vec![0]]),
(max_multiplicands >= 2).then_some(vec![vec![0, 1]]),
(max_multiplicands >= 3).then_some(vec![
vec![0, 1, 2],
vec![3, 4],
vec![0],
vec![],
]),
(max_multiplicands >= 5).then_some(vec![
vec![7, 0],
vec![2, 4, 8, 5],
vec![],
vec![3],
vec![1, 0, 8, 5, 0],
vec![3, 6, 9, 9],
vec![7, 8, 3],
vec![4, 3, 2],
vec![],
vec![9, 8, 2],
]),
(max_multiplicands >= 3).then_some(vec![
vec![],
vec![1, 0],
vec![3, 6, 1],
vec![],
vec![],
vec![1, 8],
vec![1],
vec![8],
vec![6, 6],
vec![4, 6, 7],
]),
]
.into_iter()
.flatten()
.map(move |products| (num_vars, max_multiplicands, products))
})
.map(|(num_vars, max_multiplicands, products)| {
SumcheckTestCase::rand(num_vars, max_multiplicands, products, rng)
})
}

0 comments on commit 22e06ae

Please sign in to comment.