Skip to content

Commit

Permalink
Several refactor and implement backend UniHyperPlonk (#30)
Browse files Browse the repository at this point in the history
* feat: generalize `ClassicSumCheck` to also support lexical rotation, and impl `multilinear_eval` of [PH23] and `UniHyperPlonk` with it

* refactor: `poly` and `pcs`

* feat: improve `ph23`

* chore

* feat: add keccak256 circuit from `zkevm-circuits` in benchmark

* feat: add `pcs::univariate::{hyrax,ipa}`

* chore
  • Loading branch information
han0110 committed Oct 21, 2023
1 parent 4ce89e1 commit 303cf24
Show file tree
Hide file tree
Showing 62 changed files with 4,733 additions and 2,211 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = ["benchmark", "plonkish_backend"]
resolver = "2"

[profile.flamegraph]
inherits = "release"
Expand Down
1 change: 1 addition & 0 deletions benchmark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ plonkish_backend = { path = "../plonkish_backend", features = ["benchmark"] }
halo2_proofs = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark" }
halo2_gadgets = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark", features = ["unstable"] }
snark-verifier = { git = "https://github.com/han0110/snark-verifier", branch = "feature/for-benchmark", default-features = false, features = ["loader_halo2", "system_halo2"] }
zkevm-circuits = { git = "https://github.com/han0110/zkevm-circuits", branch = "feature/for-benchmark" }

# espresso
ark-ff = { version = "0.4.0", default-features = false }
Expand Down
93 changes: 65 additions & 28 deletions benchmark/benches/proof_system.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
use benchmark::{
espresso,
halo2::{AggregationCircuit, Sha256Circuit},
halo2::{AggregationCircuit, Keccak256Circuit, Sha256Circuit},
};
use espresso_hyperplonk::{prelude::MockCircuit, HyperPlonkSNARK};
use espresso_subroutines::{MultilinearKzgPCS, PolyIOP, PolynomialCommitmentScheme};
use halo2_proofs::{
plonk::{create_proof, keygen_pk, keygen_vk, verify_proof},
poly::kzg::{
commitment::ParamsKZG,
multiopen::{ProverGWC, VerifierGWC},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
},
transcript::{Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer},
};
use itertools::Itertools;
use plonkish_backend::{
backend::{self, PlonkishBackend, PlonkishCircuit},
backend::{self, PlonkishBackend, PlonkishCircuit, WitnessEncoding},
frontend::halo2::{circuit::VanillaPlonk, CircuitExt, Halo2Circuit},
halo2_curves::bn256::{Bn256, Fr},
pcs::multilinear,
pcs::{multilinear, univariate, CommitmentChunk},
util::{
end_timer, start_timer,
test::std_rng,
transcript::{InMemoryTranscript, Keccak256Transcript},
transcript::{InMemoryTranscript, Keccak256Transcript, TranscriptRead, TranscriptWrite},
},
};
use std::{
env::args,
fmt::Display,
fs::{create_dir, File, OpenOptions},
io::Write,
io::{Cursor, Write},
iter,
ops::Range,
path::Path,
Expand All @@ -44,38 +44,54 @@ fn main() {
k_range.for_each(|k| systems.iter().for_each(|system| system.bench(k, circuit)));
}

fn bench_hyperplonk<C: CircuitExt<Fr>>(k: usize) {
type MultilinearKzg = multilinear::MultilinearKzg<Bn256>;
type HyperPlonk = backend::hyperplonk::HyperPlonk<MultilinearKzg>;

fn bench_plonkish_backend<B, C>(system: System, k: usize)
where
B: PlonkishBackend<Fr> + WitnessEncoding,
C: CircuitExt<Fr>,
Keccak256Transcript<Cursor<Vec<u8>>>: TranscriptRead<CommitmentChunk<Fr, B::Pcs>, Fr>
+ TranscriptWrite<CommitmentChunk<Fr, B::Pcs>, Fr>
+ InMemoryTranscript,
{
let circuit = C::rand(k, std_rng());
let circuit = Halo2Circuit::new::<HyperPlonk>(k, circuit);
let circuit = Halo2Circuit::new::<B>(k, circuit);
let circuit_info = circuit.circuit_info().unwrap();
let instances = circuit.instances();

let timer = start_timer(|| format!("hyperplonk_setup-{k}"));
let param = HyperPlonk::setup(&circuit_info, std_rng()).unwrap();
let timer = start_timer(|| format!("{system}_setup-{k}"));
let param = B::setup(&circuit_info, std_rng()).unwrap();
end_timer(timer);

let timer = start_timer(|| format!("hyperplonk_preprocess-{k}"));
let (pp, vp) = HyperPlonk::preprocess(&param, &circuit_info).unwrap();
let timer = start_timer(|| format!("{system}_preprocess-{k}"));
let (pp, vp) = B::preprocess(&param, &circuit_info).unwrap();
end_timer(timer);

let proof = sample(System::HyperPlonk, k, || {
let _timer = start_timer(|| format!("hyperplonk_prove-{k}"));
let proof = sample(system, k, || {
let _timer = start_timer(|| format!("{system}_prove-{k}"));
let mut transcript = Keccak256Transcript::default();
HyperPlonk::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap();
B::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap();
transcript.into_proof()
});

let _timer = start_timer(|| format!("hyperplonk_verify-{k}"));
let _timer = start_timer(|| format!("{system}_verify-{k}"));
let accept = {
let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice());
HyperPlonk::verify(&vp, instances, &mut transcript, std_rng()).is_ok()
B::verify(&vp, instances, &mut transcript, std_rng()).is_ok()
};
assert!(accept);
}

fn bench_hyperplonk<C: CircuitExt<Fr>>(k: usize) {
type GeminiKzg = multilinear::Gemini<univariate::UnivariateKzg<Bn256>>;
type HyperPlonk = backend::hyperplonk::HyperPlonk<GeminiKzg>;
bench_plonkish_backend::<HyperPlonk, C>(System::HyperPlonk, k)
}

fn bench_unihyperplonk<C: CircuitExt<Fr>>(k: usize) {
type UnivariateKzg = univariate::UnivariateKzg<Bn256>;
type UniHyperPlonk = backend::unihyperplonk::UniHyperPlonk<UnivariateKzg, true>;
bench_plonkish_backend::<UniHyperPlonk, C>(System::UniHyperPlonk, k)
}

fn bench_halo2<C: CircuitExt<Fr>>(k: usize) {
let circuit = C::rand(k, std_rng());
let circuits = &[circuit];
Expand All @@ -93,11 +109,13 @@ fn bench_halo2<C: CircuitExt<Fr>>(k: usize) {
end_timer(timer);

let create_proof = |c, d, e, mut f: Blake2bWrite<_, _, _>| {
create_proof::<_, ProverGWC<_>, _, _, _, _, false>(&param, &pk, c, d, e, &mut f).unwrap();
create_proof::<_, ProverSHPLONK<_>, _, _, _, _, false>(&param, &pk, c, d, e, &mut f)
.unwrap();
f.finalize()
};
let verify_proof =
|c, d, e| verify_proof::<_, VerifierGWC<_>, _, _, _, false>(&param, pk.get_vk(), c, d, e);
let verify_proof = |c, d, e| {
verify_proof::<_, VerifierSHPLONK<_>, _, _, _, false>(&param, pk.get_vk(), c, d, e)
};

let proof = sample(System::Halo2, k, || {
let _timer = start_timer(|| format!("halo2_prove-{k}"));
Expand Down Expand Up @@ -150,6 +168,7 @@ fn bench_espresso_hyperplonk(circuit: MockCircuit<ark_bn254::Fr>) {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum System {
HyperPlonk,
UniHyperPlonk,
Halo2,
EspressoHyperPlonk,
}
Expand All @@ -158,6 +177,7 @@ impl System {
fn all() -> Vec<System> {
vec![
System::HyperPlonk,
System::UniHyperPlonk,
System::Halo2,
System::EspressoHyperPlonk,
]
Expand All @@ -176,12 +196,15 @@ impl System {

fn support(&self, circuit: Circuit) -> bool {
match self {
System::HyperPlonk | System::Halo2 => match circuit {
Circuit::VanillaPlonk | Circuit::Aggregation | Circuit::Sha256 => true,
System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => match circuit {
Circuit::VanillaPlonk
| Circuit::Aggregation
| Circuit::Sha256
| Circuit::Keccak256 => true,
},
System::EspressoHyperPlonk => match circuit {
Circuit::VanillaPlonk => true,
Circuit::Aggregation | Circuit::Sha256 => false,
Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => false,
},
}
}
Expand All @@ -199,15 +222,23 @@ impl System {
Circuit::VanillaPlonk => bench_hyperplonk::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_hyperplonk::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_hyperplonk::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_hyperplonk::<Keccak256Circuit>(k),
},
System::UniHyperPlonk => match circuit {
Circuit::VanillaPlonk => bench_unihyperplonk::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_unihyperplonk::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_unihyperplonk::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_unihyperplonk::<Keccak256Circuit>(k),
},
System::Halo2 => match circuit {
Circuit::VanillaPlonk => bench_halo2::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_halo2::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_halo2::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_halo2::<Keccak256Circuit>(k),
},
System::EspressoHyperPlonk => match circuit {
Circuit::VanillaPlonk => bench_espresso_hyperplonk(espresso::vanilla_plonk(k)),
Circuit::Aggregation | Circuit::Sha256 => unreachable!(),
Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => unreachable!(),
},
}
}
Expand All @@ -217,6 +248,7 @@ impl Display for System {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
System::HyperPlonk => write!(f, "hyperplonk"),
System::UniHyperPlonk => write!(f, "unihyperplonk"),
System::Halo2 => write!(f, "halo2"),
System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"),
}
Expand All @@ -228,6 +260,7 @@ enum Circuit {
VanillaPlonk,
Aggregation,
Sha256,
Keccak256,
}

impl Circuit {
Expand All @@ -236,6 +269,7 @@ impl Circuit {
Circuit::VanillaPlonk => 4,
Circuit::Aggregation => 20,
Circuit::Sha256 => 17,
Circuit::Keccak256 => 10,
}
}
}
Expand All @@ -246,6 +280,7 @@ impl Display for Circuit {
Circuit::VanillaPlonk => write!(f, "vanilla_plonk"),
Circuit::Aggregation => write!(f, "aggregation"),
Circuit::Sha256 => write!(f, "sha256"),
Circuit::Keccak256 => write!(f, "keccak256"),
}
}
}
Expand All @@ -258,16 +293,18 @@ fn parse_args() -> (Vec<System>, Circuit, Range<usize>) {
"--system" => match value.as_str() {
"all" => systems = System::all(),
"hyperplonk" => systems.push(System::HyperPlonk),
"unihyperplonk" => systems.push(System::UniHyperPlonk),
"halo2" => systems.push(System::Halo2),
"espresso_hyperplonk" => systems.push(System::EspressoHyperPlonk),
_ => panic!(
"system should be one of {{all,hyperplonk,halo2,espresso_hyperplonk}}"
"system should be one of {{all,hyperplonk,unihyperplonk,halo2,espresso_hyperplonk}}"
),
},
"--circuit" => match value.as_str() {
"vanilla_plonk" => circuit = Circuit::VanillaPlonk,
"aggregation" => circuit = Circuit::Aggregation,
"sha256" => circuit = Circuit::Sha256,
"keccak256" => circuit = Circuit::Keccak256,
_ => panic!("circuit should be one of {{aggregation,vanilla_plonk}}"),
},
"--k" => {
Expand Down
79 changes: 62 additions & 17 deletions benchmark/src/bin/plotter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() {
}

fn parse_args() -> (bool, Vec<String>) {
let (verbose, logs) = args().chain(Some("".to_string())).tuple_windows().fold(
let (verbose, logs) = args().chain(["".to_string()]).tuple_windows().fold(
(false, None),
|(mut verbose, mut logs), (key, value)| {
match key.as_str() {
Expand Down Expand Up @@ -94,6 +94,7 @@ fn parse_args() -> (bool, Vec<String>) {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum System {
HyperPlonk,
UniHyperPlonk,
Halo2,
EspressoHyperPlonk,
}
Expand All @@ -102,6 +103,7 @@ impl System {
fn iter() -> impl Iterator<Item = System> {
[
System::HyperPlonk,
System::UniHyperPlonk,
System::Halo2,
System::EspressoHyperPlonk,
]
Expand All @@ -110,7 +112,7 @@ impl System {

fn key_fn(&self) -> impl Fn(&Log) -> (bool, &str) + '_ {
move |log| match self {
System::HyperPlonk | System::Halo2 => (
System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => (
false,
log.name.split([' ', '-']).next().unwrap_or(&log.name),
),
Expand Down Expand Up @@ -167,6 +169,49 @@ impl System {
]),
),
],
System::UniHyperPlonk => vec![
(
"all",
vec![
vec!["variable_base_msm"],
vec!["sum_check_prove"],
vec!["prove_multilinear_eval"],
],
None,
),
("multiexp", vec![vec!["variable_base_msm"]], None),
("sum check", vec![vec!["sum_check_prove"]], None),
(
"mleval multiexp",
vec![
vec!["prove_multilinear_eval", "variable_base_msm"],
vec![
"prove_multilinear_eval",
"pcs_batch_open",
"variable_base_msm",
],
],
None,
),
(
"mleval fft",
vec![vec!["prove_multilinear_eval", "fft"]],
None,
),
(
"mleval rest",
vec![vec!["prove_multilinear_eval"]],
Some(vec![
vec!["prove_multilinear_eval", "variable_base_msm"],
vec![
"prove_multilinear_eval",
"pcs_batch_open",
"variable_base_msm",
],
vec!["prove_multilinear_eval", "fft"],
]),
),
],
System::Halo2 => vec![
(
"all",
Expand Down Expand Up @@ -320,6 +365,7 @@ impl Display for System {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
System::HyperPlonk => write!(f, "hyperplonk"),
System::UniHyperPlonk => write!(f, "unihyperplonk"),
System::Halo2 => write!(f, "halo2"),
System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"),
}
Expand Down Expand Up @@ -613,21 +659,20 @@ fn plot_comparison(cost_breakdowns_by_system: &[BTreeMap<usize, Vec<(&str, Durat
let lines = System::iter()
.zip(cost_breakdowns_by_system.iter())
.skip(1)
.filter_map(|(system, cost_breakdowns)| {
(!cost_breakdowns.is_empty()).then(|| {
let [numer, denom] =
[cost_breakdowns, hyperplonk_cost_breakdowns].map(|cost_breakdowns| {
x.iter()
.map(|k| cost_breakdowns[k][0].1.as_nanos() as f64)
.collect_vec()
});
let ratio = numer
.iter()
.zip(denom.iter())
.map(|(numer, denom)| numer / denom)
.collect_vec();
(format!("{system}/{}", System::HyperPlonk), ratio)
})
.filter(|(_, cost_breakdowns)| !cost_breakdowns.is_empty())
.map(|(system, cost_breakdowns)| {
let [numer, denom] =
[cost_breakdowns, hyperplonk_cost_breakdowns].map(|cost_breakdowns| {
x.iter()
.map(|k| cost_breakdowns[k][0].1.as_nanos() as f64)
.collect_vec()
});
let ratio = numer
.iter()
.zip(denom.iter())
.map(|(numer, denom)| numer / denom)
.collect_vec();
(format!("{system}/{}", System::HyperPlonk), ratio)
})
.collect_vec();

Expand Down
Loading

0 comments on commit 303cf24

Please sign in to comment.