Skip to content

Commit

Permalink
feat: initial implementation of LogUp-GKR bus
Browse files Browse the repository at this point in the history
Co-authored-by: Al-Kindi-0 <[email protected]>
Co-authored-by: Philippe Laferrière <[email protected]>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent ddf536c commit 085a41b
Show file tree
Hide file tree
Showing 19 changed files with 2,500 additions and 6 deletions.
12 changes: 8 additions & 4 deletions air/src/trace/main_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ use super::{
use core::ops::{Deref, Range};
use vm_core::{utils::range, Felt, Word, ONE, ZERO};

#[cfg(any(test, feature = "internals"))]
use alloc::vec::Vec;

// CONSTANTS
// ================================================================================================

Expand All @@ -43,6 +40,13 @@ impl Deref for MainTrace {
}
}

#[cfg(any(test, feature = "internals"))]
impl core::ops::DerefMut for MainTrace {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.columns
}
}

impl MainTrace {
pub fn new(main_trace: ColMatrix<Felt>) -> Self {
Self {
Expand All @@ -55,7 +59,7 @@ impl MainTrace {
}

#[cfg(any(test, feature = "internals"))]
pub fn get_column_range(&self, range: Range<usize>) -> Vec<Vec<Felt>> {
pub fn get_column_range(&self, range: Range<usize>) -> alloc::vec::Vec<alloc::vec::Vec<Felt>> {
range.fold(vec![], |mut acc, col_idx| {
acc.push(self.get_column(col_idx).to_vec());
acc
Expand Down
4 changes: 3 additions & 1 deletion processor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ std = ["vm-core/std", "winter-prover/std"]
tracing = { version = "0.1", default-features = false, features = [
"attributes",
] }
vm-core = { package = "miden-core", path = "../core", version = "0.9", default-features = false }
miden-air = { package = "miden-air", path = "../air", version = "0.9", default-features = false }
static_assertions = "1.1.0"
thiserror = { version = "1.0", default-features = false }
vm-core = { package = "miden-core", path = "../core", version = "0.9", default-features = false }
winter-prover = { package = "winter-prover", version = "0.9", default-features = false }

[dev-dependencies]
Expand Down
5 changes: 4 additions & 1 deletion processor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@ use chiplets::Chiplets;

mod trace;
use trace::TraceFragment;
pub use trace::{ChipletsLengths, ExecutionTrace, TraceLenSummary, NUM_RAND_ROWS};
pub use trace::{
prove_virtual_bus, verify_virtual_bus, ChipletsLengths, ExecutionTrace, TraceLenSummary,
NUM_RAND_ROWS,
};

mod errors;
pub use errors::{ExecutionError, Ext2InttError};
Expand Down
3 changes: 3 additions & 0 deletions processor/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ use winter_prover::{crypto::RandomCoin, EvaluationFrame, Trace, TraceInfo};
mod utils;
pub use utils::{AuxColumnBuilder, ChipletsLengths, TraceFragment, TraceLenSummary};

mod virtual_bus;
pub use virtual_bus::{prove as prove_virtual_bus, verify as verify_virtual_bus};

#[cfg(test)]
mod tests;
#[cfg(test)]
Expand Down
24 changes: 24 additions & 0 deletions processor/src/trace/virtual_bus/circuit/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use crate::trace::virtual_bus::sum_check::SumCheckProverError;
use crate::trace::virtual_bus::sum_check::SumCheckVerifierError;

#[derive(Debug, thiserror::Error)]
pub enum ProverError {
#[error("failed to generate multi-linear from the given evaluations")]
FailedToGenerateML,
#[error("failed to generate the sum-check proof")]
FailedToProveSumCheck(#[from] SumCheckProverError),
#[error("failed to generate the random challenge")]
FailedToGenerateChallenge,
}

#[derive(Debug, thiserror::Error)]
pub enum VerifierError {
#[error("one of the claimed circuit denominators is zero")]
ZeroOutputDenominator,
#[error("the output of the fraction circuit is not equal to the expected value")]
MismatchingCircuitOutput,
#[error("failed to generate the random challenge")]
FailedToGenerateChallenge,
#[error("failed to verify the sum-check proof")]
FailedToVerifySumCheck(#[from] SumCheckVerifierError),
}
290 changes: 290 additions & 0 deletions processor/src/trace/virtual_bus/circuit/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
use core::ops::Add;

use crate::trace::virtual_bus::multilinear::EqFunction;
use crate::trace::virtual_bus::sum_check::{CompositionPolynomial, RoundProof};
use alloc::vec::Vec;
use miden_air::trace::chiplets::{MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX};
use miden_air::trace::decoder::{DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET};
use miden_air::trace::range::{M_COL_IDX, V_COL_IDX};
use miden_air::trace::{CHIPLETS_OFFSET, TRACE_WIDTH};
use prover::CircuitLayerPolys;
use static_assertions::const_assert;
use vm_core::{Felt, FieldElement};

mod error;
mod prover;
pub use prover::prove;

mod verifier;
pub use verifier::verify;

use super::multilinear::MultiLinearPoly;
use super::sum_check::{FinalOpeningClaim, Proof as SumCheckProof};

/// Defines the number of wires in the input layer that are generated from a single main trace row.
const NUM_WIRES_PER_TRACE_ROW: usize = 8;
const_assert!(NUM_WIRES_PER_TRACE_ROW.is_power_of_two());

// CIRCUIT WIRE
// ================================================================================================

/// Represents a fraction `numerator / denominator` as a pair `(numerator, denominator)`. This is
/// the type for the gates' inputs in [`prover::EvaluatedCircuit`].
///
/// Hence, addition is defined in the natural way fractions are added together: `a/b + c/d = (ad +
/// bc) / bd`.
#[derive(Debug, Clone, Copy)]
pub struct CircuitWire<E: FieldElement> {
numerator: E,
denominator: E,
}

impl<E> CircuitWire<E>
where
E: FieldElement,
{
/// Creates new projective coordinates from a numerator and a denominator.
pub fn new(numerator: E, denominator: E) -> Self {
assert_ne!(denominator, E::ZERO);

Self {
numerator,
denominator,
}
}
}

impl<E> Add for CircuitWire<E>
where
E: FieldElement,
{
type Output = Self;

fn add(self, other: Self) -> Self {
let numerator = self.numerator * other.denominator + other.numerator * self.denominator;
let denominator = self.denominator * other.denominator;

Self::new(numerator, denominator)
}
}

/// Converts a main trace row (or more generally "query") to numerators and denominators of the
/// input layer.
fn evaluate_fractions_at_main_trace_query<E>(
query: &[E],
log_up_randomness: &[E],
) -> [[E; NUM_WIRES_PER_TRACE_ROW]; 2]
where
E: FieldElement,
{
// numerators
let multiplicity = query[M_COL_IDX];
let f_m = {
let mem_selec0 = query[CHIPLETS_OFFSET];
let mem_selec1 = query[CHIPLETS_OFFSET + 1];
let mem_selec2 = query[CHIPLETS_OFFSET + 2];
mem_selec0 * mem_selec1 * (E::ONE - mem_selec2)
};

let f_rc = {
let op_bit_4 = query[DECODER_OP_BITS_OFFSET + 4];
let op_bit_5 = query[DECODER_OP_BITS_OFFSET + 5];
let op_bit_6 = query[DECODER_OP_BITS_OFFSET + 6];

(E::ONE - op_bit_4) * (E::ONE - op_bit_5) * op_bit_6
};

// denominators
let alphas = log_up_randomness;

let table_denom = alphas[0] - query[V_COL_IDX];
let memory_denom_0 = -(alphas[0] - query[MEMORY_D0_COL_IDX]);
let memory_denom_1 = -(alphas[0] - query[MEMORY_D1_COL_IDX]);
let stack_value_denom_0 = -(alphas[0] - query[DECODER_USER_OP_HELPERS_OFFSET]);
let stack_value_denom_1 = -(alphas[0] - query[DECODER_USER_OP_HELPERS_OFFSET + 1]);
let stack_value_denom_2 = -(alphas[0] - query[DECODER_USER_OP_HELPERS_OFFSET + 2]);
let stack_value_denom_3 = -(alphas[0] - query[DECODER_USER_OP_HELPERS_OFFSET + 3]);

[
[multiplicity, f_m, f_m, f_rc, f_rc, f_rc, f_rc, E::ZERO],
[
table_denom,
memory_denom_0,
memory_denom_1,
stack_value_denom_0,
stack_value_denom_1,
stack_value_denom_2,
stack_value_denom_3,
E::ONE,
],
]
}

/// Computes the wires added to the input layer that come from a given main trace row (or more
/// generally, "query").
fn compute_input_layer_wires_at_main_trace_query<E>(
query: &[E],
log_up_randomness: &[E],
) -> [CircuitWire<E>; NUM_WIRES_PER_TRACE_ROW]
where
E: FieldElement,
{
let [numerators, denominators] =
evaluate_fractions_at_main_trace_query(query, log_up_randomness);
let input_gates_values: Vec<CircuitWire<E>> = numerators
.into_iter()
.zip(denominators)
.map(|(numerator, denominator)| CircuitWire::new(numerator, denominator))
.collect();
input_gates_values.try_into().unwrap()
}

/// A GKR proof for the correct evaluation of the sum of fractions circuit.
#[derive(Debug)]
pub struct GkrCircuitProof<E: FieldElement> {
circuit_outputs: CircuitLayerPolys<E>,
before_final_layer_proofs: BeforeFinalLayerProof<E>,
final_layer_proof: FinalLayerProof<E>,
}

impl<E: FieldElement> GkrCircuitProof<E> {
pub fn get_final_opening_claim(&self) -> FinalOpeningClaim<E> {
self.final_layer_proof.after_merge_proof.openings_claim.clone()
}
}

/// A set of sum-check proofs for all GKR layers but for the input circuit layer.
#[derive(Debug)]
pub struct BeforeFinalLayerProof<E: FieldElement> {
pub proof: Vec<SumCheckProof<E>>,
}

/// A proof for the input circuit layer i.e., the final layer in the GKR protocol.
#[derive(Debug)]
pub struct FinalLayerProof<E: FieldElement> {
before_merge_proof: Vec<RoundProof<E>>,
after_merge_proof: SumCheckProof<E>,
}

/// Represents a claim to be proven by a subsequent call to the sum-check protocol.
#[derive(Debug)]
pub struct GkrClaim<E: FieldElement> {
pub evaluation_point: Vec<E>,
pub claimed_evaluation: (E, E),
}

/// A composition polynomial used in the GKR protocol for all of its sum-checks except the final
/// one.
#[derive(Clone)]
pub struct GkrComposition<E>
where
E: FieldElement<BaseField = Felt>,
{
pub combining_randomness: E,
}

impl<E> GkrComposition<E>
where
E: FieldElement<BaseField = Felt>,
{
pub fn new(combining_randomness: E) -> Self {
Self {
combining_randomness,
}
}
}

impl<E> CompositionPolynomial<E> for GkrComposition<E>
where
E: FieldElement<BaseField = Felt>,
{
fn max_degree(&self) -> u32 {
3
}

fn evaluate(&self, query: &[E]) -> E {
let eval_left_numerator = query[0];
let eval_right_numerator = query[1];
let eval_left_denominator = query[2];
let eval_right_denominator = query[3];
let eq_eval = query[4];
eq_eval
* ((eval_left_numerator * eval_right_denominator
+ eval_right_numerator * eval_left_denominator)
+ eval_left_denominator * eval_right_denominator * self.combining_randomness)
}
}

/// A composition polynomial used in the GKR protocol for its final sum-check.
#[derive(Clone)]
pub struct GkrCompositionMerge<E>
where
E: FieldElement<BaseField = Felt>,
{
pub sum_check_combining_randomness: E,
pub tensored_merge_randomness: Vec<E>,
pub log_up_randomness: Vec<E>,
}

impl<E> GkrCompositionMerge<E>
where
E: FieldElement<BaseField = Felt>,
{
pub fn new(
combining_randomness: E,
merge_randomness: Vec<E>,
log_up_randomness: Vec<E>,
) -> Self {
let tensored_merge_randomness =
EqFunction::ml_at(merge_randomness.clone()).evaluations().to_vec();

Self {
sum_check_combining_randomness: combining_randomness,
tensored_merge_randomness,
log_up_randomness,
}
}
}

impl<E> CompositionPolynomial<E> for GkrCompositionMerge<E>
where
E: FieldElement<BaseField = Felt>,
{
fn max_degree(&self) -> u32 {
// Computed as:
// 1 + max(left_numerator_degree + right_denom_degree, right_numerator_degree +
// left_denom_degree)
5
}

fn evaluate(&self, query: &[E]) -> E {
let [numerators, denominators] =
evaluate_fractions_at_main_trace_query(query, &self.log_up_randomness);

let numerators = MultiLinearPoly::from_evaluations(numerators.to_vec()).unwrap();
let denominators = MultiLinearPoly::from_evaluations(denominators.to_vec()).unwrap();

let (left_numerators, right_numerators) = numerators.project_least_significant_variable();
let (left_denominators, right_denominators) =
denominators.project_least_significant_variable();

let eval_left_numerators =
left_numerators.evaluate_with_lagrange_kernel(&self.tensored_merge_randomness);
let eval_right_numerators =
right_numerators.evaluate_with_lagrange_kernel(&self.tensored_merge_randomness);

let eval_left_denominators =
left_denominators.evaluate_with_lagrange_kernel(&self.tensored_merge_randomness);
let eval_right_denominators =
right_denominators.evaluate_with_lagrange_kernel(&self.tensored_merge_randomness);

let eq_eval = query[TRACE_WIDTH];

eq_eval
* ((eval_left_numerators * eval_right_denominators
+ eval_right_numerators * eval_left_denominators)
+ eval_left_denominators
* eval_right_denominators
* self.sum_check_combining_randomness)
}
}
Loading

0 comments on commit 085a41b

Please sign in to comment.