-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
test: improve sumcheck test coverage (#204)
# Rationale for this change The tests we have for sumcheck do not cover as many edge cases as they should. This PR adds many more tests cases. # What changes are included in this PR? See the individual commits. * Added CompositePolynomial::rand and `CompositePolynomial::hypercube_sum` - These two methods are useful for generating test cases for sumcheck. * Added `sumcheck::test_cases` module that enumerates a bunch of test cases that we should be testing. These test cases are intended to be re-used. In particular, * Added a test for all of these test cases. # Are these changes tested? Yes
- Loading branch information
Showing
5 changed files
with
264 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,6 @@ pub use subclaim::Subclaim; | |
|
||
mod prover_round; | ||
use prover_round::prove_round; | ||
|
||
#[cfg(test)] | ||
mod test_cases; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
crates/proof-of-sql/src/proof_primitive/sumcheck/test_cases.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
use crate::base::{polynomial::CompositePolynomial, scalar::Scalar}; | ||
use core::iter; | ||
use itertools::Itertools; | ||
|
||
pub struct SumcheckTestCase<S: Scalar> { | ||
pub polynomial: CompositePolynomial<S>, | ||
pub num_vars: usize, | ||
pub max_multiplicands: usize, | ||
pub sum: S, | ||
} | ||
|
||
impl<S: Scalar> SumcheckTestCase<S> { | ||
fn rand( | ||
num_vars: usize, | ||
max_multiplicands: usize, | ||
products: impl IntoIterator<Item = impl IntoIterator<Item = usize>>, | ||
rng: &mut (impl ark_std::rand::Rng + ?Sized), | ||
) -> Self { | ||
let length = 1 << num_vars; | ||
let products_vec: Vec<Vec<usize>> = products | ||
.into_iter() | ||
.map(|p| p.into_iter().collect()) | ||
.collect(); | ||
let num_multiplicands = products_vec | ||
.iter() | ||
.map(|p| p.iter().max().copied().unwrap_or(0)) | ||
.max() | ||
.unwrap_or(0) | ||
+ 1; | ||
let polynomial = CompositePolynomial::<S>::rand( | ||
num_vars, | ||
max_multiplicands, | ||
iter::repeat(length).take(num_multiplicands), | ||
products_vec, | ||
rng, | ||
); | ||
let sum = polynomial.hypercube_sum(length); | ||
Self { | ||
polynomial, | ||
num_vars, | ||
max_multiplicands, | ||
sum, | ||
} | ||
} | ||
} | ||
|
||
pub fn sumcheck_test_cases<S: Scalar>( | ||
rng: &mut (impl ark_std::rand::Rng + ?Sized), | ||
) -> impl Iterator<Item = SumcheckTestCase<S>> + '_ { | ||
(1..=8) | ||
.cartesian_product(1..=5) | ||
.flat_map(|(num_vars, max_multiplicands)| { | ||
[ | ||
Some(vec![]), | ||
Some(vec![vec![]]), | ||
(max_multiplicands >= 1).then_some(vec![vec![0]]), | ||
(max_multiplicands >= 2).then_some(vec![vec![0, 1]]), | ||
(max_multiplicands >= 3).then_some(vec![ | ||
vec![0, 1, 2], | ||
vec![3, 4], | ||
vec![0], | ||
vec![], | ||
]), | ||
(max_multiplicands >= 5).then_some(vec![ | ||
vec![7, 0], | ||
vec![2, 4, 8, 5], | ||
vec![], | ||
vec![3], | ||
vec![1, 0, 8, 5, 0], | ||
vec![3, 6, 9, 9], | ||
vec![7, 8, 3], | ||
vec![4, 3, 2], | ||
vec![], | ||
vec![9, 8, 2], | ||
]), | ||
(max_multiplicands >= 3).then_some(vec![ | ||
vec![], | ||
vec![1, 0], | ||
vec![3, 6, 1], | ||
vec![], | ||
vec![], | ||
vec![1, 8], | ||
vec![1], | ||
vec![8], | ||
vec![6, 6], | ||
vec![4, 6, 7], | ||
]), | ||
] | ||
.into_iter() | ||
.flatten() | ||
.map(move |products| (num_vars, max_multiplicands, products)) | ||
}) | ||
.map(|(num_vars, max_multiplicands, products)| { | ||
SumcheckTestCase::rand(num_vars, max_multiplicands, products, rng) | ||
}) | ||
} |