Skip to content

Commit

Permalink
trace in bit reverse order
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 26, 2024
1 parent 7c5d309 commit de00dca
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
7 changes: 5 additions & 2 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use num_traits::{One, Zero};

use crate::constraint_framework::logup::{LogupAtRow, LookupElements};
use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupAtRow, LookupElements};
use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
Expand Down Expand Up @@ -28,6 +28,7 @@ pub struct StateTransitionEval<const COORDINATE: usize> {
pub log_n_rows: u32,
pub lookup_elements: StateMachineElements,
pub total_sum: QM31,
pub claimed_sum: ClaimedPrefixSum,
}

impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE> {
Expand All @@ -39,7 +40,8 @@ impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE>
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup: LogupAtRow<E> = LogupAtRow::new(1, self.total_sum, None, is_first);
let mut logup: LogupAtRow<E> =
LogupAtRow::new(1, self.total_sum, Some(self.claimed_sum), is_first);

let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);
Expand Down Expand Up @@ -98,6 +100,7 @@ fn state_transition_info<const INDEX: usize>() -> InfoEvaluator {
log_n_rows: 1,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
};
component.evaluate(InfoEvaluator::default())
}
Expand Down
24 changes: 17 additions & 7 deletions crates/prover/src/examples/state_machine/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

// Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the
Expand All @@ -30,7 +31,9 @@ pub fn gen_trace(
// Add the states in bit reversed circle domain order.
for i in 0..1 << log_size {
for j in 0..STATE_SIZE {
trace[j][i] = curr_state[j];
let bit_rev_index =
bit_reverse_index(coset_index_to_circle_domain_index(i, log_size), log_size);
trace[j][bit_rev_index] = curr_state[j];
}
// Increment the state to the next state row.
curr_state[inc_index] += M31::one();
Expand All @@ -48,14 +51,17 @@ pub fn gen_trace(
}

pub fn gen_interaction_trace(
log_size: u32,
n_rows: usize,
trace: &ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
inc_index: usize,
lookup_elements: &LookupElements<STATE_SIZE>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, M31, BitReversedOrder>>,
QM31,
[QM31; 2],
) {
let log_size = trace[0].domain.log_size();
assert!(n_rows <= 1 << log_size, "n_rows exceeds the trace size");

let ones = PackedM31::broadcast(M31::one());
let mut logup_gen = LogupTraceGenerator::new(log_size);
let mut col_gen = logup_gen.new_col();
Expand All @@ -78,7 +84,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_last()
logup_gen.finalize_at([(1 << log_size) - 1, n_rows])
}

#[cfg(test)]
Expand All @@ -88,6 +94,7 @@ mod tests {
use crate::core::fields::qm31::QM31;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::fields::FieldExpOps;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::examples::state_machine::components::StateMachineElements;
use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace};

Expand All @@ -97,13 +104,15 @@ mod tests {
let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)];
let inc_index = 1;
let row = 123;
let bit_rev_row =
bit_reverse_index(coset_index_to_circle_domain_index(row, log_size), log_size);

let trace = gen_trace(log_size, initial_state, inc_index);

assert_eq!(trace.len(), 2);
assert_eq!(trace[0].at(row), initial_state[0]);
assert_eq!(
trace[1].at(row),
trace[1].at(bit_rev_row),
initial_state[1] + M31::from_u32_unchecked(row as u32)
);
}
Expand All @@ -122,10 +131,11 @@ mod tests {
let first_state_comb: QM31 = lookup_elements.combine(&first_state);
let last_state_comb: QM31 = lookup_elements.combine(&last_state);

let (interaction_trace, total_sum) =
gen_interaction_trace(log_size, &trace, inc_index, &lookup_elements);
let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace((1 << log_size) - 1, &trace, inc_index, &lookup_elements);

assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column.
assert_eq!(claimed_sum, total_sum);
assert_eq!(
total_sum,
first_state_comb.inverse() - last_state_comb.inverse()
Expand Down
41 changes: 22 additions & 19 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ pub fn prove_state_machine(
StateMachineComponents,
StateMachineProof<Blake2sMerkleHasher>,
) {
assert!(log_n_rows >= LOG_N_LANES);
let x_axis_log_rows = log_n_rows;
let y_axis_log_rows = log_n_rows - 1;
let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1);
let (x_row, y_row) = (34, 56);
assert!(y_axis_log_rows >= LOG_N_LANES && x_axis_log_rows >= LOG_N_LANES);
assert!(x_row < 1 << x_axis_log_rows);
assert!(y_row < 1 << y_axis_log_rows);

let mut intermediate_state = initial_state;
intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows);
intermediate_state[0] += M31::from_u32_unchecked(x_row);
let mut final_state = intermediate_state;
final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows);
final_state[1] += M31::from_u32_unchecked(y_row);

// Precompute twiddles.
let twiddles = SimdBackend::precompute_twiddles(
Expand Down Expand Up @@ -69,14 +71,14 @@ pub fn prove_state_machine(
let lookup_elements = StateMachineElements::draw(channel);

// Interaction trace.
let (interaction_trace_op0, total_sum_op0) =
gen_interaction_trace(x_axis_log_rows, &trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, total_sum_op1) =
gen_interaction_trace(y_axis_log_rows, &trace_op1, 1, &lookup_elements);
let (interaction_trace_op0, [total_sum_op0, claimed_sum_op0]) =
gen_interaction_trace(x_row as usize - 1, &trace_op0, 0, &lookup_elements);
let (interaction_trace_op1, [total_sum_op1, claimed_sum_op1]) =
gen_interaction_trace(y_row as usize - 1, &trace_op1, 1, &lookup_elements);

let stmt1 = StateMachineStatement1 {
x_axis_claimed_sum: total_sum_op0,
y_axis_claimed_sum: total_sum_op1,
x_axis_claimed_sum: claimed_sum_op0,
y_axis_claimed_sum: claimed_sum_op1,
};
stmt1.mix_into(channel);

Expand All @@ -100,6 +102,7 @@ pub fn prove_state_machine(
log_n_rows: x_axis_log_rows,
lookup_elements: lookup_elements.clone(),
total_sum: total_sum_op0,
claimed_sum: (claimed_sum_op0, x_row as usize - 1),
},
);
let component1 = StateMachineOp1Component::new(
Expand All @@ -108,6 +111,7 @@ pub fn prove_state_machine(
log_n_rows: y_axis_log_rows,
lookup_elements,
total_sum: total_sum_op1,
claimed_sum: (claimed_sum_op1, y_row as usize - 1),
},
);
let components = StateMachineComponents {
Expand Down Expand Up @@ -190,15 +194,17 @@ mod tests {
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

// Interaction trace.
let (interaction_trace, total_sum) =
gen_interaction_trace(log_n_rows, &trace, 0, &lookup_elements);
let (interaction_trace, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
);

Expand All @@ -214,16 +220,13 @@ mod tests {
}

#[test]
fn test_state_machine_total_sum() {
fn test_state_machine_claimed_sum() {
let log_n_rows = 8;
let config = PcsConfig::default();

// Initial and last state.
let initial_state = [M31::zero(); STATE_SIZE];
let last_state = [
M31::from_u32_unchecked(1 << log_n_rows),
M31::from_u32_unchecked(1 << (log_n_rows - 1)),
];
let last_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];

// Setup protocol.
let channel = &mut Blake2sChannel::default();
Expand All @@ -234,7 +237,7 @@ mod tests {
let last_state_comb: QM31 = interaction_elements.combine(&last_state);

assert_eq!(
component.component0.total_sum + component.component1.total_sum,
component.component0.claimed_sum.0 + component.component1.claimed_sum.0,
initial_state_comb.inverse() - last_state_comb.inverse()
);
}
Expand Down

0 comments on commit de00dca

Please sign in to comment.