Skip to content

Commit

Permalink
Group assignemnts by offset
Browse files Browse the repository at this point in the history
  • Loading branch information
Mason Liang committed Nov 3, 2023
1 parent eb57ded commit f1bce41
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 105 deletions.
113 changes: 44 additions & 69 deletions src/assignment_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,87 +6,62 @@ use halo2_proofs::{
circuit::{Region, Value},
plonk::Error,
};
use itertools::Itertools;
use rayon::prelude::*;
use std::collections::HashMap;
use std::collections::{BTreeMap, BTreeSet, HashMap};

#[derive(Clone, Default)]
pub struct AssignmentMap<F: FieldExt>(HashMap<(Column, usize), Value<F>>);
pub struct AssignmentMap<F: FieldExt>(BTreeMap<usize, Vec<(Column, Value<F>)>>);

impl<F: FieldExt> AssignmentMap<F> {
pub fn new(assignments: impl ParallelIterator<Item = ((Column, usize), Value<F>)>) -> Self {
Self(assignments.collect())
pub fn new(stream: impl ParallelIterator<Item = ((Column, usize), Value<F>)>) -> Self {
let mut sorted_by_offset: Vec<_> = stream
.map(|((column, offset), value)| (offset, column, value))
.collect();
sorted_by_offset.sort_by(|x, y| x.0.cmp(&y.0));
let grouped_by_offset = sorted_by_offset.iter().group_by(|(offset, _, _)| offset);
let y: BTreeMap<_, _> = grouped_by_offset
.into_iter()
.map(|(offset, group)| {
(
*offset,
group
.map(|(_offset, column, value)| (*column, *value))
.collect(),
)
})
.collect();
Self(y)
}

// pub fn enable_selector(&mut self, column: SelectorColumn, offset: usize) {
// self.add_assignment(column.into(), offset, Value::known(F::one()));
// }

// pub fn assign_fixed(&mut self, column: FixedColumn, offset: usize, assignment: F) {
// self.add_assignment(column.into(), offset, Value::known(assignment));
// }

// pub fn assign_advice(&mut self, column: AdviceColumn, offset: usize, assignment: F) {
// self.add_assignment(column.into(), offset, Value::known(assignment));
// }

// pub fn assign_second_phase_advice(
// &mut self,
// column: SecondPhaseAdviceColumn,
// offset: usize,
// assignment: Value<F>,
// ) {
// self.add_assignment(column.into(), offset, assignment);
// }

// fn add_assignment(&mut self, column: Column, offset: usize, assignment: Value<F>) {
// self.0
// .entry((column.into(), offset))
// .and_modify(|_| panic!("Did you mean to assign twice????"))
// .or_insert(assignment);
// }

pub fn assignments(self) -> Vec<impl FnMut(Region<'_, F>) -> Result<(), Error>> {
// self.0
// .into_iter()
// .map(|((column, offset), value)| {
// move |mut region: Region<'_, F>| {
// match column {
// Column::Selector(s) => {
// region.assign_fixed(|| "selector", s.0, offset, || value)
// }
// Column::Fixed(s) => region.assign_fixed(|| "fixed", s.0, offset, || value),
// Column::Advice(s) => {
// region.assign_advice(|| "advice", s.0, offset, || value)
// }
// Column::SecondPhaseAdvice(s) => {
// region.assign_advice(|| "second phase advice", s.0, offset, || value)
// }
// };
// Ok(())
// }
// })
// .collect()
vec![move |mut region: Region<'_, F>| {
let x = self.0.clone();
for ((column, offset), value) in x.into_iter() {
match column {
Column::Selector(s) => {
region.assign_fixed(|| "selector", s.0, offset, || value)
}
Column::Fixed(s) => region.assign_fixed(|| "fixed", s.0, offset, || value),
Column::Advice(s) => region.assign_advice(|| "advice", s.0, offset, || value),
Column::SecondPhaseAdvice(s) => {
region.assign_advice(|| "second phase advice", s.0, offset, || value)
pub fn to_vec(self) -> Vec<impl FnMut(Region<'_, F>) -> Result<(), Error>> {
self.0
.into_iter()
.map(|(_offset, column_assignments)| {
move |mut region: Region<'_, F>| {
for (column, value) in column_assignments.iter() {
match *column {
Column::Selector(s) => {
region.assign_fixed(|| "selector", s.0, 0, || *value)
}
Column::Fixed(s) => region.assign_fixed(|| "fixed", s.0, 0, || *value),
Column::Advice(s) => {
region.assign_advice(|| "advice", s.0, 0, || *value)
}
Column::SecondPhaseAdvice(s) => {
region.assign_advice(|| "second phase advice", s.0, 0, || *value)
}
}
.unwrap();
}
Ok(())
}
.unwrap();
}
Ok(())
}]
})
.collect()
}
}

#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum Column {
Selector(SelectorColumn),
Fixed(FixedColumn),
Expand Down
8 changes: 4 additions & 4 deletions src/constraint_builder/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use halo2_proofs::{
};
use std::fmt::Debug;

#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct SelectorColumn(pub Column<Fixed>);

impl SelectorColumn {
Expand Down Expand Up @@ -37,7 +37,7 @@ impl SelectorColumn {
}
}

#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct FixedColumn(pub Column<Fixed>);

impl FixedColumn {
Expand Down Expand Up @@ -86,7 +86,7 @@ impl FixedColumn {
}
}

#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct AdviceColumn(pub Column<Advice>);

impl AdviceColumn {
Expand Down Expand Up @@ -143,7 +143,7 @@ impl AdviceColumn {
}
}

#[derive(Clone, Copy, Hash, Eq, PartialEq)]
#[derive(Clone, Copy, Hash, Eq, PartialEq, PartialOrd, Ord)]
pub struct SecondPhaseAdviceColumn(pub Column<Advice>);

impl SecondPhaseAdviceColumn {
Expand Down
63 changes: 31 additions & 32 deletions src/mpt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
assignment_map::AssignmentMap,
assignment_map::{AssignmentMap, Column},
constraint_builder::{ConstraintBuilder, SelectorColumn},
gadgets::{
byte_bit::ByteBitGadget,
Expand All @@ -18,7 +18,7 @@ use crate::{
};
use halo2_proofs::{
arithmetic::FieldExt,
circuit::Layouter,
circuit::{Layouter, Value},
halo2curves::bn256::Fr,
plonk::{Challenge, ConstraintSystem, Error, Expression, VirtualCells},
};
Expand Down Expand Up @@ -126,23 +126,6 @@ impl MptCircuitConfig {
layouter.assign_region(
|| "mpt circuit",
|mut region| {
for offset in 1..n_rows {
self.selector.enable(&mut region, offset);
}

// pad canonical_representation to fixed count
// notice each input cost 32 rows in canonical_representation, and inside
// assign one extra input is added
let mut keys = mpt_update_keys(proofs);
keys.sort();
keys.dedup();
let total_rep_size = n_rows / 32 - 1;
assert!(
total_rep_size >= keys.len(),
"no enough space for canonical representation of all keys (need {})",
keys.len()
);

let n_assigned_rows = self.mpt_update.assign(&mut region, proofs, randomness);

assert!(
Expand All @@ -154,7 +137,6 @@ impl MptCircuitConfig {
for offset in 1 + n_assigned_rows..n_rows {
self.mpt_update.assign_padding_row(&mut region, offset);
}
self.is_final_row.enable(&mut region, n_rows - 1);

Ok(())
},
Expand All @@ -163,22 +145,27 @@ impl MptCircuitConfig {
let mut keys = mpt_update_keys(proofs);
keys.sort();
keys.dedup();

let selector_assignments = self.selector_assignments(n_rows);
let byte_bit_assignments = self.byte_bit.assignments();
let byte_representation_assignments = self
.byte_representation
.assignments(u32s, u64s, u128s, frs, randomness);
let canonical_representation_assignments = self
.canonical_representation
.assignments(keys, n_rows, randomness);
let key_bit_assignments = self.key_bit.assignments(key_bit_lookups(proofs));

layouter.assign_regions(
|| "mpt circuit parallel assignment",
AssignmentMap::new(
self.byte_bit
.assignments()
.chain(
self.byte_representation
.assignments(u32s, u64s, u128s, frs, randomness),
)
.chain(
self.canonical_representation
.assignments(keys, n_rows, randomness),
)
.chain(self.key_bit.assignments(key_bit_lookups(proofs))),
selector_assignments
.chain(byte_bit_assignments)
.chain(byte_representation_assignments)
.chain(canonical_representation_assignments)
.chain(key_bit_assignments),
)
.assignments(),
.to_vec(),
)?;

Ok(())
Expand All @@ -205,4 +192,16 @@ impl MptCircuitConfig {
.max()
.unwrap()
}

fn selector_assignments(
&self,
n_rows: usize,
) -> impl ParallelIterator<Item = ((Column, usize), Value<Fr>)> + '_ {
(0..n_rows).into_par_iter().flat_map_iter(move |offset| {
[
self.selector.assignment(offset, offset != 0),
self.is_final_row.assignment(offset, offset == n_rows - 1),
]
})
}
}

0 comments on commit f1bce41

Please sign in to comment.