Skip to content

Commit

Permalink
constant attempt 1
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Sep 29, 2024
1 parent a51f630 commit 573641f
Show file tree
Hide file tree
Showing 25 changed files with 589 additions and 70 deletions.
67 changes: 58 additions & 9 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use std::ops::Deref;
use itertools::Itertools;
use tracing::{span, Level};

use super::constant_columns::{ConstantColumn, ConstantTableLocation, StaticTree};
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::air::{Component, ComponentProver, Trace, CONST_INTERACTION};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
Expand All @@ -17,7 +18,7 @@ use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::pcs::{TreeLocation, TreeSubspan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec};
Expand All @@ -28,6 +29,7 @@ use crate::core::{utils, ColumnVec};
pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
static_table_offsets: ConstantTableLocation,
}

impl TraceLocationAllocator {
Expand All @@ -52,6 +54,23 @@ impl TraceLocationAllocator {
.collect(),
)
}

pub fn with_static_tree(tree: &StaticTree) -> Self {
Self {
next_tree_offsets: Default::default(),
static_table_offsets: tree.locations.clone(),
}
}

fn static_column_mappings(
&self,
constant_columns: &[ConstantColumn],
) -> ColumnVec<TreeLocation> {
constant_columns
.iter()
.map(|col| self.static_table_offsets.get_location(*col).unwrap())
.collect()
}
}

/// A component defined solely in means of the constraints framework.
Expand All @@ -68,16 +87,21 @@ pub trait FrameworkEval {

pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
mask_spans: TreeVec<TreeSubspan>,
static_columns_locations: Vec<TreeLocation>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(provider: &mut TraceLocationAllocator, eval: E) -> Self {
let eval_tree_structure = eval.evaluate(InfoEvaluator::default()).mask_offsets;
let trace_locations = provider.next_for_structure(&eval_tree_structure);
let info = eval.evaluate(InfoEvaluator::default());
let eval_tree_structure = info.mask_offsets;
let mask_spans = provider.next_for_structure(&eval_tree_structure);
let static_columns_locations = provider.static_column_mappings(&info.external_cols);

Self {
eval,
trace_locations,
mask_spans,
static_columns_locations,
}
}
}
Expand Down Expand Up @@ -116,14 +140,21 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
})
}

fn constant_column_locations(&self) -> ColumnVec<usize> {
self.static_columns_locations
.iter()
.map(|loc| loc.col_index)
.collect()
}

fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
mask.sub_tree(&self.mask_spans),
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
Expand All @@ -139,8 +170,26 @@ impl<E: FrameworkEval> ComponentProver<SimdBackend> for FrameworkComponent<E> {
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);
// Retrieve necessary columns.
let mut mask_spans = self.mask_spans.clone();

// Constant columns locations are static, cannot be derived from mask spans.
mask_spans[CONST_INTERACTION] = TreeSubspan::empty();

let component_polys = TreeVec::concat_cols(
[
trace.polys.sub_tree(&mask_spans),
trace.polys.sub_tree_sparse(&self.static_columns_locations),
]
.into_iter(),
);
let component_evals = TreeVec::concat_cols(
[
trace.evals.sub_tree(&mask_spans),
trace.evals.sub_tree_sparse(&self.static_columns_locations),
]
.into_iter(),
);

// Extend trace if necessary.
// TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good
Expand Down
89 changes: 89 additions & 0 deletions crates/prover/src/constraint_framework/constant_columns.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::collections::HashMap;

use num_traits::One;

use crate::core::air::CONST_INTERACTION;
use crate::core::backend::{Backend, Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::pcs::TreeLocation;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::vcs::blake2_hash::Blake2sHash;

/// Generates a column with a single one at the first position, and zeros elsewhere.
pub fn gen_is_first<B: Backend>(log_size: u32) -> CircleEvaluation<B, BaseField, BitReversedOrder> {
Expand Down Expand Up @@ -35,3 +40,87 @@ pub fn gen_is_step_with_offset<B: Backend>(

CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col)
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstantColumn {
XorTable(u32, u32, usize),
One(u32),
}

#[derive(Debug, Default, Clone)]
pub struct ConstantTableLocation {
locations: HashMap<ConstantColumn, usize>,
}

impl ConstantTableLocation {
pub fn add(&mut self, column: ConstantColumn, location: usize) {
if self.locations.contains_key(&column) {
panic!("Type already exists.");
}
self.locations.insert(column, location);
}

pub fn get_location(&self, column: ConstantColumn) -> Option<TreeLocation> {
self.locations.get(&column).map(|&col_index| TreeLocation {
tree_index: CONST_INTERACTION,
col_index,
})
}
}

#[derive(Debug, Default, Clone)]
pub struct StaticTree {
pub root: Blake2sHash,
pub locations: ConstantTableLocation,
}

impl StaticTree {
pub fn blake_tree() -> Self {
let root = Blake2sHash::default();
let mut locations = ConstantTableLocation::default();
locations.add(ConstantColumn::XorTable(12, 4, 0), 0);
locations.add(ConstantColumn::XorTable(12, 4, 1), 1);
locations.add(ConstantColumn::XorTable(12, 4, 2), 2);
locations.add(ConstantColumn::XorTable(9, 2, 0), 3);
locations.add(ConstantColumn::XorTable(9, 2, 1), 4);
locations.add(ConstantColumn::XorTable(9, 2, 2), 5);
locations.add(ConstantColumn::XorTable(8, 2, 0), 6);
locations.add(ConstantColumn::XorTable(8, 2, 1), 7);
locations.add(ConstantColumn::XorTable(8, 2, 2), 8);

locations.add(ConstantColumn::XorTable(7, 2, 0), 12);
locations.add(ConstantColumn::XorTable(7, 2, 1), 13);
locations.add(ConstantColumn::XorTable(7, 2, 2), 14);

locations.add(ConstantColumn::XorTable(4, 0, 0), 9);
locations.add(ConstantColumn::XorTable(4, 0, 1), 10);
locations.add(ConstantColumn::XorTable(4, 0, 2), 11);

StaticTree { root, locations }
}

pub fn add1(log_size: u32) -> Self {
let root = Blake2sHash::default();
let mut locations = ConstantTableLocation::default();
locations.add(ConstantColumn::One(log_size), 0);
StaticTree { root, locations }
}

pub fn get_location(&self, column: ConstantColumn) -> TreeLocation {
self.locations
.get_location(column)
.unwrap_or_else(|| panic!("{:?} column does not exist in the chosen tree!", column))
}

pub fn n_columns(&self) -> usize {
self.locations.locations.len()
}

pub fn log_sizes(&self) -> Vec<u32> {
self.locations
.locations
.iter()
.map(|(_, &log_size)| log_size as u32)
.collect()
}
}
11 changes: 11 additions & 0 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::ops::Mul;

use num_traits::One;

use super::constant_columns::ConstantColumn;
use super::EvalAtRow;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
Expand All @@ -12,6 +13,7 @@ use crate::core::pcs::TreeVec;
#[derive(Default)]
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub external_cols: Vec<ConstantColumn>,
pub n_constraints: usize,
}
impl InfoEvaluator {
Expand Down Expand Up @@ -45,4 +47,13 @@ impl EvalAtRow for InfoEvaluator {
fn combine_ef(_values: [Self::F; 4]) -> Self::EF {
SecureField::one()
}

fn constant_interaction_mask<const N: usize>(
&mut self,
col: ConstantColumn,
offsets: [isize; N],
) -> [Self::F; N] {
self.external_cols.push(col);
self.next_interaction_mask(0, offsets)
}
}
14 changes: 12 additions & 2 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ use std::ops::{Add, AddAssign, Mul, Neg, Sub};

pub use assert::{assert_constraints, AssertEvaluator};
pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
use constant_columns::ConstantColumn;
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
pub use simd_domain::SimdDomainEvaluator;

use crate::core::air::CONST_INTERACTION;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
Expand Down Expand Up @@ -65,17 +67,25 @@ pub trait EvalAtRow {

/// Returns the next mask value for the first interaction at offset 0.
fn next_trace_mask(&mut self) -> Self::F {
let [mask_item] = self.next_interaction_mask(0, [0]);
let [mask_item] = self.next_interaction_mask(1, [0]);
mask_item
}

/// Returns the mask values of the given offsets for the next column in the interaction.
/// Returns the mask values of the given offsets for the next owned column in the interaction.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N];

fn constant_interaction_mask<const N: usize>(
&mut self,
_col: ConstantColumn,
offsets: [isize; N],
) -> [Self::F; N] {
self.next_interaction_mask(CONST_INTERACTION, offsets)
}

/// Returns the extension mask values of the given offsets for the next extension degree many
/// columns in the interaction.
fn next_extension_interaction_mask<const N: usize>(
Expand Down
79 changes: 78 additions & 1 deletion crates/prover/src/core/air/components.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::collections::BTreeSet;

use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, Trace};
use super::{Component, ComponentProver, Trace, CONST_INTERACTION};
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
Expand All @@ -27,6 +29,81 @@ impl<'a> Components<'a> {
TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point)))
}

// Returns the unique mask points for each column.
pub fn mask_points_by_column(
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
let mut components_masks =
TreeVec::concat_cols(self.0.iter().map(|component| component.mask_points(point)));
components_masks[CONST_INTERACTION] = self.const_mask_points_by_column(point);
components_masks
}

fn const_mask_points_by_column(
&self,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
let mut static_column_masks: Vec<BTreeSet<CirclePoint<SecureField>>> = vec![];
for component in &self.0 {
let component_static_masks = &component.mask_points(point)[CONST_INTERACTION];
component_static_masks
.iter()
.zip(component.constant_column_locations())
.for_each(|(points, index)| {
if index >= static_column_masks.len() {
static_column_masks.resize_with(index + 1, Default::default);
}
static_column_masks[index].extend(points);
});
}
static_column_masks
.into_iter()
.map(|set| set.into_iter().collect())
.collect()
}

// Reorganizes the mask evaluations in the constant interaction according to the original mask
// points of each component.
pub fn reorganize_const_values_by_component(
&self,
point: CirclePoint<SecureField>,
mut mask_values: TreeVec<ColumnVec<Vec<SecureField>>>,
) -> TreeVec<ColumnVec<Vec<SecureField>>> {
mask_values[CONST_INTERACTION] =
self.const_mask_values_by_component(point, &mask_values[CONST_INTERACTION]);
mask_values
}

fn const_mask_values_by_component(
&self,
point: CirclePoint<SecureField>,
mask_values: &[Vec<SecureField>],
) -> ColumnVec<Vec<SecureField>> {
let mask_by_column = &self.mask_points_by_column(point)[CONST_INTERACTION];

let mut masks_values_by_component = vec![];
for component in &self.0 {
let component_static_masks = &component.mask_points(point)[CONST_INTERACTION];
component_static_masks
.iter()
.zip(component.constant_column_locations())
.for_each(|(points, column_idx)| {
let column_masks = &mask_by_column[column_idx];
masks_values_by_component.push(
points
.iter()
.map(|&point| {
mask_values[column_idx]
[column_masks.iter().position(|&p| p == point).unwrap()]
})
.collect_vec(),
);
});
}
masks_values_by_component
}

pub fn eval_composition_polynomial_at_point(
&self,
point: CirclePoint<SecureField>,
Expand Down
Loading

0 comments on commit 573641f

Please sign in to comment.