Skip to content

Commit

Permalink
GKR fractional sum check (#28)
Browse files Browse the repository at this point in the history
* feat: implement GKR for fractional sumchecks described in https://eprint.iacr.org/2023/1284.pdf

* feat: make `fractional_sum` support batching

* chore: rename
  • Loading branch information
han0110 authored Sep 27, 2023
1 parent 7b60ccf commit 6466eb2
Show file tree
Hide file tree
Showing 4 changed files with 448 additions and 66 deletions.
1 change: 1 addition & 0 deletions plonkish_backend/src/piop.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod gkr;
pub mod sum_check;
3 changes: 3 additions & 0 deletions plonkish_backend/src/piop/gkr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod fractional_sum_check;

pub use fractional_sum_check::{prove_fractional_sum_check, verify_fractional_sum_check};
371 changes: 371 additions & 0 deletions plonkish_backend/src/piop/gkr/fractional_sum_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
//! Implementation of GKR for fractional sumchecks in [PH23].
//! Notations are same as in section 3.
//!
//! [PH23]: https://eprint.iacr.org/2023/1284.pdf

use crate::{
piop::sum_check::{
classic::{ClassicSumCheck, EvaluationsProver},
evaluate, SumCheck as _, VirtualPolynomial,
},
poly::{multilinear::MultilinearPolynomial, Polynomial},
util::{
arithmetic::{div_ceil, inner_product, powers, PrimeField},
chain,
expression::{Expression, Query, Rotation},
izip,
parallel::{num_threads, parallelize_iter},
transcript::{FieldTranscriptRead, FieldTranscriptWrite},
Itertools,
},
Error,
};
use std::{array, collections::HashMap, iter};

type SumCheck<F> = ClassicSumCheck<EvaluationsProver<F>>;

struct Layer<F> {
p_l: MultilinearPolynomial<F>,
p_r: MultilinearPolynomial<F>,
q_l: MultilinearPolynomial<F>,
q_r: MultilinearPolynomial<F>,
}

impl<F> From<[Vec<F>; 4]> for Layer<F> {
fn from(values: [Vec<F>; 4]) -> Self {
let [p_l, p_r, q_l, q_r] = values.map(MultilinearPolynomial::new);
Self { p_l, p_r, q_l, q_r }
}
}

impl<F: PrimeField> Layer<F> {
fn bottom((p, q): (&&MultilinearPolynomial<F>, &&MultilinearPolynomial<F>)) -> Self {
let mid = p.evals().len() >> 1;
[&p[..mid], &p[mid..], &q[..mid], &q[mid..]]
.map(ToOwned::to_owned)
.into()
}

fn num_vars(&self) -> usize {
self.p_l.num_vars()
}

fn polys(&self) -> [&MultilinearPolynomial<F>; 4] {
[&self.p_l, &self.p_r, &self.q_l, &self.q_r]
}

fn poly_chunks(&self, chunk_size: usize) -> impl Iterator<Item = (&[F], &[F], &[F], &[F])> {
let [p_l, p_r, q_l, q_r] = self.polys().map(|poly| poly.evals().chunks(chunk_size));
izip!(p_l, p_r, q_l, q_r)
}

fn up(&self) -> Self {
assert!(self.num_vars() != 0);

let len = 1 << self.num_vars();
let chunk_size = div_ceil(len, num_threads()).next_power_of_two();

let mut outputs: [_; 4] = array::from_fn(|_| vec![F::ZERO; len >> 1]);
let (p, q) = outputs.split_at_mut(2);
parallelize_iter(
izip!(
chain![p].flat_map(|p| p.chunks_mut(chunk_size)),
chain![q].flat_map(|q| q.chunks_mut(chunk_size)),
self.poly_chunks(chunk_size),
),
|(p, q, (p_l, p_r, q_l, q_r))| {
izip!(p, q, p_l, p_r, q_l, q_r).for_each(|(p, q, p_l, p_r, q_l, q_r)| {
*p = *p_l * q_r + *p_r * q_l;
*q = *q_l * q_r;
})
},
);

outputs.into()
}
}

#[allow(clippy::type_complexity)]
pub fn prove_fractional_sum_check<'a, F: PrimeField>(
claimed_p_0s: impl IntoIterator<Item = Option<F>>,
claimed_q_0s: impl IntoIterator<Item = Option<F>>,
ps: impl IntoIterator<Item = &'a MultilinearPolynomial<F>>,
qs: impl IntoIterator<Item = &'a MultilinearPolynomial<F>>,
transcript: &mut impl FieldTranscriptWrite<F>,
) -> Result<(Vec<F>, Vec<F>, Vec<F>), Error> {
let claimed_p_0s = claimed_p_0s.into_iter().collect_vec();
let claimed_q_0s = claimed_q_0s.into_iter().collect_vec();
let ps = ps.into_iter().collect_vec();
let qs = qs.into_iter().collect_vec();
let num_batching = claimed_p_0s.len();

assert!(num_batching != 0);
assert_eq!(num_batching, claimed_q_0s.len());
assert_eq!(num_batching, ps.len());
assert_eq!(num_batching, qs.len());
for poly in chain![&ps, &qs] {
assert_eq!(poly.num_vars(), ps[0].num_vars());
}

let bottom_layers = izip!(&ps, &qs).map(Layer::bottom).collect_vec();
let layers = iter::successors(bottom_layers.into(), |layers| {
(layers[0].num_vars() > 0).then(|| layers.iter().map(Layer::up).collect())
})
.collect_vec();

let [claimed_p_0s, claimed_q_0s]: [_; 2] = {
let (p_0s, q_0s) = chain![layers.last().unwrap()]
.map(|layer| {
let [p_l, p_r, q_l, q_r] = layer.polys().map(|poly| poly[0]);
(p_l * q_r + p_r * q_l, q_l * q_r)
})
.unzip::<_, _, Vec<_>, Vec<_>>();

let mut hash_to_transcript = |claimed: Vec<_>, computed: Vec<_>| {
izip!(claimed, computed)
.map(|(claimed, computed)| match claimed {
Some(claimed) => {
if cfg!(feature = "sanity-check") {
assert_eq!(claimed, computed)
}
transcript.common_field_element(&computed).map(|_| computed)
}
None => transcript.write_field_element(&computed).map(|_| computed),
})
.try_collect::<_, Vec<_>, _>()
};

[
hash_to_transcript(claimed_p_0s, p_0s)?,
hash_to_transcript(claimed_q_0s, q_0s)?,
]
};

let expression = sum_check_expression(num_batching);

let (p_xs, q_xs, x) = layers.iter().rev().fold(
Ok((claimed_p_0s, claimed_q_0s, Vec::new())),
|result, layers| {
let (claimed_p_ys, claimed_q_ys, y) = result?;

let num_vars = layers[0].num_vars();
let polys = layers.iter().flat_map(|layer| layer.polys());

let (mut x, evals) = if num_vars == 0 {
(vec![], polys.map(|poly| poly[0]).collect_vec())
} else {
let gamma = transcript.squeeze_challenge();

let (x, evals) = {
let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma);
SumCheck::prove(
&(),
num_vars,
VirtualPolynomial::new(&expression, polys, &[gamma], &[y]),
claim,
transcript,
)?
};

(x, evals)
};

transcript.write_field_elements(&evals)?;

let mu = transcript.squeeze_challenge();

let (p_xs, q_xs) = layer_down_claim(&evals, mu);
x.push(mu);

Ok((p_xs, q_xs, x))
},
)?;

if cfg!(feature = "sanity-check") {
izip!(chain![ps, qs], chain![&p_xs, &q_xs])
.for_each(|(poly, eval)| assert_eq!(poly.evaluate(&x), *eval));
}

Ok((p_xs, q_xs, x))
}

#[allow(clippy::type_complexity)]
pub fn verify_fractional_sum_check<F: PrimeField>(
num_vars: usize,
claimed_p_0s: impl IntoIterator<Item = Option<F>>,
claimed_q_0s: impl IntoIterator<Item = Option<F>>,
transcript: &mut impl FieldTranscriptRead<F>,
) -> Result<(Vec<F>, Vec<F>, Vec<F>), Error> {
let claimed_p_0s = claimed_p_0s.into_iter().collect_vec();
let claimed_q_0s = claimed_q_0s.into_iter().collect_vec();
let num_batching = claimed_p_0s.len();

assert!(num_batching != 0);
assert_eq!(num_batching, claimed_q_0s.len());

let [claimed_p_0s, claimed_q_0s]: [_; 2] = {
[claimed_p_0s, claimed_q_0s]
.into_iter()
.map(|claimed| {
claimed
.into_iter()
.map(|claimed| match claimed {
Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed),
None => transcript.read_field_element(),
})
.try_collect::<_, Vec<_>, _>()
})
.try_collect::<_, Vec<_>, _>()?
.try_into()
.unwrap()
};

let expression = sum_check_expression(num_batching);

let (p_xs, q_xs, x) = (0..num_vars).fold(
Ok((claimed_p_0s, claimed_q_0s, Vec::new())),
|result, num_vars| {
let (claimed_p_ys, claimed_q_ys, y) = result?;

let (mut x, evals) = if num_vars == 0 {
let evals = transcript.read_field_elements(4 * num_batching)?;

for (claimed_p, claimed_q, (&p_l, &p_r, &q_l, &q_r)) in
izip!(claimed_p_ys, claimed_q_ys, evals.iter().tuples())
{
if claimed_p != p_l * q_r + p_r * q_l || claimed_q != q_l * q_r {
return Err(err_unmatched_sum_check_output());
}
}

(Vec::new(), evals)
} else {
let gamma = transcript.squeeze_challenge();

let (x_eval, x) = {
let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma);
SumCheck::verify(&(), num_vars, expression.degree(), claim, transcript)?
};

let evals = transcript.read_field_elements(4 * num_batching)?;

let eval_by_query = eval_by_query(&evals);
if x_eval != evaluate(&expression, num_vars, &eval_by_query, &[gamma], &[&y], &x) {
return Err(err_unmatched_sum_check_output());
}

(x, evals)
};

let mu = transcript.squeeze_challenge();

let (p_xs, q_xs) = layer_down_claim(&evals, mu);
x.push(mu);

Ok((p_xs, q_xs, x))
},
)?;

Ok((p_xs, q_xs, x))
}

fn sum_check_expression<F: PrimeField>(num_batching: usize) -> Expression<F> {
let exprs = &(0..4 * num_batching)
.map(|idx| Expression::<F>::Polynomial(Query::new(idx, Rotation::cur())))
.tuples()
.flat_map(|(ref p_l, ref p_r, ref q_l, ref q_r)| [p_l * q_r + p_r * q_l, q_l * q_r])
.collect_vec();
let eq_xy = &Expression::eq_xy(0);
let gamma = &Expression::Challenge(0);
Expression::distribute_powers(exprs, gamma) * eq_xy
}

fn sum_check_claim<F: PrimeField>(claimed_p_ys: &[F], claimed_q_ys: &[F], gamma: F) -> F {
inner_product(
izip!(claimed_p_ys, claimed_q_ys).flat_map(|(p, q)| [p, q]),
&powers(gamma).take(claimed_p_ys.len() * 2).collect_vec(),
)
}

fn layer_down_claim<F: PrimeField>(evals: &[F], mu: F) -> (Vec<F>, Vec<F>) {
evals
.iter()
.tuples()
.map(|(&p_l, &p_r, &q_l, &q_r)| (p_l + mu * (p_r - p_l), q_l + mu * (q_r - q_l)))
.unzip()
}

fn eval_by_query<F: PrimeField>(evals: &[F]) -> HashMap<Query, F> {
izip!(
(0..).map(|idx| Query::new(idx, Rotation::cur())),
evals.iter().cloned()
)
.collect()
}

fn err_unmatched_sum_check_output() -> Error {
Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string())
}

#[cfg(test)]
mod test {
use crate::{
piop::gkr::fractional_sum_check::{
prove_fractional_sum_check, verify_fractional_sum_check,
},
poly::multilinear::MultilinearPolynomial,
util::{
chain, izip_eq,
test::{rand_vec, seeded_std_rng},
transcript::{InMemoryTranscript, Keccak256Transcript},
Itertools,
},
};
use halo2_curves::bn256::Fr;
use std::iter;

#[test]
fn fractional_sum_check() {
let num_batching = 3;
for num_vars in 1..16 {
let mut rng = seeded_std_rng();

let polys = iter::repeat_with(|| rand_vec(1 << num_vars, &mut rng))
.map(MultilinearPolynomial::new)
.take(2 * num_batching)
.collect_vec();
let claims = vec![None; 2 * num_batching];
let (ps, qs) = polys.split_at(num_batching);
let (p_0s, q_0s) = claims.split_at(num_batching);

let proof = {
let mut transcript = Keccak256Transcript::new(());
prove_fractional_sum_check::<Fr>(
p_0s.to_vec(),
q_0s.to_vec(),
ps,
qs,
&mut transcript,
)
.unwrap();
transcript.into_proof()
};

let result = {
let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice());
verify_fractional_sum_check::<Fr>(
num_vars,
p_0s.to_vec(),
q_0s.to_vec(),
&mut transcript,
)
};
assert_eq!(result.as_ref().map(|_| ()), Ok(()));

let (p_xs, q_xs, x) = result.unwrap();
for (poly, eval) in izip_eq!(chain![ps, qs], chain![p_xs, q_xs]) {
assert_eq!(poly.evaluate(&x), eval);
}
}
}
}
Loading

0 comments on commit 6466eb2

Please sign in to comment.