diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index c0d8319fa..065f8f604 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -5,9 +5,10 @@ use std::ops::Deref; use itertools::Itertools; use tracing::{span, Level}; +use super::constant_columns::{ConstantColumn, ConstantTableLocation, StaticTree}; use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::air::{Component, ComponentProver, Trace, CONST_INTERACTION}; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; @@ -17,7 +18,7 @@ use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; -use crate::core::pcs::{TreeSubspan, TreeVec}; +use crate::core::pcs::{TreeLocation, TreeSubspan, TreeVec}; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::{utils, ColumnVec}; @@ -28,6 +29,7 @@ use crate::core::{utils, ColumnVec}; pub struct TraceLocationAllocator { /// Mapping of tree index to next available column offset. next_tree_offsets: TreeVec, + static_table_offsets: ConstantTableLocation, } impl TraceLocationAllocator { @@ -52,6 +54,23 @@ impl TraceLocationAllocator { .collect(), ) } + + pub fn with_static_tree(tree: &StaticTree) -> Self { + Self { + next_tree_offsets: Default::default(), + static_table_offsets: tree.locations.clone(), + } + } + + fn static_column_mappings( + &self, + constant_columns: &[ConstantColumn], + ) -> ColumnVec { + constant_columns + .iter() + .map(|col| self.static_table_offsets.get_location(*col).unwrap()) + .collect() + } } /// A component defined solely in means of the constraints framework. @@ -68,16 +87,21 @@ pub trait FrameworkEval { pub struct FrameworkComponent { eval: C, - trace_locations: TreeVec, + mask_spans: TreeVec, + static_columns_locations: Vec, } impl FrameworkComponent { pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self { - let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets; - let trace_locations = provider.next_for_structure(&eval_tree_structure); + let info = eval.evaluate(InfoEvaluator::default()); + let eval_tree_structure = info.mask_offsets; + let mask_spans = provider.next_for_structure(&eval_tree_structure); + let static_columns_locations = provider.static_column_mappings(&info.external_cols); + Self { eval, - trace_locations, + mask_spans, + static_columns_locations, } } } @@ -116,6 +140,13 @@ impl Component for FrameworkComponent { }) } + fn constant_column_locations(&self) -> ColumnVec { + self.static_columns_locations + .iter() + .map(|loc| loc.col_index) + .collect() + } + fn evaluate_constraint_quotients_at_point( &self, point: CirclePoint, @@ -123,7 +154,7 @@ impl Component for FrameworkComponent { evaluation_accumulator: &mut PointEvaluationAccumulator, ) { self.eval.evaluate(PointEvaluator::new( - mask.sub_tree(&self.trace_locations), + mask.sub_tree(&self.mask_spans), evaluation_accumulator, coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(), )); @@ -139,8 +170,26 @@ impl ComponentProver for FrameworkComponent { let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); let trace_domain = CanonicCoset::new(self.eval.log_size()); - let component_polys = trace.polys.sub_tree(&self.trace_locations); - let component_evals = trace.evals.sub_tree(&self.trace_locations); + // Retrieve necessary columns. + let mut mask_spans = self.mask_spans.clone(); + + // Constant columns locations are static, cannot be derived from mask spans. + mask_spans[CONST_INTERACTION] = TreeSubspan::empty(); + + let component_polys = TreeVec::concat_cols( + [ + trace.polys.sub_tree(&mask_spans), + trace.polys.sub_tree_sparse(&self.static_columns_locations), + ] + .into_iter(), + ); + let component_evals = TreeVec::concat_cols( + [ + trace.evals.sub_tree(&mask_spans), + trace.evals.sub_tree_sparse(&self.static_columns_locations), + ] + .into_iter(), + ); // Extend trace if necessary. // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good diff --git a/crates/prover/src/constraint_framework/constant_columns.rs b/crates/prover/src/constraint_framework/constant_columns.rs index e57df28ab..dfa893877 100644 --- a/crates/prover/src/constraint_framework/constant_columns.rs +++ b/crates/prover/src/constraint_framework/constant_columns.rs @@ -1,10 +1,15 @@ +use std::collections::HashMap; + use num_traits::One; +use crate::core::air::CONST_INTERACTION; use crate::core::backend::{Backend, Col, Column}; use crate::core::fields::m31::BaseField; +use crate::core::pcs::TreeLocation; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; +use crate::core::vcs::blake2_hash::Blake2sHash; /// Generates a column with a single one at the first position, and zeros elsewhere. pub fn gen_is_first(log_size: u32) -> CircleEvaluation { @@ -35,3 +40,87 @@ pub fn gen_is_step_with_offset( CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ConstantColumn { + XorTable(u32, u32, usize), + One(u32), +} + +#[derive(Debug, Default, Clone)] +pub struct ConstantTableLocation { + locations: HashMap, +} + +impl ConstantTableLocation { + pub fn add(&mut self, column: ConstantColumn, location: usize) { + if self.locations.contains_key(&column) { + panic!("Type already exists."); + } + self.locations.insert(column, location); + } + + pub fn get_location(&self, column: ConstantColumn) -> Option { + self.locations.get(&column).map(|&col_index| TreeLocation { + tree_index: CONST_INTERACTION, + col_index, + }) + } +} + +#[derive(Debug, Default, Clone)] +pub struct StaticTree { + pub root: Blake2sHash, + pub locations: ConstantTableLocation, +} + +impl StaticTree { + pub fn blake_tree() -> Self { + let root = Blake2sHash::default(); + let mut locations = ConstantTableLocation::default(); + locations.add(ConstantColumn::XorTable(12, 4, 0), 0); + locations.add(ConstantColumn::XorTable(12, 4, 1), 1); + locations.add(ConstantColumn::XorTable(12, 4, 2), 2); + locations.add(ConstantColumn::XorTable(9, 2, 0), 3); + locations.add(ConstantColumn::XorTable(9, 2, 1), 4); + locations.add(ConstantColumn::XorTable(9, 2, 2), 5); + locations.add(ConstantColumn::XorTable(8, 2, 0), 6); + locations.add(ConstantColumn::XorTable(8, 2, 1), 7); + locations.add(ConstantColumn::XorTable(8, 2, 2), 8); + + locations.add(ConstantColumn::XorTable(7, 2, 0), 12); + locations.add(ConstantColumn::XorTable(7, 2, 1), 13); + locations.add(ConstantColumn::XorTable(7, 2, 2), 14); + + locations.add(ConstantColumn::XorTable(4, 0, 0), 9); + locations.add(ConstantColumn::XorTable(4, 0, 1), 10); + locations.add(ConstantColumn::XorTable(4, 0, 2), 11); + + StaticTree { root, locations } + } + + pub fn add1(log_size: u32) -> Self { + let root = Blake2sHash::default(); + let mut locations = ConstantTableLocation::default(); + locations.add(ConstantColumn::One(log_size), 0); + StaticTree { root, locations } + } + + pub fn get_location(&self, column: ConstantColumn) -> TreeLocation { + self.locations + .get_location(column) + .unwrap_or_else(|| panic!("{:?} column does not exist in the chosen tree!", column)) + } + + pub fn n_columns(&self) -> usize { + self.locations.locations.len() + } + + pub fn log_sizes(&self) -> Vec { + self.locations + .locations + .iter() + .map(|(_, &log_size)| log_size as u32) + .collect() + } +} diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 05da93f6f..927251e35 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -2,7 +2,9 @@ use std::ops::Mul; use num_traits::One; +use super::constant_columns::ConstantColumn; use super::EvalAtRow; +use crate::core::air::CONST_INTERACTION; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::pcs::TreeVec; @@ -12,6 +14,7 @@ use crate::core::pcs::TreeVec; #[derive(Default)] pub struct InfoEvaluator { pub mask_offsets: TreeVec>>, + pub external_cols: Vec, pub n_constraints: usize, } impl InfoEvaluator { @@ -45,4 +48,13 @@ impl EvalAtRow for InfoEvaluator { fn combine_ef(_values: [Self::F; 4]) -> Self::EF { SecureField::one() } + + fn constant_interaction_mask( + &mut self, + col: ConstantColumn, + offsets: [isize; N], + ) -> [Self::F; N] { + self.external_cols.push(col); + self.next_interaction_mask(CONST_INTERACTION, offsets) + } } diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 87069d344..53ccf42c5 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -13,11 +13,13 @@ use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator}; +use constant_columns::ConstantColumn; pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; pub use simd_domain::SimdDomainEvaluator; +use crate::core::air::CONST_INTERACTION; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; @@ -65,17 +67,25 @@ pub trait EvalAtRow { /// Returns the next mask value for the first interaction at offset 0. fn next_trace_mask(&mut self) -> Self::F { - let [mask_item] = self.next_interaction_mask(0, [0]); + let [mask_item] = self.next_interaction_mask(1, [0]); mask_item } - /// Returns the mask values of the given offsets for the next column in the interaction. + /// Returns the mask values of the given offsets for the next owned column in the interaction. fn next_interaction_mask( &mut self, interaction: usize, offsets: [isize; N], ) -> [Self::F; N]; + fn constant_interaction_mask( + &mut self, + _col: ConstantColumn, + offsets: [isize; N], + ) -> [Self::F; N] { + self.next_interaction_mask(CONST_INTERACTION, offsets) + } + /// Returns the extension mask values of the given offsets for the next extension degree many /// columns in the interaction. fn next_extension_interaction_mask( diff --git a/crates/prover/src/core/air/components.rs b/crates/prover/src/core/air/components.rs index 320ec35cb..ceb635a42 100644 --- a/crates/prover/src/core/air/components.rs +++ b/crates/prover/src/core/air/components.rs @@ -1,7 +1,9 @@ +use std::collections::BTreeSet; + use itertools::Itertools; use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use super::{Component, ComponentProver, Trace}; +use super::{Component, ComponentProver, Trace, CONST_INTERACTION}; use crate::core::backend::Backend; use crate::core::circle::CirclePoint; use crate::core::fields::qm31::SecureField; @@ -27,6 +29,81 @@ impl<'a> Components<'a> { TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point))) } + // Returns the unique mask points for each column. + pub fn mask_points_by_column( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let mut components_masks = + TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point))); + components_masks[CONST_INTERACTION] = self.const_mask_points_by_column(point); + components_masks + } + + fn const_mask_points_by_column( + &self, + point: CirclePoint, + ) -> ColumnVec>> { + let mut static_column_masks: Vec>> = vec![]; + for component in &self.0 { + let component_static_masks = &component.mask_points(point)[CONST_INTERACTION]; + component_static_masks + .iter() + .zip(component.constant_column_locations()) + .for_each(|(points, index)| { + if index >= static_column_masks.len() { + static_column_masks.resize_with(index + 1, Default::default); + } + static_column_masks[index].extend(points); + }); + } + static_column_masks + .into_iter() + .map(|set| set.into_iter().collect()) + .collect() + } + + // Reorganizes the mask evaluations in the constant interaction according to the original mask + // points of each component. + pub fn reorganize_const_values_by_component( + &self, + point: CirclePoint, + mut mask_values: TreeVec>>, + ) -> TreeVec>> { + mask_values[CONST_INTERACTION] = + self.const_mask_values_by_component(point, &mask_values[CONST_INTERACTION]); + mask_values + } + + fn const_mask_values_by_component( + &self, + point: CirclePoint, + mask_values: &[Vec], + ) -> ColumnVec> { + let mask_by_column = &self.mask_points_by_column(point)[CONST_INTERACTION]; + + let mut masks_values_by_component = vec![]; + for component in &self.0 { + let component_static_masks = &component.mask_points(point)[CONST_INTERACTION]; + component_static_masks + .iter() + .zip(component.constant_column_locations()) + .for_each(|(points, column_idx)| { + let column_masks = &mask_by_column[column_idx]; + masks_values_by_component.push( + points + .iter() + .map(|&point| { + mask_values[column_idx] + [column_masks.iter().position(|&p| p == point).unwrap()] + }) + .collect_vec(), + ); + }); + } + masks_values_by_component + } + pub fn eval_composition_polynomial_at_point( &self, point: CirclePoint, diff --git a/crates/prover/src/core/air/mod.rs b/crates/prover/src/core/air/mod.rs index fcdd4d5f8..f634be8ea 100644 --- a/crates/prover/src/core/air/mod.rs +++ b/crates/prover/src/core/air/mod.rs @@ -14,6 +14,8 @@ pub mod accumulation; mod components; pub mod mask; +pub const CONST_INTERACTION: usize = 0; + /// Arithmetic Intermediate Representation (AIR). /// An Air instance is assumed to already contain all the information needed to /// evaluate the constraints. @@ -46,6 +48,11 @@ pub trait Component { point: CirclePoint, ) -> TreeVec>>>; + // TODO(Ohad): remove default implementation. + fn constant_column_locations(&self) -> ColumnVec { + vec![] + } + /// Evaluates the constraint quotients combination of the component at a point. fn evaluate_constraint_quotients_at_point( &self, diff --git a/crates/prover/src/core/pcs/mod.rs b/crates/prover/src/core/pcs/mod.rs index d9acf524b..30c2d1b30 100644 --- a/crates/prover/src/core/pcs/mod.rs +++ b/crates/prover/src/core/pcs/mod.rs @@ -25,6 +25,22 @@ pub struct TreeSubspan { pub col_end: usize, } +impl TreeSubspan { + pub fn empty() -> Self { + Self { + tree_index: 0, + col_start: 0, + col_end: 0, + } + } +} + +#[derive(Debug)] +pub struct TreeLocation { + pub tree_index: usize, + pub col_index: usize, +} + #[derive(Debug, Clone, Copy)] pub struct PcsConfig { pub pow_bits: u32, diff --git a/crates/prover/src/core/pcs/prover.rs b/crates/prover/src/core/pcs/prover.rs index e5ae1b266..2a2d703f6 100644 --- a/crates/prover/src/core/pcs/prover.rs +++ b/crates/prover/src/core/pcs/prover.rs @@ -83,14 +83,14 @@ impl<'a, B: BackendForChannel, MC: MerkleChannel> CommitmentSchemeProver<'a, pub fn prove_values( &self, - sampled_points: TreeVec>>>, + sampled_points: &TreeVec>>>, channel: &mut MC::C, ) -> CommitmentSchemeProof { // Evaluate polynomials on open points. let span = span!(Level::INFO, "Evaluate columns out of domain").entered(); let samples = self .polynomials() - .zip_cols(&sampled_points) + .zip_cols(sampled_points) .map_cols(|(poly, points)| { points .iter() diff --git a/crates/prover/src/core/pcs/utils.rs b/crates/prover/src/core/pcs/utils.rs index bfdbdb5d9..4fc8a02db 100644 --- a/crates/prover/src/core/pcs/utils.rs +++ b/crates/prover/src/core/pcs/utils.rs @@ -4,7 +4,7 @@ use std::ops::{Deref, DerefMut}; use itertools::zip_eq; use serde::{Deserialize, Serialize}; -use super::TreeSubspan; +use super::{TreeLocation, TreeSubspan}; use crate::core::ColumnVec; /// A container that holds an element for each commitment tree. @@ -127,7 +127,7 @@ impl TreeVec> { let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); - for &location in locations { + for &location in locations.iter().filter(|l| l.col_start != l.col_end) { // TODO(andrew): Throwing error here might be better instead. let chunk = self.get_chunk(location).unwrap(); res[location.tree_index] = chunk; @@ -136,6 +136,23 @@ impl TreeVec> { res } + pub fn sub_tree_sparse(&self, locations: &[TreeLocation]) -> TreeVec> { + let tree_indicies: BTreeSet = locations.iter().map(|l| l.tree_index).collect(); + let max_tree_index = tree_indicies.iter().max().unwrap_or(&0); + let mut res = TreeVec(vec![Vec::new(); max_tree_index + 1]); + + for location in locations { + let column = self.get_single_column(location).expect("Invalid location"); + res[location.tree_index].push(column); + } + + res + } + + fn get_single_column(&self, location: &TreeLocation) -> Option<&T> { + self.0.get(location.tree_index)?.get(location.col_index) + } + fn get_chunk(&self, location: TreeSubspan) -> Option> { let tree = self.0.get(location.tree_index)?; let chunk = tree.get(location.col_start..location.col_end)?; diff --git a/crates/prover/src/core/prover/mod.rs b/crates/prover/src/core/prover/mod.rs index acc7f6c01..ece34253f 100644 --- a/crates/prover/src/core/prover/mod.rs +++ b/crates/prover/src/core/prover/mod.rs @@ -53,22 +53,34 @@ pub fn prove, MC: MerkleChannel>( let oods_point = CirclePoint::::get_random_point(channel); // Get mask sample points relative to oods point. - let mut sample_points = component_provers.components().mask_points(oods_point); + let mut sample_points = component_provers + .components() + .mask_points_by_column(oods_point); // Add the composition polynomial mask points. sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); // Prove the trace and composition OODS values, and retrieve them. - let commitment_scheme_proof = commitment_scheme.prove_values(sample_points, channel); + let commitment_scheme_proof = commitment_scheme.prove_values(&sample_points, channel); let sampled_oods_values = &commitment_scheme_proof.sampled_values; let composition_oods_eval = extract_composition_eval(sampled_oods_values).unwrap(); + // Evaluations from "prove_values" are ordered by commitment order. Reorg according to component + // usage. + let reorganized_sample_values = component_provers + .components() + .reorganize_const_values_by_component(oods_point, sampled_oods_values.clone()); + // Evaluate composition polynomial at OODS point and check that it matches the trace OODS // values. This is a sanity check. if composition_oods_eval != component_provers .components() - .eval_composition_polynomial_at_point(oods_point, sampled_oods_values, random_coeff) + .eval_composition_polynomial_at_point( + oods_point, + &reorganized_sample_values, + random_coeff, + ) { return Err(ProvingError::ConstraintsNotSatisfied); } @@ -101,7 +113,7 @@ pub fn verify( let oods_point = CirclePoint::::get_random_point(channel); // Get mask sample points relative to oods point. - let mut sample_points = components.mask_points(oods_point); + let mut sample_points = components.mask_points_by_column(oods_point); // Add the composition polynomial mask points. sample_points.push(vec![vec![oods_point]; SECURE_EXTENSION_DEGREE]); @@ -110,10 +122,13 @@ pub fn verify( VerificationError::InvalidStructure("Unexpected sampled_values structure".to_string()) })?; + let reogranized_sample_values = + components.reorganize_const_values_by_component(oods_point, sampled_oods_values.clone()); + if composition_oods_eval != components.eval_composition_polynomial_at_point( oods_point, - sampled_oods_values, + &reogranized_sample_values, random_coeff, ) { diff --git a/crates/prover/src/core/vcs/prover.rs b/crates/prover/src/core/vcs/prover.rs index 6312de114..84b847fd8 100644 --- a/crates/prover/src/core/vcs/prover.rs +++ b/crates/prover/src/core/vcs/prover.rs @@ -38,7 +38,7 @@ impl, H: MerkleHasher> MerkleProver { /// /// A new instance of `MerkleProver` with the committed layers. pub fn commit(columns: Vec<&Col>) -> Self { - assert!(!columns.is_empty()); + // assert!(!columns.is_empty()); let columns = &mut columns .into_iter() @@ -46,7 +46,11 @@ impl, H: MerkleHasher> MerkleProver { .peekable(); let mut layers: Vec> = Vec::new(); - let max_log_size = columns.peek().unwrap().len().ilog2(); + let max_log_size = if let Some(c) = columns.peek() { + c.len().ilog2() + } else { + 0 + }; for log_size in (0..=max_log_size).rev() { // Take columns of the current log_size. let layer_columns = columns diff --git a/crates/prover/src/core/vcs/verifier.rs b/crates/prover/src/core/vcs/verifier.rs index 53346bb93..a379acaa2 100644 --- a/crates/prover/src/core/vcs/verifier.rs +++ b/crates/prover/src/core/vcs/verifier.rs @@ -170,7 +170,8 @@ impl MerkleVerifier { return Err(MerkleVerificationError::WitnessTooLong); } - let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap(); + // TODO(Ohad/Team): 'empty' trees will have a '000...00' root. + let [(_, computed_root)] = last_layer_hashes.unwrap().try_into().unwrap_or_default(); if computed_root != self.root { return Err(MerkleVerificationError::RootMismatch); } diff --git a/crates/prover/src/examples/blake/air.rs b/crates/prover/src/examples/blake/air.rs index ca583abe3..7cb5b7a86 100644 --- a/crates/prover/src/examples/blake/air.rs +++ b/crates/prover/src/examples/blake/air.rs @@ -8,6 +8,7 @@ use tracing::{span, Level}; use super::round::{blake_round_info, BlakeRoundComponent, BlakeRoundEval}; use super::scheduler::{BlakeSchedulerComponent, BlakeSchedulerEval}; use super::xor_table::{XorTableComponent, XorTableEval}; +use crate::constraint_framework::constant_columns::StaticTree; use crate::constraint_framework::TraceLocationAllocator; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::m31::LOG_N_LANES; @@ -52,7 +53,10 @@ impl BlakeStatement0 { sizes.push(xor_table::trace_sizes::<7, 2>()); sizes.push(xor_table::trace_sizes::<4, 0>()); - TreeVec::concat_cols(sizes.into_iter()) + // Constant columns fix. + let mut res = TreeVec::concat_cols(sizes.into_iter()); + res[0] = vec![16, 16, 16, 14, 14, 14, 12, 12, 12, 8, 8, 8, 10, 10, 10]; + res } fn mix_into(&self, channel: &mut impl Channel) { channel.mix_u64(self.log_size as u64); @@ -118,8 +122,13 @@ pub struct BlakeComponents { xor4: XorTableComponent<4, 0>, } impl BlakeComponents { - fn new(stmt0: &BlakeStatement0, all_elements: &AllElements, stmt1: &BlakeStatement1) -> Self { - let tree_span_provider = &mut TraceLocationAllocator::default(); + fn new( + stmt0: &BlakeStatement0, + all_elements: &AllElements, + stmt1: &BlakeStatement1, + static_tree: StaticTree, + ) -> Self { + let tree_span_provider = &mut TraceLocationAllocator::with_static_tree(&static_tree); Self { scheduler_component: BlakeSchedulerComponent::new( tree_span_provider, @@ -251,7 +260,24 @@ where let channel = &mut MC::C::default(); let commitment_scheme = &mut CommitmentSchemeProver::new(config, &twiddles); - let span = span!(Level::INFO, "Trace").entered(); + // Statement0. + let stmt0 = BlakeStatement0 { log_size }; + stmt0.mix_into(channel); + // Constant trace. + let span = span!(Level::INFO, "Constant Trace").entered(); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals( + chain![ + xor_table::generate_constant_trace::<12, 4>(), + xor_table::generate_constant_trace::<9, 2>(), + xor_table::generate_constant_trace::<8, 2>(), + xor_table::generate_constant_trace::<4, 0>(), + xor_table::generate_constant_trace::<7, 2>(), + ] + .collect_vec(), + ); + tree_builder.commit(channel); + span.exit(); // Scheduler. let (scheduler_trace, scheduler_lookup_data, round_inputs) = @@ -275,11 +301,8 @@ where let (xor_trace7, xor_lookup_data7) = xor_table::generate_trace(xor_accums.xor7); let (xor_trace4, xor_lookup_data4) = xor_table::generate_trace(xor_accums.xor4); - // Statement0. - let stmt0 = BlakeStatement0 { log_size }; - stmt0.mix_into(channel); - // Trace commitment. + let span = span!(Level::INFO, "Trace").entered(); let mut tree_builder = commitment_scheme.tree_builder(); tree_builder.extend_evals( chain![ @@ -294,7 +317,6 @@ where .collect_vec(), ); tree_builder.commit(channel); - span.exit(); // Draw lookup element. let all_elements = AllElements::draw(channel); @@ -346,6 +368,7 @@ where ] .collect_vec(), ); + tree_builder.commit(channel); // Statement1. let stmt1 = BlakeStatement1 { @@ -358,23 +381,6 @@ where xor4_claimed_sum, }; stmt1.mix_into(channel); - tree_builder.commit(channel); - span.exit(); - - // Constant trace. - let span = span!(Level::INFO, "Constant Trace").entered(); - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals( - chain![ - xor_table::generate_constant_trace::<12, 4>(), - xor_table::generate_constant_trace::<9, 2>(), - xor_table::generate_constant_trace::<8, 2>(), - xor_table::generate_constant_trace::<7, 2>(), - xor_table::generate_constant_trace::<4, 0>(), - ] - .collect_vec(), - ); - tree_builder.commit(channel); span.exit(); assert_eq!( @@ -386,8 +392,8 @@ where stmt0.log_sizes().0 ); - // Prove constraints. - let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + // Prove constraints + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1, StaticTree::blake_tree()); let stark_proof = prove(&components.component_provers(), channel, commitment_scheme).unwrap(); BlakeProof { @@ -411,21 +417,22 @@ pub fn verify_blake( let log_sizes = stmt0.log_sizes(); - // Trace. + // Constant trace. stmt0.mix_into(channel); commitment_scheme.commit(stark_proof.commitments[0], &log_sizes[0], channel); + // Trace. + commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); + // Draw interaction elements. let all_elements = AllElements::draw(channel); // Interaction trace. - stmt1.mix_into(channel); - commitment_scheme.commit(stark_proof.commitments[1], &log_sizes[1], channel); - - // Constant trace. commitment_scheme.commit(stark_proof.commitments[2], &log_sizes[2], channel); + stmt1.mix_into(channel); + let static_tree = StaticTree::blake_tree(); - let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1); + let components = BlakeComponents::new(&stmt0, &all_elements, &stmt1, static_tree); // Check that all sums are correct. let total_sum = stmt1.scheduler_claimed_sum diff --git a/crates/prover/src/examples/blake/round/gen.rs b/crates/prover/src/examples/blake/round/gen.rs index 7adbe6fd5..3c5752ad6 100644 --- a/crates/prover/src/examples/blake/round/gen.rs +++ b/crates/prover/src/examples/blake/round/gen.rs @@ -37,7 +37,7 @@ pub struct TraceGenerator { impl TraceGenerator { fn new(log_size: u32) -> Self { assert!(log_size >= LOG_N_LANES); - let trace = (0..blake_round_info().mask_offsets[0].len()) + let trace = (0..blake_round_info().mask_offsets[1].len()) .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) .collect_vec(); Self { diff --git a/crates/prover/src/examples/blake/round/mod.rs b/crates/prover/src/examples/blake/round/mod.rs index cf8311339..f4f6e1dd4 100644 --- a/crates/prover/src/examples/blake/round/mod.rs +++ b/crates/prover/src/examples/blake/round/mod.rs @@ -32,7 +32,7 @@ impl FrameworkEval for BlakeRoundEval { eval, xor_lookup_elements: &self.xor_lookup_elements, round_lookup_elements: &self.round_lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size), + logup: LogupAtRow::new(2, self.claimed_sum, self.log_size), }; blake_eval.eval() } @@ -90,7 +90,7 @@ mod tests { &round_lookup_elements, ); - let trace = TreeVec::new(vec![trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); + let trace = TreeVec::new(vec![vec![], trace, interaction_trace, vec![gen_is_first(LOG_SIZE)]]); let trace_polys = trace.map_cols(|c| c.interpolate()); let component = BlakeRoundEval { diff --git a/crates/prover/src/examples/blake/scheduler/gen.rs b/crates/prover/src/examples/blake/scheduler/gen.rs index cd6a99b2f..077abdc82 100644 --- a/crates/prover/src/examples/blake/scheduler/gen.rs +++ b/crates/prover/src/examples/blake/scheduler/gen.rs @@ -54,7 +54,7 @@ pub fn gen_trace( let mut lookup_data = BlakeSchedulerLookupData::new(log_size); let mut round_inputs = Vec::with_capacity(inputs.len() * N_ROUNDS); - let mut trace = (0..blake_scheduler_info().mask_offsets[0].len()) + let mut trace = (0..blake_scheduler_info().mask_offsets[1].len()) .map(|_| unsafe { BaseColumn::uninitialized(1 << log_size) }) .collect_vec(); diff --git a/crates/prover/src/examples/blake/scheduler/mod.rs b/crates/prover/src/examples/blake/scheduler/mod.rs index e8a8c32f3..bed987a0d 100644 --- a/crates/prover/src/examples/blake/scheduler/mod.rs +++ b/crates/prover/src/examples/blake/scheduler/mod.rs @@ -33,7 +33,7 @@ impl FrameworkEval for BlakeSchedulerEval { &mut eval, &self.blake_lookup_elements, &self.round_lookup_elements, - LogupAtRow::new(1, self.claimed_sum, self.log_size), + LogupAtRow::new(2, self.claimed_sum, self.log_size), ); eval } @@ -86,7 +86,7 @@ mod tests { &blake_lookup_elements, ); - let trace = TreeVec::new(vec![trace, interaction_trace]); + let trace = TreeVec::new(vec![vec![], trace, interaction_trace]); let trace_polys = trace.map_cols(|c| c.interpolate()); let component = BlakeSchedulerEval { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index f43d0088b..17c1b7245 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -1,6 +1,7 @@ use itertools::Itertools; use super::{limb_bits, XorElements}; +use crate::constraint_framework::constant_columns::ConstantColumn; use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; use crate::constraint_framework::EvalAtRow; use crate::core::fields::m31::BaseField; @@ -19,9 +20,15 @@ impl<'a, E: EvalAtRow, const ELEM_BITS: u32, const EXPAND_BITS: u32> // al, bl are the constant columns for the inputs: All pairs of elements in [0, // 2^LIMB_BITS). // cl is the constant column for the xor: al ^ bl. - let [al] = self.eval.next_interaction_mask(2, [0]); - let [bl] = self.eval.next_interaction_mask(2, [0]); - let [cl] = self.eval.next_interaction_mask(2, [0]); + let [al] = self + .eval + .constant_interaction_mask(ConstantColumn::XorTable(ELEM_BITS, EXPAND_BITS, 0), [0]); + let [bl] = self + .eval + .constant_interaction_mask(ConstantColumn::XorTable(ELEM_BITS, EXPAND_BITS, 1), [0]); + let [cl] = self + .eval + .constant_interaction_mask(ConstantColumn::XorTable(ELEM_BITS, EXPAND_BITS, 2), [0]); let frac_chunks = (0..(1 << (2 * EXPAND_BITS))) .map(|i| { diff --git a/crates/prover/src/examples/blake/xor_table/mod.rs b/crates/prover/src/examples/blake/xor_table/mod.rs index 877a65114..f1c8f3b4a 100644 --- a/crates/prover/src/examples/blake/xor_table/mod.rs +++ b/crates/prover/src/examples/blake/xor_table/mod.rs @@ -106,7 +106,7 @@ impl FrameworkEval let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> { eval, lookup_elements: &self.lookup_elements, - logup: LogupAtRow::new(1, self.claimed_sum, self.log_size()), + logup: LogupAtRow::new(2, self.claimed_sum, self.log_size()), }; xor_eval.eval() } diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index 330662de9..60509595b 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,5 +1,6 @@ pub mod blake; pub mod plonk; pub mod poseidon; +pub mod toy_const; pub mod wide_fibonacci; pub mod xor; diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index f2340e681..1f6d4a5a9 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -245,7 +245,7 @@ pub fn prove_fibonacci_plonk( #[cfg(test)] mod tests { use std::env; - + use crate::constraint_framework::logup::LookupElements; use crate::core::air::Component; use crate::core::channel::Blake2sChannel; @@ -254,7 +254,8 @@ mod tests { use crate::core::prover::verify; use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; use crate::examples::plonk::prove_fibonacci_plonk; - + + #[ignore = "Rewrite with constant columns consideration"] #[test_log::test] fn test_simd_plonk_prove() { // Get from environment variable: diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index d25cc1865..f81f0a1a2 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -59,7 +59,7 @@ impl FrameworkEval for PoseidonEval { self.log_n_rows + LOG_EXPAND } fn evaluate(&self, mut eval: E) -> E { - let logup = LogupAtRow::new(1, self.claimed_sum, self.log_n_rows); + let logup = LogupAtRow::new(2, self.claimed_sum, self.log_n_rows); eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements); eval } @@ -347,6 +347,11 @@ pub fn prove_poseidon( let commitment_scheme = &mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles); + // Constant Trace. + let span = span!(Level::INFO, "Constant Trace").entered(); + commitment_scheme.tree_builder().commit(channel); + span.exit(); + // Trace. let span = span!(Level::INFO, "Trace").entered(); let (trace, lookup_data) = gen_trace(log_n_rows); @@ -463,7 +468,7 @@ mod tests { let (trace1, claimed_sum) = gen_interaction_trace(LOG_N_ROWS, interaction_data, &lookup_elements); - let traces = TreeVec::new(vec![trace0, trace1]); + let traces = TreeVec::new(vec![vec![], trace0, trace1]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); assert_constraints(&trace_polys, CanonicCoset::new(LOG_N_ROWS), |mut eval| { @@ -505,12 +510,14 @@ mod tests { let sizes = component.trace_log_degree_bounds(); // Trace columns. commitment_scheme.commit(proof.commitments[0], &sizes[0], channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + // Draw lookup element. let lookup_elements = PoseidonElements::draw(channel); assert_eq!(lookup_elements, component.lookup_elements); // TODO(spapini): Check claimed sum against first and last instances. // Interaction columns. - commitment_scheme.commit(proof.commitments[1], &sizes[1], channel); + commitment_scheme.commit(proof.commitments[2], &sizes[2], channel); verify(&[&component], channel, commitment_scheme, proof).unwrap(); } diff --git a/crates/prover/src/examples/toy_const/constraints.rs b/crates/prover/src/examples/toy_const/constraints.rs new file mode 100644 index 000000000..a42017184 --- /dev/null +++ b/crates/prover/src/examples/toy_const/constraints.rs @@ -0,0 +1,42 @@ +use crate::constraint_framework::constant_columns::ConstantColumn; +use crate::constraint_framework::{EvalAtRow, FrameworkEval}; + +pub struct Add1Eval { + pub log_size: u32, +} + +impl FrameworkEval for Add1Eval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let a = eval.next_trace_mask(); + let b = eval.next_trace_mask(); + let [one] = eval.constant_interaction_mask(ConstantColumn::One(self.log_size), [0]); + eval.add_constraint(a + one - b); + eval + } +} + +pub struct Add2Eval { + pub log_size: u32, +} + +impl FrameworkEval for Add2Eval { + fn log_size(&self) -> u32 { + self.log_size + } + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size + 1 + } + fn evaluate(&self, mut eval: E) -> E { + let a = eval.next_trace_mask(); + let b = eval.next_trace_mask(); + let [one, _] = eval.constant_interaction_mask(ConstantColumn::One(self.log_size), [0, 1]); + eval.add_constraint(a + one + one - b); + eval + } +} diff --git a/crates/prover/src/examples/toy_const/gen.rs b/crates/prover/src/examples/toy_const/gen.rs new file mode 100644 index 000000000..c05a73bb3 --- /dev/null +++ b/crates/prover/src/examples/toy_const/gen.rs @@ -0,0 +1,61 @@ +#![allow(unused)] +use itertools::Itertools; +use num_traits::One; +use rand::rngs::SmallRng; +use rand::{RngCore, SeedableRng}; + +use crate::core::backend::simd::column::BaseColumn; +use crate::core::backend::simd::SimdBackend; +use crate::core::fields::m31::BaseField; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; + +pub fn gen_add_1_trace( + log_size: u32, +) -> ColumnVec> { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::from_seed([0; 32]); + let a_col = BaseColumn::from_iter((0..1 << log_size).map(|_| BaseField::from(rng.next_u32()))); + let b_col = BaseColumn::from_iter( + a_col + .clone() + .into_cpu_vec() + .iter() + .map(|&x| x + BaseField::one()), + ); + [a_col, b_col] + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +pub fn gen_add_2_trace( + log_size: u32, +) -> ColumnVec> { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::from_seed([1; 32]); + let a_col = BaseColumn::from_iter((0..1 << log_size).map(|_| BaseField::from(rng.next_u32()))); + let b_col = BaseColumn::from_iter( + a_col + .clone() + .into_cpu_vec() + .iter() + .map(|&x| x + BaseField::one() + BaseField::one()), + ); + [a_col, b_col] + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} + +pub fn gen_const_1_trace( + log_size: u32, +) -> ColumnVec> { + let domain = CanonicCoset::new(log_size).circle_domain(); + let col = BaseColumn::from_iter((0..1 << log_size).map(|_| BaseField::one())); + [col] + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec() +} diff --git a/crates/prover/src/examples/toy_const/mod.rs b/crates/prover/src/examples/toy_const/mod.rs new file mode 100644 index 000000000..0ef044485 --- /dev/null +++ b/crates/prover/src/examples/toy_const/mod.rs @@ -0,0 +1,88 @@ +mod constraints; +mod gen; + +#[cfg(test)] +mod tests { + use itertools::chain; + + use crate::constraint_framework::constant_columns::StaticTree; + use crate::constraint_framework::{FrameworkComponent, TraceLocationAllocator}; + use crate::core::backend::simd::SimdBackend; + use crate::core::channel::Blake2sChannel; + use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; + use crate::core::poly::circle::{CanonicCoset, PolyOps}; + use crate::core::prover::{prove, verify}; + use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel; + use crate::examples::toy_const::constraints::{Add1Eval, Add2Eval}; + use crate::examples::toy_const::gen::{gen_add_1_trace, gen_add_2_trace, gen_const_1_trace}; + + #[test] + fn test_toy_const() { + const LOG_N_INSTANCES: u32 = 6; + let config = PcsConfig::default(); + // Precompute twiddles. + let twiddles = SimdBackend::precompute_twiddles( + CanonicCoset::new(LOG_N_INSTANCES + 1 + config.fri_config.log_blowup_factor) + .circle_domain() + .half_coset, + ); + + // Setup protocol. + let prover_channel = &mut Blake2sChannel::default(); + let commitment_scheme = + &mut CommitmentSchemeProver::::new( + config, &twiddles, + ); + let tree_span_provider = + &mut TraceLocationAllocator::with_static_tree(&StaticTree::add1(LOG_N_INSTANCES)); + + // Constant Trace. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(gen_const_1_trace(LOG_N_INSTANCES)); + tree_builder.commit(prover_channel); + + // Trace. + let add_1_trace = gen_add_1_trace(LOG_N_INSTANCES); + let add_2_trace = gen_add_2_trace(LOG_N_INSTANCES); + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(chain![add_1_trace, add_2_trace]); + tree_builder.commit(prover_channel); + + // Prove constraints. + let add_1_component = FrameworkComponent::::new( + tree_span_provider, + Add1Eval { + log_size: LOG_N_INSTANCES, + }, + ); + let add_2_component = FrameworkComponent::::new( + tree_span_provider, + Add2Eval { + log_size: LOG_N_INSTANCES, + }, + ); + + let proof = prove::( + &[&add_1_component, &add_2_component], + prover_channel, + commitment_scheme, + ) + .unwrap(); + + // Verify. + let verifier_channel = &mut Blake2sChannel::default(); + let commitment_scheme = &mut CommitmentSchemeVerifier::::new(config); + + // Retrieve the expected column sizes in each commitment interaction, from the AIR. + let sizes = vec![6; 4]; + commitment_scheme.commit(proof.commitments[0], &[6], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes, verifier_channel); + verify( + &[&add_1_component, &add_2_component], + verifier_channel, + commitment_scheme, + proof, + ) + .unwrap(); + } +} diff --git a/crates/prover/src/examples/wide_fibonacci/mod.rs b/crates/prover/src/examples/wide_fibonacci/mod.rs index 7b0e9b766..ec08544c5 100644 --- a/crates/prover/src/examples/wide_fibonacci/mod.rs +++ b/crates/prover/src/examples/wide_fibonacci/mod.rs @@ -119,7 +119,7 @@ mod tests { #[test] fn test_wide_fibonacci_constraints() { const LOG_N_INSTANCES: u32 = 6; - let traces = TreeVec::new(vec![generate_test_trace(LOG_N_INSTANCES)]); + let traces = TreeVec::new(vec![vec![], generate_test_trace(LOG_N_INSTANCES)]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); @@ -138,7 +138,7 @@ mod tests { let mut trace = generate_test_trace(LOG_N_INSTANCES); // Modify the trace such that a constraint fail. trace[17].values.set(2, BaseField::one()); - let traces = TreeVec::new(vec![trace]); + let traces = TreeVec::new(vec![vec![], trace]); let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect_vec()); @@ -166,6 +166,9 @@ mod tests { &mut CommitmentSchemeProver::::new( config, &twiddles, ); + // Constant Trace. + let tree_builder = commitment_scheme.tree_builder(); + tree_builder.commit(prover_channel); // Trace. let trace = generate_test_trace(LOG_N_INSTANCES); @@ -195,6 +198,7 @@ mod tests { // Retrieve the expected column sizes in each commitment interaction, from the AIR. let sizes = component.trace_log_degree_bounds(); commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); } @@ -218,6 +222,9 @@ mod tests { config, &twiddles, ); + // Constant Trace. + commitment_scheme.tree_builder().commit(prover_channel); + // Trace. let trace = generate_test_trace(LOG_N_INSTANCES); let mut tree_builder = commitment_scheme.tree_builder(); @@ -246,6 +253,8 @@ mod tests { // Retrieve the expected column sizes in each commitment interaction, from the AIR. let sizes = component.trace_log_degree_bounds(); commitment_scheme.commit(proof.commitments[0], &sizes[0], verifier_channel); + println!("commitments: {:?}", proof.commitments); + commitment_scheme.commit(proof.commitments[1], &sizes[1], verifier_channel); verify(&[&component], verifier_channel, commitment_scheme, proof).unwrap(); } }