Skip to content

Commit

Permalink
parallelize mpt updates (#99)
Browse files Browse the repository at this point in the history
* parallelize mpt updates

* fix

* remove unnecessary log

* add ci task for testing parallel assignment

* fix all_paddings test failure

* collect runtime statistics

* defer invert to be batched

* fix

* assign_par for key_bit gadget

* assign_par for canonical_repr gadget

* comment

* fix lock file merge errors

---------

Co-authored-by: z2trillion <[email protected]>
Co-authored-by: Mason Liang <[email protected]>
  • Loading branch information
3 people authored Nov 27, 2023
1 parent 9325879 commit 6353bc8
Show file tree
Hide file tree
Showing 11 changed files with 606 additions and 261 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,19 @@ jobs:
toolchain: nightly-2022-12-10
override: true
- run: make test
bench:

par-test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: nightly-2022-12-10
override: true
- run: make test_par

bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
Expand Down
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ num-bigint = "0.4"
hex = "0.4"
thiserror = "1.0"
log = "0.4"
env_logger = "0.9.0"
mpt-zktrie = { git = "https://github.com/scroll-tech/zkevm-circuits.git", rev = "7d9bc181953cfc6e7baf82ff0ce651281fd70a8a" }
rand_chacha = "0.3.0"
criterion = { version = "0.4", optional = true}
Expand All @@ -33,7 +34,8 @@ ethers-core = { git = "https://github.com/scroll-tech/ethers-rs.git", branch = "
[features]
# printout the layout of circuits for demo and some unittests
print_layout = ["halo2_proofs/dev-graph"]
bench = [ "dep:criterion" ]
default = ["halo2_proofs/mock-batch-inv", "halo2_proofs/parallel_syn"]
bench = ["dep:criterion"]

[dev-dependencies]
# mpt-zktrie = { path = "../scroll-circuits/zktrie" }
Expand All @@ -52,4 +54,4 @@ debug-assertions = true
[[bench]]
name = "parallel_assignment"
harness = false
required-features = [ "bench" ]
required-features = ["bench"]
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
test:
@cargo test

test_par:
PARALLEL_SYN=true cargo test -- --nocapture

fmt:
@cargo fmt

Expand Down
12 changes: 12 additions & 0 deletions src/constraint_builder/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{BinaryQuery, Query};
use halo2_proofs::plonk::Assigned;
use halo2_proofs::{
arithmetic::FieldExt,
circuit::{Region, Value},
Expand Down Expand Up @@ -101,6 +102,17 @@ impl AdviceColumn {
)
.expect("failed assign_advice");
}

pub fn assign_rational<F: FieldExt>(
&self,
region: &mut Region<'_, F>,
offset: usize,
value: Assigned<F>,
) {
region
.assign_advice(|| "advice", self.0, offset, || Value::known(value))
.expect("failed assign_advice");
}
}

#[derive(Clone, Copy)]
Expand Down
93 changes: 93 additions & 0 deletions src/gadgets/canonical_representation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use super::super::constraint_builder::{
};
use super::{byte_bit::RangeCheck256Lookup, is_zero::IsZeroGadget, rlc_randomness::RlcRandomness};
use ethers_core::types::U256;
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::Error;
use halo2_proofs::{
arithmetic::{Field, FieldExt},
circuit::{Region, Value},
Expand Down Expand Up @@ -212,6 +214,97 @@ impl CanonicalRepresentationConfig {
}
}

pub fn assign_par(
&self,
layouter: &mut impl Layouter<Fr>,
randomness: Value<Fr>,
values: &[Fr],
n_rows: usize,
) {
let modulus = U256::from_str_radix(Fr::MODULUS, 16).unwrap();
let mut modulus_bytes = [0u8; 32];
modulus.to_big_endian(&mut modulus_bytes);

let num_threads = std::thread::available_parallelism().unwrap().get();
let num_values = n_rows / 32;
let zero = Fr::zero();
log::debug!("num_real_values: {}", values.len());
let values = values
.iter()
.chain(std::iter::repeat(&zero))
.take(num_values)
.collect_vec();
let chunk_size = (num_values + num_threads - 1) / num_threads;
let mut is_first_passes = vec![true; num_threads];
let assignments = values
.chunks(chunk_size)
.zip(is_first_passes.iter_mut())
.enumerate()
.map(|(i, (values, is_first_pass))| {
move |mut region: Region<'_, Fr>| -> Result<(), Error> {
let region = &mut region;
if *is_first_pass {
*is_first_pass = false;
let last_off = if i == 0 {
values.len() * 32
} else {
values.len() * 32 - 1
};
self.value.assign(region, last_off, Fr::zero());
return Ok(());
}
let mut offset = if i == 0 { 1 } else { 0 };
for value in values.iter() {
let mut bytes = value.to_bytes();
bytes.reverse();
let mut differences_are_zero_so_far = true;
let mut rlc = Value::known(Fr::zero());
for (index, (byte, modulus_byte)) in
bytes.iter().zip_eq(&modulus_bytes).enumerate()
{
self.byte.assign(region, offset, u64::from(*byte));
self.modulus_byte
.assign(region, offset, u64::from(*modulus_byte));

self.index
.assign(region, offset, u64::try_from(index).unwrap());
if index.is_zero() {
self.index_is_zero.enable(region, offset);
} else if index == 31 {
self.index_is_31.enable(region, offset);
}

let difference =
Fr::from(u64::from(*modulus_byte)) - Fr::from(u64::from(*byte));
self.difference.assign(region, offset, difference);
self.difference_is_zero.assign(region, offset, difference);

self.differences_are_zero_so_far.assign(
region,
offset,
differences_are_zero_so_far,
);
differences_are_zero_so_far &= difference.is_zero_vartime();

self.value.assign(region, offset, **value);

rlc = rlc * randomness + Value::known(Fr::from(u64::from(*byte)));
self.rlc.assign(region, offset, rlc);

offset += 1
}
}

Ok(())
}
})
.collect_vec();

layouter
.assign_regions(|| "canonical_repr", assignments)
.unwrap();
}

pub fn n_rows_required(values: &[Fr]) -> usize {
// +1 because assigment starts on offset = 1 instead of offset = 0.
values.len() * 32 + 1
Expand Down
6 changes: 4 additions & 2 deletions src/gadgets/is_zero.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::constraint_builder::{AdviceColumn, BinaryQuery, ConstraintBuilder, Query};
use halo2_proofs::plonk::Assigned;
use halo2_proofs::{arithmetic::FieldExt, circuit::Region, plonk::ConstraintSystem};
use std::fmt::Debug;

Expand All @@ -25,10 +26,11 @@ impl IsZeroGadget {
) where
<T as TryInto<F>>::Error: Debug,
{
self.inverse_or_zero.assign(
self.inverse_or_zero.assign_rational(
region,
offset,
value.try_into().unwrap().invert().unwrap_or(F::zero()),
// invert is deferred and then batched by the real/mock prover
Assigned::<F>::from(value.try_into().unwrap()).invert(),
);
}

Expand Down
53 changes: 52 additions & 1 deletion src/gadgets/key_bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use super::{
canonical_representation::CanonicalRepresentationLookup,
};
use crate::constraint_builder::{AdviceColumn, ConstraintBuilder, Query};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::{
arithmetic::FieldExt, circuit::Region, halo2curves::bn256::Fr, plonk::ConstraintSystem,
};
use itertools::Itertools;

pub trait KeyBitLookup {
fn lookup<F: FieldExt>(&self) -> [Query<F>; 3];
Expand Down Expand Up @@ -83,10 +85,22 @@ impl KeyBitConfig {
}

pub fn assign(&self, region: &mut Region<'_, Fr>, lookups: &[(Fr, usize, bool)]) {
self.assign_internal(region, lookups, false)
}
pub fn assign_internal(
&self,
region: &mut Region<'_, Fr>,
lookups: &[(Fr, usize, bool)],
use_par: bool,
) {
// TODO; dedup lookups
for (offset, (value, index, bit)) in lookups.iter().enumerate() {
// TODO: either move the disabled row to the end of the assigment or get rid of it entirely.
let offset = offset + 1; // Start assigning at offet = 1 because the first row is disabled.
let offset = if !use_par {
offset + 1 // Start assigning at offet = 1 because the first row is disabled.
} else {
offset
};
let bytes = value.to_bytes();

let index_div_8 = index / 8; // index = (31 - index/8) * 8
Expand All @@ -107,6 +121,43 @@ impl KeyBitConfig {
}
}

pub fn assign_par(&self, layouter: &mut impl Layouter<Fr>, lookups: &[(Fr, usize, bool)]) {
let num_threads = std::thread::available_parallelism()
.expect("get num threads")
.get();
let chunk_size = (lookups.len() + num_threads - 1) / num_threads;
let mut is_first_pass = vec![true; num_threads];
let assignments = lookups
.chunks(chunk_size)
.zip(is_first_pass.iter_mut())
.enumerate()
.map(|(i, (lookups, is_first_pass))| {
move |mut region: Region<'_, Fr>| {
if *is_first_pass {
*is_first_pass = false;

if !lookups.is_empty() {
// only meant to get region's shape.
let last_off = if i == 0 {
// 1st row is disabled.
lookups.len()
} else {
lookups.len() - 1
};
self.byte.assign(&mut region, last_off, 0_u64);
}
return Ok(());
}
self.assign_internal(&mut region, lookups, true);

Ok(())
}
})
.collect_vec();

layouter.assign_regions(|| "key_bit", assignments).unwrap();
}

pub fn n_rows_required(lookups: &[(Fr, usize, bool)]) -> usize {
// +1 because assigment starts on offset = 1 instead of offset = 0.
1 + lookups.len()
Expand Down
Loading

0 comments on commit 6353bc8

Please sign in to comment.