Skip to content

Commit

Permalink
Removed batching from LogupAtRow. (starkware-libs#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti authored and jarnesino committed Sep 17, 2024
1 parent dc89e3b commit 5d68b31
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 130 deletions.
82 changes: 20 additions & 62 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,76 +22,44 @@ use crate::core::ColumnVec;

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
/// BATCH_SIZE is the number of fractions to batch together. The degree of the resulting constraints
/// will be BATCH_SIZE + 1.
pub struct LogupAtRow<const BATCH_SIZE: usize, E: EvalAtRow> {
pub struct LogupAtRow<E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// Queue of fractions waiting to be batched together.
pub queue: [(E::EF, E::EF); BATCH_SIZE],
/// Number of fractions in the queue.
pub queue_size: usize,
/// A constant to subtract from each row, to make the totall sum of the last column zero.
/// In other words, claimed_sum / 2^log_size.
/// This is used to make the constraint uniform.
pub cumsum_shift: SecureField,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
cur_frac: Option<Fraction<E::EF, E::EF>>,
is_finalized: bool,
}
impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
impl<E: EvalAtRow> LogupAtRow<E> {
pub fn new(interaction: usize, claimed_sum: SecureField, log_size: u32) -> Self {
Self {
interaction,
queue: [(E::EF::zero(), E::EF::zero()); BATCH_SIZE],
queue_size: 0,
cumsum_shift: claimed_sum / BaseField::from_u32_unchecked(1 << log_size),
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: false,
}
}
pub fn push_lookup<const N: usize>(
&mut self,
eval: &mut E,
numerator: E::EF,
values: &[E::F],
lookup_elements: &LookupElements<N>,
) {
let shifted_value = lookup_elements.combine(values);
self.push_frac(eval, numerator, shifted_value);
}

pub fn push_frac(&mut self, eval: &mut E, numerator: E::EF, denominator: E::EF) {
if self.queue_size < BATCH_SIZE {
self.queue[self.queue_size] = (numerator, denominator);
self.queue_size += 1;
return;
}

// Compute sum_i pi/qi over batch, as a fraction, num/denom.
let (num, denom) = self.fold_queue();

self.queue[0] = (numerator, denominator);
self.queue_size = 1;

// Add a constraint that num / denom = diff.
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * denom - num);
}

pub fn add_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
pub fn write_frac(&mut self, eval: &mut E, fraction: Fraction<E::EF, E::EF>) {
// Add a constraint that num / denom = diff.
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * fraction.denominator - fraction.numerator);
if let Some(cur_frac) = self.cur_frac {
let cur_cumsum = eval.next_extension_interaction_mask(self.interaction, [0])[0];
let diff = cur_cumsum - self.prev_col_cumsum;
self.prev_col_cumsum = cur_cumsum;
eval.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}
self.cur_frac = Some(fraction);
}

pub fn finalize(mut self, eval: &mut E) {
assert!(!self.is_finalized, "LogupAtRow was already finalized");
let (num, denom) = self.fold_queue();

let frac = self.cur_frac.unwrap();

let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
Expand All @@ -101,25 +69,15 @@ impl<const BATCH_SIZE: usize, E: EvalAtRow> LogupAtRow<BATCH_SIZE, E> {
// This makes (num / denom - cumsum_shift) have sum zero, which makes the constraint
// uniform - apply on all rows.
let fixed_diff = diff + self.cumsum_shift;

eval.add_constraint(fixed_diff * denom - num);
eval.add_constraint(fixed_diff * frac.denominator - frac.numerator);

self.is_finalized = true;
}

fn fold_queue(&self) -> (E::EF, E::EF) {
self.queue[0..self.queue_size]
.iter()
.copied()
.fold((E::EF::zero(), E::EF::one()), |(p0, q0), (pi, qi)| {
(p0 * qi + pi * q0, qi * q0)
})
}
}

/// Ensures that the LogupAtRow is finalized.
/// LogupAtRow should be finalized exactly once.
impl<const BATCH_SIZE: usize, E: EvalAtRow> Drop for LogupAtRow<BATCH_SIZE, E> {
impl<E: EvalAtRow> Drop for LogupAtRow<E> {
fn drop(&mut self) {
assert!(self.is_finalized, "LogupAtRow was not finalized");
}
Expand Down Expand Up @@ -298,15 +256,15 @@ mod tests {
use super::LogupAtRow;
use crate::constraint_framework::InfoEvaluator;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;

#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<2, InfoEvaluator>::new(1, SecureField::one(), 7);
logup.push_frac(
let mut logup = LogupAtRow::<InfoEvaluator>::new(1, SecureField::one(), 7);
logup.write_frac(
&mut InfoEvaluator::default(),
SecureField::one(),
SecureField::one(),
Fraction::new(SecureField::one(), SecureField::one()),
);
}
}
25 changes: 14 additions & 11 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct BlakeRoundEval<'a, E: EvalAtRow> {
pub eval: E,
pub xor_lookup_elements: &'a BlakeXorElements,
pub round_lookup_elements: &'a RoundElements,
pub logup: LogupAtRow<2, E>,
pub logup: LogupAtRow<E>,
}
impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
pub fn eval(mut self) -> E {
Expand All @@ -33,16 +33,19 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
self.g(v.get_many_mut([3, 4, 9, 14]).unwrap(), m[14], m[15]);

// Yield `Round(input_v, output_v, message)`.
self.logup.push_lookup(
self.logup.write_frac(
&mut self.eval,
-E::EF::one(),
&chain![
input_v.iter().copied().flat_map(Fu32::to_felts),
v.iter().copied().flat_map(Fu32::to_felts),
m.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
self.round_lookup_elements,
Fraction::new(
-E::EF::one(),
self.round_lookup_elements.combine(
&chain![
input_v.iter().copied().flat_map(Fu32::to_felts),
v.iter().copied().flat_map(Fu32::to_felts),
m.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
),
);

self.logup.finalize(&mut self.eval);
Expand Down Expand Up @@ -158,7 +161,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
denominator: comb0 * comb1,
};

self.logup.add_frac(&mut self.eval, frac);
self.logup.write_frac(&mut self.eval, frac);
c
}
}
56 changes: 30 additions & 26 deletions crates/prover/src/examples/blake/scheduler/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use itertools::{chain, Itertools};
use num_traits::{One, Zero};
use num_traits::Zero;

use super::BlakeElements;
use crate::constraint_framework::logup::LogupAtRow;
use crate::constraint_framework::EvalAtRow;
use crate::core::lookups::utils::Fraction;
use crate::core::vcs::blake2s_ref::SIGMA;
use crate::examples::blake::round::RoundElements;
use crate::examples::blake::{Fu32, N_ROUNDS, STATE_SIZE};
Expand All @@ -12,46 +13,49 @@ pub fn eval_blake_scheduler_constraints<E: EvalAtRow>(
eval: &mut E,
blake_lookup_elements: &BlakeElements,
round_lookup_elements: &RoundElements,
mut logup: LogupAtRow<2, E>,
mut logup: LogupAtRow<E>,
) {
let messages: [Fu32<E::F>; STATE_SIZE] = std::array::from_fn(|_| eval_next_u32(eval));
let states: [[Fu32<E::F>; STATE_SIZE]; N_ROUNDS + 1] =
std::array::from_fn(|_| std::array::from_fn(|_| eval_next_u32(eval)));

// Schedule.
for i in 0..N_ROUNDS {
let input_state = &states[i];
let output_state = &states[i + 1];
let round_messages = SIGMA[i].map(|j| messages[j as usize]);
for [i, j] in (0..N_ROUNDS).array_chunks::<2>() {
// Use triplet in round lookup.
logup.push_lookup(
eval,
E::EF::one(),
&chain![
input_state.iter().copied().flat_map(Fu32::to_felts),
output_state.iter().copied().flat_map(Fu32::to_felts),
round_messages.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
round_lookup_elements,
)
let [denom_i, denom_j] = [i, j].map(|idx| {
let input_state = &states[idx];
let output_state = &states[idx + 1];
let round_messages = SIGMA[idx].map(|k| messages[k as usize]);
round_lookup_elements.combine::<E::F, E::EF>(
&chain![
input_state.iter().copied().flat_map(Fu32::to_felts),
output_state.iter().copied().flat_map(Fu32::to_felts),
round_messages.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
)
});
logup.write_frac(eval, Fraction::new(denom_i + denom_j, denom_i * denom_j));
}

let input_state = &states[0];
let output_state = &states[N_ROUNDS];

// TODO(spapini): Support multiplicities.
// TODO(spapini): Change to -1.
logup.push_lookup(
logup.write_frac(
eval,
E::EF::zero(),
&chain![
input_state.iter().copied().flat_map(Fu32::to_felts),
output_state.iter().copied().flat_map(Fu32::to_felts),
messages.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
blake_lookup_elements,
Fraction::new(
E::EF::zero(),
blake_lookup_elements.combine(
&chain![
input_state.iter().copied().flat_map(Fu32::to_felts),
output_state.iter().copied().flat_map(Fu32::to_felts),
messages.iter().copied().flat_map(Fu32::to_felts)
]
.collect_vec(),
),
),
);

logup.finalize(eval);
Expand Down
28 changes: 17 additions & 11 deletions crates/prover/src/examples/blake/xor_table/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use itertools::Itertools;

use super::{limb_bits, XorElements};
use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::lookups::utils::Fraction;

/// Constraints for the xor table.
pub struct XorTableEval<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> {
pub eval: E,
pub lookup_elements: &'a XorElements,
pub logup: LogupAtRow<2, E>,
pub logup: LogupAtRow<E>,
}
impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
XorTableEval<'a, E, ELEM_BITS, EXPAND_BITS>
Expand All @@ -19,8 +22,10 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
let [al] = self.eval.next_interaction_mask(2, [0]);
let [bl] = self.eval.next_interaction_mask(2, [0]);
let [cl] = self.eval.next_interaction_mask(2, [0]);
for i in 0..1 << EXPAND_BITS {
for j in 0..1 << EXPAND_BITS {

let frac_chunks = (0..(1 << (2 * EXPAND_BITS)))
.map(|i| {
let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32);
let multiplicity = self.eval.next_trace_mask();

let a = al
Expand All @@ -36,15 +41,16 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32>
(i ^ j) << limb_bits::<ELEM_BITS, EXPAND_BITS>(),
));

// Add with negative multiplicity. Consumers should lookup with positive
// multiplicity.
self.logup.push_lookup(
&mut self.eval,
Fraction::<E::EF, E::EF>::new(
(-multiplicity).into(),
&[a, b, c],
self.lookup_elements,
);
}
self.lookup_elements.combine(&[a, b, c]),
)
})
.collect_vec();

for frac_chunk in frac_chunks.chunks(2) {
let sum_frac: Fraction<E::EF, E::EF> = frac_chunk.iter().copied().sum();
self.logup.write_frac(&mut self.eval, sum_frac);
}
self.logup.finalize(&mut self.eval);
self.eval
Expand Down
27 changes: 12 additions & 15 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::core::backend::Column;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::{CommitmentSchemeProver, PcsConfig, TreeSubspan};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -43,7 +44,7 @@ impl FrameworkEval for PlonkEval {
}

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let mut logup = LogupAtRow::<2, _>::new(1, self.claimed_sum, self.log_n_rows);
let mut logup = LogupAtRow::<_>::new(1, self.claimed_sum, self.log_n_rows);

let [a_wire] = eval.next_interaction_mask(2, [0]);
let [b_wire] = eval.next_interaction_mask(2, [0]);
Expand All @@ -59,23 +60,19 @@ impl FrameworkEval for PlonkEval {

eval.add_constraint(c_val - op * (a_val + b_val) + (E::F::one() - op) * a_val * b_val);

logup.push_lookup(
&mut eval,
E::EF::one(),
&[a_wire, a_val],
&self.lookup_elements,
);
logup.push_lookup(
let denom_a: E::EF = self.lookup_elements.combine(&[a_wire, a_val]);
let denom_b: E::EF = self.lookup_elements.combine(&[b_wire, b_val]);

logup.write_frac(
&mut eval,
E::EF::one(),
&[b_wire, b_val],
&self.lookup_elements,
Fraction::new(denom_a + denom_b, denom_a * denom_b),
);
logup.push_lookup(
logup.write_frac(
&mut eval,
E::EF::from(-mult),
&[c_wire, c_val],
&self.lookup_elements,
Fraction::new(
(-mult).into(),
self.lookup_elements.combine(&[c_wire, c_val]),
),
);

logup.finalize(&mut eval);
Expand Down
Loading

0 comments on commit 5d68b31

Please sign in to comment.