Skip to content

Commit

Permalink
feat: toggled quark hybrid grand product (#497)
Browse files Browse the repository at this point in the history
* feat: toggled quark hybrid grand product

* cleanup

* refactor to share code between the existing quarks impl and sparse quarks

* fmt

* fix confused comments from merge
  • Loading branch information
sagar-a16z authored Nov 12, 2024
1 parent 1fb2811 commit 1861026
Show file tree
Hide file tree
Showing 5 changed files with 385 additions and 136 deletions.
2 changes: 1 addition & 1 deletion jolt-core/benches/grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ fn benchmark_verify<PCS, F, G, ProofTranscript>(
transcript = ProofTranscript::new(b"test_transcript");
let mut verifier_accumulator: VerifierOpeningAccumulator<F, PCS, ProofTranscript> =
VerifierOpeningAccumulator::new();
let (_, r_verifier) = QuarkGrandProduct::verify_grand_product(
let (_, r_verifier) = QuarkGrandProduct::verify_quark_grand_product(
&proof,
&known_products,
Some(&mut verifier_accumulator),
Expand Down
4 changes: 4 additions & 0 deletions jolt-core/src/subprotocols/grand_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ where

Self::verify_layers(&proof.gkr_layers, claim, transcript, r)
}

fn quark_poly(&self) -> Option<&[F]> {
None
}
}

pub trait BatchedGrandProductLayer<F, ProofTranscript>:
Expand Down
151 changes: 89 additions & 62 deletions jolt-core/src/subprotocols/grand_product_quarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::poly::dense_interleaved_poly::DenseInterleavedPolynomial;
use crate::poly::dense_mlpoly::DensePolynomial;
use crate::poly::eq_poly::EqPolynomial;
use crate::poly::opening_proof::{ProverOpeningAccumulator, VerifierOpeningAccumulator};
use crate::subprotocols::QuarkHybridLayerDepth;
use crate::utils::math::Math;
use crate::utils::transcript::{AppendToTranscript, Transcript};
use ark_serialize::*;
Expand All @@ -27,38 +28,16 @@ pub struct QuarkGrandProductProof<
g_r_sumcheck: PCS::Field,
g_r_prime: (PCS::Field, PCS::Field),
v_r_prime: (PCS::Field, PCS::Field),
num_vars: usize,
pub num_vars: usize,
}

pub struct QuarkGrandProduct<F: JoltField, ProofTranscript: Transcript> {
batch_size: usize,
quark_poly: Vec<F>,
quark_poly: Option<Vec<F>>,
base_layers: Vec<DenseInterleavedPolynomial<F>>,
_marker: PhantomData<ProofTranscript>,
}

#[derive(Clone, Copy, Debug, Default)]
pub enum QuarkHybridLayerDepth {
#[default]
Default,
Min,
Max,
Custom(usize),
}

impl QuarkHybridLayerDepth {
/// The depth in the binary tree of the GKR grand product at which the hybrid scheme
/// will switch to using Quarks Section 5 grand product argument.
pub fn get_crossover_depth(&self) -> usize {
match self {
QuarkHybridLayerDepth::Min => 0,
QuarkHybridLayerDepth::Default => 4,
QuarkHybridLayerDepth::Max => usize::MAX,
QuarkHybridLayerDepth::Custom(depth) => *depth,
}
}
}

#[derive(Clone, Copy, Debug, Default)]
pub struct QuarkGrandProductConfig {
pub hybrid_layer_depth: QuarkHybridLayerDepth,
Expand Down Expand Up @@ -114,7 +93,7 @@ where
if tree_depth <= num_layers {
return Self {
batch_size,
quark_poly: Vec::new(),
quark_poly: None,
base_layers: layers,
_marker: PhantomData,
};
Expand All @@ -125,61 +104,114 @@ where
let quark_poly = layers.pop().unwrap().coeffs;
Self {
batch_size,
quark_poly,
quark_poly: Some(quark_poly),
base_layers: layers,
_marker: PhantomData,
}
}

fn num_layers(&self) -> usize {
unimplemented!("Unused");
self.base_layers.len()
}

/// The claimed outputs of the grand products.
fn claimed_outputs(&self) -> Vec<F> {
let chunk_size = self.quark_poly.len() / self.batch_size;
self.quark_poly
.par_chunks(chunk_size)
.map(|chunk| chunk.iter().product())
.collect()
if let Some(quark_poly) = &self.quark_poly {
let chunk_size = quark_poly.len() / self.batch_size;
quark_poly
.par_chunks(chunk_size)
.map(|chunk| chunk.iter().product())
.collect()
} else {
let top_layer = &self.base_layers[self.base_layers.len() - 1];
top_layer
.par_chunks(2)
.map(|chunk| chunk[0] * chunk[1])
.collect()
}
}

/// Returns an iterator over the layers of this batched grand product circuit.
/// Each layer is mutable so that its polynomials can be bound over the course
/// of proving.
#[allow(unreachable_code)]
fn layers(
&'_ mut self,
) -> impl Iterator<Item = &'_ mut dyn BatchedGrandProductLayer<F, ProofTranscript>> {
unimplemented!("We don't use the default prover and so we don't need the generic iterator");
std::iter::empty()
self.base_layers
.iter_mut()
.map(|layer| layer as &mut dyn BatchedGrandProductLayer<F, ProofTranscript>)
.rev()
}

fn quark_poly(&self) -> Option<&[F]> {
self.quark_poly.as_deref()
}

/// Computes a batched grand product proof, layer by layer.
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::prove_grand_product")]
fn prove_grand_product(
&mut self,
opening_accumulator: Option<&mut ProverOpeningAccumulator<F, ProofTranscript>>,
transcript: &mut ProofTranscript,
setup: Option<&PCS::Setup>,
) -> (BatchedGrandProductProof<PCS, ProofTranscript>, Vec<F>) {
let mut proof_layers = Vec::with_capacity(self.base_layers.len());
QuarkGrandProductBase::prove_quark_grand_product(
self,
opening_accumulator,
transcript,
setup,
)
}

let outputs: Vec<F> =
<Self as BatchedGrandProduct<F, PCS, ProofTranscript>>::claimed_outputs(self);
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::verify_grand_product")]
fn verify_grand_product(
proof: &BatchedGrandProductProof<PCS, ProofTranscript>,
claimed_outputs: &[F],
opening_accumulator: Option<&mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>>,
transcript: &mut ProofTranscript,
_setup: Option<&PCS::Setup>,
) -> (F, Vec<F>) {
QuarkGrandProductBase::verify_quark_grand_product::<Self, PCS>(
proof,
claimed_outputs,
opening_accumulator,
transcript,
)
}
}

pub struct QuarkGrandProductBase<F: JoltField, ProofTranscript: Transcript> {
_marker: PhantomData<(F, ProofTranscript)>,
}

impl<F, ProofTranscript> QuarkGrandProductBase<F, ProofTranscript>
where
F: JoltField,
ProofTranscript: Transcript,
{
/// Computes a batched grand product proof, layer by layer.
#[tracing::instrument(skip_all, name = "QuarkGrandProduct::prove_grand_product")]
pub fn prove_quark_grand_product<PCS: CommitmentScheme<ProofTranscript, Field = F>>(
grand_product: &mut impl BatchedGrandProduct<F, PCS, ProofTranscript>,
opening_accumulator: Option<&mut ProverOpeningAccumulator<F, ProofTranscript>>,
transcript: &mut ProofTranscript,
setup: Option<&PCS::Setup>,
) -> (BatchedGrandProductProof<PCS, ProofTranscript>, Vec<F>) {
let mut proof_layers = Vec::with_capacity(grand_product.num_layers());

let outputs: Vec<F> = grand_product.claimed_outputs();
transcript.append_scalars(&outputs);
let output_mle = DensePolynomial::new_padded(outputs);
let r_outputs: Vec<F> = transcript.challenge_vector(output_mle.get_num_vars());
let claim = output_mle.evaluate(&r_outputs);

// For polynomials of size less than 16 we just use the GKR grand product
let (quark_proof, mut random, mut claim) = if !self.quark_poly.is_empty() {
// When doing the quark hybrid proof, we first prove the grand product of a layer of a polynomial which is 4 layers deep in the tree
let (quark_proof, mut random, mut claim) = if grand_product.quark_poly().is_some() {
// When doing the quark hybrid proof, we first prove the grand product of a layer of a polynomial which is N layers deep in the tree
// of a standard layered sumcheck grand product, then we use the sumcheck layers to prove via GKR layers that the random point opened
// by the quark proof is in fact the folded result of the base layer.
let (quark, random, quark_claim) =
QuarkGrandProductProof::<PCS, ProofTranscript>::prove(
&self.quark_poly,
grand_product.quark_poly().unwrap(),
r_outputs,
claim,
opening_accumulator.unwrap(),
Expand All @@ -191,7 +223,7 @@ where
(None, r_outputs, claim)
};

for layer in self.base_layers.iter_mut().rev() {
for layer in grand_product.layers() {
proof_layers.push(layer.prove_layer(&mut claim, &mut random, transcript));
}

Expand All @@ -205,16 +237,17 @@ where
}

/// Verifies the given grand product proof.
#[tracing::instrument(skip_all, name = "BatchedGrandProduct::verify_grand_product")]
fn verify_grand_product(
#[tracing::instrument(skip_all, name = "QuarkGrandProduct::verify_grand_product")]
pub fn verify_quark_grand_product<G, PCS>(
proof: &BatchedGrandProductProof<PCS, ProofTranscript>,
claimed_outputs: &[F],
opening_accumulator: Option<&mut VerifierOpeningAccumulator<F, PCS, ProofTranscript>>,
transcript: &mut ProofTranscript,
_setup: Option<&PCS::Setup>,
) -> (F, Vec<F>) {
// Evaluate the MLE of the output layer at a random point to reduce the outputs to
// a single claim.
) -> (F, Vec<F>)
where
PCS: CommitmentScheme<ProofTranscript, Field = F>,
G: BatchedGrandProduct<F, PCS, ProofTranscript>,
{
transcript.append_scalars(claimed_outputs);
let r_outputs: Vec<F> =
transcript.challenge_vector(claimed_outputs.len().next_power_of_two().log_2());
Expand All @@ -225,7 +258,6 @@ where
Some(quark) => {
// In this case we verify the quark which fixes the first log(n)-4 vars in the random eval point.
let v_len = quark.num_vars;
// Todo (aleph_v) - bubble up errors
quark
.verify(
r_outputs,
Expand All @@ -234,21 +266,16 @@ where
transcript,
v_len,
)
.unwrap()
.unwrap_or_else(|e| panic!("quark verify error: {:?}", e))
}
None => {
// Otherwise we must check the actual claims and the preset random will be empty.
(claim, r_outputs)
}
};

let (grand_product_claim, grand_product_r) = <Self as BatchedGrandProduct<
F,
PCS,
ProofTranscript,
>>::verify_layers(
&proof.gkr_layers, claim, transcript, rand
);
let (grand_product_claim, grand_product_r) =
G::verify_layers(&proof.gkr_layers, claim, transcript, rand);

(grand_product_claim, grand_product_r)
}
Expand Down Expand Up @@ -277,7 +304,7 @@ where
/// Then - Constructs a g poly and preforms sumcheck proof that sum == 0
/// Finally - computes opening proofs for a random sampled during sumcheck proof and returns
/// Returns a random point and evaluation to be verified by the caller (which our hybrid prover does with GKR)
fn prove(
pub fn prove(
v: &[PCS::Field],
r_outputs: Vec<PCS::Field>,
claim: PCS::Field,
Expand All @@ -288,7 +315,7 @@ where
let v_length = v.len();
let v_variables = v_length.log_2();

let v_polynomial = DensePolynomial::<PCS::Field>::new(v.to_vec());
let v_polynomial = DensePolynomial::<PCS::Field>::new_padded(v.to_vec());
// Compute f(1, x), f(x, 0), and f(x, 1) from v(x)
let (f_1x, f_x0, f_x1) = v_into_f::<PCS::Field>(&v_polynomial);

Expand Down Expand Up @@ -443,7 +470,7 @@ where

/// Verifies the given grand product proof.
#[allow(clippy::type_complexity)]
fn verify(
pub fn verify(
&self,
r_outputs: Vec<PCS::Field>,
claim: PCS::Field,
Expand Down Expand Up @@ -681,7 +708,7 @@ mod quark_grand_product_tests {
&known_products,
Some(&mut verifier_accumulator),
&mut verifier_transcript,
Some(&setup),
None,
);
assert!(verifier_accumulator
.reduce_and_verify(&setup, &batched_proof, &mut verifier_transcript)
Expand Down
22 changes: 22 additions & 0 deletions jolt-core/src/subprotocols/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,25 @@ pub mod grand_product;
pub mod grand_product_quarks;
pub mod sparse_grand_product;
pub mod sumcheck;

#[derive(Clone, Copy, Debug, Default)]
pub enum QuarkHybridLayerDepth {
#[default]
Default,
Min,
Max,
Custom(usize),
}

impl QuarkHybridLayerDepth {
// The depth in the product tree of the grand product at which the
// hybrid implementation will switch to using quarks grand product proofs
pub fn get_crossover_depth(&self) -> usize {
match self {
QuarkHybridLayerDepth::Min => 0, // Always use quarks
QuarkHybridLayerDepth::Default => 4,
QuarkHybridLayerDepth::Max => usize::MAX, // Never use quarks
QuarkHybridLayerDepth::Custom(depth) => *depth,
}
}
}
Loading

0 comments on commit 1861026

Please sign in to comment.