From 5e9f719ebef99ca384c887b77a6298c6b6b30a12 Mon Sep 17 00:00:00 2001 From: Philippe Laferriere Date: Thu, 19 Sep 2024 17:10:00 -0400 Subject: [PATCH] feat: move op group table to LogUp-GKR --- air/src/lib.rs | 259 +++++++++++++++++++++++++++-------- air/src/trace/decoder/mod.rs | 8 ++ 2 files changed, 213 insertions(+), 54 deletions(-) diff --git a/air/src/lib.rs b/air/src/lib.rs index e0627af1de..37dcb31300 100644 --- a/air/src/lib.rs +++ b/air/src/lib.rs @@ -9,7 +9,11 @@ extern crate std; use alloc::vec::Vec; use core::marker::PhantomData; -use decoder::{DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET}; +use decoder::{ + DECODER_ADDR_COL_IDX, DECODER_GROUP_COUNT_COL_IDX, DECODER_HASHER_STATE_OFFSET, + DECODER_IN_SPAN_COL_IDX, DECODER_OP_BATCH_FLAGS_OFFSET, DECODER_OP_BITS_EXTRA_COLS_OFFSET, + DECODER_OP_BITS_OFFSET, DECODER_USER_OP_HELPERS_OFFSET, +}; use vm_core::{ utils::{ByteReader, ByteWriter, Deserializable, Serializable}, ExtensionOf, ProgramInfo, StackInputs, StackOutputs, ONE, ZERO, @@ -29,6 +33,7 @@ pub use trace::rows::RowIndex; use trace::{ chiplets::{MEMORY_D0_COL_IDX, MEMORY_D1_COL_IDX}, range::{M_COL_IDX, V_COL_IDX}, + stack::STACK_TOP_OFFSET, *, }; @@ -316,6 +321,7 @@ impl Deserializable for PublicInputs { // LOGUP-GKR // ================================================================================================ +// TODO(plafer): move to submodule #[derive(Clone, Default)] pub struct MidenLogUpGkrEval { oracles: Vec, @@ -324,7 +330,13 @@ pub struct MidenLogUpGkrEval { impl MidenLogUpGkrEval { pub fn new() -> Self { - let oracles = (0..TRACE_WIDTH).map(LogUpGkrOracle::CurrentRow).collect(); + let oracles = { + let oracles_current_row = (0..TRACE_WIDTH).map(LogUpGkrOracle::CurrentRow); + let oracles_next_row = (0..TRACE_WIDTH).map(LogUpGkrOracle::NextRow); + + oracles_current_row.chain(oracles_next_row).collect() + }; + Self { oracles, _field: PhantomData } } } @@ -339,22 +351,30 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval { } fn get_num_rand_values(&self) -> usize { - 1 + // TODO(plafer): use constants + 1 // range checker + + 4 // Op group table } fn get_num_fractions(&self) -> usize { - 8 + // TODO(plafer): use constants + 7 // range checker + + 4 // op group table + + 5 // padding } fn max_degree(&self) -> usize { - 5 + // TODO(plafer): double check + 9 } fn build_query(&self, frame: &EvaluationFrame, query: &mut [E]) where E: FieldElement, { - query.iter_mut().zip(frame.current().iter()).for_each(|(q, f)| *q = *f) + let frame_current_then_next = frame.current().iter().chain(frame.next().iter()); + + query.iter_mut().zip(frame_current_then_next).for_each(|(q, f)| *q = *f); } fn evaluate_query( @@ -368,54 +388,23 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval { F: FieldElement, E: FieldElement + ExtensionOf, { - assert_eq!(numerator.len(), 8); - assert_eq!(denominator.len(), 8); - assert_eq!(query.len(), TRACE_WIDTH); - - // numerators - let multiplicity = query[M_COL_IDX]; - let f_m = { - let mem_selec0 = query[CHIPLETS_OFFSET]; - let mem_selec1 = query[CHIPLETS_OFFSET + 1]; - let mem_selec2 = query[CHIPLETS_OFFSET + 2]; - mem_selec0 * mem_selec1 * (F::ONE - mem_selec2) - }; - - let f_rc = { - let op_bit_4 = query[DECODER_OP_BITS_OFFSET + 4]; - let op_bit_5 = query[DECODER_OP_BITS_OFFSET + 5]; - let op_bit_6 = query[DECODER_OP_BITS_OFFSET + 6]; - - (F::ONE - op_bit_4) * (F::ONE - op_bit_5) * op_bit_6 - }; - numerator[0] = E::from(multiplicity); - numerator[1] = E::from(f_m); - numerator[2] = E::from(f_m); - numerator[3] = E::from(f_rc); - numerator[4] = E::from(f_rc); - numerator[5] = E::from(f_rc); - numerator[6] = E::from(f_rc); - numerator[7] = E::ZERO; - - // denominators - let alpha = rand_values[0]; - - let table_denom = alpha - E::from(query[V_COL_IDX]); - let memory_denom_0 = -(alpha - E::from(query[MEMORY_D0_COL_IDX])); - let memory_denom_1 = -(alpha - E::from(query[MEMORY_D1_COL_IDX])); - let stack_value_denom_0 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET])); - let stack_value_denom_1 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 1])); - let stack_value_denom_2 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 2])); - let stack_value_denom_3 = -(alpha - E::from(query[DECODER_USER_OP_HELPERS_OFFSET + 3])); - - denominator[0] = table_denom; - denominator[1] = memory_denom_0; - denominator[2] = memory_denom_1; - denominator[3] = stack_value_denom_0; - denominator[4] = stack_value_denom_1; - denominator[5] = stack_value_denom_2; - denominator[6] = stack_value_denom_3; - denominator[7] = E::ONE; + // TODO(plafer): use constants + assert_eq!(numerator.len(), 16); + assert_eq!(denominator.len(), 16); + assert_eq!(query.len(), TRACE_WIDTH * 2); + + let query_current = &query[0..TRACE_WIDTH]; + let query_next = &query[TRACE_WIDTH..]; + + range_checker(query_current, rand_values[0], &mut numerator[0..7], &mut denominator[0..7]); + op_group_table( + query_current, + query_next, + &rand_values[1..5], + &mut numerator[7..11], + &mut denominator[7..11], + ); + padding(&mut numerator[11..], &mut denominator[11..]); } fn compute_claim(&self, _inputs: &Self::PublicInputs, _rand_values: &[E]) -> E @@ -425,3 +414,165 @@ impl LogUpGkrEvaluator for MidenLogUpGkrEval { E::ZERO } } + +// HELPERS +// ----------------------------------------------------------------------------------------------- + +/// TODO(plafer): docs +fn range_checker(query_current: &[F], alpha: E, numerator: &mut [E], denominator: &mut [E]) +where + F: FieldElement, + E: FieldElement + ExtensionOf, +{ + // numerators + let multiplicity = query_current[M_COL_IDX]; + let f_m = { + let mem_selec0 = query_current[CHIPLETS_OFFSET]; + let mem_selec1 = query_current[CHIPLETS_OFFSET + 1]; + let mem_selec2 = query_current[CHIPLETS_OFFSET + 2]; + mem_selec0 * mem_selec1 * (F::ONE - mem_selec2) + }; + + let f_rc = { + let op_bit_4 = query_current[DECODER_OP_BITS_OFFSET + 4]; + let op_bit_5 = query_current[DECODER_OP_BITS_OFFSET + 5]; + let op_bit_6 = query_current[DECODER_OP_BITS_OFFSET + 6]; + + (F::ONE - op_bit_4) * (F::ONE - op_bit_5) * op_bit_6 + }; + numerator[0] = E::from(multiplicity); + numerator[1] = E::from(f_m); + numerator[2] = E::from(f_m); + numerator[3] = E::from(f_rc); + numerator[4] = E::from(f_rc); + numerator[5] = E::from(f_rc); + numerator[6] = E::from(f_rc); + + // denominators + let table_denom = alpha - E::from(query_current[V_COL_IDX]); + let memory_denom_0 = -(alpha - E::from(query_current[MEMORY_D0_COL_IDX])); + let memory_denom_1 = -(alpha - E::from(query_current[MEMORY_D1_COL_IDX])); + let stack_value_denom_0 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET])); + let stack_value_denom_1 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 1])); + let stack_value_denom_2 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 2])); + let stack_value_denom_3 = -(alpha - E::from(query_current[DECODER_USER_OP_HELPERS_OFFSET + 3])); + + denominator[0] = table_denom; + denominator[1] = memory_denom_0; + denominator[2] = memory_denom_1; + denominator[3] = stack_value_denom_0; + denominator[4] = stack_value_denom_1; + denominator[5] = stack_value_denom_2; + denominator[6] = stack_value_denom_3; +} + +/// TODO(plafer): docs +fn op_group_table( + query_current: &[F], + query_next: &[F], + alphas: &[E], + numerator: &mut [E], + denominator: &mut [E], +) where + F: FieldElement, + E: FieldElement + ExtensionOf, +{ + // numerators + // TODO(plafer): don't hardcode HALT's opcode + let f_not_halt = { + let op_bit_e1 = query_current[DECODER_OP_BITS_EXTRA_COLS_OFFSET + 1]; + let op_bit_4 = query_current[DECODER_OP_BITS_OFFSET + 4]; + let op_bit_3 = query_current[DECODER_OP_BITS_OFFSET + 3]; + let op_bit_2 = query_current[DECODER_OP_BITS_OFFSET + 2]; + let f_halt = op_bit_e1 * op_bit_4 * op_bit_3 * op_bit_2; + + E::from(F::ONE - f_halt) + }; + let f_delete_group = query_current[DECODER_IN_SPAN_COL_IDX] + * (query_current[DECODER_GROUP_COUNT_COL_IDX] - query_next[DECODER_GROUP_COUNT_COL_IDX]); + + let (f_g2, f_g4, f_g8) = { + let bc0 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET]; + let bc1 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET + 1]; + let bc2 = query_current[DECODER_OP_BATCH_FLAGS_OFFSET + 2]; + + ((F::ONE - bc0) * (F::ONE - bc1) * bc2, (F::ONE - bc0) * bc1 * bc2, bc0) + }; + + numerator[0] = f_not_halt.mul_base(f_delete_group); + numerator[1] = f_not_halt.mul_base(f_g2); + numerator[2] = f_not_halt.mul_base(f_g4); + numerator[3] = f_not_halt.mul_base(f_g8); + + // denominators + let addr = query_current[DECODER_ADDR_COL_IDX]; + let addr_next = query_next[DECODER_ADDR_COL_IDX]; + let group_count = query_current[DECODER_GROUP_COUNT_COL_IDX]; + let h0_next = query_next[DECODER_HASHER_STATE_OFFSET]; + let op_next = parse_op_value(query_next); + let (f_push, f_emit, f_imm) = { + // TODO(plafer): don't hardcode + let e0 = query_current[DECODER_OP_BITS_EXTRA_COLS_OFFSET]; + let b3 = query_current[DECODER_OP_BITS_OFFSET + 3]; + let b2 = query_current[DECODER_OP_BITS_OFFSET + 2]; + let b1 = query_current[DECODER_OP_BITS_OFFSET + 1]; + let b0 = query_current[DECODER_OP_BITS_OFFSET]; + + let f_push = e0 * b3 * (F::ONE - b2) * b1 * b0; + let f_emit = e0 * b3 * (F::ONE - b2) * b1 * (F::ONE - b0); + + (f_push, f_emit, f_push + f_emit) + }; + let h2 = query_current[DECODER_HASHER_STATE_OFFSET + 2]; + let s0_next = query_next[STACK_TRACE_OFFSET + STACK_TOP_OFFSET]; + let v = |idx: u8| { + alphas[0] + + alphas[1].mul_base(addr_next) + + alphas[2].mul_base(group_count - idx.into()) + + alphas[3].mul_base(query_current[DECODER_HASHER_STATE_OFFSET + idx as usize]) + }; + + denominator[0] = -(alphas[0] + + alphas[1].mul_base(addr) + + alphas[2].mul_base(group_count) + + alphas[3].mul_base( + (F::from(2_u32 << 7) * h0_next + op_next) * (F::ONE - f_imm) + + s0_next * f_push + + h2 * f_emit, + )); + denominator[1] = -v(1); + denominator[2] = -v(1) * v(2) * v(3); + denominator[3] = -v(1) * v(2) * v(3) * v(4) * v(5) * v(6) * v(7); +} + +/// TODO(plafer): docs +fn padding(numerator: &mut [E], denominator: &mut [E]) +where + E: FieldElement, +{ + numerator.fill(E::ZERO); + denominator.fill(E::ONE); +} + +/// TODO(plafer): docs +fn parse_op_value(query: &[E]) -> E +where + E: FieldElement, +{ + let b0 = query[DECODER_OP_BITS_OFFSET]; + let b1 = query[DECODER_OP_BITS_OFFSET + 1]; + let b2 = query[DECODER_OP_BITS_OFFSET + 2]; + let b3 = query[DECODER_OP_BITS_OFFSET + 3]; + let b4 = query[DECODER_OP_BITS_OFFSET + 4]; + let b5 = query[DECODER_OP_BITS_OFFSET + 5]; + let b6 = query[DECODER_OP_BITS_OFFSET + 6]; + let b7 = query[DECODER_OP_BITS_OFFSET + 7]; + + b0 + E::from(2_u32) * b1 + + E::from(2_u32 << 2) * b2 + + E::from(2_u32 << 3) * b3 + + E::from(2_u32 << 4) * b4 + + E::from(2_u32 << 5) * b5 + + E::from(2_u32 << 6) * b6 + + E::from(2_u32 << 7) * b7 +} diff --git a/air/src/trace/decoder/mod.rs b/air/src/trace/decoder/mod.rs index 77a7531d7f..10b1c7468e 100644 --- a/air/src/trace/decoder/mod.rs +++ b/air/src/trace/decoder/mod.rs @@ -101,6 +101,14 @@ pub const P2_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 1; pub const P3_COL_IDX: usize = DECODER_AUX_TRACE_OFFSET + 2; // --- GLOBALLY-INDEXED DECODER COLUMN ACCESSORS -------------------------------------------------- +pub const DECODER_ADDR_COL_IDX: usize = super::DECODER_TRACE_OFFSET + ADDR_COL_IDX; pub const DECODER_OP_BITS_OFFSET: usize = super::DECODER_TRACE_OFFSET + OP_BITS_OFFSET; +pub const DECODER_HASHER_STATE_OFFSET: usize = super::DECODER_TRACE_OFFSET + HASHER_STATE_OFFSET; pub const DECODER_USER_OP_HELPERS_OFFSET: usize = super::DECODER_TRACE_OFFSET + USER_OP_HELPERS_OFFSET; +pub const DECODER_IN_SPAN_COL_IDX: usize = super::DECODER_TRACE_OFFSET + IN_SPAN_COL_IDX; +pub const DECODER_GROUP_COUNT_COL_IDX: usize = super::DECODER_TRACE_OFFSET + GROUP_COUNT_COL_IDX; +pub const DECODER_OP_BATCH_FLAGS_OFFSET: usize = + super::DECODER_TRACE_OFFSET + OP_BATCH_FLAGS_OFFSET; +pub const DECODER_OP_BITS_EXTRA_COLS_OFFSET: usize = + super::DECODER_TRACE_OFFSET + OP_BITS_EXTRA_COLS_OFFSET;