diff --git a/plonkish_backend/src/piop.rs b/plonkish_backend/src/piop.rs index 6ce0f90e..5739e324 100644 --- a/plonkish_backend/src/piop.rs +++ b/plonkish_backend/src/piop.rs @@ -1 +1,2 @@ +pub mod gkr; pub mod sum_check; diff --git a/plonkish_backend/src/piop/gkr.rs b/plonkish_backend/src/piop/gkr.rs new file mode 100644 index 00000000..b26907c7 --- /dev/null +++ b/plonkish_backend/src/piop/gkr.rs @@ -0,0 +1,3 @@ +mod fractional_sum_check; + +pub use fractional_sum_check::{prove_fractional_sum_check, verify_fractional_sum_check}; diff --git a/plonkish_backend/src/piop/gkr/fractional_sum_check.rs b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs new file mode 100644 index 00000000..5e16213e --- /dev/null +++ b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs @@ -0,0 +1,371 @@ +//! Implementation of GKR for fractional sumchecks in [PH23]. +//! Notations are same as in section 3. +//! +//! [PH23]: https://eprint.iacr.org/2023/1284.pdf + +use crate::{ + piop::sum_check::{ + classic::{ClassicSumCheck, EvaluationsProver}, + evaluate, SumCheck as _, VirtualPolynomial, + }, + poly::{multilinear::MultilinearPolynomial, Polynomial}, + util::{ + arithmetic::{div_ceil, inner_product, powers, PrimeField}, + chain, + expression::{Expression, Query, Rotation}, + izip, + parallel::{num_threads, parallelize_iter}, + transcript::{FieldTranscriptRead, FieldTranscriptWrite}, + Itertools, + }, + Error, +}; +use std::{array, collections::HashMap, iter}; + +type SumCheck = ClassicSumCheck>; + +struct Layer { + p_l: MultilinearPolynomial, + p_r: MultilinearPolynomial, + q_l: MultilinearPolynomial, + q_r: MultilinearPolynomial, +} + +impl From<[Vec; 4]> for Layer { + fn from(values: [Vec; 4]) -> Self { + let [p_l, p_r, q_l, q_r] = values.map(MultilinearPolynomial::new); + Self { p_l, p_r, q_l, q_r } + } +} + +impl Layer { + fn bottom((p, q): (&&MultilinearPolynomial, &&MultilinearPolynomial)) -> Self { + let mid = p.evals().len() >> 1; + [&p[..mid], &p[mid..], &q[..mid], &q[mid..]] + .map(ToOwned::to_owned) + .into() + } + + fn num_vars(&self) -> usize { + self.p_l.num_vars() + } + + fn polys(&self) -> [&MultilinearPolynomial; 4] { + [&self.p_l, &self.p_r, &self.q_l, &self.q_r] + } + + fn poly_chunks(&self, chunk_size: usize) -> impl Iterator { + let [p_l, p_r, q_l, q_r] = self.polys().map(|poly| poly.evals().chunks(chunk_size)); + izip!(p_l, p_r, q_l, q_r) + } + + fn up(&self) -> Self { + assert!(self.num_vars() != 0); + + let len = 1 << self.num_vars(); + let chunk_size = div_ceil(len, num_threads()).next_power_of_two(); + + let mut outputs: [_; 4] = array::from_fn(|_| vec![F::ZERO; len >> 1]); + let (p, q) = outputs.split_at_mut(2); + parallelize_iter( + izip!( + chain![p].flat_map(|p| p.chunks_mut(chunk_size)), + chain![q].flat_map(|q| q.chunks_mut(chunk_size)), + self.poly_chunks(chunk_size), + ), + |(p, q, (p_l, p_r, q_l, q_r))| { + izip!(p, q, p_l, p_r, q_l, q_r).for_each(|(p, q, p_l, p_r, q_l, q_r)| { + *p = *p_l * q_r + *p_r * q_l; + *q = *q_l * q_r; + }) + }, + ); + + outputs.into() + } +} + +#[allow(clippy::type_complexity)] +pub fn prove_fractional_sum_check<'a, F: PrimeField>( + claimed_p_0s: impl IntoIterator>, + claimed_q_0s: impl IntoIterator>, + ps: impl IntoIterator>, + qs: impl IntoIterator>, + transcript: &mut impl FieldTranscriptWrite, +) -> Result<(Vec, Vec, Vec), Error> { + let claimed_p_0s = claimed_p_0s.into_iter().collect_vec(); + let claimed_q_0s = claimed_q_0s.into_iter().collect_vec(); + let ps = ps.into_iter().collect_vec(); + let qs = qs.into_iter().collect_vec(); + let num_batching = claimed_p_0s.len(); + + assert!(num_batching != 0); + assert_eq!(num_batching, claimed_q_0s.len()); + assert_eq!(num_batching, ps.len()); + assert_eq!(num_batching, qs.len()); + for poly in chain![&ps, &qs] { + assert_eq!(poly.num_vars(), ps[0].num_vars()); + } + + let bottom_layers = izip!(&ps, &qs).map(Layer::bottom).collect_vec(); + let layers = iter::successors(bottom_layers.into(), |layers| { + (layers[0].num_vars() > 0).then(|| layers.iter().map(Layer::up).collect()) + }) + .collect_vec(); + + let [claimed_p_0s, claimed_q_0s]: [_; 2] = { + let (p_0s, q_0s) = chain![layers.last().unwrap()] + .map(|layer| { + let [p_l, p_r, q_l, q_r] = layer.polys().map(|poly| poly[0]); + (p_l * q_r + p_r * q_l, q_l * q_r) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let mut hash_to_transcript = |claimed: Vec<_>, computed: Vec<_>| { + izip!(claimed, computed) + .map(|(claimed, computed)| match claimed { + Some(claimed) => { + if cfg!(feature = "sanity-check") { + assert_eq!(claimed, computed) + } + transcript.common_field_element(&computed).map(|_| computed) + } + None => transcript.write_field_element(&computed).map(|_| computed), + }) + .try_collect::<_, Vec<_>, _>() + }; + + [ + hash_to_transcript(claimed_p_0s, p_0s)?, + hash_to_transcript(claimed_q_0s, q_0s)?, + ] + }; + + let expression = sum_check_expression(num_batching); + + let (p_xs, q_xs, x) = layers.iter().rev().fold( + Ok((claimed_p_0s, claimed_q_0s, Vec::new())), + |result, layers| { + let (claimed_p_ys, claimed_q_ys, y) = result?; + + let num_vars = layers[0].num_vars(); + let polys = layers.iter().flat_map(|layer| layer.polys()); + + let (mut x, evals) = if num_vars == 0 { + (vec![], polys.map(|poly| poly[0]).collect_vec()) + } else { + let gamma = transcript.squeeze_challenge(); + + let (x, evals) = { + let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma); + SumCheck::prove( + &(), + num_vars, + VirtualPolynomial::new(&expression, polys, &[gamma], &[y]), + claim, + transcript, + )? + }; + + (x, evals) + }; + + transcript.write_field_elements(&evals)?; + + let mu = transcript.squeeze_challenge(); + + let (p_xs, q_xs) = layer_down_claim(&evals, mu); + x.push(mu); + + Ok((p_xs, q_xs, x)) + }, + )?; + + if cfg!(feature = "sanity-check") { + izip!(chain![ps, qs], chain![&p_xs, &q_xs]) + .for_each(|(poly, eval)| assert_eq!(poly.evaluate(&x), *eval)); + } + + Ok((p_xs, q_xs, x)) +} + +#[allow(clippy::type_complexity)] +pub fn verify_fractional_sum_check( + num_vars: usize, + claimed_p_0s: impl IntoIterator>, + claimed_q_0s: impl IntoIterator>, + transcript: &mut impl FieldTranscriptRead, +) -> Result<(Vec, Vec, Vec), Error> { + let claimed_p_0s = claimed_p_0s.into_iter().collect_vec(); + let claimed_q_0s = claimed_q_0s.into_iter().collect_vec(); + let num_batching = claimed_p_0s.len(); + + assert!(num_batching != 0); + assert_eq!(num_batching, claimed_q_0s.len()); + + let [claimed_p_0s, claimed_q_0s]: [_; 2] = { + [claimed_p_0s, claimed_q_0s] + .into_iter() + .map(|claimed| { + claimed + .into_iter() + .map(|claimed| match claimed { + Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed), + None => transcript.read_field_element(), + }) + .try_collect::<_, Vec<_>, _>() + }) + .try_collect::<_, Vec<_>, _>()? + .try_into() + .unwrap() + }; + + let expression = sum_check_expression(num_batching); + + let (p_xs, q_xs, x) = (0..num_vars).fold( + Ok((claimed_p_0s, claimed_q_0s, Vec::new())), + |result, num_vars| { + let (claimed_p_ys, claimed_q_ys, y) = result?; + + let (mut x, evals) = if num_vars == 0 { + let evals = transcript.read_field_elements(4 * num_batching)?; + + for (claimed_p, claimed_q, (&p_l, &p_r, &q_l, &q_r)) in + izip!(claimed_p_ys, claimed_q_ys, evals.iter().tuples()) + { + if claimed_p != p_l * q_r + p_r * q_l || claimed_q != q_l * q_r { + return Err(err_unmatched_sum_check_output()); + } + } + + (Vec::new(), evals) + } else { + let gamma = transcript.squeeze_challenge(); + + let (x_eval, x) = { + let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma); + SumCheck::verify(&(), num_vars, expression.degree(), claim, transcript)? + }; + + let evals = transcript.read_field_elements(4 * num_batching)?; + + let eval_by_query = eval_by_query(&evals); + if x_eval != evaluate(&expression, num_vars, &eval_by_query, &[gamma], &[&y], &x) { + return Err(err_unmatched_sum_check_output()); + } + + (x, evals) + }; + + let mu = transcript.squeeze_challenge(); + + let (p_xs, q_xs) = layer_down_claim(&evals, mu); + x.push(mu); + + Ok((p_xs, q_xs, x)) + }, + )?; + + Ok((p_xs, q_xs, x)) +} + +fn sum_check_expression(num_batching: usize) -> Expression { + let exprs = &(0..4 * num_batching) + .map(|idx| Expression::::Polynomial(Query::new(idx, Rotation::cur()))) + .tuples() + .flat_map(|(ref p_l, ref p_r, ref q_l, ref q_r)| [p_l * q_r + p_r * q_l, q_l * q_r]) + .collect_vec(); + let eq_xy = &Expression::eq_xy(0); + let gamma = &Expression::Challenge(0); + Expression::distribute_powers(exprs, gamma) * eq_xy +} + +fn sum_check_claim(claimed_p_ys: &[F], claimed_q_ys: &[F], gamma: F) -> F { + inner_product( + izip!(claimed_p_ys, claimed_q_ys).flat_map(|(p, q)| [p, q]), + &powers(gamma).take(claimed_p_ys.len() * 2).collect_vec(), + ) +} + +fn layer_down_claim(evals: &[F], mu: F) -> (Vec, Vec) { + evals + .iter() + .tuples() + .map(|(&p_l, &p_r, &q_l, &q_r)| (p_l + mu * (p_r - p_l), q_l + mu * (q_r - q_l))) + .unzip() +} + +fn eval_by_query(evals: &[F]) -> HashMap { + izip!( + (0..).map(|idx| Query::new(idx, Rotation::cur())), + evals.iter().cloned() + ) + .collect() +} + +fn err_unmatched_sum_check_output() -> Error { + Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string()) +} + +#[cfg(test)] +mod test { + use crate::{ + piop::gkr::fractional_sum_check::{ + prove_fractional_sum_check, verify_fractional_sum_check, + }, + poly::multilinear::MultilinearPolynomial, + util::{ + chain, izip_eq, + test::{rand_vec, seeded_std_rng}, + transcript::{InMemoryTranscript, Keccak256Transcript}, + Itertools, + }, + }; + use halo2_curves::bn256::Fr; + use std::iter; + + #[test] + fn fractional_sum_check() { + let num_batching = 3; + for num_vars in 1..16 { + let mut rng = seeded_std_rng(); + + let polys = iter::repeat_with(|| rand_vec(1 << num_vars, &mut rng)) + .map(MultilinearPolynomial::new) + .take(2 * num_batching) + .collect_vec(); + let claims = vec![None; 2 * num_batching]; + let (ps, qs) = polys.split_at(num_batching); + let (p_0s, q_0s) = claims.split_at(num_batching); + + let proof = { + let mut transcript = Keccak256Transcript::new(()); + prove_fractional_sum_check::( + p_0s.to_vec(), + q_0s.to_vec(), + ps, + qs, + &mut transcript, + ) + .unwrap(); + transcript.into_proof() + }; + + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + verify_fractional_sum_check::( + num_vars, + p_0s.to_vec(), + q_0s.to_vec(), + &mut transcript, + ) + }; + assert_eq!(result.as_ref().map(|_| ()), Ok(())); + + let (p_xs, q_xs, x) = result.unwrap(); + for (poly, eval) in izip_eq!(chain![ps, qs], chain![p_xs, q_xs]) { + assert_eq!(poly.evaluate(&x), eval); + } + } + } +} diff --git a/plonkish_backend/src/piop/sum_check/classic/coeff.rs b/plonkish_backend/src/piop/sum_check/classic/coeff.rs index 0635d2ea..78c39c76 100644 --- a/plonkish_backend/src/piop/sum_check/classic/coeff.rs +++ b/plonkish_backend/src/piop/sum_check/classic/coeff.rs @@ -1,17 +1,17 @@ use crate::{ piop::sum_check::classic::{ClassicSumCheckProver, ClassicSumCheckRoundMessage, ProverState}, - poly::multilinear::zip_self, + poly::multilinear::{zip_self, MultilinearPolynomial}, util::{ arithmetic::{div_ceil, horner, PrimeField}, expression::{CommonPolynomial, Expression, Rotation}, - impl_index, + impl_index, izip_eq, parallel::{num_threads, parallelize_iter}, transcript::{FieldTranscriptRead, FieldTranscriptWrite}, Itertools, }, Error, }; -use std::{fmt::Debug, iter, ops::AddAssign}; +use std::{array, fmt::Debug, iter, ops::AddAssign}; #[derive(Debug)] pub struct Coefficients(Vec); @@ -63,7 +63,10 @@ impl<'rhs, F: PrimeField> AddAssign<(&'rhs F, &'rhs Coefficients)> for Coeffi impl_index!(Coefficients, 0); #[derive(Clone, Debug)] -pub struct CoefficientsProver(F, Vec<(F, Vec>)>); +pub struct CoefficientsProver { + constant: F, + products: Vec<(F, Vec>)>, +} impl ClassicSumCheckProver for CoefficientsProver where @@ -72,7 +75,7 @@ where type RoundMessage = Coefficients; fn new(state: &ProverState) -> Self { - let (constant, flattened) = state.expression.evaluate( + let (constant, products) = state.expression.evaluate( &|constant| (constant, vec![]), &|poly| { ( @@ -127,21 +130,21 @@ where (constant * &rhs, products) }, ); - Self(constant, flattened) + Self { constant, products } } fn prove_round(&self, state: &ProverState) -> Self::RoundMessage { let mut coeffs = Coefficients(vec![F::ZERO; state.expression.degree() + 1]); - coeffs += &(F::from(state.size() as u64) * &self.0); - if self.1.iter().all(|(_, products)| products.len() == 2) { - for (scalar, products) in self.1.iter() { - let [lhs, rhs] = [0, 1].map(|idx| &products[idx]); - coeffs += (scalar, &self.karatsuba::(state, lhs, rhs)); + coeffs += &(F::from(state.size() as u64) * &self.constant); + + for (scalar, products) in self.products.iter() { + match products.len() { + 2 => coeffs += (scalar, &self.karatsuba::(state, products)), + _ => unimplemented!(), } - coeffs[1] = state.sum - coeffs[0].double() - coeffs[2]; - } else { - unimplemented!() } + + coeffs[1] = state.sum - coeffs.sum(); coeffs } } @@ -150,60 +153,64 @@ impl CoefficientsProver { fn karatsuba( &self, state: &ProverState, - lhs: &Expression, - rhs: &Expression, + items: &[Expression], ) -> Coefficients { - let mut coeffs = [F::ZERO; 3]; - match (lhs, rhs) { - ( - Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)), - Expression::Polynomial(query), + debug_assert_eq!(items.len(), 2); + + let [lhs, rhs] = array::from_fn(|idx| poly(state, &items[idx])); + let evaluate_serial = |coeffs: &mut [F; 3], start: usize, n: usize| { + izip_eq!( + zip_self!(lhs.iter(), 2, start), + zip_self!(rhs.iter(), 2, start) ) - | ( - Expression::Polynomial(query), - Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)), - ) if query.rotation() == Rotation::cur() => { - let lhs = &state.eq_xys[*idx]; - let rhs = &state.polys[query.poly()][state.num_vars]; - - let evaluate_serial = |coeffs: &mut [F; 3], start: usize, n: usize| { - zip_self!(lhs.iter(), 2, start) - .zip(zip_self!(rhs.iter(), 2, start)) - .take(n) - .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { - let coeff_0 = *lhs_0 * rhs_0; - let coeff_2 = (*lhs_1 - lhs_0) * &(*rhs_1 - rhs_0); - coeffs[0] += &coeff_0; - coeffs[2] += &coeff_2; - if !LAZY { - coeffs[1] += &(*lhs_1 * rhs_1 - &coeff_0 - &coeff_2); - } - }); - }; - - let num_threads = num_threads(); - if state.size() < num_threads { - evaluate_serial(&mut coeffs, 0, state.size()); - } else { - let chunk_size = div_ceil(state.size(), num_threads); - let mut partials = vec![[F::ZERO; 3]; num_threads]; - parallelize_iter( - partials.iter_mut().zip((0..).step_by(chunk_size << 1)), - |(partial, start)| { - evaluate_serial(partial, start, chunk_size); - }, - ); - partials.iter().for_each(|partial| { - coeffs[0] += partial[0]; - coeffs[2] += partial[2]; - if !LAZY { - coeffs[1] += partial[1]; - } - }) - }; - } - _ => unimplemented!(), - } + .take(n) + .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { + let eval_0 = *lhs_0 * rhs_0; + let eval_2 = (*lhs_1 - lhs_0) * &(*rhs_1 - rhs_0); + coeffs[0] += &eval_0; + coeffs[2] += &eval_2; + if !LAZY { + coeffs[1] += &(*lhs_1 * rhs_1 - &eval_0 - &eval_2); + } + }); + }; + + let mut coeffs = [F::ZERO; 3]; + + let num_threads = num_threads(); + if state.size() < 16 { + evaluate_serial(&mut coeffs, 0, state.size()); + } else { + let chunk_size = div_ceil(state.size(), num_threads); + let mut partials = vec![[F::ZERO; 3]; num_threads]; + parallelize_iter( + partials.iter_mut().zip((0..).step_by(chunk_size << 1)), + |(partial, start)| { + evaluate_serial(partial, start, chunk_size); + }, + ); + partials.iter().for_each(|partial| { + coeffs[0] += partial[0]; + coeffs[2] += partial[2]; + if !LAZY { + coeffs[1] += partial[1]; + } + }) + }; + Coefficients(coeffs.to_vec()) } } + +fn poly<'a, F: PrimeField>( + state: &'a ProverState, + expr: &Expression, +) -> &'a MultilinearPolynomial { + match expr { + Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)) => &state.eq_xys[*idx], + Expression::Polynomial(query) if query.rotation() == Rotation::cur() => { + &state.polys[query.poly()][state.num_vars] + } + _ => unimplemented!(), + } +}