Skip to content

Commit

Permalink
feat: move op group table to LogUp-GKR
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer committed Sep 19, 2024
1 parent 3d2294e commit 5e9f719
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 54 deletions.
259 changes: 205 additions & 54 deletions air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ extern crate std;
use alloc::vec::Vec;
use core::marker::PhantomData;

use decoder::{DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET};
use decoder::{
DECODER_ADDR_COL_IDX, DECODER_GROUP_COUNT_COL_IDX, DECODER_HASHER_STATE_OFFSET,
DECODER_IN_SPAN_COL_IDX, DECODER_OP_BATCH_FLAGS_OFFSET, DECODER_OP_BITS_EXTRA_COLS_OFFSET,
DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET,
};
use vm_core::{
utils::{ByteReader, ByteWriter, Deserializable, Serializable},
ExtensionOf, ProgramInfo, StackInputs, StackOutputs, ONE, ZERO,
Expand All @@ -29,6 +33,7 @@ pub use trace::rows::RowIndex;
use trace::{
chiplets::{MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX},
range::{M_COL_IDX, V_COL_IDX},
stack::STACK_TOP_OFFSET,
*,
};

Expand Down Expand Up @@ -316,6 +321,7 @@ impl Deserializable for PublicInputs {
// LOGUP-GKR
// ================================================================================================

// TODO(plafer): move to submodule
#[derive(Clone, Default)]
pub struct MidenLogUpGkrEval<B: FieldElement + StarkField> {
oracles: Vec<LogUpGkrOracle>,
Expand All @@ -324,7 +330,13 @@ pub struct MidenLogUpGkrEval<B: FieldElement + StarkField> {

impl<B: FieldElement + StarkField> MidenLogUpGkrEval<B> {
pub fn new() -> Self {
let oracles = (0..TRACE_WIDTH).map(LogUpGkrOracle::CurrentRow).collect();
let oracles = {
let oracles_current_row = (0..TRACE_WIDTH).map(LogUpGkrOracle::CurrentRow);
let oracles_next_row = (0..TRACE_WIDTH).map(LogUpGkrOracle::NextRow);

oracles_current_row.chain(oracles_next_row).collect()
};

Self { oracles, _field: PhantomData }
}
}
Expand All @@ -339,22 +351,30 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval<Felt> {
}

fn get_num_rand_values(&self) -> usize {
1
// TODO(plafer): use constants
1 // range checker
+ 4 // Op group table
}

fn get_num_fractions(&self) -> usize {
8
// TODO(plafer): use constants
7 // range checker
+ 4 // op group table
+ 5 // padding
}

fn max_degree(&self) -> usize {
5
// TODO(plafer): double check
9
}

fn build_query<E>(&self, frame: &EvaluationFrame<E>, query: &mut [E])
where
E: FieldElement<BaseField = Self::BaseField>,
{
query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f)
let frame_current_then_next = frame.current().iter().chain(frame.next().iter());

query.iter_mut().zip(frame_current_then_next).for_each(|(q, f)| *q = *f);
}

fn evaluate_query<F, E>(
Expand All @@ -368,54 +388,23 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval<Felt> {
F: FieldElement<BaseField = Self::BaseField>,
E: FieldElement<BaseField = Self::BaseField> + ExtensionOf<F>,
{
assert_eq!(numerator.len(), 8);
assert_eq!(denominator.len(), 8);
assert_eq!(query.len(), TRACE_WIDTH);

// numerators
let multiplicity = query[M_COL_IDX];
let f_m = {
let mem_selec0 = query[CHIPLETS_OFFSET];
let mem_selec1 = query[CHIPLETS_OFFSET + 1];
let mem_selec2 = query[CHIPLETS_OFFSET + 2];
mem_selec0 * mem_selec1 * (F::ONE - mem_selec2)
};

let f_rc = {
let op_bit_4 = query[DECODER_OP_BITS_OFFSET + 4];
let op_bit_5 = query[DECODER_OP_BITS_OFFSET + 5];
let op_bit_6 = query[DECODER_OP_BITS_OFFSET + 6];

(F::ONE - op_bit_4) * (F::ONE - op_bit_5) * op_bit_6
};
numerator[0] = E::from(multiplicity);
numerator[1] = E::from(f_m);
numerator[2] = E::from(f_m);
numerator[3] = E::from(f_rc);
numerator[4] = E::from(f_rc);
numerator[5] = E::from(f_rc);
numerator[6] = E::from(f_rc);
numerator[7] = E::ZERO;

// denominators
let alpha = rand_values[0];

let table_denom = alpha - E::from(query[V_COL_IDX]);
let memory_denom_0 = -(alpha - E::from(query[MEMORY_D0_COL_IDX]));
let memory_denom_1 = -(alpha - E::from(query[MEMORY_D1_COL_IDX]));
let stack_value_denom_0 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET]));
let stack_value_denom_1 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 1]));
let stack_value_denom_2 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 2]));
let stack_value_denom_3 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 3]));

denominator[0] = table_denom;
denominator[1] = memory_denom_0;
denominator[2] = memory_denom_1;
denominator[3] = stack_value_denom_0;
denominator[4] = stack_value_denom_1;
denominator[5] = stack_value_denom_2;
denominator[6] = stack_value_denom_3;
denominator[7] = E::ONE;
// TODO(plafer): use constants
assert_eq!(numerator.len(), 16);
assert_eq!(denominator.len(), 16);
assert_eq!(query.len(), TRACE_WIDTH * 2);

let query_current = &query[0..TRACE_WIDTH];
let query_next = &query[TRACE_WIDTH..];

range_checker(query_current, rand_values[0], &mut numerator[0..7], &mut denominator[0..7]);
op_group_table(
query_current,
query_next,
&rand_values[1..5],
&mut numerator[7..11],
&mut denominator[7..11],
);
padding(&mut numerator[11..], &mut denominator[11..]);
}

fn compute_claim<E>(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E
Expand All @@ -425,3 +414,165 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval<Felt> {
E::ZERO
}
}

// HELPERS
// -----------------------------------------------------------------------------------------------

/// TODO(plafer): docs
fn range_checker<F, E>(query_current: &[F], alpha: E, numerator: &mut [E], denominator: &mut [E])
where
F: FieldElement,
E: FieldElement + ExtensionOf<F>,
{
// numerators
let multiplicity = query_current[M_COL_IDX];
let f_m = {
let mem_selec0 = query_current[CHIPLETS_OFFSET];
let mem_selec1 = query_current[CHIPLETS_OFFSET + 1];
let mem_selec2 = query_current[CHIPLETS_OFFSET + 2];
mem_selec0 * mem_selec1 * (F::ONE - mem_selec2)
};

let f_rc = {
let op_bit_4 = query_current[DECODER_OP_BITS_OFFSET + 4];
let op_bit_5 = query_current[DECODER_OP_BITS_OFFSET + 5];
let op_bit_6 = query_current[DECODER_OP_BITS_OFFSET + 6];

(F::ONE - op_bit_4) * (F::ONE - op_bit_5) * op_bit_6
};
numerator[0] = E::from(multiplicity);
numerator[1] = E::from(f_m);
numerator[2] = E::from(f_m);
numerator[3] = E::from(f_rc);
numerator[4] = E::from(f_rc);
numerator[5] = E::from(f_rc);
numerator[6] = E::from(f_rc);

// denominators
let table_denom = alpha - E::from(query_current[V_COL_IDX]);
let memory_denom_0 = -(alpha - E::from(query_current[MEMORY_D0_COL_IDX]));
let memory_denom_1 = -(alpha - E::from(query_current[MEMORY_D1_COL_IDX]));
let stack_value_denom_0 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET]));
let stack_value_denom_1 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 1]));
let stack_value_denom_2 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 2]));
let stack_value_denom_3 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 3]));

denominator[0] = table_denom;
denominator[1] = memory_denom_0;
denominator[2] = memory_denom_1;
denominator[3] = stack_value_denom_0;
denominator[4] = stack_value_denom_1;
denominator[5] = stack_value_denom_2;
denominator[6] = stack_value_denom_3;
}

/// TODO(plafer): docs
fn op_group_table<F, E>(
query_current: &[F],
query_next: &[F],
alphas: &[E],
numerator: &mut [E],
denominator: &mut [E],
) where
F: FieldElement,
E: FieldElement + ExtensionOf<F>,
{
// numerators
// TODO(plafer): don't hardcode HALT's opcode
let f_not_halt = {
let op_bit_e1 = query_current[DECODER_OP_BITS_EXTRA_COLS_OFFSET + 1];
let op_bit_4 = query_current[DECODER_OP_BITS_OFFSET + 4];
let op_bit_3 = query_current[DECODER_OP_BITS_OFFSET + 3];
let op_bit_2 = query_current[DECODER_OP_BITS_OFFSET + 2];
let f_halt = op_bit_e1 * op_bit_4 * op_bit_3 * op_bit_2;

E::from(F::ONE - f_halt)
};
let f_delete_group = query_current[DECODER_IN_SPAN_COL_IDX]
* (query_current[DECODER_GROUP_COUNT_COL_IDX] - query_next[DECODER_GROUP_COUNT_COL_IDX]);

let (f_g2, f_g4, f_g8) = {
let bc0 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET];
let bc1 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET + 1];
let bc2 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET + 2];

((F::ONE - bc0) * (F::ONE - bc1) * bc2, (F::ONE - bc0) * bc1 * bc2, bc0)
};

numerator[0] = f_not_halt.mul_base(f_delete_group);
numerator[1] = f_not_halt.mul_base(f_g2);
numerator[2] = f_not_halt.mul_base(f_g4);
numerator[3] = f_not_halt.mul_base(f_g8);

// denominators
let addr = query_current[DECODER_ADDR_COL_IDX];
let addr_next = query_next[DECODER_ADDR_COL_IDX];
let group_count = query_current[DECODER_GROUP_COUNT_COL_IDX];
let h0_next = query_next[DECODER_HASHER_STATE_OFFSET];
let op_next = parse_op_value(query_next);
let (f_push, f_emit, f_imm) = {
// TODO(plafer): don't hardcode
let e0 = query_current[DECODER_OP_BITS_EXTRA_COLS_OFFSET];
let b3 = query_current[DECODER_OP_BITS_OFFSET + 3];
let b2 = query_current[DECODER_OP_BITS_OFFSET + 2];
let b1 = query_current[DECODER_OP_BITS_OFFSET + 1];
let b0 = query_current[DECODER_OP_BITS_OFFSET];

let f_push = e0 * b3 * (F::ONE - b2) * b1 * b0;
let f_emit = e0 * b3 * (F::ONE - b2) * b1 * (F::ONE - b0);

(f_push, f_emit, f_push + f_emit)
};
let h2 = query_current[DECODER_HASHER_STATE_OFFSET + 2];
let s0_next = query_next[STACK_TRACE_OFFSET + STACK_TOP_OFFSET];
let v = |idx: u8| {
alphas[0]
+ alphas[1].mul_base(addr_next)
+ alphas[2].mul_base(group_count - idx.into())
+ alphas[3].mul_base(query_current[DECODER_HASHER_STATE_OFFSET + idx as usize])
};

denominator[0] = -(alphas[0]
+ alphas[1].mul_base(addr)
+ alphas[2].mul_base(group_count)
+ alphas[3].mul_base(
(F::from(2_u32 << 7) * h0_next + op_next) * (F::ONE - f_imm)
+ s0_next * f_push
+ h2 * f_emit,
));
denominator[1] = -v(1);
denominator[2] = -v(1) * v(2) * v(3);
denominator[3] = -v(1) * v(2) * v(3) * v(4) * v(5) * v(6) * v(7);
}

/// TODO(plafer): docs
fn padding<E>(numerator: &mut [E], denominator: &mut [E])
where
E: FieldElement,
{
numerator.fill(E::ZERO);
denominator.fill(E::ONE);
}

/// TODO(plafer): docs
fn parse_op_value<E>(query: &[E]) -> E
where
E: FieldElement,
{
let b0 = query[DECODER_OP_BITS_OFFSET];
let b1 = query[DECODER_OP_BITS_OFFSET + 1];
let b2 = query[DECODER_OP_BITS_OFFSET + 2];
let b3 = query[DECODER_OP_BITS_OFFSET + 3];
let b4 = query[DECODER_OP_BITS_OFFSET + 4];
let b5 = query[DECODER_OP_BITS_OFFSET + 5];
let b6 = query[DECODER_OP_BITS_OFFSET + 6];
let b7 = query[DECODER_OP_BITS_OFFSET + 7];

b0 + E::from(2_u32) * b1
+ E::from(2_u32 << 2) * b2
+ E::from(2_u32 << 3) * b3
+ E::from(2_u32 << 4) * b4
+ E::from(2_u32 << 5) * b5
+ E::from(2_u32 << 6) * b6
+ E::from(2_u32 << 7) * b7
}
8 changes: 8 additions & 0 deletions air/src/trace/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ pub const P2_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 1;
pub const P3_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 2;

// --- GLOBALLY-INDEXED DECODER COLUMN ACCESSORS --------------------------------------------------
pub const DECODER_ADDR_COL_IDX: usize = super::DECODER_TRACE_OFFSET + ADDR_COL_IDX;
pub const DECODER_OP_BITS_OFFSET: usize = super::DECODER_TRACE_OFFSET + OP_BITS_OFFSET;
pub const DECODER_HASHER_STATE_OFFSET: usize = super::DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET;
pub const DECODER_USER_OP_HELPERS_OFFSET: usize =
super::DECODER_TRACE_OFFSET + USER_OP_HELPERS_OFFSET;
pub const DECODER_IN_SPAN_COL_IDX: usize = super::DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX;
pub const DECODER_GROUP_COUNT_COL_IDX: usize = super::DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX;
pub const DECODER_OP_BATCH_FLAGS_OFFSET: usize =
super::DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET;
pub const DECODER_OP_BITS_EXTRA_COLS_OFFSET: usize =
super::DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET;

0 comments on commit 5e9f719

Please sign in to comment.