From 5d68b31490e1908ed425a0a92b60ad917923ad24 Mon Sep 17 00:00:00 2001 From: Alon-Ti <54235977+Alon-Ti@users.noreply.github.com> Date: Sun, 15 Sep 2024 14:26:58 +0300 Subject: [PATCH] Removed batching from LogupAtRow. (#817) --- .../prover/src/constraint_framework/logup.rs | 82 +++++-------------- .../src/examples/blake/round/constraints.rs | 25 +++--- .../examples/blake/scheduler/constraints.rs | 56 +++++++------ .../examples/blake/xor_table/constraints.rs | 28 ++++--- crates/prover/src/examples/plonk/mod.rs | 27 +++--- crates/prover/src/examples/poseidon/mod.rs | 18 ++-- 6 files changed, 106 insertions(+), 130 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index c6023bee2..bd8521645 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -22,76 +22,44 @@ use crate::core::ColumnVec; /// Evaluates constraints for batched logups. /// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum. -/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints -/// will be BATCH_SIZE + 1. -pub struct LogupAtRow { +pub struct LogupAtRow { /// The index of the interaction used for the cumulative sum columns. pub interaction: usize, - /// Queue of fractions waiting to be batched together. - pub queue: [(E::EF, E::EF); BATCH_SIZE], - /// Number of fractions in the queue. - pub queue_size: usize, /// A constant to subtract from each row, to make the totall sum of the last column zero. /// In other words, claimed_sum / 2^log_size. /// This is used to make the constraint uniform. pub cumsum_shift: SecureField, /// The evaluation of the last cumulative sum column. pub prev_col_cumsum: E::EF, + cur_frac: Option>, is_finalized: bool, } -impl LogupAtRow { +impl LogupAtRow { pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self { Self { interaction, - queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE], - queue_size: 0, cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size), prev_col_cumsum: E::EF::zero(), + cur_frac: None, is_finalized: false, } } - pub fn push_lookup( - &mut self, - eval: &mut E, - numerator: E::EF, - values: &[E::F], - lookup_elements: &LookupElements, - ) { - let shifted_value = lookup_elements.combine(values); - self.push_frac(eval, numerator, shifted_value); - } - - pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) { - if self.queue_size < BATCH_SIZE { - self.queue[self.queue_size] = (numerator, denominator); - self.queue_size += 1; - return; - } - - // Compute sum_i pi/qi over batch, as a fraction, num/denom. - let (num, denom) = self.fold_queue(); - - self.queue[0] = (numerator, denominator); - self.queue_size = 1; - - // Add a constraint that num / denom = diff. - let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; - let diff = cur_cumsum - self.prev_col_cumsum; - self.prev_col_cumsum = cur_cumsum; - eval.add_constraint(diff * denom - num); - } - pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction) { + pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction) { // Add a constraint that num / denom = diff. - let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; - let diff = cur_cumsum - self.prev_col_cumsum; - self.prev_col_cumsum = cur_cumsum; - eval.add_constraint(diff * fraction.denominator - fraction.numerator); + if let Some(cur_frac) = self.cur_frac { + let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0]; + let diff = cur_cumsum - self.prev_col_cumsum; + self.prev_col_cumsum = cur_cumsum; + eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } + self.cur_frac = Some(fraction); } pub fn finalize(mut self, eval: &mut E) { assert!(!self.is_finalized, "LogupAtRow was already finalized"); - let (num, denom) = self.fold_queue(); + + let frac = self.cur_frac.unwrap(); let [cur_cumsum, prev_row_cumsum] = eval.next_extension_interaction_mask(self.interaction, [0, -1]); @@ -101,25 +69,15 @@ impl LogupAtRow { // This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint // uniform - apply on all rows. let fixed_diff = diff + self.cumsum_shift; - - eval.add_constraint(fixed_diff * denom - num); + eval.add_constraint(fixed_diff * frac.denominator - frac.numerator); self.is_finalized = true; } - - fn fold_queue(&self) -> (E::EF, E::EF) { - self.queue[0..self.queue_size] - .iter() - .copied() - .fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| { - (p0 * qi + pi * q0, qi * q0) - }) - } } /// Ensures that the LogupAtRow is finalized. /// LogupAtRow should be finalized exactly once. -impl Drop for LogupAtRow { +impl Drop for LogupAtRow { fn drop(&mut self) { assert!(self.is_finalized, "LogupAtRow was not finalized"); } @@ -298,15 +256,15 @@ mod tests { use super::LogupAtRow; use crate::constraint_framework::InfoEvaluator; use crate::core::fields::qm31::SecureField; + use crate::core::lookups::utils::Fraction; #[test] #[should_panic] fn test_logup_not_finalized_panic() { - let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7); - logup.push_frac( + let mut logup = LogupAtRow::::new(1, SecureField::one(), 7); + logup.write_frac( &mut InfoEvaluator::default(), - SecureField::one(), - SecureField::one(), + Fraction::new(SecureField::one(), SecureField::one()), ); } } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index 944094482..0a2732d12 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -15,7 +15,7 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> { pub eval: E, pub xor_lookup_elements: &'a BlakeXorElements, pub round_lookup_elements: &'a RoundElements, - pub logup: LogupAtRow<2, E>, + pub logup: LogupAtRow, } impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { pub fn eval(mut self) -> E { @@ -33,16 +33,19 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]); // Yield `Round(input_v, output_v, message)`. - self.logup.push_lookup( + self.logup.write_frac( &mut self.eval, - -E::EF::one(), - &chain![ - input_v.iter().copied().flat_map(Fu32::to_felts), - v.iter().copied().flat_map(Fu32::to_felts), - m.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - self.round_lookup_elements, + Fraction::new( + -E::EF::one(), + self.round_lookup_elements.combine( + &chain![ + input_v.iter().copied().flat_map(Fu32::to_felts), + v.iter().copied().flat_map(Fu32::to_felts), + m.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + ), + ), ); self.logup.finalize(&mut self.eval); @@ -158,7 +161,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { denominator: comb0 * comb1, }; - self.logup.add_frac(&mut self.eval, frac); + self.logup.write_frac(&mut self.eval, frac); c } } diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 63b3cf696..ee9a1c654 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -1,9 +1,10 @@ use itertools::{chain, Itertools}; -use num_traits::{One, Zero}; +use num_traits::Zero; use super::BlakeElements; use crate::constraint_framework::logup::LogupAtRow; use crate::constraint_framework::EvalAtRow; +use crate::core::lookups::utils::Fraction; use crate::core::vcs::blake2s_ref::SIGMA; use crate::examples::blake::round::RoundElements; use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE}; @@ -12,29 +13,29 @@ pub fn eval_blake_scheduler_constraints( eval: &mut E, blake_lookup_elements: &BlakeElements, round_lookup_elements: &RoundElements, - mut logup: LogupAtRow<2, E>, + mut logup: LogupAtRow, ) { let messages: [Fu32; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval)); let states: [[Fu32; STATE_SIZE]; N_ROUNDS + 1] = std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval))); // Schedule. - for i in 0..N_ROUNDS { - let input_state = &states[i]; - let output_state = &states[i + 1]; - let round_messages = SIGMA[i].map(|j| messages[j as usize]); + for [i, j] in (0..N_ROUNDS).array_chunks::<2>() { // Use triplet in round lookup. - logup.push_lookup( - eval, - E::EF::one(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - round_messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - round_lookup_elements, - ) + let [denom_i, denom_j] = [i, j].map(|idx| { + let input_state = &states[idx]; + let output_state = &states[idx + 1]; + let round_messages = SIGMA[idx].map(|k| messages[k as usize]); + round_lookup_elements.combine::( + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + round_messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + ) + }); + logup.write_frac(eval, Fraction::new(denom_i + denom_j, denom_i * denom_j)); } let input_state = &states[0]; @@ -42,16 +43,19 @@ pub fn eval_blake_scheduler_constraints( // TODO(spapini): Support multiplicities. // TODO(spapini): Change to -1. - logup.push_lookup( + logup.write_frac( eval, - E::EF::zero(), - &chain![ - input_state.iter().copied().flat_map(Fu32::to_felts), - output_state.iter().copied().flat_map(Fu32::to_felts), - messages.iter().copied().flat_map(Fu32::to_felts) - ] - .collect_vec(), - blake_lookup_elements, + Fraction::new( + E::EF::zero(), + blake_lookup_elements.combine( + &chain![ + input_state.iter().copied().flat_map(Fu32::to_felts), + output_state.iter().copied().flat_map(Fu32::to_felts), + messages.iter().copied().flat_map(Fu32::to_felts) + ] + .collect_vec(), + ), + ), ); logup.finalize(eval); diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 00a658311..f43d0088b 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -1,13 +1,16 @@ +use itertools::Itertools; + use super::{limb_bits, XorElements}; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::EvalAtRow; use crate::core::fields::m31::BaseField; +use crate::core::lookups::utils::Fraction; /// Constraints for the xor table. pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> { pub eval: E, pub lookup_elements: &'a XorElements, - pub logup: LogupAtRow<2, E>, + pub logup: LogupAtRow, } impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS> @@ -19,8 +22,10 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> let [al] = self.eval.next_interaction_mask(2, [0]); let [bl] = self.eval.next_interaction_mask(2, [0]); let [cl] = self.eval.next_interaction_mask(2, [0]); - for i in 0..1 << EXPAND_BITS { - for j in 0..1 << EXPAND_BITS { + + let frac_chunks = (0..(1 << (2 * EXPAND_BITS))) + .map(|i| { + let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); let multiplicity = self.eval.next_trace_mask(); let a = al @@ -36,15 +41,16 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> (i ^ j) << limb_bits::(), )); - // Add with negative multiplicity. Consumers should lookup with positive - // multiplicity. - self.logup.push_lookup( - &mut self.eval, + Fraction::::new( (-multiplicity).into(), - &[a, b, c], - self.lookup_elements, - ); - } + self.lookup_elements.combine(&[a, b, c]), + ) + }) + .collect_vec(); + + for frac_chunk in frac_chunks.chunks(2) { + let sum_frac: Fraction = frac_chunk.iter().copied().sum(); + self.logup.write_frac(&mut self.eval, sum_frac); } self.logup.finalize(&mut self.eval); self.eval diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index eabfdc5bd..f2340e681 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -14,6 +14,7 @@ use crate::core::backend::Column; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::lookups::utils::Fraction; use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -43,7 +44,7 @@ impl FrameworkEval for PlonkEval { } fn evaluate(&self, mut eval: E) -> E { - let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, self.log_n_rows); + let mut logup = LogupAtRow::<_>::new(1, self.claimed_sum, self.log_n_rows); let [a_wire] = eval.next_interaction_mask(2, [0]); let [b_wire] = eval.next_interaction_mask(2, [0]); @@ -59,23 +60,19 @@ impl FrameworkEval for PlonkEval { eval.add_constraint(c_val - op * (a_val + b_val) + (E::F::one() - op) * a_val * b_val); - logup.push_lookup( - &mut eval, - E::EF::one(), - &[a_wire, a_val], - &self.lookup_elements, - ); - logup.push_lookup( + let denom_a: E::EF = self.lookup_elements.combine(&[a_wire, a_val]); + let denom_b: E::EF = self.lookup_elements.combine(&[b_wire, b_val]); + + logup.write_frac( &mut eval, - E::EF::one(), - &[b_wire, b_val], - &self.lookup_elements, + Fraction::new(denom_a + denom_b, denom_a * denom_b), ); - logup.push_lookup( + logup.write_frac( &mut eval, - E::EF::from(-mult), - &[c_wire, c_val], - &self.lookup_elements, + Fraction::new( + (-mult).into(), + self.lookup_elements.combine(&[c_wire, c_val]), + ), ); logup.finalize(&mut eval); diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 9eb9b1c1a..5f26161a2 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -3,7 +3,6 @@ use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; -use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; @@ -19,6 +18,7 @@ use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; +use crate::core::lookups::utils::Fraction; use crate::core::pcs::{CommitmentSchemeProver, PcsConfig}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; @@ -136,14 +136,14 @@ fn pow5(x: F) -> F { pub fn eval_poseidon_constraints( eval: &mut E, - mut logup: LogupAtRow<2, E>, + mut logup: LogupAtRow, lookup_elements: &PoseidonElements, ) { for _ in 0..N_INSTANCES_PER_ROW { let mut state: [_; N_STATE] = std::array::from_fn(|_| eval.next_trace_mask()); // Require state lookup. - logup.push_lookup(eval, E::EF::one(), &state, lookup_elements); + let initial_state_denom: E::EF = lookup_elements.combine(&state); // 4 full rounds. (0..N_HALF_FULL_ROUNDS).for_each(|round| { @@ -183,8 +183,16 @@ pub fn eval_poseidon_constraints( }); }); - // Provide state lookup. - logup.push_lookup(eval, -E::EF::one(), &state, lookup_elements); + // Provide state lookups. + let final_state_denom: E::EF = lookup_elements.combine(&state); + // (1 / denom0) - (1 / denom1) = (denom1 - denom0) / (denom0 * denom1). + logup.write_frac( + eval, + Fraction::new( + final_state_denom - initial_state_denom, + initial_state_denom * final_state_denom, + ), + ); } logup.finalize(eval);