Skip to content

Commit

Permalink
Implement Hyrax commitment in a streaming manner
Browse files Browse the repository at this point in the history
  • Loading branch information
jprider63 committed Sep 27, 2024
1 parent fdcc587 commit daca462
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 27 deletions.
12 changes: 6 additions & 6 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::field::JoltField;
use crate::host;
use crate::jolt::vm::rv32i_vm::{RV32IJoltVM, C, M};
use crate::jolt::vm::Jolt;
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme;
use crate::poly::commitment::hyperkzg::HyperKZG;
use crate::poly::commitment::hyrax::HyraxScheme;
use crate::poly::commitment::zeromorph::Zeromorph;
Expand Down Expand Up @@ -61,23 +61,23 @@ pub fn benchmarks(
fn fibonacci<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<u32, PCS, F>("fibonacci-guest", &9u32)
}

fn sha2<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<Vec<u8>, PCS, F>("sha2-guest", &vec![5u8; 2048])
}

fn sha3<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
prove_example::<Vec<u8>, PCS, F>("sha3-guest", &vec![5u8; 2048])
}
Expand All @@ -99,7 +99,7 @@ fn prove_example<T: Serialize, PCS, F>(
) -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
let mut tasks = Vec::new();
let mut program = host::Program::new(example_name);
Expand Down Expand Up @@ -152,7 +152,7 @@ where
fn sha2chain<F, PCS>() -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
where
F: JoltField,
PCS: CommitmentScheme<Field = F>,
PCS: StreamingCommitmentScheme<Field = F>,
{
let mut tasks = Vec::new();
let mut program = host::Program::new("sha2-chain-guest");
Expand Down
28 changes: 22 additions & 6 deletions jolt-core/src/jolt/vm/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ pub struct BytecodeRowStep<F: JoltField, C: CommitmentScheme<Field = F>> {
_group: PhantomData<C>,

/// Memory address as read from the ELF.
address: F,
pub(super) address: F,
/// Packed instruction/circuit flags, used for r1cs
bitflags: F,
pub(super) bitflags: F,
/// Index of the destination register for this instruction (0 if register is unused).
rd: F,
pub(super) rd: F,
/// Index of the first source register for this instruction (0 if register is unused).
rs1: F,
pub(super) rs1: F,
/// Index of the second source register for this instruction (0 if register is unused).
rs2: F,
pub(super) rs2: F,
/// "Immediate" value for this instruction (0 if unused).
imm: F,
pub(super) imm: F,
// /// If this instruction is part of a "virtual sequence" (see Section 6.2 of the
// /// Jolt paper), then this contains the number of virtual instructions after this
// /// one in the sequence. I.e. if this is the last instruction in the sequence,
Expand Down Expand Up @@ -195,6 +195,8 @@ pub struct BytecodePolynomials<F: JoltField, C: CommitmentScheme<Field = F>> {
}

pub struct StreamingBytecodePolynomials<'a, F: JoltField, C: CommitmentScheme<Field = F>> {
/// Length of the polynomial.
length: usize,
/// Stream that builds the bytecode polynomial.
polynomial_stream: Box<dyn Iterator<Item = BytecodePolynomialStep<F, C>> + 'a>, // MapState<Vec<usize>, I, FN>,
}
Expand Down Expand Up @@ -299,6 +301,7 @@ impl<'a, F: JoltField, C: CommitmentScheme<Field = F>> StreamingBytecodePolynomi
trace: &'a mut [JoltTraceStep<InstructionSet>],
) -> Self {
let final_cts: Vec<usize> = vec![0; preprocessing.code_size];
let length = trace.len();

let polynomial_stream = map_state(final_cts, trace.iter_mut(), |final_cts, step| {
if !step.bytecode_row.address.is_zero() {
Expand Down Expand Up @@ -347,9 +350,22 @@ impl<'a, F: JoltField, C: CommitmentScheme<Field = F>> StreamingBytecodePolynomi
});

StreamingBytecodePolynomials {
length,
polynomial_stream: Box::new(polynomial_stream),
}
}

pub fn fold<T, FN>(self, init: T, f: FN) -> T
where
FN: FnMut(T, BytecodePolynomialStep<F, C>) -> T,
{
self.polynomial_stream.fold(init, f)
}

/// Returns the number of evaluations of the polynomial.
pub fn length(&self) -> usize {
self.length
}
}

impl<F: JoltField, C: CommitmentScheme<Field = F>> BytecodePolynomials<F, C> {
Expand Down
26 changes: 21 additions & 5 deletions jolt-core/src/jolt/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,23 @@ use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use strum::EnumCount;

use crate::jolt::vm::timestamp_range_check::RangeCheckPolynomials;
use crate::jolt::{
instruction::{
div::DIVInstruction, divu::DIVUInstruction, mulh::MULHInstruction,
mulhsu::MULHSUInstruction, rem::REMInstruction, remu::REMUInstruction, JoltInstruction,
VirtualInstructionSequence,
},
subtable::JoltSubtableSet,
vm::timestamp_range_check::TimestampValidityProof,
vm::{
bytecode::StreamingBytecodePolynomials,
timestamp_range_check::{
RangeCheckPolynomials,
TimestampValidityProof,
},
},
};
use crate::lasso::memory_checking::{MemoryCheckingProver, MemoryCheckingVerifier};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitmentScheme};
use crate::poly::commitment::commitment_scheme::{BatchType, CommitmentScheme, StreamingCommitmentScheme};
use crate::poly::dense_mlpoly::DensePolynomial;
use crate::poly::structured_poly::StructuredCommitment;
use crate::r1cs::inputs::{R1CSCommitment, R1CSInputs, R1CSProof};
Expand Down Expand Up @@ -246,7 +251,7 @@ where
}
}

pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, const M: usize> {
pub trait Jolt<F: JoltField, PCS: StreamingCommitmentScheme<Field = F>, const C: usize, const M: usize> {
type InstructionSet: JoltInstructionSet;
type Subtables: JoltSubtableSet<F>;

Expand Down Expand Up @@ -329,10 +334,11 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
) {
let trace_length = trace.len();
let padded_trace_length = trace_length.next_power_of_two();
println!("Trace length: {}", trace_length);

JoltTraceStep::pad(&mut trace);

let mut trace2 = trace.clone();

let mut transcript = ProofTranscript::new(b"Jolt transcript");
Self::fiat_shamir_preamble(&mut transcript, &program_io, trace_length);

Expand Down Expand Up @@ -360,6 +366,15 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
|| RangeCheckPolynomials::<F, PCS>::new(read_timestamps),
);

let streaming_bytecode_polynomials = StreamingBytecodePolynomials::<F, PCS>::new(&preprocessing.bytecode, &mut trace2);
let initialized_commitment = PCS::initialize(streaming_bytecode_polynomials.length(), &preprocessing.generators, &BatchType::Big);
// JP: `fold` likely isn't sufficient since we need to extract the internal state.
let streaming_trace_commitments =
streaming_bytecode_polynomials.fold(initialized_commitment, |state, step| {
PCS::process(state, step.a_read_write)
});
let a_read_write_commitment = PCS::finalize(streaming_trace_commitments);

let jolt_polynomials = JoltPolynomials {
bytecode: bytecode_polynomials,
read_write_memory: memory_polynomials,
Expand All @@ -368,6 +383,7 @@ pub trait Jolt<F: JoltField, PCS: CommitmentScheme<Field = F>, const C: usize, c
};

let mut jolt_commitments = jolt_polynomials.commit(&preprocessing.generators);
assert_eq!(a_read_write_commitment, jolt_commitments.bytecode.trace_commitments[0]);

let (witness_segments, r1cs_commitments, r1cs_builder) = Self::r1cs_setup(
padded_trace_length,
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/vm/rv32i_vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::jolt::subtable::{
truncate_overflow::TruncateOverflowSubtable, xor::XorSubtable, JoltSubtableSet, LassoSubtable,
SubtableId,
};
use crate::poly::commitment::commitment_scheme::CommitmentScheme;
use crate::poly::commitment::commitment_scheme::StreamingCommitmentScheme;

/// Generates an enum out of a list of JoltInstruction types. All JoltInstruction methods
/// are callable on the enum type via enum_dispatch.
Expand Down Expand Up @@ -166,7 +166,7 @@ pub const M: usize = 1 << 16;
impl<F, CS> Jolt<F, CS, C, M> for RV32IJoltVM
where
F: JoltField,
CS: CommitmentScheme<Field = F>,
CS: StreamingCommitmentScheme<Field = F>,
{
type InstructionSet = RV32I;
type Subtables = RV32ISubtables<F>;
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/poly/commitment/binius.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
#[derive(Clone)]
pub struct Binius128Scheme {}

#[derive(CanonicalSerialize, CanonicalDeserialize)]
#[derive(CanonicalSerialize, CanonicalDeserialize, Debug, Eq, PartialEq)]
pub struct BiniusCommitment {}

impl AppendToTranscript for BiniusCommitment {
Expand Down
10 changes: 9 additions & 1 deletion jolt-core/src/poly/commitment/commitment_scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub enum BatchType {
pub trait CommitmentScheme: Clone + Sync + Send + 'static {
type Field: JoltField + Sized;
type Setup: Clone + Sync + Send;
type Commitment: Sync + Send + CanonicalSerialize + CanonicalDeserialize + AppendToTranscript;
type Commitment: Sync + Send + CanonicalSerialize + CanonicalDeserialize + AppendToTranscript + std::fmt::Debug + Eq;
type Proof: Sync + Send + CanonicalSerialize + CanonicalDeserialize;
type BatchedProof: Sync + Send + CanonicalSerialize + CanonicalDeserialize;

Expand Down Expand Up @@ -98,3 +98,11 @@ pub trait CommitmentScheme: Clone + Sync + Send + 'static {

fn protocol_name() -> &'static [u8];
}

pub trait StreamingCommitmentScheme: CommitmentScheme {
type State;

fn initialize(size: usize, setup: &Self::Setup, batch_type: &BatchType) -> Self::State;
fn process(state: Self::State, eval: Self::Field) -> Self::State;
fn finalize(state: Self::State) -> Self::Commitment;
}
21 changes: 19 additions & 2 deletions jolt-core/src/poly/commitment/hyperkzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! (2) HyperKZG is specialized to use KZG as the univariate commitment scheme, so it includes several optimizations (both during the transformation of multilinear-to-univariate claims
//! and within the KZG commitment scheme implementation itself).
use super::{
commitment_scheme::{BatchType, CommitmentScheme},
commitment_scheme::{BatchType, CommitmentScheme, StreamingCommitmentScheme},
kzg::{KZGProverKey, KZGVerifierKey, UnivariateKZG},
};
use crate::field;
Expand Down Expand Up @@ -58,7 +58,7 @@ pub struct HyperKZGVerifierKey<P: Pairing> {
pub kzg_vk: KZGVerifierKey<P>,
}

#[derive(Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
#[derive(Debug, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct HyperKZGCommitment<P: Pairing>(pub P::G1Affine);

impl<P: Pairing> AppendToTranscript for HyperKZGCommitment<P> {
Expand Down Expand Up @@ -627,6 +627,23 @@ where
}
}

impl<P: Pairing> StreamingCommitmentScheme for HyperKZG<P>
where
<P as Pairing>::ScalarField: field::JoltField,
{
type State = ();

fn initialize(size: usize, setup: &Self::Setup, batch_type: &BatchType) -> Self::State {
todo!()
}
fn process(state: Self::State, eval: Self::Field) -> Self::State {
todo!()
}
fn finalize(state: Self::State) -> Self::Commitment {
todo!()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
58 changes: 56 additions & 2 deletions jolt-core/src/poly/commitment/hyrax.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::marker::PhantomData;

use super::commitment_scheme::{BatchType, CommitShape, CommitmentScheme};
use super::commitment_scheme::{BatchType, CommitShape, CommitmentScheme, StreamingCommitmentScheme};
use super::pedersen::{PedersenCommitment, PedersenGenerators};
use crate::field::JoltField;
use crate::poly::dense_mlpoly::DensePolynomial;
Expand Down Expand Up @@ -148,12 +148,66 @@ impl<F: JoltField, G: CurveGroup<ScalarField = F>> CommitmentScheme for HyraxSch
}
}

pub struct HyraxSchemeState<G: CurveGroup> {
row_commitments: Vec<G>,
generators: Vec<<G as CurveGroup>::Affine>,
current_row: Vec<G::ScalarField>,
L_size: usize,
R_size: usize,
}

impl<F: JoltField, G: CurveGroup<ScalarField = F>> StreamingCommitmentScheme for HyraxScheme<G> {
type State = HyraxSchemeState<G>;

fn initialize(n: usize, generators: &Self::Setup, batch_type: &BatchType) -> Self::State {
let ell = n.log_2();

let ratio = batch_type_to_ratio(batch_type);

let (L_size, R_size) = matrix_dimensions(ell, ratio);
assert_eq!(dbg!(L_size) * dbg!(R_size), dbg!(n));

let generators = CurveGroup::normalize_batch(&generators.generators[..R_size]);

let row_commitments = Vec::with_capacity(L_size);
let current_row = Vec::with_capacity(R_size);

HyraxSchemeState {
row_commitments,
generators,
current_row,
L_size,
R_size,
}
}

fn process(mut state: Self::State, eval: Self::Field) -> Self::State {
state.current_row.push(eval);

if state.current_row.len() == state.R_size {
let commitment = PedersenCommitment::commit_vector(&state.current_row, &state.generators);
state.row_commitments.push(commitment);

state.current_row.clear();
}

state
}
fn finalize(state: Self::State) -> Self::Commitment {
assert_eq!(state.current_row.len(), 0, "Incorrect number of elements processed.");
assert_eq!(state.row_commitments.len(), state.L_size, "Incorrect number of elements processed.");

HyraxCommitment {
row_commitments: state.row_commitments,
}
}
}
#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
pub struct HyraxGenerators<G: CurveGroup> {
pub gens: PedersenGenerators<G>,
}

#[derive(Clone, Debug, CanonicalSerialize, CanonicalDeserialize)]
#[derive(Clone, Debug, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct HyraxCommitment<G: CurveGroup> {
pub row_commitments: Vec<G>,
}
Expand Down
21 changes: 19 additions & 2 deletions jolt-core/src/poly/commitment/zeromorph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use tracing::trace_span;
use rayon::prelude::*;

use super::{
commitment_scheme::{BatchType, CommitShape, CommitmentScheme},
commitment_scheme::{BatchType, CommitShape, CommitmentScheme, StreamingCommitmentScheme},
kzg::{KZGProverKey, KZGVerifierKey, UnivariateKZG, SRS},
};

Expand Down Expand Up @@ -64,7 +64,7 @@ pub struct ZeromorphVerifierKey<P: Pairing> {
pub tau_N_max_sub_2_N: P::G2Affine,
}

#[derive(Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
#[derive(Debug, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct ZeromorphCommitment<P: Pairing>(P::G1Affine);

impl<P: Pairing> AppendToTranscript for ZeromorphCommitment<P> {
Expand Down Expand Up @@ -593,6 +593,23 @@ where
}
}

impl<P: Pairing> StreamingCommitmentScheme for Zeromorph<P>
where
<P as Pairing>::ScalarField: field::JoltField,
{
type State = ();

fn initialize(size: usize, setup: &Self::Setup, batch_type: &BatchType) -> Self::State {
todo!()
}
fn process(state: Self::State, eval: Self::Field) -> Self::State {
todo!()
}
fn finalize(state: Self::State) -> Self::Commitment {
todo!()
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down

0 comments on commit daca462

Please sign in to comment.