diff --git a/Cargo.toml b/Cargo.toml index 0fe466bd..247d8a82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["benchmark", "plonkish_backend"] +resolver = "2" [profile.flamegraph] inherits = "release" diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml index 5b91eaca..f8b03b99 100644 --- a/benchmark/Cargo.toml +++ b/benchmark/Cargo.toml @@ -15,6 +15,7 @@ plonkish_backend = { path = "../plonkish_backend", features = ["benchmark"] } halo2_proofs = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark" } halo2_gadgets = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark", features = ["unstable"] } snark-verifier = { git = "https://github.com/han0110/snark-verifier", branch = "feature/for-benchmark", default-features = false, features = ["loader_halo2", "system_halo2"] } +zkevm-circuits = { git = "https://github.com/han0110/zkevm-circuits", branch = "feature/for-benchmark" } # espresso ark-ff = { version = "0.4.0", default-features = false } diff --git a/benchmark/benches/proof_system.rs b/benchmark/benches/proof_system.rs index 99000df6..17a879ec 100644 --- a/benchmark/benches/proof_system.rs +++ b/benchmark/benches/proof_system.rs @@ -1,6 +1,6 @@ use benchmark::{ espresso, - halo2::{AggregationCircuit, Sha256Circuit}, + halo2::{AggregationCircuit, Keccak256Circuit, Sha256Circuit}, }; use espresso_hyperplonk::{prelude::MockCircuit, HyperPlonkSNARK}; use espresso_subroutines::{MultilinearKzgPCS, PolyIOP, PolynomialCommitmentScheme}; @@ -8,28 +8,28 @@ use halo2_proofs::{ plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, poly::kzg::{ commitment::ParamsKZG, - multiopen::{ProverGWC, VerifierGWC}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, strategy::SingleStrategy, }, transcript::{Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; use plonkish_backend::{ - backend::{self, PlonkishBackend, PlonkishCircuit}, + backend::{self, PlonkishBackend, PlonkishCircuit, WitnessEncoding}, frontend::halo2::{circuit::VanillaPlonk, CircuitExt, Halo2Circuit}, halo2_curves::bn256::{Bn256, Fr}, - pcs::multilinear, + pcs::{multilinear, univariate, CommitmentChunk}, util::{ end_timer, start_timer, test::std_rng, - transcript::{InMemoryTranscript, Keccak256Transcript}, + transcript::{InMemoryTranscript, Keccak256Transcript, TranscriptRead, TranscriptWrite}, }, }; use std::{ env::args, fmt::Display, fs::{create_dir, File, OpenOptions}, - io::Write, + io::{Cursor, Write}, iter, ops::Range, path::Path, @@ -44,38 +44,54 @@ fn main() { k_range.for_each(|k| systems.iter().for_each(|system| system.bench(k, circuit))); } -fn bench_hyperplonk>(k: usize) { - type MultilinearKzg = multilinear::MultilinearKzg; - type HyperPlonk = backend::hyperplonk::HyperPlonk; - +fn bench_plonkish_backend(system: System, k: usize) +where + B: PlonkishBackend + WitnessEncoding, + C: CircuitExt, + Keccak256Transcript>>: TranscriptRead, Fr> + + TranscriptWrite, Fr> + + InMemoryTranscript, +{ let circuit = C::rand(k, std_rng()); - let circuit = Halo2Circuit::new::(k, circuit); + let circuit = Halo2Circuit::new::(k, circuit); let circuit_info = circuit.circuit_info().unwrap(); let instances = circuit.instances(); - let timer = start_timer(|| format!("hyperplonk_setup-{k}")); - let param = HyperPlonk::setup(&circuit_info, std_rng()).unwrap(); + let timer = start_timer(|| format!("{system}_setup-{k}")); + let param = B::setup(&circuit_info, std_rng()).unwrap(); end_timer(timer); - let timer = start_timer(|| format!("hyperplonk_preprocess-{k}")); - let (pp, vp) = HyperPlonk::preprocess(¶m, &circuit_info).unwrap(); + let timer = start_timer(|| format!("{system}_preprocess-{k}")); + let (pp, vp) = B::preprocess(¶m, &circuit_info).unwrap(); end_timer(timer); - let proof = sample(System::HyperPlonk, k, || { - let _timer = start_timer(|| format!("hyperplonk_prove-{k}")); + let proof = sample(system, k, || { + let _timer = start_timer(|| format!("{system}_prove-{k}")); let mut transcript = Keccak256Transcript::default(); - HyperPlonk::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap(); + B::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap(); transcript.into_proof() }); - let _timer = start_timer(|| format!("hyperplonk_verify-{k}")); + let _timer = start_timer(|| format!("{system}_verify-{k}")); let accept = { let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); - HyperPlonk::verify(&vp, instances, &mut transcript, std_rng()).is_ok() + B::verify(&vp, instances, &mut transcript, std_rng()).is_ok() }; assert!(accept); } +fn bench_hyperplonk>(k: usize) { + type GeminiKzg = multilinear::Gemini>; + type HyperPlonk = backend::hyperplonk::HyperPlonk; + bench_plonkish_backend::(System::HyperPlonk, k) +} + +fn bench_unihyperplonk>(k: usize) { + type UnivariateKzg = univariate::UnivariateKzg; + type UniHyperPlonk = backend::unihyperplonk::UniHyperPlonk; + bench_plonkish_backend::(System::UniHyperPlonk, k) +} + fn bench_halo2>(k: usize) { let circuit = C::rand(k, std_rng()); let circuits = &[circuit]; @@ -93,11 +109,13 @@ fn bench_halo2>(k: usize) { end_timer(timer); let create_proof = |c, d, e, mut f: Blake2bWrite<_, _, _>| { - create_proof::<_, ProverGWC<_>, _, _, _, _, false>(¶m, &pk, c, d, e, &mut f).unwrap(); + create_proof::<_, ProverSHPLONK<_>, _, _, _, _, false>(¶m, &pk, c, d, e, &mut f) + .unwrap(); f.finalize() }; - let verify_proof = - |c, d, e| verify_proof::<_, VerifierGWC<_>, _, _, _, false>(¶m, pk.get_vk(), c, d, e); + let verify_proof = |c, d, e| { + verify_proof::<_, VerifierSHPLONK<_>, _, _, _, false>(¶m, pk.get_vk(), c, d, e) + }; let proof = sample(System::Halo2, k, || { let _timer = start_timer(|| format!("halo2_prove-{k}")); @@ -150,6 +168,7 @@ fn bench_espresso_hyperplonk(circuit: MockCircuit) { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] enum System { HyperPlonk, + UniHyperPlonk, Halo2, EspressoHyperPlonk, } @@ -158,6 +177,7 @@ impl System { fn all() -> Vec { vec![ System::HyperPlonk, + System::UniHyperPlonk, System::Halo2, System::EspressoHyperPlonk, ] @@ -176,12 +196,15 @@ impl System { fn support(&self, circuit: Circuit) -> bool { match self { - System::HyperPlonk | System::Halo2 => match circuit { - Circuit::VanillaPlonk | Circuit::Aggregation | Circuit::Sha256 => true, + System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => match circuit { + Circuit::VanillaPlonk + | Circuit::Aggregation + | Circuit::Sha256 + | Circuit::Keccak256 => true, }, System::EspressoHyperPlonk => match circuit { Circuit::VanillaPlonk => true, - Circuit::Aggregation | Circuit::Sha256 => false, + Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => false, }, } } @@ -199,15 +222,23 @@ impl System { Circuit::VanillaPlonk => bench_hyperplonk::>(k), Circuit::Aggregation => bench_hyperplonk::>(k), Circuit::Sha256 => bench_hyperplonk::(k), + Circuit::Keccak256 => bench_hyperplonk::(k), + }, + System::UniHyperPlonk => match circuit { + Circuit::VanillaPlonk => bench_unihyperplonk::>(k), + Circuit::Aggregation => bench_unihyperplonk::>(k), + Circuit::Sha256 => bench_unihyperplonk::(k), + Circuit::Keccak256 => bench_unihyperplonk::(k), }, System::Halo2 => match circuit { Circuit::VanillaPlonk => bench_halo2::>(k), Circuit::Aggregation => bench_halo2::>(k), Circuit::Sha256 => bench_halo2::(k), + Circuit::Keccak256 => bench_halo2::(k), }, System::EspressoHyperPlonk => match circuit { Circuit::VanillaPlonk => bench_espresso_hyperplonk(espresso::vanilla_plonk(k)), - Circuit::Aggregation | Circuit::Sha256 => unreachable!(), + Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => unreachable!(), }, } } @@ -217,6 +248,7 @@ impl Display for System { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { System::HyperPlonk => write!(f, "hyperplonk"), + System::UniHyperPlonk => write!(f, "unihyperplonk"), System::Halo2 => write!(f, "halo2"), System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"), } @@ -228,6 +260,7 @@ enum Circuit { VanillaPlonk, Aggregation, Sha256, + Keccak256, } impl Circuit { @@ -236,6 +269,7 @@ impl Circuit { Circuit::VanillaPlonk => 4, Circuit::Aggregation => 20, Circuit::Sha256 => 17, + Circuit::Keccak256 => 10, } } } @@ -246,6 +280,7 @@ impl Display for Circuit { Circuit::VanillaPlonk => write!(f, "vanilla_plonk"), Circuit::Aggregation => write!(f, "aggregation"), Circuit::Sha256 => write!(f, "sha256"), + Circuit::Keccak256 => write!(f, "keccak256"), } } } @@ -258,16 +293,18 @@ fn parse_args() -> (Vec, Circuit, Range) { "--system" => match value.as_str() { "all" => systems = System::all(), "hyperplonk" => systems.push(System::HyperPlonk), + "unihyperplonk" => systems.push(System::UniHyperPlonk), "halo2" => systems.push(System::Halo2), "espresso_hyperplonk" => systems.push(System::EspressoHyperPlonk), _ => panic!( - "system should be one of {{all,hyperplonk,halo2,espresso_hyperplonk}}" + "system should be one of {{all,hyperplonk,unihyperplonk,halo2,espresso_hyperplonk}}" ), }, "--circuit" => match value.as_str() { "vanilla_plonk" => circuit = Circuit::VanillaPlonk, "aggregation" => circuit = Circuit::Aggregation, "sha256" => circuit = Circuit::Sha256, + "keccak256" => circuit = Circuit::Keccak256, _ => panic!("circuit should be one of {{aggregation,vanilla_plonk}}"), }, "--k" => { diff --git a/benchmark/src/bin/plotter.rs b/benchmark/src/bin/plotter.rs index 84483a52..d4f890bc 100644 --- a/benchmark/src/bin/plotter.rs +++ b/benchmark/src/bin/plotter.rs @@ -57,7 +57,7 @@ fn main() { } fn parse_args() -> (bool, Vec) { - let (verbose, logs) = args().chain(Some("".to_string())).tuple_windows().fold( + let (verbose, logs) = args().chain(["".to_string()]).tuple_windows().fold( (false, None), |(mut verbose, mut logs), (key, value)| { match key.as_str() { @@ -94,6 +94,7 @@ fn parse_args() -> (bool, Vec) { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] enum System { HyperPlonk, + UniHyperPlonk, Halo2, EspressoHyperPlonk, } @@ -102,6 +103,7 @@ impl System { fn iter() -> impl Iterator { [ System::HyperPlonk, + System::UniHyperPlonk, System::Halo2, System::EspressoHyperPlonk, ] @@ -110,7 +112,7 @@ impl System { fn key_fn(&self) -> impl Fn(&Log) -> (bool, &str) + '_ { move |log| match self { - System::HyperPlonk | System::Halo2 => ( + System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => ( false, log.name.split([' ', '-']).next().unwrap_or(&log.name), ), @@ -167,6 +169,49 @@ impl System { ]), ), ], + System::UniHyperPlonk => vec![ + ( + "all", + vec![ + vec!["variable_base_msm"], + vec!["sum_check_prove"], + vec!["prove_multilinear_eval"], + ], + None, + ), + ("multiexp", vec![vec!["variable_base_msm"]], None), + ("sum check", vec![vec!["sum_check_prove"]], None), + ( + "mleval multiexp", + vec![ + vec!["prove_multilinear_eval", "variable_base_msm"], + vec![ + "prove_multilinear_eval", + "pcs_batch_open", + "variable_base_msm", + ], + ], + None, + ), + ( + "mleval fft", + vec![vec!["prove_multilinear_eval", "fft"]], + None, + ), + ( + "mleval rest", + vec![vec!["prove_multilinear_eval"]], + Some(vec![ + vec!["prove_multilinear_eval", "variable_base_msm"], + vec![ + "prove_multilinear_eval", + "pcs_batch_open", + "variable_base_msm", + ], + vec!["prove_multilinear_eval", "fft"], + ]), + ), + ], System::Halo2 => vec![ ( "all", @@ -320,6 +365,7 @@ impl Display for System { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { System::HyperPlonk => write!(f, "hyperplonk"), + System::UniHyperPlonk => write!(f, "unihyperplonk"), System::Halo2 => write!(f, "halo2"), System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"), } @@ -613,21 +659,20 @@ fn plot_comparison(cost_breakdowns_by_system: &[BTreeMap); + + impl Circuit for Keccak256Circuit { + type Config = (KeccakCircuitConfig, Challenges); + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + KeccakCircuit::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + layouter: impl Layouter, + ) -> Result<(), Error> { + self.0.synthesize(config, layouter) + } + } + + impl CircuitExt for Keccak256Circuit { + fn rand(k: usize, mut rng: impl RngCore) -> Self { + let capacity = KeccakCircuit::::new(1 << k, Vec::new()).capacity(); + let mut input = vec![0; (capacity.unwrap() - 1) * 136]; + rng.fill_bytes(&mut input); + Keccak256Circuit(KeccakCircuit::new(1 << k, vec![input])) + } + + fn instances(&self) -> Vec> { + Vec::new() + } + } +} diff --git a/plonkish_backend/benches/zero_check.rs b/plonkish_backend/benches/zero_check.rs index de085d85..f4a7e9ac 100644 --- a/plonkish_backend/benches/zero_check.rs +++ b/plonkish_backend/benches/zero_check.rs @@ -7,24 +7,29 @@ use plonkish_backend::{ SumCheck, VirtualPolynomial, }, util::{ + arithmetic::Field, + expression::rotate::BinaryField, test::{rand_vec, seeded_std_rng}, transcript::Keccak256Transcript, }, }; use pprof::criterion::{Output, PProfProfiler}; -type ZeroCheck = ClassicSumCheck>; +type ZeroCheck = ClassicSumCheck, BinaryField>; fn run(num_vars: usize, virtual_poly: VirtualPolynomial) { let mut transcript = Keccak256Transcript::>::default(); - ZeroCheck::prove(&(), num_vars, virtual_poly, Fr::zero(), &mut transcript).unwrap(); + ZeroCheck::prove(&(), num_vars, virtual_poly, Fr::ZERO, &mut transcript).unwrap(); } fn zero_check(c: &mut Criterion) { let setup = |num_vars: usize| { let expression = vanilla_plonk_expression(num_vars); - let (polys, challenges) = - rand_vanilla_plonk_assignment::(num_vars, seeded_std_rng(), seeded_std_rng()); + let (polys, challenges) = rand_vanilla_plonk_assignment::( + num_vars, + seeded_std_rng(), + seeded_std_rng(), + ); let ys = [rand_vec(num_vars, seeded_std_rng())]; (expression, polys, challenges, ys) }; diff --git a/plonkish_backend/src/accumulation.rs b/plonkish_backend/src/accumulation.rs index 12ca8ca6..bfb342ac 100644 --- a/plonkish_backend/src/accumulation.rs +++ b/plonkish_backend/src/accumulation.rs @@ -255,7 +255,7 @@ pub(crate) mod test { seeded_std_rng(), ) }; - assert!(matches!(result, Ok(_))); + assert_eq!(result, Ok(())); end_timer(timer); } } diff --git a/plonkish_backend/src/accumulation/protostar.rs b/plonkish_backend/src/accumulation/protostar.rs index b0f6bfff..3537f64f 100644 --- a/plonkish_backend/src/accumulation/protostar.rs +++ b/plonkish_backend/src/accumulation/protostar.rs @@ -4,7 +4,7 @@ use crate::{ PlonkishNark, PlonkishNarkInstance, }, backend::PlonkishBackend, - pcs::{AdditiveCommitment, PolynomialCommitmentScheme}, + pcs::{Additive, PolynomialCommitmentScheme}, util::{ arithmetic::{inner_product, powers, Field}, chain, @@ -136,7 +136,7 @@ where cross_term_comms: &[Pcs::Commitment], r: &F, ) where - Pcs::Commitment: AdditiveCommitment, + Pcs::Commitment: Additive, { self.instance .fold_uncompressed(&rhs.instance, cross_term_comms, r); @@ -154,7 +154,7 @@ where compressed_cross_term_sums: &[F], r: &F, ) where - Pcs::Commitment: AdditiveCommitment, + Pcs::Commitment: Additive, { self.instance.fold_compressed( &rhs.instance, @@ -257,19 +257,19 @@ where fn fold_uncompressed(&mut self, rhs: &Self, cross_term_comms: &[C], r: &F) where - C: AdditiveCommitment, + C: Additive, { let one = F::ONE; let powers_of_r = powers(*r).take(cross_term_comms.len() + 2).collect_vec(); izip_eq!(&mut self.instances, &rhs.instances) .for_each(|(lhs, rhs)| izip_eq!(lhs, rhs).for_each(|(lhs, rhs)| *lhs += &(*rhs * r))); izip_eq!(&mut self.witness_comms, &rhs.witness_comms) - .for_each(|(lhs, rhs)| *lhs = C::sum_with_scalar([&one, r], [lhs, rhs])); + .for_each(|(lhs, rhs)| *lhs = C::msm([&one, r], [lhs, rhs])); izip_eq!(&mut self.challenges, &rhs.challenges).for_each(|(lhs, rhs)| *lhs += &(*rhs * r)); self.u += &(rhs.u * r); self.e_comm = { let comms = chain![[&self.e_comm], cross_term_comms, [&rhs.e_comm]]; - C::sum_with_scalar(&powers_of_r, comms) + C::msm(&powers_of_r, comms) }; } @@ -280,7 +280,7 @@ where compressed_cross_term_sums: &[F], r: &F, ) where - C: AdditiveCommitment, + C: Additive, { let one = F::ONE; let powers_of_r = powers(*r) @@ -289,12 +289,12 @@ where izip_eq!(&mut self.instances, &rhs.instances) .for_each(|(lhs, rhs)| izip_eq!(lhs, rhs).for_each(|(lhs, rhs)| *lhs += &(*rhs * r))); izip_eq!(&mut self.witness_comms, &rhs.witness_comms) - .for_each(|(lhs, rhs)| *lhs = C::sum_with_scalar([&one, r], [lhs, rhs])); + .for_each(|(lhs, rhs)| *lhs = C::msm([&one, r], [lhs, rhs])); izip_eq!(&mut self.challenges, &rhs.challenges).for_each(|(lhs, rhs)| *lhs += &(*rhs * r)); self.u += &(rhs.u * r); self.e_comm = { let comms = [&self.e_comm, zeta_cross_term_comm, &rhs.e_comm]; - C::sum_with_scalar(&powers_of_r[..3], comms) + C::msm(&powers_of_r[..3], comms) }; *self.compressed_e_sum.as_mut().unwrap() += &inner_product( &powers_of_r[1..], diff --git a/plonkish_backend/src/accumulation/protostar/hyperplonk.rs b/plonkish_backend/src/accumulation/protostar/hyperplonk.rs index bda73b78..221cab44 100644 --- a/plonkish_backend/src/accumulation/protostar/hyperplonk.rs +++ b/plonkish_backend/src/accumulation/protostar/hyperplonk.rs @@ -23,13 +23,15 @@ use crate::{ verifier::verify_sum_check, HyperPlonk, }, - PlonkishCircuit, PlonkishCircuitInfo, + PlonkishCircuit, PlonkishCircuitInfo, WitnessEncoding, }, - pcs::{AdditiveCommitment, CommitmentChunk, PolynomialCommitmentScheme}, + pcs::{Additive, CommitmentChunk, PolynomialCommitmentScheme}, poly::multilinear::MultilinearPolynomial, util::{ arithmetic::{powers, PrimeField}, - end_timer, start_timer, + chain, end_timer, + expression::rotate::BinaryField, + start_timer, transcript::{TranscriptRead, TranscriptWrite}, DeserializeOwned, Itertools, Serialize, }, @@ -45,8 +47,7 @@ impl AccumulationScheme for Protostar>, - Pcs::Commitment: AdditiveCommitment, - Pcs::CommitmentChunk: AdditiveCommitment, + Pcs::Commitment: Additive, { type Pcs = Pcs; type ProverParam = ProtostarProverParam>; @@ -147,17 +148,10 @@ where let timer = start_timer(|| format!("lookup_compressed_polys-{}", pp.lookups.len())); let lookup_compressed_polys = { - let instance_polys = instance_polys(pp.num_vars, instances); - let polys = iter::empty() - .chain(instance_polys.iter()) - .chain(pp.preprocess_polys.iter()) - .chain(witness_polys.iter()) - .collect_vec(); - let thetas = iter::empty() - .chain(Some(F::ONE)) - .chain(theta_primes.iter().cloned()) - .collect_vec(); - lookup_compressed_polys(&pp.lookups, &polys, &challenges, &thetas) + let instance_polys = instance_polys::<_, BinaryField>(pp.num_vars, instances); + let polys = chain![&instance_polys, &pp.preprocess_polys, &witness_polys].collect_vec(); + let thetas = chain![[F::ONE], theta_primes.iter().cloned()].collect_vec(); + lookup_compressed_polys::<_, BinaryField>(&pp.lookups, &polys, &challenges, &thetas) }; end_timer(timer); @@ -211,25 +205,21 @@ where Ok(PlonkishNark::new( instances.to_vec(), - iter::empty() - .chain(challenges) - .chain(theta_primes) - .chain(Some(beta_prime)) - .chain(zeta) - .chain(alpha_primes) - .collect(), - iter::empty() - .chain(witness_comms) - .chain(lookup_m_comms) - .chain(lookup_h_comms) - .chain(powers_of_zeta_comm) - .collect(), - iter::empty() - .chain(witness_polys) - .chain(lookup_m_polys) - .chain(lookup_h_polys.into_iter().flatten()) - .chain(powers_of_zeta_poly) - .collect(), + chain![challenges, theta_primes, [beta_prime], zeta, alpha_primes,].collect(), + chain![ + witness_comms, + lookup_m_comms, + lookup_h_comms, + powers_of_zeta_comm, + ] + .collect(), + chain![ + witness_polys, + lookup_m_polys, + lookup_h_polys.into_iter().flatten(), + powers_of_zeta_poly, + ] + .collect(), )) } @@ -397,19 +387,14 @@ where let nark = PlonkishNarkInstance::new( instances.to_vec(), - iter::empty() - .chain(challenges) - .chain(theta_primes) - .chain(Some(beta_prime)) - .chain(zeta) - .chain(alpha_primes) - .collect(), - iter::empty() - .chain(witness_comms) - .chain(lookup_m_comms) - .chain(lookup_h_comms) - .chain(powers_of_zeta_comm) - .collect(), + chain![challenges, theta_primes, [beta_prime], zeta, alpha_primes,].collect(), + chain![ + witness_comms, + lookup_m_comms, + lookup_h_comms, + powers_of_zeta_comm, + ] + .collect(), ); let incoming = ProtostarAccumulatorInstance::from_nark(*strategy, nark); accumulator.absorb_into(transcript)?; @@ -463,14 +448,16 @@ where let timer = start_timer(|| format!("permutation_z_polys-{}", pp.permutation_polys.len())); let builtin_witness_poly_offset = pp.num_witness_polys.iter().sum::(); - let instance_polys = instance_polys(pp.num_vars, &accumulator.instance.instances); - let polys = iter::empty() - .chain(&instance_polys) - .chain(&pp.preprocess_polys) - .chain(&accumulator.witness_polys[..builtin_witness_poly_offset]) - .chain(pp.permutation_polys.iter().map(|(_, poly)| poly)) - .collect_vec(); - let permutation_z_polys = permutation_z_polys( + let instance_polys = + instance_polys::<_, BinaryField>(pp.num_vars, &accumulator.instance.instances); + let polys = chain![ + &instance_polys, + &pp.preprocess_polys, + &accumulator.witness_polys[..builtin_witness_poly_offset], + pp.permutation_polys.iter().map(|(_, poly)| poly), + ] + .collect_vec(); + let permutation_z_polys = permutation_z_polys::<_, BinaryField>( pp.num_permutation_z_polys, &pp.permutation_polys, &polys, @@ -487,17 +474,19 @@ where let alpha = transcript.squeeze_challenge(); let y = transcript.squeeze_challenges(pp.num_vars); - let polys = iter::empty() - .chain(polys) - .chain(&accumulator.witness_polys[builtin_witness_poly_offset..]) - .chain(permutation_z_polys.iter()) - .chain(Some(&accumulator.e_poly)) - .collect_vec(); - let challenges = iter::empty() - .chain(accumulator.instance.challenges.iter().copied()) - .chain([accumulator.instance.u]) - .chain([beta, gamma, alpha]) - .collect(); + let polys = chain![ + polys, + &accumulator.witness_polys[builtin_witness_poly_offset..], + &permutation_z_polys, + [&accumulator.e_poly], + ] + .collect_vec(); + let challenges = chain![ + accumulator.instance.challenges.iter().copied(), + [accumulator.instance.u], + [beta, gamma, alpha], + ] + .collect(); let (points, evals) = { prove_sum_check( pp.num_instances.len(), @@ -513,15 +502,16 @@ where // PCS open let dummy_comm = Pcs::Commitment::default(); - let comms = iter::empty() - .chain(iter::repeat(&dummy_comm).take(pp.num_instances.len())) - .chain(&pp.preprocess_comms) - .chain(&accumulator.instance.witness_comms[..builtin_witness_poly_offset]) - .chain(&pp.permutation_comms) - .chain(&accumulator.instance.witness_comms[builtin_witness_poly_offset..]) - .chain(&permutation_z_comms) - .chain(Some(&accumulator.instance.e_comm)) - .collect_vec(); + let comms = chain![ + iter::repeat(&dummy_comm).take(pp.num_instances.len()), + &pp.preprocess_comms, + &accumulator.instance.witness_comms[..builtin_witness_poly_offset], + &pp.permutation_comms, + &accumulator.instance.witness_comms[builtin_witness_poly_offset..], + &permutation_z_comms, + [&accumulator.instance.e_comm], + ] + .collect_vec(); let timer = start_timer(|| format!("pcs_batch_open-{}", evals.len())); Pcs::batch_open(&pp.pcs, polys, comms, &points, &evals, transcript)?; end_timer(timer); @@ -552,11 +542,12 @@ where let alpha = transcript.squeeze_challenge(); let y = transcript.squeeze_challenges(vp.num_vars); - let challenges = iter::empty() - .chain(accumulator.challenges.iter().copied()) - .chain([accumulator.u]) - .chain([beta, gamma, alpha]) - .collect_vec(); + let challenges = chain![ + accumulator.challenges.iter().copied(), + [accumulator.u], + [beta, gamma, alpha], + ] + .collect_vec(); let (points, evals) = { verify_sum_check( vp.num_vars, @@ -573,27 +564,34 @@ where let builtin_witness_poly_offset = vp.num_witness_polys.iter().sum::(); let dummy_comm = Pcs::Commitment::default(); - let comms = iter::empty() - .chain(iter::repeat(&dummy_comm).take(vp.num_instances.len())) - .chain(&vp.preprocess_comms) - .chain(&accumulator.witness_comms[..builtin_witness_poly_offset]) - .chain(vp.permutation_comms.iter().map(|(_, comm)| comm)) - .chain(&accumulator.witness_comms[builtin_witness_poly_offset..]) - .chain(&permutation_z_comms) - .chain(Some(&accumulator.e_comm)) - .collect_vec(); + let comms = chain![ + iter::repeat(&dummy_comm).take(vp.num_instances.len()), + &vp.preprocess_comms, + &accumulator.witness_comms[..builtin_witness_poly_offset], + vp.permutation_comms.iter().map(|(_, comm)| comm), + &accumulator.witness_comms[builtin_witness_poly_offset..], + &permutation_z_comms, + [&accumulator.e_comm], + ] + .collect_vec(); Pcs::batch_verify(&vp.pcs, comms, &points, &evals, transcript)?; Ok(()) } } +impl WitnessEncoding for Protostar, STRATEGY> { + fn row_mapping(k: usize) -> Vec { + HyperPlonk::::row_mapping(k) + } +} + #[cfg(test)] pub(crate) mod test { use crate::{ accumulation::{protostar::Protostar, test::run_accumulation_scheme}, backend::hyperplonk::{ - util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_with_lookup_circuit}, + util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_w_lookup_circuit}, HyperPlonk, }, pcs::{ @@ -601,6 +599,7 @@ pub(crate) mod test { univariate::UnivariateKzg, }, util::{ + expression::rotate::BinaryField, test::{seeded_std_rng, std_rng}, transcript::Keccak256Transcript, Itertools, @@ -610,14 +609,14 @@ pub(crate) mod test { use std::iter; macro_rules! tests { - ($name:ident, $pcs:ty, $num_vars_range:expr) => { + ($suffix:ident, $pcs:ty, $num_vars_range:expr) => { paste::paste! { #[test] - fn [<$name _protostar_hyperplonk_vanilla_plonk>]() { + fn []() { run_accumulation_scheme::<_, Protostar>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - let (circuit_info, _) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + let (circuit_info, _) = rand_vanilla_plonk_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); let circuits = iter::repeat_with(|| { - let (_, circuit) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + let (_, circuit) = rand_vanilla_plonk_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); circuit }).take(3).collect_vec(); (circuit_info, circuits) @@ -625,11 +624,11 @@ pub(crate) mod test { } #[test] - fn [<$name _protostar_hyperplonk_vanilla_plonk_with_lookup>]() { + fn []() { run_accumulation_scheme::<_, Protostar>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - let (circuit_info, _) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + let (circuit_info, _) = rand_vanilla_plonk_w_lookup_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); let circuits = iter::repeat_with(|| { - let (_, circuit) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + let (_, circuit) = rand_vanilla_plonk_w_lookup_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); circuit }).take(3).collect_vec(); (circuit_info, circuits) @@ -637,8 +636,8 @@ pub(crate) mod test { } } }; - ($name:ident, $pcs:ty) => { - tests!($name, $pcs, 2..16); + ($suffix:ident, $pcs:ty) => { + tests!($suffix, $pcs, 2..16); }; } diff --git a/plonkish_backend/src/accumulation/protostar/hyperplonk/preprocessor.rs b/plonkish_backend/src/accumulation/protostar/hyperplonk/preprocessor.rs index 1ca692ca..83d212d4 100644 --- a/plonkish_backend/src/accumulation/protostar/hyperplonk/preprocessor.rs +++ b/plonkish_backend/src/accumulation/protostar/hyperplonk/preprocessor.rs @@ -21,7 +21,7 @@ use crate::{ }, Error, }; -use std::{array, borrow::Cow, collections::BTreeSet, hash::Hash, iter}; +use std::{array, borrow::Cow, collections::BTreeSet, hash::Hash}; pub(crate) fn batch_size( circuit_info: &PlonkishCircuitInfo, @@ -99,24 +99,22 @@ where witness_poly_offset + num_witness_polys + circuit_info.permutation_polys().len(); let poly_set = PolynomialSet { - preprocess: iter::empty() - .chain( - (circuit_info.num_instances.len()..) - .take(circuit_info.preprocess_polys.len()), - ) - .collect(), - folding: iter::empty() - .chain(0..circuit_info.num_instances.len()) - .chain((witness_poly_offset..).take(num_witness_polys)) - .chain((builtin_witness_poly_offset..).take(num_builtin_witness_polys)) - .collect(), + preprocess: chain![ + (circuit_info.num_instances.len()..).take(circuit_info.preprocess_polys.len()), + ] + .collect(), + folding: chain![ + 0..circuit_info.num_instances.len(), + (witness_poly_offset..).take(num_witness_polys), + (builtin_witness_poly_offset..).take(num_builtin_witness_polys), + ] + .collect(), }; let products = { - let mut constraints = iter::empty() - .chain(circuit_info.constraints.iter()) - .chain(lookup_constraints.iter()) - .collect_vec(); + let mut constraints = + chain![circuit_info.constraints.iter(), lookup_constraints.iter()] + .collect_vec(); let folding_degrees = constraints .iter() .map(|constraint| folding_degree(&poly_set.preprocess, constraint)) @@ -128,16 +126,15 @@ where constraints.swap(0, a.0); } } - let compressed_constraint = iter::empty() - .chain(constraints.first().cloned().cloned()) - .chain( - constraints - .into_iter() - .skip(1) - .zip((alpha_prime_offset..).map(Expression::Challenge)) - .map(|(constraint, challenge)| constraint * challenge), - ) - .sum::>(); + let compressed_constraint = chain![ + constraints.first().cloned().cloned(), + constraints + .into_iter() + .skip(1) + .zip((alpha_prime_offset..).map(Expression::Challenge)) + .map(|(constraint, challenge)| constraint * challenge), + ] + .sum::>(); products(&poly_set.preprocess, &compressed_constraint) }; @@ -170,25 +167,22 @@ where witness_poly_offset + num_witness_polys + circuit_info.permutation_polys().len(); let poly_set = PolynomialSet { - preprocess: iter::empty() - .chain( - (circuit_info.num_instances.len()..) - .take(circuit_info.preprocess_polys.len()), - ) - .collect(), - folding: iter::empty() - .chain(0..circuit_info.num_instances.len()) - .chain((witness_poly_offset..).take(num_witness_polys)) - .chain((builtin_witness_poly_offset..).take(num_builtin_witness_polys)) + preprocess: (circuit_info.num_instances.len()..) + .take(circuit_info.preprocess_polys.len()) .collect(), + folding: chain![ + 0..circuit_info.num_instances.len(), + (witness_poly_offset..).take(num_witness_polys), + (builtin_witness_poly_offset..).take(num_builtin_witness_polys), + ] + .collect(), }; let powers_of_zeta = builtin_witness_poly_offset + circuit_info.lookups.len() * 3; let compressed_products = { - let mut constraints = iter::empty() - .chain(circuit_info.constraints.iter()) - .chain(lookup_constraints.iter()) - .collect_vec(); + let mut constraints = + chain![circuit_info.constraints.iter(), lookup_constraints.iter()] + .collect_vec(); let folding_degrees = constraints .iter() .map(|constraint| folding_degree(&poly_set.preprocess, constraint)) @@ -202,16 +196,15 @@ where } let powers_of_zeta = Expression::::Polynomial(Query::new(powers_of_zeta, Rotation::cur())); - let compressed_constraint = iter::empty() - .chain(constraints.first().cloned().cloned()) - .chain( - constraints - .into_iter() - .skip(1) - .zip((alpha_prime_offset..).map(Expression::Challenge)) - .map(|(constraint, challenge)| constraint * challenge), - ) - .sum::>() + let compressed_constraint = chain![ + constraints.first().cloned().cloned(), + constraints + .into_iter() + .skip(1) + .zip((alpha_prime_offset..).map(Expression::Challenge)) + .map(|(constraint, challenge)| constraint * challenge), + ] + .sum::>() * powers_of_zeta; products(&poly_set.preprocess, &compressed_constraint) }; @@ -255,16 +248,15 @@ where let expression = { let zero_check_on_every_row = Expression::distribute_powers( - iter::empty() - .chain(Some(&zero_check_on_every_row)) - .chain(&permutation_constraints), + chain![[&zero_check_on_every_row], &permutation_constraints], alpha, ) * Expression::eq_xy(0); Expression::distribute_powers( - iter::empty() - .chain(&sum_check) - .chain(lookup_zero_checks.iter()) - .chain(Some(&zero_check_on_every_row)), + chain![ + &sum_check, + lookup_zero_checks.iter(), + [&zero_check_on_every_row], + ], alpha, ) }; @@ -317,13 +309,14 @@ pub(crate) fn max_degree( self::lookup_constraints(circuit_info, &dummy_challenges, &dummy_challenges[0]).0, ) }); - iter::empty() - .chain(circuit_info.constraints.iter().map(Expression::degree)) - .chain(lookup_constraints.iter().map(Expression::degree)) - .chain(circuit_info.max_degree) - .chain(Some(2)) - .max() - .unwrap() + chain![ + circuit_info.constraints.iter().map(Expression::degree), + lookup_constraints.iter().map(Expression::degree), + circuit_info.max_degree, + [2], + ] + .max() + .unwrap() } pub(crate) fn folding_degree( @@ -364,16 +357,15 @@ pub(crate) fn lookup_constraints( .map(|(input, table)| (input, table)) .unzip::<_, _, Vec<_>, Vec<_>>(); let [input, table] = &[inputs, tables].map(|exprs| { - iter::empty() - .chain(exprs.first().cloned().cloned()) - .chain( - exprs - .into_iter() - .skip(1) - .zip(theta_primes) - .map(|(expr, theta_prime)| expr * theta_prime), - ) - .sum::>() + chain![ + exprs.first().cloned().cloned(), + exprs + .into_iter() + .skip(1) + .zip(theta_primes) + .map(|(expr, theta_prime)| expr * theta_prime), + ] + .sum::>() }); [ h_input * (input + beta_prime) - one, @@ -395,12 +387,11 @@ pub(crate) fn lookup_constraints( } fn powers_of_zeta_constraint(zeta: usize, powers_of_zeta: usize) -> Expression { - let l_0 = &Expression::::lagrange(0); let l_last = &Expression::::lagrange(-1); let one = &Expression::one(); let zeta = &Expression::Challenge(zeta); let [powers_of_zeta, powers_of_zeta_next] = &[Rotation::cur(), Rotation::next()] .map(|rotation| Expression::Polynomial(Query::new(powers_of_zeta, rotation))); - powers_of_zeta_next - (l_0 + l_last * zeta + (one - (l_0 + l_last)) * powers_of_zeta * zeta) + powers_of_zeta_next - (l_last + (one - l_last) * powers_of_zeta * zeta) } diff --git a/plonkish_backend/src/accumulation/protostar/hyperplonk/prover.rs b/plonkish_backend/src/accumulation/protostar/hyperplonk/prover.rs index 54c45a5b..4fcb5e7c 100644 --- a/plonkish_backend/src/accumulation/protostar/hyperplonk/prover.rs +++ b/plonkish_backend/src/accumulation/protostar/hyperplonk/prover.rs @@ -4,14 +4,19 @@ use crate::{ pcs::PolynomialCommitmentScheme, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{div_ceil, powers, sum, BatchInvert, BooleanHypercube, PrimeField}, - expression::{evaluator::ExpressionRegistry, Expression, Rotation}, + arithmetic::{div_ceil, powers, sum, BatchInvert, PrimeField}, + chain, + expression::{ + evaluator::hadamard::HadamardEvaluator, + rotate::{BinaryField, Rotatable}, + Expression, Rotation, + }, izip, izip_eq, parallel::{num_threads, par_map_collect, parallelize, parallelize_iter}, Itertools, }, }; -use std::{borrow::Cow, hash::Hash, iter}; +use std::{borrow::Cow, hash::Hash}; pub(crate) fn lookup_h_polys( compressed_polys: &[[MultilinearPolynomial; 2]], @@ -47,11 +52,12 @@ fn lookup_h_poly( let chunk_size = div_ceil(2 * h_input.len(), num_threads()); parallelize_iter( - iter::empty() - .chain(h_input.chunks_mut(chunk_size)) - .chain(h_table.chunks_mut(chunk_size)), + chain![ + h_input.chunks_mut(chunk_size), + h_table.chunks_mut(chunk_size) + ], |h| { - h.iter_mut().batch_invert(); + h.batch_invert(); }, ); @@ -75,8 +81,10 @@ pub(super) fn powers_of_zeta_poly( num_vars: usize, zeta: F, ) -> MultilinearPolynomial { - let powers_of_zeta = powers(zeta).take(1 << num_vars).collect_vec(); - let nth_map = BooleanHypercube::new(num_vars).nth_map(); + let powers_of_zeta = chain![[F::ZERO], powers(zeta)] + .take(1 << num_vars) + .collect_vec(); + let nth_map = BinaryField::new(num_vars).nth_map(); MultilinearPolynomial::new(par_map_collect(&nth_map, |b| powers_of_zeta[*b])) } @@ -211,8 +219,8 @@ where let size = 1 << num_vars; let mut cross_term = vec![F::ZERO; size]; - let bh = BooleanHypercube::new(num_vars); - let next_map = bh.rotation_map(Rotation::next()); + let bf = BinaryField::new(num_vars); + let next_map = bf.rotation_map(Rotation::next()); parallelize(&mut cross_term, |(cross_term, start)| { cross_term .iter_mut() @@ -222,12 +230,9 @@ where - (acc_pow[b] * incoming_zeta + incoming_pow[b] * acc_zeta); }) }); - let b_0 = 0; - let b_last = bh.rotate(1, Rotation::prev()); - cross_term[b_0] += acc_pow[b_0] * incoming_zeta + incoming_pow[b_0] * acc_zeta - acc_u.double(); - cross_term[b_last] += acc_pow[b_last] * incoming_zeta + incoming_pow[b_last] * acc_zeta - - acc_u * incoming_zeta - - acc_zeta; + let b_last = bf.rotate(1, Rotation::prev()); + cross_term[b_last] += + acc_pow[b_last] * incoming_zeta + incoming_pow[b_last] * acc_zeta - acc_u.double(); MultilinearPolynomial::new(cross_term) } @@ -238,28 +243,32 @@ fn init_hadamard_evaluator<'a, F, Pcs>( preprocess_polys: &'a [MultilinearPolynomial], accumulator: &'a ProtostarAccumulator, incoming: &'a ProtostarAccumulator, -) -> HadamardEvaluator<'a, F> +) -> HadamardEvaluator<'a, F, BinaryField> where F: PrimeField, Pcs: PolynomialCommitmentScheme>, { assert!(!expressions.is_empty()); - let acc_instance_polys = instance_polys(num_vars, &accumulator.instance.instances); - let incoming_instance_polys = instance_polys(num_vars, &incoming.instance.instances); - let polys = iter::empty() - .chain(preprocess_polys.iter().map(Cow::Borrowed)) - .chain(acc_instance_polys.into_iter().map(Cow::Owned)) - .chain(accumulator.witness_polys.iter().map(Cow::Borrowed)) - .chain(incoming_instance_polys.into_iter().map(Cow::Owned)) - .chain(incoming.witness_polys.iter().map(Cow::Borrowed)) - .collect_vec(); - let challenges = iter::empty() - .chain(accumulator.instance.challenges.iter().cloned()) - .chain(Some(accumulator.instance.u)) - .chain(incoming.instance.challenges.iter().cloned()) - .chain(Some(incoming.instance.u)) - .collect_vec(); + let accumulator_instance_polys = + instance_polys::<_, BinaryField>(num_vars, &accumulator.instance.instances); + let incoming_instance_polys = + instance_polys::<_, BinaryField>(num_vars, &incoming.instance.instances); + let polys = chain![ + chain![preprocess_polys].map(|poly| Cow::Borrowed(poly.evals())), + chain![accumulator_instance_polys].map(|poly| Cow::Owned(poly.into_evals())), + chain![&accumulator.witness_polys].map(|poly| Cow::Borrowed(poly.evals())), + chain![incoming_instance_polys].map(|poly| Cow::Owned(poly.into_evals())), + chain![&incoming.witness_polys].map(|poly| Cow::Borrowed(poly.evals())), + ] + .collect_vec(); + let challenges = chain![ + accumulator.instance.challenges.iter().cloned(), + [accumulator.instance.u], + incoming.instance.challenges.iter().cloned(), + [incoming.instance.u], + ] + .collect_vec(); let expressions = expressions .iter() @@ -272,75 +281,3 @@ where HadamardEvaluator::new(num_vars, &expressions, polys) } - -#[derive(Clone, Debug)] -pub(crate) struct HadamardEvaluator<'a, F: PrimeField> { - pub(crate) num_vars: usize, - pub(crate) reg: ExpressionRegistry, - lagranges: Vec, - polys: Vec>>, -} - -impl<'a, F: PrimeField> HadamardEvaluator<'a, F> { - pub(crate) fn new( - num_vars: usize, - expressions: &[Expression], - polys: Vec>>, - ) -> Self { - let mut reg = ExpressionRegistry::new(); - for expression in expressions.iter() { - reg.register(expression); - } - assert!(reg.eq_xys().is_empty()); - - let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); - let lagranges = reg - .lagranges() - .iter() - .map(|i| bh[i.rem_euclid(1 << num_vars) as usize]) - .collect_vec(); - - Self { - num_vars, - reg, - lagranges, - polys, - } - } - - pub(crate) fn cache(&self) -> Vec { - self.reg.cache() - } - - pub(crate) fn evaluate(&self, evals: &mut [F], cache: &mut [F], b: usize) { - self.evaluate_calculations(cache, b); - izip_eq!(evals, self.reg.indexed_outputs()).for_each(|(eval, idx)| *eval = cache[*idx]) - } - - pub(crate) fn evaluate_and_sum(&self, sums: &mut [F], cache: &mut [F], b: usize) { - self.evaluate_calculations(cache, b); - izip_eq!(sums, self.reg.indexed_outputs()).for_each(|(sum, idx)| *sum += cache[*idx]) - } - - fn evaluate_calculations(&self, cache: &mut [F], b: usize) { - let bh = BooleanHypercube::new(self.num_vars); - if self.reg.has_identity() { - cache[self.reg.offsets().identity()] = F::from(b as u64); - } - cache[self.reg.offsets().lagranges()..] - .iter_mut() - .zip(&self.lagranges) - .for_each(|(value, i)| *value = if &b == i { F::ONE } else { F::ZERO }); - cache[self.reg.offsets().polys()..] - .iter_mut() - .zip(self.reg.polys()) - .for_each(|(value, (query, _))| { - *value = self.polys[query.poly()][bh.rotate(b, query.rotation())] - }); - self.reg - .indexed_calculations() - .iter() - .zip(self.reg.offsets().calculations()..) - .for_each(|(calculation, idx)| calculation.calculate(cache, idx)); - } -} diff --git a/plonkish_backend/src/accumulation/sangria/hyperplonk.rs b/plonkish_backend/src/accumulation/sangria/hyperplonk.rs index 314b2f71..80a44714 100644 --- a/plonkish_backend/src/accumulation/sangria/hyperplonk.rs +++ b/plonkish_backend/src/accumulation/sangria/hyperplonk.rs @@ -3,7 +3,7 @@ pub(crate) mod test { use crate::{ accumulation::{sangria::Sangria, test::run_accumulation_scheme}, backend::hyperplonk::{ - util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_with_lookup_circuit}, + util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_w_lookup_circuit}, HyperPlonk, }, pcs::{ @@ -11,6 +11,7 @@ pub(crate) mod test { univariate::UnivariateKzg, }, util::{ + expression::rotate::BinaryField, test::{seeded_std_rng, std_rng}, transcript::Keccak256Transcript, Itertools, @@ -20,14 +21,14 @@ pub(crate) mod test { use std::iter; macro_rules! tests { - ($name:ident, $pcs:ty, $num_vars_range:expr) => { + ($suffix:ident, $pcs:ty, $num_vars_range:expr) => { paste::paste! { #[test] - fn [<$name _sangria_hyperplonk_vanilla_plonk>]() { + fn []() { run_accumulation_scheme::<_, Sangria>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - let (circuit_info, _) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + let (circuit_info, _) = rand_vanilla_plonk_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); let circuits = iter::repeat_with(|| { - let (_, circuit) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + let (_, circuit) = rand_vanilla_plonk_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); circuit }).take(3).collect_vec(); (circuit_info, circuits) @@ -35,11 +36,11 @@ pub(crate) mod test { } #[test] - fn [<$name _sangria_hyperplonk_vanilla_plonk_with_lookup>]() { + fn []() { run_accumulation_scheme::<_, Sangria>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - let (circuit_info, _) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + let (circuit_info, _) = rand_vanilla_plonk_w_lookup_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); let circuits = iter::repeat_with(|| { - let (_, circuit) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + let (_, circuit) = rand_vanilla_plonk_w_lookup_circuit::<_, BinaryField>(num_vars, std_rng(), seeded_std_rng()); circuit }).take(3).collect_vec(); (circuit_info, circuits) @@ -47,8 +48,8 @@ pub(crate) mod test { } } }; - ($name:ident, $pcs:ty) => { - tests!($name, $pcs, 2..16); + ($suffix:ident, $pcs:ty) => { + tests!($suffix, $pcs, 2..16); }; } diff --git a/plonkish_backend/src/backend.rs b/plonkish_backend/src/backend.rs index db879eee..462e0483 100644 --- a/plonkish_backend/src/backend.rs +++ b/plonkish_backend/src/backend.rs @@ -2,6 +2,7 @@ use crate::{ pcs::{CommitmentChunk, PolynomialCommitmentScheme}, util::{ arithmetic::Field, + chain, expression::Expression, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Itertools, Serialize, @@ -9,9 +10,10 @@ use crate::{ Error, }; use rand::RngCore; -use std::{collections::BTreeSet, fmt::Debug, iter}; +use std::{collections::BTreeSet, fmt::Debug}; pub mod hyperplonk; +pub mod unihyperplonk; pub trait PlonkishBackend: Clone + Debug { type Pcs: PolynomialCommitmentScheme; @@ -76,12 +78,12 @@ impl PlonkishCircuitInfo { pub fn is_well_formed(&self) -> bool { let num_poly = self.num_poly(); let num_challenges = self.num_challenges.iter().sum::(); - let polys = iter::empty() - .chain(self.expressions().flat_map(Expression::used_poly)) - .chain(self.permutation_polys()) - .collect::>(); - let challenges = iter::empty() - .chain(self.expressions().flat_map(Expression::used_challenge)) + let polys = chain![ + self.expressions().flat_map(Expression::used_poly), + self.permutation_polys(), + ] + .collect::>(); + let challenges = chain![self.expressions().flat_map(Expression::used_challenge)] .collect::>(); // Same amount of phases self.num_witness_polys.len() == self.num_challenges.len() @@ -121,11 +123,11 @@ impl PlonkishCircuitInfo { } pub fn expressions(&self) -> impl Iterator> { - iter::empty().chain(self.constraints.iter()).chain( - self.lookups - .iter() + chain![ + &self.constraints, + chain![&self.lookups] .flat_map(|lookup| lookup.iter().flat_map(|(input, table)| [input, table])), - ) + ] } } diff --git a/plonkish_backend/src/backend/hyperplonk.rs b/plonkish_backend/src/backend/hyperplonk.rs index e94d29fc..363ea533 100644 --- a/plonkish_backend/src/backend/hyperplonk.rs +++ b/plonkish_backend/src/backend/hyperplonk.rs @@ -1,7 +1,7 @@ use crate::{ backend::{ hyperplonk::{ - preprocessor::{batch_size, compose, permutation_polys}, + preprocessor::{batch_size, preprocess}, prover::{ instance_polys, lookup_compressed_polys, lookup_h_polys, lookup_m_polys, permutation_z_polys, prove_zero_check, @@ -13,9 +13,12 @@ use crate::{ pcs::PolynomialCommitmentScheme, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{powers, BooleanHypercube, PrimeField}, - end_timer, - expression::Expression, + arithmetic::{powers, PrimeField}, + chain, end_timer, + expression::{ + rotate::{BinaryField, Rotatable}, + Expression, + }, start_timer, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Itertools, Serialize, @@ -98,67 +101,10 @@ where param: &Pcs::Param, circuit_info: &PlonkishCircuitInfo, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { - assert!(circuit_info.is_well_formed()); - - let num_vars = circuit_info.k; - let poly_size = 1 << num_vars; - let batch_size = batch_size(circuit_info); - let (pcs_pp, pcs_vp) = Pcs::trim(param, poly_size, batch_size)?; - - // Compute preprocesses comms - let preprocess_polys = circuit_info - .preprocess_polys - .iter() - .cloned() - .map(MultilinearPolynomial::new) - .collect_vec(); - let preprocess_comms = Pcs::batch_commit(&pcs_pp, &preprocess_polys)?; - - // Compute permutation polys and comms - let permutation_polys = permutation_polys( - num_vars, - &circuit_info.permutation_polys(), - &circuit_info.permutations, - ); - let permutation_comms = Pcs::batch_commit(&pcs_pp, &permutation_polys)?; - - // Compose `VirtualPolynomialInfo` - let (num_permutation_z_polys, expression) = compose(circuit_info); - let vp = HyperPlonkVerifierParam { - pcs: pcs_vp, - num_instances: circuit_info.num_instances.clone(), - num_witness_polys: circuit_info.num_witness_polys.clone(), - num_challenges: circuit_info.num_challenges.clone(), - num_lookups: circuit_info.lookups.len(), - num_permutation_z_polys, - num_vars, - expression: expression.clone(), - preprocess_comms: preprocess_comms.clone(), - permutation_comms: circuit_info - .permutation_polys() - .into_iter() - .zip(permutation_comms.clone()) - .collect(), - }; - let pp = HyperPlonkProverParam { - pcs: pcs_pp, - num_instances: circuit_info.num_instances.clone(), - num_witness_polys: circuit_info.num_witness_polys.clone(), - num_challenges: circuit_info.num_challenges.clone(), - lookups: circuit_info.lookups.clone(), - num_permutation_z_polys, - num_vars, - expression, - preprocess_polys, - preprocess_comms, - permutation_polys: circuit_info - .permutation_polys() - .into_iter() - .zip(permutation_polys) - .collect(), - permutation_comms, - }; - Ok((pp, vp)) + preprocess(param, circuit_info, |pp, polys| { + let comms = Pcs::batch_commit(pp, &polys)?; + Ok((polys, comms)) + }) } fn prove( @@ -175,7 +121,7 @@ where transcript.common_field_element(instance)?; } } - instance_polys(pp.num_vars, instances) + instance_polys::<_, BinaryField>(pp.num_vars, instances) }; // Round 0..n @@ -202,11 +148,7 @@ where witness_polys.extend(polys); challenges.extend(transcript.squeeze_challenges(*num_challenges)); } - let polys = iter::empty() - .chain(instance_polys.iter()) - .chain(pp.preprocess_polys.iter()) - .chain(witness_polys.iter()) - .collect_vec(); + let polys = chain![&instance_polys, &pp.preprocess_polys, &witness_polys].collect_vec(); // Round n @@ -216,7 +158,7 @@ where let lookup_compressed_polys = { let max_lookup_width = pp.lookups.iter().map(Vec::len).max().unwrap_or_default(); let betas = powers(beta).take(max_lookup_width).collect_vec(); - lookup_compressed_polys(&pp.lookups, &polys, &challenges, &betas) + lookup_compressed_polys::<_, BinaryField>(&pp.lookups, &polys, &challenges, &betas) }; end_timer(timer); @@ -235,7 +177,7 @@ where end_timer(timer); let timer = start_timer(|| format!("permutation_z_polys-{}", pp.permutation_polys.len())); - let permutation_z_polys = permutation_z_polys( + let permutation_z_polys = permutation_z_polys::<_, BinaryField>( pp.num_permutation_z_polys, &pp.permutation_polys, &polys, @@ -244,10 +186,8 @@ where ); end_timer(timer); - let lookup_h_permutation_z_polys = iter::empty() - .chain(lookup_h_polys.iter()) - .chain(permutation_z_polys.iter()) - .collect_vec(); + let lookup_h_permutation_z_polys = + chain![lookup_h_polys.iter(), permutation_z_polys.iter()].collect_vec(); let lookup_h_permutation_z_comms = Pcs::batch_commit_and_write(&pp.pcs, lookup_h_permutation_z_polys.clone(), transcript)?; @@ -256,12 +196,13 @@ where let alpha = transcript.squeeze_challenge(); let y = transcript.squeeze_challenges(pp.num_vars); - let polys = iter::empty() - .chain(polys) - .chain(pp.permutation_polys.iter().map(|(_, poly)| poly)) - .chain(lookup_m_polys.iter()) - .chain(lookup_h_permutation_z_polys) - .collect_vec(); + let polys = chain![ + polys, + pp.permutation_polys.iter().map(|(_, poly)| poly), + lookup_m_polys.iter(), + lookup_h_permutation_z_polys, + ] + .collect_vec(); challenges.extend([beta, gamma, alpha]); let (points, evals) = prove_zero_check( pp.num_instances.len(), @@ -275,14 +216,15 @@ where // PCS open let dummy_comm = Pcs::Commitment::default(); - let comms = iter::empty() - .chain(iter::repeat(&dummy_comm).take(pp.num_instances.len())) - .chain(&pp.preprocess_comms) - .chain(&witness_comms) - .chain(&pp.permutation_comms) - .chain(&lookup_m_comms) - .chain(&lookup_h_permutation_z_comms) - .collect_vec(); + let comms = chain![ + iter::repeat(&dummy_comm).take(pp.num_instances.len()), + &pp.preprocess_comms, + &witness_comms, + &pp.permutation_comms, + &lookup_m_comms, + &lookup_h_permutation_z_comms, + ] + .collect_vec(); let timer = start_timer(|| format!("pcs_batch_open-{}", evals.len())); Pcs::batch_open(&pp.pcs, polys, comms, &points, &evals, transcript)?; end_timer(timer); @@ -348,14 +290,15 @@ where // PCS verify let dummy_comm = Pcs::Commitment::default(); - let comms = iter::empty() - .chain(iter::repeat(&dummy_comm).take(vp.num_instances.len())) - .chain(&vp.preprocess_comms) - .chain(&witness_comms) - .chain(vp.permutation_comms.iter().map(|(_, comm)| comm)) - .chain(&lookup_m_comms) - .chain(&lookup_h_permutation_z_comms) - .collect_vec(); + let comms = chain![ + iter::repeat(&dummy_comm).take(vp.num_instances.len()), + &vp.preprocess_comms, + &witness_comms, + vp.permutation_comms.iter().map(|(_, comm)| comm), + &lookup_m_comms, + &lookup_h_permutation_z_comms, + ] + .collect_vec(); Pcs::batch_verify(&vp.pcs, comms, &points, &evals, transcript)?; Ok(()) @@ -364,7 +307,7 @@ where impl WitnessEncoding for HyperPlonk { fn row_mapping(k: usize) -> Vec { - BooleanHypercube::new(k).iter().skip(1).chain([0]).collect() + BinaryField::new(k).usable_indices() } } @@ -373,7 +316,7 @@ mod test { use crate::{ backend::{ hyperplonk::{ - util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_with_lookup_circuit}, + util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_w_lookup_circuit}, HyperPlonk, }, test::run_plonkish_backend, @@ -386,8 +329,8 @@ mod test { univariate::UnivariateKzg, }, util::{ - code::BrakedownSpec6, hash::Keccak256, test::seeded_std_rng, - transcript::Keccak256Transcript, + code::BrakedownSpec6, expression::rotate::BinaryField, hash::Keccak256, + test::seeded_std_rng, transcript::Keccak256Transcript, }, }; use halo2_curves::{ @@ -396,25 +339,25 @@ mod test { }; macro_rules! tests { - ($name:ident, $pcs:ty, $num_vars_range:expr) => { + ($suffix:ident, $pcs:ty, $num_vars_range:expr) => { paste::paste! { #[test] - fn [<$name _hyperplonk_vanilla_plonk>]() { + fn []() { run_plonkish_backend::<_, HyperPlonk<$pcs>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - rand_vanilla_plonk_circuit(num_vars, seeded_std_rng(), seeded_std_rng()) + rand_vanilla_plonk_circuit::<_, BinaryField>(num_vars, seeded_std_rng(), seeded_std_rng()) }); } #[test] - fn [<$name _hyperplonk_vanilla_plonk_with_lookup>]() { + fn []() { run_plonkish_backend::<_, HyperPlonk<$pcs>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { - rand_vanilla_plonk_with_lookup_circuit(num_vars, seeded_std_rng(), seeded_std_rng()) + rand_vanilla_plonk_w_lookup_circuit::<_, BinaryField>(num_vars, seeded_std_rng(), seeded_std_rng()) }); } } }; - ($name:ident, $pcs:ty) => { - tests!($name, $pcs, 2..16); + ($suffix:ident, $pcs:ty) => { + tests!($suffix, $pcs, 2..16); }; } diff --git a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs index c23c9b07..5ce57289 100644 --- a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs +++ b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs @@ -1,5 +1,9 @@ use crate::{ - backend::PlonkishCircuitInfo, + backend::{ + hyperplonk::{HyperPlonkProverParam, HyperPlonkVerifierParam}, + PlonkishCircuitInfo, + }, + pcs::PolynomialCommitmentScheme, poly::multilinear::MultilinearPolynomial, util::{ arithmetic::{div_ceil, steps, PrimeField}, @@ -7,10 +11,11 @@ use crate::{ expression::{Expression, Query, Rotation}, Itertools, }, + Error, }; -use std::{array, borrow::Cow, iter, mem}; +use std::{array, borrow::Cow, mem}; -pub(super) fn batch_size(circuit_info: &PlonkishCircuitInfo) -> usize { +pub(crate) fn batch_size(circuit_info: &PlonkishCircuitInfo) -> usize { let num_lookups = circuit_info.lookups.len(); let num_permutation_polys = circuit_info.permutation_polys().len(); chain![ @@ -22,7 +27,85 @@ pub(super) fn batch_size(circuit_info: &PlonkishCircuitInfo) - .sum() } -pub(super) fn compose( +#[allow(clippy::type_complexity)] +pub(crate) fn preprocess>( + param: &Pcs::Param, + circuit_info: &PlonkishCircuitInfo, + batch_commit: impl Fn( + &Pcs::ProverParam, + Vec>, + ) -> Result<(Vec>, Vec), Error>, +) -> Result< + ( + HyperPlonkProverParam, + HyperPlonkVerifierParam, + ), + Error, +> { + assert!(circuit_info.is_well_formed()); + + let num_vars = circuit_info.k; + let poly_size = 1 << num_vars; + let batch_size = batch_size(circuit_info); + let (pcs_pp, pcs_vp) = Pcs::trim(param, poly_size, batch_size)?; + + // Compute preprocesses comms + let preprocess_polys = circuit_info + .preprocess_polys + .iter() + .cloned() + .map(MultilinearPolynomial::new) + .collect_vec(); + let (preprocess_polys, preprocess_comms) = batch_commit(&pcs_pp, preprocess_polys)?; + + // Compute permutation polys and comms + let permutation_polys = permutation_polys( + num_vars, + &circuit_info.permutation_polys(), + &circuit_info.permutations, + ); + let (permutation_polys, permutation_comms) = batch_commit(&pcs_pp, permutation_polys)?; + + // Compose expression + let (num_permutation_z_polys, expression) = compose(circuit_info); + let vp = HyperPlonkVerifierParam { + pcs: pcs_vp, + num_instances: circuit_info.num_instances.clone(), + num_witness_polys: circuit_info.num_witness_polys.clone(), + num_challenges: circuit_info.num_challenges.clone(), + num_lookups: circuit_info.lookups.len(), + num_permutation_z_polys, + num_vars, + expression: expression.clone(), + preprocess_comms: preprocess_comms.clone(), + permutation_comms: circuit_info + .permutation_polys() + .into_iter() + .zip(permutation_comms.clone()) + .collect(), + }; + let pp = HyperPlonkProverParam { + pcs: pcs_pp, + num_instances: circuit_info.num_instances.clone(), + num_witness_polys: circuit_info.num_witness_polys.clone(), + num_challenges: circuit_info.num_challenges.clone(), + lookups: circuit_info.lookups.clone(), + num_permutation_z_polys, + num_vars, + expression, + preprocess_polys, + preprocess_comms, + permutation_polys: circuit_info + .permutation_polys() + .into_iter() + .zip(permutation_polys) + .collect(), + permutation_comms, + }; + Ok((pp, vp)) +} + +pub(crate) fn compose( circuit_info: &PlonkishCircuitInfo, ) -> (usize, Expression) { let challenge_offset = circuit_info.num_challenges.iter().sum::(); @@ -41,17 +124,16 @@ pub(super) fn compose( ); let expression = { - let constraints = iter::empty() - .chain(circuit_info.constraints.iter()) - .chain(lookup_constraints.iter()) - .chain(permutation_constraints.iter()) - .collect_vec(); + let constraints = chain![ + circuit_info.constraints.iter(), + lookup_constraints.iter(), + permutation_constraints.iter(), + ] + .collect_vec(); let eq = Expression::eq_xy(0); let zero_check_on_every_row = Expression::distribute_powers(constraints, alpha) * eq; Expression::distribute_powers( - iter::empty() - .chain(lookup_zero_checks.iter()) - .chain(Some(&zero_check_on_every_row)), + chain![lookup_zero_checks.iter(), [&zero_check_on_every_row]], alpha, ) }; @@ -67,13 +149,14 @@ pub(super) fn max_degree( let dummy_challenge = Expression::zero(); Cow::Owned(self::lookup_constraints(circuit_info, &dummy_challenge, &dummy_challenge).0) }); - iter::empty() - .chain(circuit_info.constraints.iter().map(Expression::degree)) - .chain(lookup_constraints.iter().map(Expression::degree)) - .chain(circuit_info.max_degree) - .chain(Some(2)) - .max() - .unwrap() + chain![ + circuit_info.constraints.iter().map(Expression::degree), + lookup_constraints.iter().map(Expression::degree), + circuit_info.max_degree, + [2], + ] + .max() + .unwrap() } pub(super) fn lookup_constraints( @@ -139,37 +222,36 @@ pub(crate) fn permutation_constraints( .take(num_chunks) .collect_vec(); let z_0_next = Expression::::Polynomial(Query::new(z_offset, Rotation::next())); - let l_1 = &Expression::::lagrange(1); + let l_0 = &Expression::::lagrange(0); let one = &Expression::one(); - let constraints = iter::empty() - .chain(zs.first().map(|z_0| l_1 * (z_0 - one))) - .chain( - polys - .chunks(chunk_size) - .zip(ids.chunks(chunk_size)) - .zip(permutations.chunks(chunk_size)) - .zip(zs.iter()) - .zip(zs.iter().skip(1).chain(Some(&z_0_next))) - .map(|((((polys, ids), permutations), z_lhs), z_rhs)| { - z_lhs + let constraints = chain![ + zs.first().map(|z_0| l_0 * (z_0 - one)), + polys + .chunks(chunk_size) + .zip(ids.chunks(chunk_size)) + .zip(permutations.chunks(chunk_size)) + .zip(zs.iter()) + .zip(zs.iter().skip(1).chain([&z_0_next])) + .map(|((((polys, ids), permutations), z_lhs), z_rhs)| { + z_lhs + * polys + .iter() + .zip(ids) + .map(|(poly, id)| poly + beta * id + gamma) + .product::>() + - z_rhs * polys .iter() - .zip(ids) - .map(|(poly, id)| poly + beta * id + gamma) + .zip(permutations) + .map(|(poly, permutation)| poly + beta * permutation + gamma) .product::>() - - z_rhs - * polys - .iter() - .zip(permutations) - .map(|(poly, permutation)| poly + beta * permutation + gamma) - .product::>() - }), - ) - .collect(); + }), + ] + .collect(); (num_chunks, constraints) } -pub(super) fn permutation_polys( +pub(crate) fn permutation_polys( num_vars: usize, permutation_polys: &[usize], cycles: &[Vec<(usize, usize)>], @@ -192,7 +274,6 @@ pub(super) fn permutation_polys( let (i0, j0) = cycle[0]; let mut last = permutations[poly_index[i0]][j0]; for &(i, j) in cycle.iter().cycle().skip(1).take(cycle.len()) { - assert_ne!(j, 0); mem::swap(&mut permutations[poly_index[i]][j], &mut last); } } @@ -205,9 +286,7 @@ pub(super) fn permutation_polys( #[cfg(test)] pub(crate) mod test { use crate::{ - backend::hyperplonk::util::{ - vanilla_plonk_expression, vanilla_plonk_with_lookup_expression, - }, + backend::hyperplonk::util::{vanilla_plonk_expression, vanilla_plonk_w_lookup_expression}, util::expression::{Expression, Query, Rotation}, }; use halo2_curves::bn256::Fr; @@ -230,12 +309,12 @@ pub(crate) mod test { let [id_1, id_2, id_3] = array::from_fn(|idx| { Expression::Constant(Fr::from((idx << num_vars) as u64)) + Expression::identity() }); - let l_1 = Expression::::lagrange(1); + let l_0 = Expression::::lagrange(0); let one = Expression::one(); let constraints = { vec![ q_l * w_l + q_r * w_r + q_m * w_l * w_r + q_o * w_o + q_c + pi, - l_1 * (z - one), + l_0 * (z - one), (z * ((w_l + beta * id_1 + gamma) * (w_r + beta * id_2 + gamma) * (w_o + beta * id_3 + gamma))) @@ -251,9 +330,9 @@ pub(crate) mod test { } #[test] - fn compose_vanilla_plonk_with_lookup() { + fn compose_vanilla_plonk_w_lookup() { let num_vars = 3; - let expression = vanilla_plonk_with_lookup_expression(num_vars); + let expression = vanilla_plonk_w_lookup_expression(num_vars); assert_eq!(expression, { let [pi, q_l, q_r, q_m, q_o, q_c, q_lookup, t_l, t_r, t_o, w_l, w_r, w_o, s_1, s_2, s_3] = &array::from_fn(|poly| Query::new(poly, Rotation::cur())) @@ -272,7 +351,7 @@ pub(crate) mod test { let [id_1, id_2, id_3] = array::from_fn(|idx| { Expression::Constant(Fr::from((idx << num_vars) as u64)) + Expression::identity() }); - let l_1 = &Expression::::lagrange(1); + let l_0 = &Expression::::lagrange(0); let one = &Expression::one(); let lookup_input = &Expression::distribute_powers(&[w_l, w_r, w_o].map(|w| q_lookup * w), beta); @@ -283,7 +362,7 @@ pub(crate) mod test { lookup_h * (lookup_input + gamma) * (lookup_table + gamma) - (lookup_table + gamma) + lookup_m * (lookup_input + gamma), - l_1 * (perm_z - one), + l_0 * (perm_z - one), (perm_z * ((w_l + beta * id_1 + gamma) * (w_r + beta * id_2 + gamma) diff --git a/plonkish_backend/src/backend/hyperplonk/prover.rs b/plonkish_backend/src/backend/hyperplonk/prover.rs index 81e8fcec..72ffeb56 100644 --- a/plonkish_backend/src/backend/hyperplonk/prover.rs +++ b/plonkish_backend/src/backend/hyperplonk/prover.rs @@ -1,11 +1,5 @@ use crate::{ - backend::{ - hyperplonk::{ - verifier::{pcs_query, point_offset, points}, - HyperPlonk, - }, - WitnessEncoding, - }, + backend::hyperplonk::verifier::{pcs_query, point_offset, points}, pcs::Evaluation, piop::sum_check::{ classic::{ClassicSumCheck, EvaluationsProver}, @@ -13,9 +7,12 @@ use crate::{ }, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{div_ceil, steps_by, sum, BatchInvert, BooleanHypercube, PrimeField}, - end_timer, - expression::{CommonPolynomial, Expression, Rotation}, + arithmetic::{div_ceil, steps_by, sum, BatchInvert, PrimeField}, + chain, end_timer, + expression::{ + rotate::{BinaryField, Rotatable}, + CommonPolynomial, Expression, Rotation, + }, parallel::{num_threads, par_map_collect, parallelize, parallelize_iter}, start_timer, transcript::FieldTranscriptWrite, @@ -24,21 +21,21 @@ use crate::{ Error, }; use std::{ + borrow::Borrow, collections::{HashMap, HashSet}, hash::Hash, - iter, }; -pub(crate) fn instance_polys<'a, F: PrimeField>( +pub(crate) fn instance_polys<'a, F: PrimeField, R: Rotatable + From>( num_vars: usize, instances: impl IntoIterator>, ) -> Vec> { - let row_mapping = HyperPlonk::<()>::row_mapping(num_vars); + let usable_indices = R::from(num_vars).usable_indices(); instances .into_iter() .map(|instances| { let mut poly = vec![F::ZERO; 1 << num_vars]; - for (b, instance) in row_mapping.iter().zip(instances.into_iter()) { + for (b, instance) in usable_indices.iter().zip(instances.into_iter()) { poly[*b] = *instance; } poly @@ -47,9 +44,9 @@ pub(crate) fn instance_polys<'a, F: PrimeField>( .collect() } -pub(crate) fn lookup_compressed_polys( +pub(crate) fn lookup_compressed_polys>( lookups: &[Vec<(Expression, Expression)>], - polys: &[&MultilinearPolynomial], + polys: &[impl Borrow>], challenges: &[F], betas: &[F], ) -> Vec<[MultilinearPolynomial; 2]> { @@ -57,26 +54,27 @@ pub(crate) fn lookup_compressed_polys( return Default::default(); } + let polys = polys.iter().map(Borrow::borrow).collect_vec(); let num_vars = polys[0].num_vars(); let expression = lookups .iter() .flat_map(|lookup| lookup.iter().map(|(input, table)| (input + table))) .sum::>(); let lagranges = { - let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); + let rotatable = R::from(num_vars); expression .used_langrange() .into_iter() - .map(|i| (i, bh[i.rem_euclid(1 << num_vars) as usize])) + .map(|i| (i, rotatable.nth(i))) .collect::>() }; lookups .iter() - .map(|lookup| lookup_compressed_poly(lookup, &lagranges, polys, challenges, betas)) + .map(|lookup| lookup_compressed_poly::<_, R>(lookup, &lagranges, &polys, challenges, betas)) .collect() } -pub(super) fn lookup_compressed_poly( +pub(super) fn lookup_compressed_poly>( lookup: &[(Expression, Expression)], lagranges: &HashSet<(i32, usize)>, polys: &[&MultilinearPolynomial], @@ -84,7 +82,7 @@ pub(super) fn lookup_compressed_poly( betas: &[F], ) -> [MultilinearPolynomial; 2] { let num_vars = polys[0].num_vars(); - let bh = BooleanHypercube::new(num_vars); + let rotatable = R::from(num_vars); let compress = |expressions: &[&Expression]| { betas .iter() @@ -106,7 +104,7 @@ pub(super) fn lookup_compressed_poly( } CommonPolynomial::EqXY(_) => unreachable!(), }, - &|query| polys[query.poly()][bh.rotate(b, query.rotation())], + &|query| polys[query.poly()][rotatable.rotate(b, query.rotation())], &|challenge| challenges[challenge], &|value| -value, &|lhs, rhs| lhs + &rhs, @@ -191,7 +189,7 @@ pub(super) fn lookup_m_poly( Ok(MultilinearPolynomial::new(m)) } -pub(super) fn lookup_h_polys( +pub(crate) fn lookup_h_polys( compressed_polys: &[[MultilinearPolynomial; 2]], m_polys: &[MultilinearPolynomial], gamma: &F, @@ -225,11 +223,12 @@ pub(super) fn lookup_h_poly( let chunk_size = div_ceil(2 * h_input.len(), num_threads()); parallelize_iter( - iter::empty() - .chain(h_input.chunks_mut(chunk_size)) - .chain(h_table.chunks_mut(chunk_size)), + chain![ + h_input.chunks_mut(chunk_size), + h_table.chunks_mut(chunk_size) + ], |h| { - h.iter_mut().batch_invert(); + h.batch_invert(); }, ); @@ -249,10 +248,10 @@ pub(super) fn lookup_h_poly( MultilinearPolynomial::new(h_input) } -pub(crate) fn permutation_z_polys( +pub(crate) fn permutation_z_polys>( num_chunks: usize, permutation_polys: &[(usize, MultilinearPolynomial)], - polys: &[&MultilinearPolynomial], + polys: &[impl Borrow>], beta: &F, gamma: &F, ) -> Vec> { @@ -261,6 +260,7 @@ pub(crate) fn permutation_z_polys( } let chunk_size = div_ceil(permutation_polys.len(), num_chunks); + let polys = polys.iter().map(Borrow::borrow).collect_vec(); let num_vars = polys[0].num_vars(); let timer = start_timer(|| "products"); @@ -283,7 +283,7 @@ pub(crate) fn permutation_z_polys( } parallelize(&mut product, |(product, _)| { - product.iter_mut().batch_invert(); + product.batch_invert(); }); for ((poly, _), idx) in permutation_polys.iter().zip(chunk_idx * chunk_size..) { @@ -304,44 +304,31 @@ pub(crate) fn permutation_z_polys( .collect_vec(); end_timer(timer); - let timer = start_timer(|| "z_polys"); - let z = iter::empty() - .chain(iter::repeat(F::ZERO).take(num_chunks)) - .chain(Some(F::ONE)) - .chain( - BooleanHypercube::new(num_vars) - .iter() - .skip(1) - .flat_map(|b| iter::repeat(b).take(num_chunks)) - .zip(products.iter().cycle()) - .scan(F::ONE, |state, (b, product)| { - *state *= &product[b]; - Some(*state) - }), - ) - .take(num_chunks << num_vars) - .collect_vec(); + let _timer = start_timer(|| "z_polys"); + let mut z = vec![vec![F::ZERO; 1 << num_vars]; num_chunks]; + + let usable_indices = R::from(num_vars).usable_indices(); + let first_idx = usable_indices[0]; + z[0][first_idx] = F::ONE; + for chunk_idx in 1..num_chunks { + z[chunk_idx][first_idx] = z[chunk_idx - 1][first_idx] * products[chunk_idx - 1][first_idx]; + } + for (last_idx, idx) in usable_indices.iter().copied().tuple_windows() { + z[0][idx] = z[num_chunks - 1][last_idx] * products[num_chunks - 1][last_idx]; + for chunk_idx in 1..num_chunks { + z[chunk_idx][idx] = z[chunk_idx - 1][idx] * products[chunk_idx - 1][idx]; + } + } if cfg!(feature = "sanity-check") { - let b_last = BooleanHypercube::new(num_vars).iter().last().unwrap(); + let last_idx = *usable_indices.last().unwrap(); assert_eq!( - *z.last().unwrap() * products.last().unwrap()[b_last], + z.last().unwrap()[last_idx] * products.last().unwrap()[last_idx], F::ONE ); } - drop(products); - end_timer(timer); - - let _timer = start_timer(|| "into_bh_order"); - let nth_map = BooleanHypercube::new(num_vars) - .nth_map() - .into_iter() - .map(|b| num_chunks * b) - .collect_vec(); - (0..num_chunks) - .map(|offset| MultilinearPolynomial::new(par_map_collect(&nth_map, |b| z[offset + b]))) - .collect() + z.into_iter().map(MultilinearPolynomial::new).collect() } #[allow(clippy::type_complexity)] @@ -377,7 +364,7 @@ pub(crate) fn prove_sum_check( let num_vars = polys[0].num_vars(); let ys = [y]; let virtual_poly = VirtualPolynomial::new(expression, polys.to_vec(), &challenges, &ys); - let (_, x, evals) = ClassicSumCheck::>::prove( + let (_, x, evals) = ClassicSumCheck::, BinaryField>::prove( &(), num_vars, virtual_poly, @@ -394,7 +381,7 @@ pub(crate) fn prove_sum_check( .flat_map(|query| { (point_offset[&query.rotation()]..) .zip(if query.rotation() == Rotation::cur() { - vec![evals[query.poly()]] + vec![evals[query]] } else { polys[query.poly()].evaluate_for_rotation(&x, query.rotation()) }) diff --git a/plonkish_backend/src/backend/hyperplonk/util.rs b/plonkish_backend/src/backend/hyperplonk/util.rs index faa80780..96d1d995 100644 --- a/plonkish_backend/src/backend/hyperplonk/util.rs +++ b/plonkish_backend/src/backend/hyperplonk/util.rs @@ -12,8 +12,9 @@ use crate::{ }, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{powers, BooleanHypercube, PrimeField}, - expression::{Expression, Query, Rotation}, + arithmetic::{powers, PrimeField}, + chain, + expression::{rotate::Rotatable, Expression, Query, Rotation}, test::{rand_array, rand_idx, rand_vec}, Itertools, }, @@ -24,7 +25,7 @@ use std::{ array, collections::{HashMap, HashSet}, hash::Hash, - iter, + iter, mem, }; pub fn vanilla_plonk_circuit_info( @@ -60,7 +61,7 @@ pub fn vanilla_plonk_expression(num_vars: usize) -> Expression expression } -pub fn vanilla_plonk_with_lookup_circuit_info( +pub fn vanilla_plonk_w_lookup_circuit_info( num_vars: usize, num_instances: usize, preprocess_polys: [Vec; 9], @@ -85,8 +86,8 @@ pub fn vanilla_plonk_with_lookup_circuit_info( } } -pub fn vanilla_plonk_with_lookup_expression(num_vars: usize) -> Expression { - let circuit_info = vanilla_plonk_with_lookup_circuit_info( +pub fn vanilla_plonk_w_lookup_expression(num_vars: usize) -> Expression { + let circuit_info = vanilla_plonk_w_lookup_circuit_info( num_vars, 0, Default::default(), @@ -97,7 +98,7 @@ pub fn vanilla_plonk_with_lookup_expression(num_vars: usize) -> E expression } -pub fn rand_vanilla_plonk_circuit( +pub fn rand_vanilla_plonk_circuit>( num_vars: usize, mut preprocess_rng: impl RngCore, mut witness_rng: impl RngCore, @@ -106,7 +107,7 @@ pub fn rand_vanilla_plonk_circuit( let mut polys = [(); 9].map(|_| vec![F::ZERO; size]); let instances = rand_vec(num_vars, &mut witness_rng); - polys[0] = instance_polys(num_vars, [&instances])[0].evals().to_vec(); + polys[0] = mem::take(&mut instance_polys::<_, R>(num_vars, [&instances])[0]).into_evals(); let mut permutation = Permutation::default(); for poly in [6, 7, 8] { @@ -168,31 +169,27 @@ pub fn rand_vanilla_plonk_circuit( ) } -pub fn rand_vanilla_plonk_assignment( +pub fn rand_vanilla_plonk_assignment>( num_vars: usize, mut preprocess_rng: impl RngCore, mut witness_rng: impl RngCore, ) -> (Vec>, Vec) { let (polys, permutations) = { let (circuit_info, circuit) = - rand_vanilla_plonk_circuit(num_vars, &mut preprocess_rng, &mut witness_rng); + rand_vanilla_plonk_circuit::<_, R>(num_vars, &mut preprocess_rng, &mut witness_rng); let witness = circuit.synthesize(0, &[]).unwrap(); - let polys = iter::empty() - .chain(instance_polys(num_vars, circuit.instances())) - .chain( - iter::empty() - .chain(circuit_info.preprocess_polys) - .chain(witness) - .map(MultilinearPolynomial::new), - ) - .collect_vec(); + let polys = chain![ + instance_polys::<_, R>(num_vars, circuit.instances()), + chain![circuit_info.preprocess_polys, witness].map(MultilinearPolynomial::new), + ] + .collect_vec(); (polys, circuit_info.permutations) }; let challenges: [_; 3] = rand_array(&mut witness_rng); let [beta, gamma, _] = challenges; let permutation_polys = permutation_polys(num_vars, &[6, 7, 8], &permutations); - let permutation_z_polys = permutation_z_polys( + let permutation_z_polys = permutation_z_polys::<_, R>( 1, &[6, 7, 8] .into_iter() @@ -204,16 +201,12 @@ pub fn rand_vanilla_plonk_assignment( ); ( - iter::empty() - .chain(polys) - .chain(permutation_polys) - .chain(permutation_z_polys) - .collect_vec(), + chain![polys, permutation_polys, permutation_z_polys].collect_vec(), challenges.to_vec(), ) } -pub fn rand_vanilla_plonk_with_lookup_circuit( +pub fn rand_vanilla_plonk_w_lookup_circuit>( num_vars: usize, mut preprocess_rng: impl RngCore, mut witness_rng: impl RngCore, @@ -222,20 +215,24 @@ pub fn rand_vanilla_plonk_with_lookup_circuit( let mut polys = [(); 13].map(|_| vec![F::ZERO; size]); let [t_l, t_r, t_o] = [(); 3].map(|_| { - iter::empty() - .chain([F::ZERO, F::ZERO]) - .chain(iter::repeat_with(|| F::random(&mut preprocess_rng))) - .take(size) - .collect_vec() + chain![ + [F::ZERO; 2], + iter::repeat_with(|| F::random(&mut preprocess_rng)), + ] + .take(size) + .collect_vec() }); polys[7] = t_l; polys[8] = t_r; polys[9] = t_o; let instances = rand_vec(num_vars, &mut witness_rng); - polys[0] = instance_polys(num_vars, [&instances])[0].evals().to_vec(); - let instance_rows = BooleanHypercube::new(num_vars) - .iter() + polys[0] = instance_polys::<_, R>(num_vars, [&instances])[0] + .evals() + .to_vec(); + let instance_rows = R::from(num_vars) + .usable_indices() + .into_iter() .take(num_vars + 1) .collect::>(); @@ -303,7 +300,7 @@ pub fn rand_vanilla_plonk_with_lookup_circuit( } let [_, q_l, q_r, q_m, q_o, q_c, q_lookup, t_l, t_r, t_o, w_l, w_r, w_o] = polys; - let circuit_info = vanilla_plonk_with_lookup_circuit_info( + let circuit_info = vanilla_plonk_w_lookup_circuit_info( num_vars, instances.len(), [q_l, q_r, q_m, q_o, q_c, q_lookup, t_l, t_r, t_o], @@ -315,24 +312,23 @@ pub fn rand_vanilla_plonk_with_lookup_circuit( ) } -pub fn rand_vanilla_plonk_with_lookup_assignment( +pub fn rand_vanilla_plonk_w_lookup_assignment>( num_vars: usize, mut preprocess_rng: impl RngCore, mut witness_rng: impl RngCore, ) -> (Vec>, Vec) { let (polys, permutations) = { - let (circuit_info, circuit) = - rand_vanilla_plonk_with_lookup_circuit(num_vars, &mut preprocess_rng, &mut witness_rng); + let (circuit_info, circuit) = rand_vanilla_plonk_w_lookup_circuit::<_, R>( + num_vars, + &mut preprocess_rng, + &mut witness_rng, + ); let witness = circuit.synthesize(0, &[]).unwrap(); - let polys = iter::empty() - .chain(instance_polys(num_vars, circuit.instances())) - .chain( - iter::empty() - .chain(circuit_info.preprocess_polys) - .chain(witness) - .map(MultilinearPolynomial::new), - ) - .collect_vec(); + let polys = chain![ + instance_polys::<_, R>(num_vars, circuit.instances()), + chain![circuit_info.preprocess_polys, witness].map(MultilinearPolynomial::new), + ] + .collect_vec(); (polys, circuit_info.permutations) }; let challenges: [_; 3] = rand_array(&mut witness_rng); @@ -340,17 +336,18 @@ pub fn rand_vanilla_plonk_with_lookup_assignment( let (lookup_compressed_polys, lookup_m_polys) = { let PlonkishCircuitInfo { lookups, .. } = - vanilla_plonk_with_lookup_circuit_info(0, 0, Default::default(), Vec::new()); + vanilla_plonk_w_lookup_circuit_info(0, 0, Default::default(), Vec::new()); + let polys = polys.iter().collect_vec(); let betas = powers(beta).take(3).collect_vec(); let lookup_compressed_polys = - lookup_compressed_polys(&lookups, &polys.iter().collect_vec(), &[], &betas); + lookup_compressed_polys::<_, R>(&lookups, &polys, &[], &betas); let lookup_m_polys = lookup_m_polys(&lookup_compressed_polys).unwrap(); (lookup_compressed_polys, lookup_m_polys) }; let lookup_h_polys = lookup_h_polys(&lookup_compressed_polys, &lookup_m_polys, &gamma); let permutation_polys = permutation_polys(num_vars, &[10, 11, 12], &permutations); - let permutation_z_polys = permutation_z_polys( + let permutation_z_polys = permutation_z_polys::<_, R>( 1, &[10, 11, 12] .into_iter() @@ -362,13 +359,14 @@ pub fn rand_vanilla_plonk_with_lookup_assignment( ); ( - iter::empty() - .chain(polys) - .chain(permutation_polys) - .chain(lookup_m_polys) - .chain(lookup_h_polys) - .chain(permutation_z_polys) - .collect_vec(), + chain![ + polys, + permutation_polys, + lookup_m_polys, + lookup_h_polys, + permutation_z_polys, + ] + .collect_vec(), challenges.to_vec(), ) } diff --git a/plonkish_backend/src/backend/hyperplonk/verifier.rs b/plonkish_backend/src/backend/hyperplonk/verifier.rs index dcc602c7..4a20ef67 100644 --- a/plonkish_backend/src/backend/hyperplonk/verifier.rs +++ b/plonkish_backend/src/backend/hyperplonk/verifier.rs @@ -6,8 +6,11 @@ use crate::{ }, poly::multilinear::{rotation_eval, rotation_eval_points}, util::{ - arithmetic::{inner_product, BooleanHypercube, PrimeField}, - expression::{Expression, Query, Rotation}, + arithmetic::{inner_product, PrimeField}, + expression::{ + rotate::{BinaryField, Rotatable}, + Expression, Query, Rotation, + }, transcript::FieldTranscriptRead, Itertools, }, @@ -45,7 +48,7 @@ pub(crate) fn verify_sum_check( y: &[F], transcript: &mut impl FieldTranscriptRead, ) -> Result<(Vec>, Vec>), Error> { - let (x_eval, x) = ClassicSumCheck::>::verify( + let (x_eval, x) = ClassicSumCheck::, BinaryField>::verify( &(), num_vars, expression.degree(), @@ -66,11 +69,11 @@ pub(crate) fn verify_sum_check( .into_iter() .unzip::<_, _, Vec<_>, Vec<_>>(); - let evals = instance_evals(num_vars, expression, instances, &x) + let evals = instance_evals::<_, BinaryField>(num_vars, expression, instances, &x) .into_iter() .chain(evals) .collect(); - if evaluate(expression, num_vars, &evals, challenges, &[y], &x) != x_eval { + if evaluate::<_, BinaryField>(expression, num_vars, &evals, challenges, &[y], &x) != x_eval { return Err(Error::InvalidSnark( "Unmatched between sum_check output and query evaluation".to_string(), )); @@ -89,7 +92,7 @@ pub(crate) fn verify_sum_check( Ok((points(&pcs_query, &x), evals)) } -fn instance_evals( +pub(crate) fn instance_evals>( num_vars: usize, expression: &Expression, instances: &[Vec], @@ -98,53 +101,31 @@ fn instance_evals( let mut instance_query = expression.used_query(); instance_query.retain(|query| query.poly() < instances.len()); - let lagranges = { - let mut lagranges = instance_query.iter().fold(0..0, |range, query| { - let i = -query.rotation().0; - range.start.min(i)..range.end.max(i + instances[query.poly()].len() as i32) - }); - if lagranges.start < 0 { - lagranges.start -= 1; - } - if lagranges.end > 0 { - lagranges.end += 1; - } - lagranges + let (min_rotation, max_rotation) = instance_query.iter().fold((0, 0), |(min, max), query| { + (min.min(query.rotation().0), max.max(query.rotation().0)) + }); + let lagrange_evals = { + let rotatable = R::from(num_vars); + let max_instance_len = instances.iter().map(Vec::len).max().unwrap_or_default(); + (-max_rotation..max_instance_len as i32 + min_rotation.abs()) + .map(|i| lagrange_eval(x, rotatable.nth(i))) + .collect_vec() }; - let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); - let lagrange_evals = lagranges - .filter_map(|i| { - (i != 0).then(|| { - let b = bh[i.rem_euclid(1 << num_vars as i32) as usize]; - (i, lagrange_eval(x, b)) - }) - }) - .collect::>(); - instance_query .into_iter() .map(|query| { - let is = if query.rotation() > Rotation::cur() { - (-query.rotation().0..0) - .chain(1..) - .take(instances[query.poly()].len()) - .collect_vec() - } else { - (1 - query.rotation().0..) - .take(instances[query.poly()].len()) - .collect_vec() - }; + let offset = (max_rotation - query.rotation().0) as usize; let eval = inner_product( &instances[query.poly()], - is.iter().map(|i| lagrange_evals.get(i).unwrap()), + &lagrange_evals[offset..offset + instances[query.poly()].len()], ); (query, eval) }) .collect() } -pub(super) fn pcs_query( +pub(crate) fn pcs_query( expression: &Expression, num_instance_poly: usize, ) -> BTreeSet { diff --git a/plonkish_backend/src/backend/unihyperplonk.rs b/plonkish_backend/src/backend/unihyperplonk.rs new file mode 100644 index 00000000..d8737c45 --- /dev/null +++ b/plonkish_backend/src/backend/unihyperplonk.rs @@ -0,0 +1,436 @@ +use crate::{ + backend::{ + hyperplonk::{HyperPlonkProverParam, HyperPlonkVerifierParam}, + unihyperplonk::{ + preprocessor::{batch_size, preprocess}, + prover::{ + instance_polys, lookup_compressed_polys, lookup_h_polys, lookup_m_polys, + permutation_z_polys, prove_zero_check, + }, + verifier::verify_zero_check, + }, + PlonkishBackend, PlonkishCircuit, PlonkishCircuitInfo, WitnessEncoding, + }, + pcs::{Additive, PolynomialCommitmentScheme}, + piop::multilinear_eval::ph23::{self, s_polys}, + poly::{multilinear::MultilinearPolynomial, univariate::UnivariatePolynomial}, + util::{ + arithmetic::{powers, WithSmallOrderMulGroup}, + chain, end_timer, + expression::rotate::{Lexical, Rotatable}, + start_timer, + transcript::{TranscriptRead, TranscriptWrite}, + Deserialize, DeserializeOwned, Itertools, Serialize, + }, + Error, +}; +use rand::RngCore; +use std::{borrow::Cow, fmt::Debug, hash::Hash, iter, marker::PhantomData, ops::Deref}; + +pub(crate) mod preprocessor; +pub(crate) mod prover; +pub(crate) mod verifier; + +#[derive(Clone, Debug)] +pub struct UniHyperPlonk(PhantomData); + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound(serialize = "F: Serialize", deserialize = "F: DeserializeOwned"))] +pub struct UniHyperPlonkProverParam +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) pp: HyperPlonkProverParam, + pub(crate) s_polys: Vec>, +} + +impl Deref for UniHyperPlonkProverParam +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme, +{ + type Target = HyperPlonkProverParam; + + fn deref(&self) -> &Self::Target { + &self.pp + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound(serialize = "F: Serialize", deserialize = "F: DeserializeOwned"))] +pub struct UniHyperPlonkVerifierParam +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) vp: HyperPlonkVerifierParam, +} + +impl Deref for UniHyperPlonkVerifierParam +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme, +{ + type Target = HyperPlonkVerifierParam; + + fn deref(&self) -> &Self::Target { + &self.vp + } +} + +impl PlonkishBackend for UniHyperPlonk +where + F: WithSmallOrderMulGroup<3> + Hash + Serialize + DeserializeOwned, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: Additive, +{ + type Pcs = Pcs; + type ProverParam = UniHyperPlonkProverParam; + type VerifierParam = UniHyperPlonkVerifierParam; + + fn setup( + circuit_info: &PlonkishCircuitInfo, + rng: impl RngCore, + ) -> Result { + assert!(circuit_info.is_well_formed()); + + let num_vars = circuit_info.k; + let poly_size = 1 << num_vars; + let batch_size = batch_size(circuit_info); + Pcs::setup(poly_size, batch_size, rng) + } + + fn preprocess( + param: &Pcs::Param, + circuit_info: &PlonkishCircuitInfo, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + let (pp, vp) = preprocess(param, circuit_info, |pp, polys| { + batch_commit::<_, Pcs>(pp, polys) + })?; + let s_polys = s_polys(circuit_info.k); + Ok(( + UniHyperPlonkProverParam { pp, s_polys }, + UniHyperPlonkVerifierParam { vp }, + )) + } + + fn prove( + pp: &Self::ProverParam, + circuit: &impl PlonkishCircuit, + transcript: &mut impl TranscriptWrite, + _: impl RngCore, + ) -> Result<(), Error> { + let instance_polys = { + let instances = circuit.instances(); + for (num_instances, instances) in pp.num_instances.iter().zip_eq(instances) { + assert_eq!(instances.len(), *num_instances); + for instance in instances.iter() { + transcript.common_field_element(instance)?; + } + } + instance_polys::<_, Lexical>(pp.num_vars, instances) + }; + + // Round 0..n + + let mut witness_polys = Vec::with_capacity(pp.num_witness_polys.iter().sum()); + let mut witness_comms = Vec::with_capacity(witness_polys.len()); + let mut challenges = Vec::with_capacity(pp.num_challenges.iter().sum::() + 4); + for (round, (num_witness_polys, num_challenges)) in pp + .num_witness_polys + .iter() + .zip_eq(pp.num_challenges.iter()) + .enumerate() + { + let timer = start_timer(|| format!("witness_collector-{round}")); + let polys = circuit + .synthesize(round, &challenges)? + .into_iter() + .map(MultilinearPolynomial::new) + .collect_vec(); + assert_eq!(polys.len(), *num_witness_polys); + end_timer(timer); + + let (polys, comms) = batch_commit_and_write::<_, Pcs>(&pp.pcs, polys, transcript)?; + witness_comms.extend(comms); + witness_polys.extend(polys); + challenges.extend(transcript.squeeze_challenges(*num_challenges)); + } + let polys = chain![ + instance_polys.into_iter().map(Cow::Owned), + pp.preprocess_polys.iter().map(Cow::Borrowed), + witness_polys.into_iter().map(Cow::Owned) + ] + .collect_vec(); + + // Round n + + let beta = transcript.squeeze_challenge(); + + let timer = start_timer(|| format!("lookup_compressed_polys-{}", pp.lookups.len())); + let lookup_compressed_polys = { + let max_lookup_width = pp.lookups.iter().map(Vec::len).max().unwrap_or_default(); + let betas = powers(beta).take(max_lookup_width).collect_vec(); + lookup_compressed_polys::<_, Lexical>(&pp.lookups, &polys, &challenges, &betas) + }; + end_timer(timer); + + let timer = start_timer(|| format!("lookup_m_polys-{}", pp.lookups.len())); + let lookup_m_polys = lookup_m_polys(&lookup_compressed_polys)?; + end_timer(timer); + + let (lookup_m_polys, lookup_m_comms) = + batch_commit_and_write::<_, Pcs>(&pp.pcs, lookup_m_polys, transcript)?; + + // Round n+1 + + let gamma = transcript.squeeze_challenge(); + + let timer = start_timer(|| format!("lookup_h_polys-{}", pp.lookups.len())); + let lookup_h_polys = lookup_h_polys(&lookup_compressed_polys, &lookup_m_polys, &gamma); + end_timer(timer); + + let timer = start_timer(|| format!("permutation_z_polys-{}", pp.permutation_polys.len())); + let permutation_z_polys = permutation_z_polys::<_, Lexical>( + pp.num_permutation_z_polys, + &pp.permutation_polys, + &polys, + &beta, + &gamma, + ); + end_timer(timer); + + let lookup_h_permutation_z_polys = + chain![lookup_h_polys, permutation_z_polys].collect_vec(); + let (lookup_h_permutation_z_polys, lookup_h_permutation_z_comms) = + batch_commit_and_write::<_, Pcs>(&pp.pcs, lookup_h_permutation_z_polys, transcript)?; + + // Round n+2 + + let alpha = transcript.squeeze_challenge(); + let y = transcript.squeeze_challenges(pp.num_vars); + + let polys = chain![ + polys, + chain![&pp.permutation_polys].map(|(_, poly)| Cow::Borrowed(poly)), + lookup_m_polys.into_iter().map(Cow::Owned), + lookup_h_permutation_z_polys.into_iter().map(Cow::Owned), + ] + .collect_vec(); + challenges.extend([beta, gamma, alpha]); + let (point, evals) = prove_zero_check( + pp.num_instances.len(), + &pp.expression, + &polys, + challenges, + y, + transcript, + )?; + + // Prove PH23 multilinear evaluation + + let polys = polys + .into_iter() + .map(|poly| poly.into_owned().into_evals()) + .map(UnivariatePolynomial::lagrange) + .collect_vec(); + let dummy_comm = Pcs::Commitment::default(); + let comms = chain![ + iter::repeat(&dummy_comm).take(pp.num_instances.len()), + &pp.preprocess_comms, + &witness_comms, + &pp.permutation_comms, + &lookup_m_comms, + &lookup_h_permutation_z_comms, + ] + .collect_vec(); + let timer = start_timer(|| format!("prove_multilinear_eval-{}", evals.len())); + ph23::additive::prove_multilinear_eval::<_, Pcs>( + &pp.pcs, + pp.num_vars, + &pp.s_polys, + &polys, + comms, + &point, + &evals, + transcript, + )?; + end_timer(timer); + + Ok(()) + } + + fn verify( + vp: &Self::VerifierParam, + instances: &[Vec], + transcript: &mut impl TranscriptRead, + _: impl RngCore, + ) -> Result<(), Error> { + for (num_instances, instances) in vp.num_instances.iter().zip_eq(instances) { + assert_eq!(instances.len(), *num_instances); + for instance in instances.iter() { + transcript.common_field_element(instance)?; + } + } + + // Round 0..n + + let mut witness_comms = Vec::with_capacity(vp.num_witness_polys.iter().sum()); + let mut challenges = Vec::with_capacity(vp.num_challenges.iter().sum::() + 4); + for (num_polys, num_challenges) in + vp.num_witness_polys.iter().zip_eq(vp.num_challenges.iter()) + { + witness_comms.extend(Pcs::read_commitments(&vp.pcs, *num_polys, transcript)?); + challenges.extend(transcript.squeeze_challenges(*num_challenges)); + } + + // Round n + + let beta = transcript.squeeze_challenge(); + + let lookup_m_comms = Pcs::read_commitments(&vp.pcs, vp.num_lookups, transcript)?; + + // Round n+1 + + let gamma = transcript.squeeze_challenge(); + + let lookup_h_permutation_z_comms = Pcs::read_commitments( + &vp.pcs, + vp.num_lookups + vp.num_permutation_z_polys, + transcript, + )?; + + // Round n+2 + + let alpha = transcript.squeeze_challenge(); + let y = transcript.squeeze_challenges(vp.num_vars); + + challenges.extend([beta, gamma, alpha]); + let (point, evals) = verify_zero_check( + vp.num_vars, + &vp.expression, + instances, + &challenges, + &y, + transcript, + )?; + + // Verify PH23 multilinear evaluation + + let dummy_comm = Pcs::Commitment::default(); + let comms = chain![ + iter::repeat(&dummy_comm).take(vp.num_instances.len()), + &vp.preprocess_comms, + &witness_comms, + vp.permutation_comms.iter().map(|(_, comm)| comm), + &lookup_m_comms, + &lookup_h_permutation_z_comms, + ] + .collect_vec(); + ph23::additive::verify_multilinear_eval::<_, Pcs>( + &vp.pcs, + vp.num_vars, + comms, + &point, + &evals, + transcript, + )?; + + Ok(()) + } +} + +impl WitnessEncoding for UniHyperPlonk { + fn row_mapping(k: usize) -> Vec { + Lexical::new(k).usable_indices() + } +} + +#[allow(clippy::type_complexity)] +fn batch_commit( + pp: &Pcs::ProverParam, + polys: impl IntoIterator>, +) -> Result<(Vec>, Vec), Error> +where + F: WithSmallOrderMulGroup<3> + Hash + Serialize + DeserializeOwned, + Pcs: PolynomialCommitmentScheme>, +{ + let polys = polys + .into_iter() + .map(MultilinearPolynomial::into_evals) + .map(UnivariatePolynomial::lagrange) + .collect_vec(); + let comms = Pcs::batch_commit(pp, &polys)?; + let polys = polys + .into_iter() + .map(UnivariatePolynomial::into_coeffs) + .map(MultilinearPolynomial::new) + .collect_vec(); + Ok((polys, comms)) +} + +#[allow(clippy::type_complexity)] +fn batch_commit_and_write( + pp: &Pcs::ProverParam, + polys: impl IntoIterator>, + transcript: &mut impl TranscriptWrite, +) -> Result<(Vec>, Vec), Error> +where + F: WithSmallOrderMulGroup<3> + Hash + Serialize + DeserializeOwned, + Pcs: PolynomialCommitmentScheme>, +{ + let polys = polys + .into_iter() + .map(MultilinearPolynomial::into_evals) + .map(UnivariatePolynomial::lagrange) + .collect_vec(); + let comms = Pcs::batch_commit_and_write(pp, &polys, transcript)?; + let polys = polys + .into_iter() + .map(UnivariatePolynomial::into_coeffs) + .map(MultilinearPolynomial::new) + .collect_vec(); + Ok((polys, comms)) +} + +#[cfg(test)] +mod test { + use crate::{ + backend::{ + hyperplonk::util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_w_lookup_circuit}, + test::run_plonkish_backend, + unihyperplonk::UniHyperPlonk, + }, + pcs::univariate::UnivariateKzg, + util::{ + expression::rotate::Lexical, test::seeded_std_rng, transcript::Keccak256Transcript, + }, + }; + use halo2_curves::bn256::Bn256; + + macro_rules! tests { + ($suffix:ident, $pcs:ty, $additive:literal, $num_vars_range:expr) => { + paste::paste! { + #[test] + fn []() { + run_plonkish_backend::<_, UniHyperPlonk<$pcs, $additive>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { + rand_vanilla_plonk_circuit::<_, Lexical>(num_vars, seeded_std_rng(), seeded_std_rng()) + }); + } + + #[test] + fn []() { + run_plonkish_backend::<_, UniHyperPlonk<$pcs, $additive>, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { + rand_vanilla_plonk_w_lookup_circuit::<_, Lexical>(num_vars, seeded_std_rng(), seeded_std_rng()) + }); + } + } + }; + ($suffix:ident, $pcs:ty, $additive:literal) => { + tests!($suffix, $pcs, $additive, 2..16); + }; + } + + tests!(kzg, UnivariateKzg, true); +} diff --git a/plonkish_backend/src/backend/unihyperplonk/preprocessor.rs b/plonkish_backend/src/backend/unihyperplonk/preprocessor.rs new file mode 100644 index 00000000..4011712d --- /dev/null +++ b/plonkish_backend/src/backend/unihyperplonk/preprocessor.rs @@ -0,0 +1 @@ +pub(super) use crate::backend::hyperplonk::preprocessor::{batch_size, preprocess}; diff --git a/plonkish_backend/src/backend/unihyperplonk/prover.rs b/plonkish_backend/src/backend/unihyperplonk/prover.rs new file mode 100644 index 00000000..44429767 --- /dev/null +++ b/plonkish_backend/src/backend/unihyperplonk/prover.rs @@ -0,0 +1,69 @@ +use crate::{ + backend::hyperplonk::verifier::pcs_query, + piop::sum_check::{ + classic::{ClassicSumCheck, EvaluationsProver}, + SumCheck, VirtualPolynomial, + }, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::PrimeField, + expression::{rotate::Lexical, Expression, Query}, + transcript::FieldTranscriptWrite, + Itertools, + }, + Error, +}; +use std::borrow::Borrow; + +pub(super) use crate::backend::hyperplonk::prover::{ + instance_polys, lookup_compressed_polys, lookup_h_polys, lookup_m_polys, permutation_z_polys, +}; + +#[allow(clippy::type_complexity)] +pub(super) fn prove_zero_check( + num_instance_poly: usize, + expression: &Expression, + polys: &[impl Borrow>], + challenges: Vec, + y: Vec, + transcript: &mut impl FieldTranscriptWrite, +) -> Result<(Vec, Vec<(Query, F)>), Error> { + prove_sum_check( + num_instance_poly, + expression, + F::ZERO, + polys, + challenges, + y, + transcript, + ) +} + +#[allow(clippy::type_complexity)] +pub(super) fn prove_sum_check( + num_instance_poly: usize, + expression: &Expression, + sum: F, + polys: &[impl Borrow>], + challenges: Vec, + y: Vec, + transcript: &mut impl FieldTranscriptWrite, +) -> Result<(Vec, Vec<(Query, F)>), Error> { + let polys = polys.iter().map(Borrow::borrow).collect_vec(); + let num_vars = polys[0].num_vars(); + let ys = [y]; + let virtual_poly = VirtualPolynomial::new(expression, polys, &challenges, &ys); + let (_, x, mut evals) = ClassicSumCheck::, Lexical>::prove( + &(), + num_vars, + virtual_poly, + sum, + transcript, + )?; + + let pcs_query = pcs_query(expression, num_instance_poly); + evals.retain(|query, _| pcs_query.contains(query)); + transcript.write_field_elements(evals.values())?; + + Ok((x, evals.into_iter().collect())) +} diff --git a/plonkish_backend/src/backend/unihyperplonk/verifier.rs b/plonkish_backend/src/backend/unihyperplonk/verifier.rs new file mode 100644 index 00000000..0bd073ce --- /dev/null +++ b/plonkish_backend/src/backend/unihyperplonk/verifier.rs @@ -0,0 +1,74 @@ +use crate::{ + backend::hyperplonk::verifier::{instance_evals, pcs_query}, + piop::sum_check::{ + classic::{ClassicSumCheck, EvaluationsProver}, + evaluate, SumCheck, + }, + util::{ + arithmetic::PrimeField, + chain, + expression::{rotate::Lexical, Expression, Query}, + izip, + transcript::FieldTranscriptRead, + Itertools, + }, + Error, +}; + +#[allow(clippy::type_complexity)] +pub(super) fn verify_zero_check( + num_vars: usize, + expression: &Expression, + instances: &[Vec], + challenges: &[F], + y: &[F], + transcript: &mut impl FieldTranscriptRead, +) -> Result<(Vec, Vec<(Query, F)>), Error> { + verify_sum_check( + num_vars, + expression, + F::ZERO, + instances, + challenges, + y, + transcript, + ) +} + +#[allow(clippy::type_complexity)] +pub(super) fn verify_sum_check( + num_vars: usize, + expression: &Expression, + sum: F, + instances: &[Vec], + challenges: &[F], + y: &[F], + transcript: &mut impl FieldTranscriptRead, +) -> Result<(Vec, Vec<(Query, F)>), Error> { + let (x_eval, x) = ClassicSumCheck::, Lexical>::verify( + &(), + num_vars, + expression.degree(), + sum, + transcript, + )?; + + let evals = { + let pcs_query = pcs_query(expression, instances.len()); + let evals = transcript.read_field_elements(pcs_query.len())?; + izip!(pcs_query, evals).collect_vec() + }; + + let query_eval = { + let instance_evals = instance_evals::<_, Lexical>(num_vars, expression, instances, &x); + let evals = chain![evals.iter().copied(), instance_evals].collect(); + evaluate::<_, Lexical>(expression, num_vars, &evals, challenges, &[y], &x) + }; + if query_eval != x_eval { + return Err(Error::InvalidSnark( + "Unmatched between sum_check output and query evaluation".to_string(), + )); + } + + Ok((x, evals.into_iter().collect())) +} diff --git a/plonkish_backend/src/frontend/halo2.rs b/plonkish_backend/src/frontend/halo2.rs index 9bae2ad8..55520587 100644 --- a/plonkish_backend/src/frontend/halo2.rs +++ b/plonkish_backend/src/frontend/halo2.rs @@ -2,8 +2,9 @@ use crate::{ backend::{PlonkishCircuit, PlonkishCircuitInfo, WitnessEncoding}, util::{ arithmetic::{BatchInvert, Field}, + chain, expression::{Expression, Query, Rotation}, - Itertools, + izip, Itertools, }, }; use halo2_proofs::{ @@ -16,7 +17,7 @@ use halo2_proofs::{ use rand::RngCore; use std::{ collections::{HashMap, HashSet}, - iter, mem, + mem, }; #[cfg(any(test, feature = "benchmark"))] @@ -201,15 +202,16 @@ impl> PlonkishCircuit for Halo2Circuit { ) .map_err(|err| crate::Error::InvalidSnark(format!("Synthesize failure: {err:?}")))?; - circuit_info.preprocess_polys = iter::empty() - .chain(batch_invert_assigned(preprocess_collector.fixeds)) - .chain(preprocess_collector.selectors.into_iter().map(|selectors| { + circuit_info.preprocess_polys = chain![ + batch_invert_assigned(preprocess_collector.fixeds), + preprocess_collector.selectors.into_iter().map(|selectors| { selectors .into_iter() .map(|selector| if selector { F::ONE } else { F::ZERO }) .collect() - })) - .collect(); + }), + ] + .collect(); circuit_info.permutations = preprocess_collector.permutation.into_cycles(); Ok(circuit_info) @@ -589,13 +591,14 @@ fn advice_idx(cs: &ConstraintSystem) -> Vec { fn column_idx(cs: &ConstraintSystem) -> HashMap<(Any, usize), usize> { let advice_idx = advice_idx(cs); - iter::empty() - .chain((0..cs.num_instance_columns()).map(|idx| (Any::Instance, idx))) - .chain((0..cs.num_fixed_columns() + cs.num_selectors()).map(|idx| (Any::Fixed, idx))) - .enumerate() - .map(|(idx, column)| (column, idx)) - .chain((0..advice_idx.len()).map(|idx| ((Any::advice(), idx), advice_idx[idx]))) - .collect() + chain![ + (0..cs.num_instance_columns()).map(|idx| (Any::Instance, idx)), + (0..cs.num_fixed_columns() + cs.num_selectors()).map(|idx| (Any::Fixed, idx)), + ] + .enumerate() + .map(|(idx, column)| (column, idx)) + .chain((0..advice_idx.len()).map(|idx| ((Any::advice(), idx), advice_idx[idx]))) + .collect() } fn num_phases(phases: &[u8]) -> usize { @@ -691,9 +694,7 @@ fn batch_invert_assigned(assigneds: Vec>>) -> Vec Self { let mut rand_row = || [(); 8].map(|_| Assigned::Rational(F::random(&mut rng), F::random(&mut rng))); - let values = iter::empty() - .chain(Some(rand_row())) - .chain( - iter::repeat_with(|| { - let mut values = rand_row(); - let [q_l, q_r, q_m, q_o, _, w_l, w_r, w_o] = values; - values[4] = -(q_l * w_l + q_r * w_r + q_m * w_l * w_r + q_o * w_o); - values - }) - .take((1 << k) - 7) - .collect_vec(), - ) - .collect(); + let values = chain![ + [rand_row()], + iter::repeat_with(|| { + let mut values = rand_row(); + let [q_l, q_r, q_m, q_o, _, w_l, w_r, w_o] = values; + values[4] = -(q_l * w_l + q_r * w_r + q_m * w_l * w_r + q_o * w_o); + values + }) + .take((1 << k) - 7) + .collect_vec(), + ] + .collect(); Self(k, values) } diff --git a/plonkish_backend/src/pcs.rs b/plonkish_backend/src/pcs.rs index 6982e5ac..746252b7 100644 --- a/plonkish_backend/src/pcs.rs +++ b/plonkish_backend/src/pcs.rs @@ -1,9 +1,9 @@ use crate::{ poly::Polynomial, util::{ - arithmetic::{variable_base_msm, Curve, CurveAffine, Field}, + arithmetic::Field, transcript::{TranscriptRead, TranscriptWrite}, - DeserializeOwned, Itertools, Serialize, + DeserializeOwned, Serialize, }, Error, }; @@ -154,24 +154,147 @@ impl Evaluation { } } -pub trait AdditiveCommitment: Debug + Default + PartialEq + Eq { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, +pub trait Additive: Clone + Debug + Default + PartialEq + Eq { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, ) -> Self where - Self: 'a; + Self: 'b; } -impl AdditiveCommitment for C { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, - ) -> Self { - let scalars = scalars.into_iter().collect_vec(); - let bases = bases.into_iter().collect_vec(); - assert_eq!(scalars.len(), bases.len()); +#[cfg(test)] +mod test { + use crate::{ + pcs::{Evaluation, PolynomialCommitmentScheme}, + poly::Polynomial, + util::{ + arithmetic::PrimeField, + chain, + transcript::{InMemoryTranscript, TranscriptRead, TranscriptWrite}, + Itertools, + }, + }; + use rand::{rngs::OsRng, Rng}; + use std::iter; - variable_base_msm(scalars, bases).to_affine() + pub(super) fn run_commit_open_verify() + where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, + T: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + { + for k in 3..16 { + // Setup + let (pp, vp) = { + let mut rng = OsRng; + let poly_size = 1 << k; + let param = Pcs::setup(poly_size, 1, &mut rng).unwrap(); + Pcs::trim(¶m, poly_size, 1).unwrap() + }; + // Commit and open + let proof = { + let mut transcript = T::new(()); + let poly = >::rand(1 << k, OsRng); + let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); + let point = >::squeeze_point(k, &mut transcript); + let eval = poly.evaluate(&point); + transcript.write_field_element(&eval).unwrap(); + Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); + transcript.into_proof() + }; + // Verify + let result = { + let mut transcript = T::from_proof((), proof.as_slice()); + Pcs::verify( + &vp, + &Pcs::read_commitment(&vp, &mut transcript).unwrap(), + &>::squeeze_point(k, &mut transcript), + &transcript.read_field_element().unwrap(), + &mut transcript, + ) + }; + assert_eq!(result, Ok(())); + } + } + + pub(super) fn run_batch_commit_open_verify() + where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, + T: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + { + for k in 3..16 { + let batch_size = 8; + let num_points = batch_size >> 1; + let mut rng = OsRng; + // Setup + let (pp, vp) = { + let poly_size = 1 << k; + let param = Pcs::setup(poly_size, batch_size, &mut rng).unwrap(); + Pcs::trim(¶m, poly_size, batch_size).unwrap() + }; + // Batch commit and open + let evals = chain![ + (0..num_points).map(|point| (0, point)), + (0..batch_size).map(|poly| (poly, 0)), + iter::repeat_with(|| (rng.gen_range(0..batch_size), rng.gen_range(0..num_points))) + .take(batch_size) + ] + .unique() + .collect_vec(); + let proof = { + let mut transcript = T::new(()); + let polys = + iter::repeat_with(|| >::rand(1 << k, OsRng)) + .take(batch_size) + .collect_vec(); + let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); + let points = iter::repeat_with(|| { + >::squeeze_point(k, &mut transcript) + }) + .take(num_points) + .collect_vec(); + let evals = evals + .iter() + .copied() + .map(|(poly, point)| Evaluation { + poly, + point, + value: polys[poly].evaluate(&points[point]), + }) + .collect_vec(); + transcript + .write_field_elements(evals.iter().map(Evaluation::value)) + .unwrap(); + Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); + transcript.into_proof() + }; + // Batch verify + let result = { + let mut transcript = T::from_proof((), proof.as_slice()); + Pcs::batch_verify( + &vp, + &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(), + &iter::repeat_with(|| { + >::squeeze_point(k, &mut transcript) + }) + .take(num_points) + .collect_vec(), + &evals + .iter() + .copied() + .zip(transcript.read_field_elements(evals.len()).unwrap()) + .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) + .collect_vec(), + &mut transcript, + ) + }; + assert_eq!(result, Ok(())); + } } } diff --git a/plonkish_backend/src/pcs/multilinear.rs b/plonkish_backend/src/pcs/multilinear.rs index d0192237..54574879 100644 --- a/plonkish_backend/src/pcs/multilinear.rs +++ b/plonkish_backend/src/pcs/multilinear.rs @@ -12,14 +12,14 @@ mod kzg; mod zeromorph; pub use brakedown::{ - MultilinearBrakedown, MultilinearBrakedownCommitment, MultilinearBrakedownParams, + MultilinearBrakedown, MultilinearBrakedownCommitment, MultilinearBrakedownParam, }; pub use gemini::Gemini; -pub use hyrax::{MultilinearHyrax, MultilinearHyraxCommitment, MultilinearHyraxParams}; -pub use ipa::{MultilinearIpa, MultilinearIpaCommitment, MultilinearIpaParams}; +pub use hyrax::{MultilinearHyrax, MultilinearHyraxCommitment, MultilinearHyraxParam}; +pub use ipa::{MultilinearIpa, MultilinearIpaCommitment, MultilinearIpaParam}; pub use kzg::{ - MultilinearKzg, MultilinearKzgCommitment, MultilinearKzgParams, MultilinearKzgProverParams, - MultilinearKzgVerifierParams, + MultilinearKzg, MultilinearKzgCommitment, MultilinearKzgParam, MultilinearKzgProverParam, + MultilinearKzgVerifierParam, }; pub use zeromorph::{Zeromorph, ZeromorphKzgProverParam, ZeromorphKzgVerifierParam}; @@ -109,8 +109,7 @@ fn quotients( mod additive { use crate::{ pcs::{ - multilinear::validate_input, AdditiveCommitment, Evaluation, Point, - PolynomialCommitmentScheme, + multilinear::validate_input, Additive, Evaluation, Point, PolynomialCommitmentScheme, }, piop::sum_check::{ classic::{ClassicSumCheck, CoefficientsProver}, @@ -118,7 +117,7 @@ mod additive { }, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{inner_product, PrimeField}, + arithmetic::{fe_to_bytes, inner_product, PrimeField}, end_timer, expression::{Expression, Query, Rotation}, start_timer, @@ -143,10 +142,25 @@ mod additive { where F: PrimeField, Pcs: PolynomialCommitmentScheme>, - Pcs::Commitment: AdditiveCommitment, + Pcs::Commitment: Additive, { validate_input("batch open", num_vars, polys.clone(), points)?; + if cfg!(feature = "sanity-check") { + assert_eq!( + points + .iter() + .map(|point| point.iter().map(fe_to_bytes::).collect_vec()) + .unique() + .count(), + points.len() + ); + for eval in evals { + let (poly, point) = (&polys[eval.poly()], &points[eval.point()]); + assert_eq!(poly.evaluate(point), *eval.value()); + } + } + let ell = evals.len().next_power_of_two().ilog2() as usize; let t = transcript.squeeze_challenges(ell); @@ -219,7 +233,7 @@ mod additive { .map(|(eval, eq_xt_i)| eq_xy_evals[eval.point()] * eq_xt_i) .collect_vec(); let bases = evals.iter().map(|eval| comms[eval.poly()]); - Pcs::Commitment::sum_with_scalar(&scalars, bases) + Pcs::Commitment::msm(&scalars, bases) } else { Pcs::Commitment::default() }; @@ -244,7 +258,7 @@ mod additive { where F: PrimeField, Pcs: PolynomialCommitmentScheme>, - Pcs::Commitment: AdditiveCommitment, + Pcs::Commitment: Additive, { validate_input("batch verify", num_vars, [], points)?; @@ -268,139 +282,8 @@ mod additive { .map(|(eval, eq_xt_i)| eq_xy_evals[eval.point()] * eq_xt_i) .collect_vec(); let bases = evals.iter().map(|eval| comms[eval.poly()]); - Pcs::Commitment::sum_with_scalar(&scalars, bases) + Pcs::Commitment::msm(&scalars, bases) }; Pcs::verify(vp, &g_prime_comm, &challenges, &g_prime_eval, transcript) } } - -#[cfg(test)] -mod test { - use crate::{ - pcs::{Evaluation, PolynomialCommitmentScheme}, - poly::multilinear::MultilinearPolynomial, - util::{ - arithmetic::PrimeField, - chain, - transcript::{InMemoryTranscript, TranscriptRead, TranscriptWrite}, - Itertools, - }, - }; - use rand::{rngs::OsRng, Rng}; - use std::iter; - - pub(super) fn run_commit_open_verify() - where - F: PrimeField, - Pcs: PolynomialCommitmentScheme>, - T: TranscriptRead - + TranscriptWrite - + InMemoryTranscript, - { - for num_vars in 3..16 { - // Setup - let (pp, vp) = { - let mut rng = OsRng; - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size, 1, &mut rng).unwrap(); - Pcs::trim(¶m, poly_size, 1).unwrap() - }; - // Commit and open - let proof = { - let mut transcript = T::new(()); - let poly = MultilinearPolynomial::rand(num_vars, OsRng); - let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - let point = transcript.squeeze_challenges(num_vars); - let eval = poly.evaluate(point.as_slice()); - transcript.write_field_element(&eval).unwrap(); - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - transcript.into_proof() - }; - // Verify - let result = { - let mut transcript = T::from_proof((), proof.as_slice()); - Pcs::verify( - &vp, - &Pcs::read_commitment(&vp, &mut transcript).unwrap(), - &transcript.squeeze_challenges(num_vars), - &transcript.read_field_element().unwrap(), - &mut transcript, - ) - }; - assert_eq!(result, Ok(())); - } - } - - pub(super) fn run_batch_commit_open_verify() - where - F: PrimeField, - Pcs: PolynomialCommitmentScheme>, - T: TranscriptRead - + TranscriptWrite - + InMemoryTranscript, - { - for num_vars in 3..16 { - let batch_size = 8; - let num_points = batch_size >> 1; - let mut rng = OsRng; - // Setup - let (pp, vp) = { - let poly_size = 1 << num_vars; - let param = Pcs::setup(poly_size, batch_size, &mut rng).unwrap(); - Pcs::trim(¶m, poly_size, batch_size).unwrap() - }; - // Batch commit and open - let evals = chain![ - (0..num_points).map(|point| (0, point)), - (0..batch_size).map(|poly| (poly, 0)), - iter::repeat_with(|| (rng.gen_range(0..batch_size), rng.gen_range(0..num_points))) - .take(batch_size) - ] - .unique() - .collect_vec(); - let proof = { - let mut transcript = T::new(()); - let polys = iter::repeat_with(|| MultilinearPolynomial::rand(num_vars, OsRng)) - .take(batch_size) - .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - let points = iter::repeat_with(|| transcript.squeeze_challenges(num_vars)) - .take(num_points) - .collect_vec(); - let evals = evals - .iter() - .copied() - .map(|(poly, point)| Evaluation { - poly, - point, - value: polys[poly].evaluate(&points[point]), - }) - .collect_vec(); - transcript - .write_field_elements(evals.iter().map(Evaluation::value)) - .unwrap(); - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - transcript.into_proof() - }; - // Batch verify - let result = { - let mut transcript = T::from_proof((), proof.as_slice()); - Pcs::batch_verify( - &vp, - &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(), - &iter::repeat_with(|| transcript.squeeze_challenges(num_vars)) - .take(num_points) - .collect_vec(), - &evals - .iter() - .copied() - .zip(transcript.read_field_elements(evals.len()).unwrap()) - .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) - .collect_vec(), - &mut transcript, - ) - }; - assert_eq!(result, Ok(())); - } - } -} diff --git a/plonkish_backend/src/pcs/multilinear/brakedown.rs b/plonkish_backend/src/pcs/multilinear/brakedown.rs index 5975acd7..39f961c1 100644 --- a/plonkish_backend/src/pcs/multilinear/brakedown.rs +++ b/plonkish_backend/src/pcs/multilinear/brakedown.rs @@ -31,13 +31,13 @@ impl Clone for MultilinearBrakedown { +pub struct MultilinearBrakedownParam { num_vars: usize, num_rows: usize, brakedown: Brakedown, } -impl MultilinearBrakedownParams { +impl MultilinearBrakedownParam { pub fn num_vars(&self) -> usize { self.num_vars } @@ -92,9 +92,9 @@ where H: Hash, S: BrakedownSpec, { - type Param = MultilinearBrakedownParams; - type ProverParam = MultilinearBrakedownParams; - type VerifierParam = MultilinearBrakedownParams; + type Param = MultilinearBrakedownParam; + type ProverParam = MultilinearBrakedownParam; + type VerifierParam = MultilinearBrakedownParam; type Polynomial = MultilinearPolynomial; type Commitment = MultilinearBrakedownCommitment; type CommitmentChunk = Output; @@ -103,7 +103,7 @@ where assert!(poly_size.is_power_of_two()); let num_vars = poly_size.ilog2() as usize; let brakedown = Brakedown::new_multilinear::(num_vars, 20.min((1 << num_vars) - 1), rng); - Ok(MultilinearBrakedownParams { + Ok(MultilinearBrakedownParam { num_vars, num_rows: (1 << num_vars) / brakedown.row_len(), brakedown, @@ -120,7 +120,7 @@ where Ok((param.clone(), param.clone())) } else { Err(Error::InvalidPcsParam( - "Can't trim MultilinearBrakedownParams into different poly_size".to_string(), + "Can't trim MultilinearBrakedownParam into different poly_size".to_string(), )) } } @@ -437,8 +437,8 @@ fn squeeze_challenge_idx( #[cfg(test)] mod test { use crate::{ - pcs::multilinear::{ - brakedown::MultilinearBrakedown, + pcs::{ + multilinear::brakedown::MultilinearBrakedown, test::{run_batch_commit_open_verify, run_commit_open_verify}, }, util::{code::BrakedownSpec6, hash::Keccak256, transcript::Keccak256Transcript}, diff --git a/plonkish_backend/src/pcs/multilinear/gemini.rs b/plonkish_backend/src/pcs/multilinear/gemini.rs index 7b47b3e1..49af3e1a 100644 --- a/plonkish_backend/src/pcs/multilinear/gemini.rs +++ b/plonkish_backend/src/pcs/multilinear/gemini.rs @@ -4,13 +4,12 @@ use crate::{ pcs::{ multilinear::additive, - univariate::{UnivariateKzg, UnivariateKzgCommitment}, + univariate::{err_too_large_deree, UnivariateKzg, UnivariateKzgCommitment}, Evaluation, Point, PolynomialCommitmentScheme, }, poly::{ multilinear::{merge_into, MultilinearPolynomial}, - univariate::{UnivariateBasis::Monomial, UnivariatePolynomial}, - Polynomial, + univariate::UnivariatePolynomial, }, util::{ arithmetic::{squares, Field, MultiMillerLoop}, @@ -55,11 +54,8 @@ where fn commit(pp: &Self::ProverParam, poly: &Self::Polynomial) -> Result { if pp.degree() + 1 < poly.evals().len() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to commit (param supports degree up to {} but got {})", - pp.degree(), - poly.evals().len() - ))); + let got = poly.evals().len() - 1; + return Err(err_too_large_deree("commit", pp.degree(), got)); } Ok(UnivariateKzg::commit_monomial(pp, poly.evals())) @@ -85,11 +81,8 @@ where ) -> Result<(), Error> { let num_vars = point.len(); if pp.degree() + 1 < poly.evals().len() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to open (param supports degree up to {} but got {})", - pp.degree(), - poly.evals().len() - ))); + let got = poly.evals().len() - 1; + return Err(err_too_large_deree("open", pp.degree(), got)); } if cfg!(feature = "sanity-check") { @@ -99,12 +92,12 @@ where let fs = { let mut fs = Vec::with_capacity(num_vars); - fs.push(UnivariatePolynomial::new(Monomial, poly.evals().to_vec())); + fs.push(UnivariatePolynomial::monomial(poly.evals().to_vec())); for x_i in &point[..num_vars - 1] { let f_i_minus_one = fs.last().unwrap().coeffs(); let mut f_i = Vec::with_capacity(f_i_minus_one.len() >> 1); merge_into(&mut f_i, f_i_minus_one, x_i, 1, 0); - fs.push(UnivariatePolynomial::new(Monomial, f_i)); + fs.push(UnivariatePolynomial::monomial(f_i)); } if cfg!(feature = "sanity-check") { @@ -214,10 +207,8 @@ where mod test { use crate::{ pcs::{ - multilinear::{ - gemini::Gemini, - test::{run_batch_commit_open_verify, run_commit_open_verify}, - }, + multilinear::gemini::Gemini, + test::{run_batch_commit_open_verify, run_commit_open_verify}, univariate::UnivariateKzg, }, util::transcript::Keccak256Transcript, diff --git a/plonkish_backend/src/pcs/multilinear/hyrax.rs b/plonkish_backend/src/pcs/multilinear/hyrax.rs index b4b7e4a5..45e3ecb0 100644 --- a/plonkish_backend/src/pcs/multilinear/hyrax.rs +++ b/plonkish_backend/src/pcs/multilinear/hyrax.rs @@ -2,10 +2,10 @@ use crate::{ pcs::{ multilinear::{ additive, err_too_many_variates, - ipa::{MultilinearIpa, MultilinearIpaCommitment, MultilinearIpaParams}, + ipa::{MultilinearIpa, MultilinearIpaCommitment, MultilinearIpaParam}, validate_input, }, - AdditiveCommitment, Evaluation, Point, PolynomialCommitmentScheme, + Additive, Evaluation, Point, PolynomialCommitmentScheme, }, poly::multilinear::MultilinearPolynomial, util::{ @@ -16,7 +16,6 @@ use crate::{ }, Error, }; - use rand::RngCore; use std::{borrow::Cow, iter, marker::PhantomData}; @@ -24,14 +23,14 @@ use std::{borrow::Cow, iter, marker::PhantomData}; pub struct MultilinearHyrax(PhantomData); #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct MultilinearHyraxParams { +pub struct MultilinearHyraxParam { num_vars: usize, batch_num_vars: usize, row_num_vars: usize, - ipa: MultilinearIpaParams, + ipa: MultilinearIpaParam, } -impl MultilinearHyraxParams { +impl MultilinearHyraxParam { pub fn num_vars(&self) -> usize { self.num_vars } @@ -77,10 +76,10 @@ impl AsRef<[C]> for MultilinearHyraxCommitment { } // TODO: Batch all MSMs into one -impl AdditiveCommitment for MultilinearHyraxCommitment { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, +impl Additive for MultilinearHyraxCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, ) -> Self { let (scalars, bases) = scalars .into_iter() @@ -108,9 +107,9 @@ where C: CurveAffine + Serialize + DeserializeOwned, C::ScalarExt: Serialize + DeserializeOwned, { - type Param = MultilinearHyraxParams; - type ProverParam = MultilinearHyraxParams; - type VerifierParam = MultilinearHyraxParams; + type Param = MultilinearHyraxParam; + type ProverParam = MultilinearHyraxParam; + type VerifierParam = MultilinearHyraxParam; type Polynomial = MultilinearPolynomial; type Commitment = MultilinearHyraxCommitment; type CommitmentChunk = C; @@ -171,8 +170,8 @@ where let comm = { let mut comm = vec![C::CurveExt::identity(); pp.num_chunks()]; parallelize(&mut comm, |(comm, start)| { - for (comm, start) in comm.iter_mut().zip((start * row_len..).step_by(row_len)) { - *comm = variable_base_msm(&scalars[start..start + row_len], pp.g()); + for (comm, offset) in comm.iter_mut().zip((start * row_len..).step_by(row_len)) { + *comm = variable_base_msm(&scalars[offset..offset + row_len], pp.g()); } }); batch_projective_to_affine(&comm) @@ -198,8 +197,8 @@ where let comms = { let mut comms = vec![C::CurveExt::identity(); scalars.len()]; parallelize(&mut comms, |(comms, start)| { - for (comm, scalars) in comms.iter_mut().zip(&scalars[start..]) { - *comm = variable_base_msm(*scalars, pp.g()); + for (comm, row) in comms.iter_mut().zip(&scalars[start..]) { + *comm = variable_base_msm(*row, pp.g()); } }); batch_projective_to_affine(&comms) @@ -268,14 +267,13 @@ where num_polys: usize, transcript: &mut impl TranscriptRead, ) -> Result, Error> { - let comms = iter::repeat_with(|| { + iter::repeat_with(|| { transcript .read_commitments(vp.num_chunks()) .map(MultilinearHyraxCommitment) }) .take(num_polys) - .try_collect()?; - Ok(comms) + .collect() } fn verify( @@ -297,7 +295,6 @@ where variable_base_msm(&scalars, &comm.0).into() }) }; - MultilinearIpa::verify(&vp.ipa, &comm, &lo.to_vec(), eval, transcript) } @@ -316,8 +313,8 @@ where #[cfg(test)] mod test { use crate::{ - pcs::multilinear::{ - hyrax::MultilinearHyrax, + pcs::{ + multilinear::hyrax::MultilinearHyrax, test::{run_batch_commit_open_verify, run_commit_open_verify}, }, util::transcript::Keccak256Transcript, diff --git a/plonkish_backend/src/pcs/multilinear/ipa.rs b/plonkish_backend/src/pcs/multilinear/ipa.rs index d9437c64..a62c58b6 100644 --- a/plonkish_backend/src/pcs/multilinear/ipa.rs +++ b/plonkish_backend/src/pcs/multilinear/ipa.rs @@ -1,36 +1,34 @@ use crate::{ pcs::{ multilinear::{additive, err_too_many_variates, validate_input}, - AdditiveCommitment, Evaluation, Point, PolynomialCommitmentScheme, + univariate::ipa::{prove_bulletproof_reduction, verify_bulletproof_reduction}, + Additive, Evaluation, Point, PolynomialCommitmentScheme, }, poly::multilinear::MultilinearPolynomial, util::{ arithmetic::{ - batch_projective_to_affine, inner_product, variable_base_msm, Curve, CurveAffine, - CurveExt, Field, Group, + batch_projective_to_affine, variable_base_msm, Curve, CurveAffine, CurveExt, Group, }, - chain, parallel::parallelize, transcript::{TranscriptRead, TranscriptWrite}, - Deserialize, DeserializeOwned, Itertools, Serialize, + Deserialize, DeserializeOwned, Either, Itertools, Serialize, }, Error, }; -use halo2_curves::group::ff::BatchInvert; use rand::RngCore; -use std::{iter, marker::PhantomData, slice}; +use std::{marker::PhantomData, slice}; #[derive(Clone, Debug)] pub struct MultilinearIpa(PhantomData); #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct MultilinearIpaParams { +pub struct MultilinearIpaParam { num_vars: usize, g: Vec, h: C, } -impl MultilinearIpaParams { +impl MultilinearIpaParam { pub fn num_vars(&self) -> usize { self.num_vars } @@ -71,15 +69,13 @@ impl From for MultilinearIpaCommitment { } } -impl AdditiveCommitment for MultilinearIpaCommitment { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, +impl Additive for MultilinearIpaCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, ) -> Self { let scalars = scalars.into_iter().collect_vec(); let bases = bases.into_iter().map(|base| &base.0).collect_vec(); - assert_eq!(scalars.len(), bases.len()); - MultilinearIpaCommitment(variable_base_msm(scalars, bases).to_affine()) } } @@ -89,9 +85,9 @@ where C: CurveAffine + Serialize + DeserializeOwned, C::ScalarExt: Serialize + DeserializeOwned, { - type Param = MultilinearIpaParams; - type ProverParam = MultilinearIpaParams; - type VerifierParam = MultilinearIpaParams; + type Param = MultilinearIpaParam; + type ProverParam = MultilinearIpaParam; + type VerifierParam = MultilinearIpaParam; type Polynomial = MultilinearPolynomial; type Commitment = MultilinearIpaCommitment; type CommitmentChunk = C; @@ -175,62 +171,10 @@ where assert_eq!(poly.evaluate(point), *eval); } - let xi_0 = transcript.squeeze_challenge(); - let h_prime = (pp.h * xi_0).to_affine(); - - let mut bases = pp.g().to_vec(); - let mut coeffs = poly.evals().to_vec(); - let mut zs = MultilinearPolynomial::eq_xy(point).into_evals(); - - for i in 0..pp.num_vars() { - let mid = 1 << (pp.num_vars() - i - 1); - - let (bases_l, bases_r) = bases.split_at(mid); - let (coeffs_l, coeffs_r) = coeffs.split_at(mid); - let (zs_l, zs_r) = zs.split_at(mid); - let (c_l, c_r) = (inner_product(coeffs_r, zs_l), inner_product(coeffs_l, zs_r)); - let l_i = variable_base_msm( - chain![coeffs_r, Some(&c_l)], - chain![bases_l, Some(&h_prime)], - ); - let r_i = variable_base_msm( - chain![coeffs_l, Some(&c_r)], - chain![bases_r, Some(&h_prime)], - ); - transcript.write_commitment(&l_i.to_affine())?; - transcript.write_commitment(&r_i.to_affine())?; - - let xi_i = transcript.squeeze_challenge(); - let xi_i_inv = xi_i.invert().unwrap(); - - let (bases_l, bases_r) = bases.split_at_mut(mid); - let (coeffs_l, coeffs_r) = coeffs.split_at_mut(mid); - let (zs_l, zs_r) = zs.split_at_mut(mid); - parallelize(bases_l, |(bases_l, start)| { - let mut tmp = Vec::with_capacity(bases_l.len()); - for (lhs, rhs) in bases_l.iter().zip(bases_r[start..].iter()) { - tmp.push(lhs.to_curve() + *rhs * xi_i); - } - C::Curve::batch_normalize(&tmp, bases_l); - }); - parallelize(coeffs_l, |(coeffs_l, start)| { - for (lhs, rhs) in coeffs_l.iter_mut().zip(coeffs_r[start..].iter()) { - *lhs += xi_i_inv * rhs; - } - }); - parallelize(zs_l, |(zs_l, start)| { - for (lhs, rhs) in zs_l.iter_mut().zip(zs_r[start..].iter()) { - *lhs += xi_i * rhs; - } - }); - bases.truncate(mid); - coeffs.truncate(mid); - zs.truncate(mid); - } - - transcript.write_field_element(&coeffs[0])?; - - Ok(()) + let bases = pp.g(); + let coeffs = poly.evals(); + let zs = MultilinearPolynomial::eq_xy(point).into_evals(); + prove_bulletproof_reduction(bases, pp.h(), coeffs, zs, transcript) } fn batch_open<'a>( @@ -266,35 +210,9 @@ where eval: &C::Scalar, transcript: &mut impl TranscriptRead, ) -> Result<(), Error> { - validate_input("verify", vp.num_vars(), [], [point])?; - - let xi_0 = transcript.squeeze_challenge(); - - let (ls, rs, xis) = iter::repeat_with(|| { - Ok(( - transcript.read_commitment()?, - transcript.read_commitment()?, - transcript.squeeze_challenge(), - )) - }) - .take(vp.num_vars()) - .collect::, _>>()? - .into_iter() - .multiunzip::<(Vec<_>, Vec<_>, Vec<_>)>(); - let neg_c = -transcript.read_field_element()?; - - let xi_invs = { - let mut xi_invs = xis.clone(); - xi_invs.iter_mut().batch_invert(); - xi_invs - }; - let neg_c_h = MultilinearPolynomial::new(h_coeffs(neg_c, &xis)); - let u = &(xi_0 * (neg_c_h.evaluate(point) + eval)); - let scalars = chain![&xi_invs, &xis, neg_c_h.evals(), Some(u)]; - let bases = chain![&ls, &rs, vp.g(), Some(vp.h())]; - bool::from((variable_base_msm(scalars, bases) + comm.0).is_identity()) - .then_some(()) - .ok_or_else(|| Error::InvalidPcsOpen("Invalid multilinear IPA open".to_string())) + let bases = vp.g(); + let point = Either::Right(point.as_slice()); + verify_bulletproof_reduction(bases, vp.h(), comm, point, eval, transcript) } fn batch_verify<'a>( @@ -309,31 +227,11 @@ where } } -fn h_coeffs(scalar: F, xi: &[F]) -> Vec { - assert!(!xi.is_empty()); - - let mut coeffs = vec![F::ZERO; 1 << xi.len()]; - coeffs[0] = scalar; - - for (len, xi) in xi.iter().rev().enumerate().map(|(i, xi)| (1 << i, xi)) { - let (left, right) = coeffs.split_at_mut(len); - let right = &mut right[0..len]; - right.copy_from_slice(left); - parallelize(right, |(right, _)| { - for coeff in right { - *coeff *= xi; - } - }); - } - - coeffs -} - #[cfg(test)] mod test { use crate::{ - pcs::multilinear::{ - ipa::MultilinearIpa, + pcs::{ + multilinear::ipa::MultilinearIpa, test::{run_batch_commit_open_verify, run_commit_open_verify}, }, util::transcript::Keccak256Transcript, diff --git a/plonkish_backend/src/pcs/multilinear/kzg.rs b/plonkish_backend/src/pcs/multilinear/kzg.rs index 71348fde..3c375b4c 100644 --- a/plonkish_backend/src/pcs/multilinear/kzg.rs +++ b/plonkish_backend/src/pcs/multilinear/kzg.rs @@ -1,7 +1,7 @@ use crate::{ pcs::{ multilinear::{additive, err_too_many_variates, quotients, validate_input}, - AdditiveCommitment, Evaluation, Point, PolynomialCommitmentScheme, + Additive, Evaluation, Point, PolynomialCommitmentScheme, }, poly::multilinear::MultilinearPolynomial, util::{ @@ -9,7 +9,7 @@ use crate::{ batch_projective_to_affine, fixed_base_msm, variable_base_msm, window_size, window_table, Curve, CurveAffine, Field, MultiMillerLoop, PrimeCurveAffine, }, - izip, + chain, izip, parallel::parallelize, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Itertools, Serialize, @@ -23,14 +23,14 @@ use std::{iter, marker::PhantomData, ops::Neg, slice}; pub struct MultilinearKzg(PhantomData); #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct MultilinearKzgParams { +pub struct MultilinearKzgParam { g1: M::G1Affine, eqs: Vec>, g2: M::G2Affine, ss: Vec, } -impl MultilinearKzgParams { +impl MultilinearKzgParam { pub fn num_vars(&self) -> usize { self.eqs.len() } @@ -53,12 +53,12 @@ impl MultilinearKzgParams { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct MultilinearKzgProverParams { +pub struct MultilinearKzgProverParam { g1: M::G1Affine, eqs: Vec>, } -impl MultilinearKzgProverParams { +impl MultilinearKzgProverParam { pub fn num_vars(&self) -> usize { self.eqs.len() - 1 } @@ -77,13 +77,13 @@ impl MultilinearKzgProverParams { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct MultilinearKzgVerifierParams { +pub struct MultilinearKzgVerifierParam { g1: M::G1Affine, g2: M::G2Affine, ss: Vec, } -impl MultilinearKzgVerifierParams { +impl MultilinearKzgVerifierParam { pub fn num_vars(&self) -> usize { self.ss.len() } @@ -136,15 +136,13 @@ impl From for MultilinearKzgCommitment { } } -impl AdditiveCommitment for MultilinearKzgCommitment { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, +impl Additive for MultilinearKzgCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, ) -> Self { let scalars = scalars.into_iter().collect_vec(); let bases = bases.into_iter().map(|base| &base.0).collect_vec(); - assert_eq!(scalars.len(), bases.len()); - MultilinearKzgCommitment(variable_base_msm(scalars, bases).to_affine()) } } @@ -156,15 +154,16 @@ where M::G1Affine: Serialize + DeserializeOwned, M::G2Affine: Serialize + DeserializeOwned, { - type Param = MultilinearKzgParams; - type ProverParam = MultilinearKzgProverParams; - type VerifierParam = MultilinearKzgVerifierParams; + type Param = MultilinearKzgParam; + type ProverParam = MultilinearKzgProverParam; + type VerifierParam = MultilinearKzgVerifierParam; type Polynomial = MultilinearPolynomial; type Commitment = MultilinearKzgCommitment; type CommitmentChunk = M::G1Affine; fn setup(poly_size: usize, _: usize, mut rng: impl RngCore) -> Result { assert!(poly_size.is_power_of_two()); + let num_vars = poly_size.ilog2() as usize; let ss = iter::repeat_with(|| M::Scalar::random(&mut rng)) .take(num_vars) @@ -330,21 +329,21 @@ where let window_size = window_size(point.len()); let window_table = window_table(window_size, vp.g2); - let rhs = iter::empty() - .chain(Some(vp.g2.neg())) - .chain( - vp.ss(point.len()) - .iter() - .cloned() - .zip_eq(fixed_base_msm(window_size, &window_table, point)) - .map(|(s_i, x_i)| (s_i - x_i.into()).into()), - ) - .map_into() - .collect_vec(); - let lhs = iter::empty() - .chain(Some((comm.0.to_curve() - vp.g1 * eval).into())) - .chain(quotients.iter().cloned()) - .collect_vec(); + let rhs = chain![ + [vp.g2.neg()], + vp.ss(point.len()) + .iter() + .cloned() + .zip_eq(fixed_base_msm(window_size, &window_table, point)) + .map(|(s_i, x_i)| (s_i - x_i.into()).into()), + ] + .map_into() + .collect_vec(); + let lhs = chain![ + [(comm.0.to_curve() - vp.g1 * eval).into()], + quotients.iter().cloned() + ] + .collect_vec(); M::pairings_product_is_identity(&lhs.iter().zip_eq(rhs.iter()).collect_vec()) .then_some(()) .ok_or_else(|| Error::InvalidPcsOpen("Invalid multilinear KZG open".to_string())) @@ -365,8 +364,8 @@ where #[cfg(test)] mod test { use crate::{ - pcs::multilinear::{ - kzg::MultilinearKzg, + pcs::{ + multilinear::kzg::MultilinearKzg, test::{run_batch_commit_open_verify, run_commit_open_verify}, }, util::transcript::Keccak256Transcript, diff --git a/plonkish_backend/src/pcs/multilinear/zeromorph.rs b/plonkish_backend/src/pcs/multilinear/zeromorph.rs index aa553caa..ffd18abc 100644 --- a/plonkish_backend/src/pcs/multilinear/zeromorph.rs +++ b/plonkish_backend/src/pcs/multilinear/zeromorph.rs @@ -1,14 +1,13 @@ use crate::{ pcs::{ multilinear::{additive, quotients}, - univariate::{UnivariateKzg, UnivariateKzgProverParam, UnivariateKzgVerifierParam}, + univariate::{ + err_too_large_deree, UnivariateKzg, UnivariateKzgProverParam, + UnivariateKzgVerifierParam, + }, Evaluation, Point, PolynomialCommitmentScheme, }, - poly::{ - multilinear::MultilinearPolynomial, - univariate::{UnivariateBasis::Monomial, UnivariatePolynomial}, - Polynomial, - }, + poly::{multilinear::MultilinearPolynomial, univariate::UnivariatePolynomial}, util::{ arithmetic::{ powers, squares, variable_base_msm, BatchInvert, Curve, Field, MultiMillerLoop, @@ -82,6 +81,8 @@ where as PolynomialCommitmentScheme>::CommitmentChunk; fn setup(poly_size: usize, batch_size: usize, rng: impl RngCore) -> Result { + assert!(poly_size.is_power_of_two()); + UnivariateKzg::::setup(poly_size, batch_size, rng) } @@ -90,11 +91,13 @@ where poly_size: usize, batch_size: usize, ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + assert!(poly_size.is_power_of_two()); + let (commit_pp, vp) = UnivariateKzg::::trim(param, poly_size, batch_size)?; let offset = param.monomial_g1().len() - poly_size; let open_pp = { let monomial_g1 = param.monomial_g1()[offset..].to_vec(); - UnivariateKzgProverParam::new(monomial_g1, Vec::new()) + UnivariateKzgProverParam::new(poly_size.ilog2() as usize, monomial_g1, Vec::new()) }; let s_offset_g2 = param.powers_of_s_g2()[offset]; @@ -106,11 +109,8 @@ where fn commit(pp: &Self::ProverParam, poly: &Self::Polynomial) -> Result { if pp.degree() + 1 < poly.evals().len() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to commit (param supports degree up to {} but got {})", - pp.degree(), - poly.evals().len() - ))); + let got = poly.evals().len() - 1; + return Err(err_too_large_deree("commit", pp.degree(), got)); } Ok(UnivariateKzg::commit_monomial(&pp.commit_pp, poly.evals())) @@ -136,11 +136,8 @@ where ) -> Result<(), Error> { let num_vars = poly.num_vars(); if pp.degree() + 1 < poly.evals().len() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to open (param supports degree up to {} but got {})", - pp.degree(), - poly.evals().len() - ))); + let got = poly.evals().len() - 1; + return Err(err_too_large_deree("open", pp.degree(), got)); } if cfg!(feature = "sanity-check") { @@ -149,7 +146,7 @@ where } let (quotients, remainder) = - quotients(poly, point, |_, q| UnivariatePolynomial::new(Monomial, q)); + quotients(poly, point, |_, q| UnivariatePolynomial::monomial(q)); UnivariateKzg::batch_commit_and_write(&pp.commit_pp, "ients, transcript)?; if cfg!(feature = "sanity-check") { @@ -167,7 +164,7 @@ where .for_each(|(q_hat, q)| *q_hat += power_of_y * q) }); } - UnivariatePolynomial::new(Monomial, q_hat) + UnivariatePolynomial::monomial(q_hat) }; UnivariateKzg::commit_and_write(&pp.commit_pp, &q_hat, transcript)?; @@ -176,7 +173,7 @@ where let (eval_scalar, q_scalars) = eval_and_quotient_scalars(y, x, z, point); - let mut f = UnivariatePolynomial::new(Monomial, poly.evals().to_vec()); + let mut f = UnivariatePolynomial::monomial(poly.evals().to_vec()); f *= &z; f += &q_hat; f[0] += eval_scalar * eval; @@ -288,7 +285,7 @@ fn eval_and_quotient_scalars(y: F, x: F, z: F, u: &[F]) -> (F, Vec) .iter() .map(|square_of_x| *square_of_x - F::ONE) .collect_vec(); - v_denoms.iter_mut().batch_invert(); + v_denoms.batch_invert(); v_denoms .iter() .map(|v_denom| v_numer * v_denom) @@ -307,10 +304,8 @@ fn eval_and_quotient_scalars(y: F, x: F, z: F, u: &[F]) -> (F, Vec) mod test { use crate::{ pcs::{ - multilinear::{ - test::{run_batch_commit_open_verify, run_commit_open_verify}, - zeromorph::Zeromorph, - }, + multilinear::zeromorph::Zeromorph, + test::{run_batch_commit_open_verify, run_commit_open_verify}, univariate::UnivariateKzg, }, util::transcript::Keccak256Transcript, diff --git a/plonkish_backend/src/pcs/univariate.rs b/plonkish_backend/src/pcs/univariate.rs index 09310f1c..c58db3ba 100644 --- a/plonkish_backend/src/pcs/univariate.rs +++ b/plonkish_backend/src/pcs/univariate.rs @@ -1,30 +1,345 @@ -use crate::util::{ - arithmetic::{ - batch_projective_to_affine, fft, root_of_unity_inv, CurveAffine, Field, PrimeField, +use crate::{ + poly::univariate::{UnivariateBasis::*, UnivariatePolynomial}, + util::{ + arithmetic::{ + batch_projective_to_affine, radix2_fft, root_of_unity_inv, squares, CurveAffine, Field, + PrimeField, + }, + parallel::parallelize, + Itertools, }, - parallel::parallelize, - Itertools, + Error, }; +mod hyrax; +pub(super) mod ipa; mod kzg; +pub use hyrax::{ + UnivariateHyrax, UnivariateHyraxCommitment, UnivariateHyraxParam, UnivariateHyraxVerifierParam, +}; +pub use ipa::{ + UnivariateIpa, UnivariateIpaCommitment, UnivariateIpaParam, UnivariateIpaVerifierParam, +}; pub use kzg::{ UnivariateKzg, UnivariateKzgCommitment, UnivariateKzgParam, UnivariateKzgProverParam, UnivariateKzgVerifierParam, }; -fn monomial_g1_to_lagrange_g1(monomial_g1: &[C]) -> Vec { - assert!(monomial_g1.len().is_power_of_two()); +fn monomial_g_to_lagrange_g(monomial_g: &[C]) -> Vec { + assert!(monomial_g.len().is_power_of_two()); - let k = monomial_g1.len().ilog2() as usize; - let n_inv = C::Scalar::TWO_INV.pow_vartime([k as u64]); + let k = monomial_g.len().ilog2() as usize; + let n_inv = squares(C::Scalar::TWO_INV).nth(k).unwrap(); let omega_inv = root_of_unity_inv(k); - let mut lagrange = monomial_g1.iter().map(C::to_curve).collect_vec(); - fft(&mut lagrange, omega_inv, k); + let mut lagrange = monomial_g.iter().map(C::to_curve).collect_vec(); + radix2_fft(&mut lagrange, omega_inv, k); parallelize(&mut lagrange, |(g, _)| { g.iter_mut().for_each(|g| *g *= n_inv) }); batch_projective_to_affine(&lagrange) } + +fn validate_input<'a, F: Field>( + function: &str, + param_degree: usize, + polys: impl IntoIterator>, +) -> Result<(), Error> { + let polys = polys.into_iter().collect_vec(); + for poly in polys.iter() { + match poly.basis() { + Monomial => { + if param_degree < poly.degree() { + return Err(err_too_large_deree(function, param_degree, poly.degree())); + } + } + Lagrange => { + if param_degree + 1 != poly.coeffs().len() { + return Err(err_invalid_evals_len(param_degree, poly.coeffs().len() - 1)); + } + } + } + } + Ok(()) +} + +pub(super) fn err_too_large_deree(function: &str, upto: usize, got: usize) -> Error { + Error::InvalidPcsParam(if function == "trim" { + format!("Too large degree to {function} (param supports degree up to {upto} but got {got})") + } else { + format!( + "Too large degree of poly to {function} (param supports degree up to {upto} but got {got})" + ) + }) +} + +fn err_invalid_evals_len(expected: usize, got: usize) -> Error { + Error::InvalidPcsParam(format!( + "Invalid number of poly evaluations to commit (param needs {expected} evaluations but got {got})" + )) +} + +mod additive { + use crate::{ + pcs::{Additive, Evaluation, Point, PolynomialCommitmentScheme}, + poly::univariate::UnivariatePolynomial, + util::{ + arithmetic::{ + barycentric_interpolate, barycentric_weights, fe_to_bytes, inner_product, powers, + Field, PrimeField, + }, + chain, izip, izip_eq, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, + }; + use std::collections::BTreeSet; + + pub fn batch_open( + pp: &Pcs::ProverParam, + polys: Vec<&Pcs::Polynomial>, + comms: Vec<&Pcs::Commitment>, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> + where + F: PrimeField, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: Additive, + { + if cfg!(feature = "sanity-check") { + assert_eq!( + points.iter().map(fe_to_bytes::).unique().count(), + points.len() + ); + for eval in evals { + let (poly, point) = (&polys[eval.poly()], &points[eval.point()]); + assert_eq!(poly.evaluate(point), *eval.value()); + } + } + + let (sets, superset) = eval_sets(evals); + + let beta = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + + let max_set_len = sets.iter().map(|set| set.polys.len()).max().unwrap(); + let powers_of_beta = powers(beta).take(max_set_len).collect_vec(); + let powers_of_gamma = powers(gamma).take(sets.len()).collect_vec(); + let (fs, (qs, rs)) = sets + .iter() + .map(|set| { + let vanishing_poly = set.vanishing_poly(points); + let f = izip!(&powers_of_beta, set.polys.iter().map(|poly| polys[*poly])) + .sum::>(); + let (q, r) = f.div_rem(&vanishing_poly); + (f, (q, r)) + }) + .unzip::<_, _, Vec<_>, (Vec<_>, Vec<_>)>(); + let q = izip_eq!(&powers_of_gamma, qs.iter()).sum::>(); + + let q_comm = Pcs::commit_and_write(pp, &q, transcript)?; + + let z = transcript.squeeze_challenge(); + + let (normalized_scalars, normalizer) = set_scalars(&sets, &powers_of_gamma, points, &z); + let superset_eval = vanishing_eval(superset.iter().map(|idx| &points[*idx]), &z); + let q_scalar = -superset_eval * normalizer; + let f = { + let mut f = izip_eq!(&normalized_scalars, &fs).sum::>(); + f += (&q_scalar, &q); + f + }; + let (comm, eval) = if cfg!(feature = "sanity-check") { + let scalars = comm_scalars(comms.len(), &sets, &powers_of_beta, &normalized_scalars); + let comm = + Pcs::Commitment::msm(chain![&scalars, [&q_scalar]], chain![comms, [&q_comm]]); + let r_evals = rs.iter().map(|r| r.evaluate(&z)).collect_vec(); + (comm, inner_product(&normalized_scalars, &r_evals)) + } else { + (Pcs::Commitment::default(), F::ZERO) + }; + Pcs::open(pp, &f, &comm, &z, &eval, transcript) + } + + pub fn batch_verify( + vp: &Pcs::VerifierParam, + comms: Vec<&Pcs::Commitment>, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> + where + F: PrimeField, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: Additive, + { + let (sets, superset) = eval_sets(evals); + + let beta = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + + let q_comm = Pcs::read_commitment(vp, transcript)?; + + let z = transcript.squeeze_challenge(); + + let max_set_len = sets.iter().map(|set| set.polys.len()).max().unwrap(); + let powers_of_beta = powers(beta).take(max_set_len).collect_vec(); + let powers_of_gamma = powers(gamma).take(sets.len()).collect_vec(); + + let (normalized_scalars, normalizer) = set_scalars(&sets, &powers_of_gamma, points, &z); + let f = { + let scalars = comm_scalars(comms.len(), &sets, &powers_of_beta, &normalized_scalars); + let superset_eval = vanishing_eval(superset.iter().map(|idx| &points[*idx]), &z); + let q_scalar = -superset_eval * normalizer; + Pcs::Commitment::msm(chain![&scalars, [&q_scalar]], chain![comms, [&q_comm]]) + }; + let eval = inner_product( + &normalized_scalars, + &sets + .iter() + .map(|set| set.r_eval(points, &z, &powers_of_beta)) + .collect_vec(), + ); + Pcs::verify(vp, &f, &z, &eval, transcript) + } + + #[derive(Debug)] + struct EvaluationSet { + polys: Vec, + points: Vec, + diffs: Vec, + evals: Vec>, + } + + impl EvaluationSet { + fn vanishing_diff_eval(&self, points: &[F], z: &F) -> F { + self.diffs + .iter() + .map(|idx| points[*idx]) + .fold(F::ONE, |eval, point| eval * (*z - point)) + } + + fn vanishing_poly(&self, points: &[F]) -> UnivariatePolynomial { + UnivariatePolynomial::vanishing(self.points.iter().map(|point| &points[*point]), F::ONE) + } + + fn r_eval(&self, points: &[F], z: &F, powers_of_beta: &[F]) -> F { + let points = self.points.iter().map(|idx| points[*idx]).collect_vec(); + let weights = barycentric_weights(&points); + let r_evals = self + .evals + .iter() + .map(|evals| barycentric_interpolate(&weights, &points, evals, z)) + .collect_vec(); + inner_product(&powers_of_beta[..r_evals.len()], &r_evals) + } + } + + fn eval_sets(evals: &[Evaluation]) -> (Vec>, BTreeSet) { + let (poly_shifts, superset) = evals.iter().fold( + (Vec::<(usize, Vec, Vec)>::new(), BTreeSet::new()), + |(mut poly_shifts, mut superset), eval| { + if let Some(pos) = poly_shifts + .iter() + .position(|(poly, _, _)| *poly == eval.poly) + { + let (_, points, evals) = &mut poly_shifts[pos]; + if !points.contains(&eval.point) { + points.push(eval.point); + evals.push(*eval.value()); + } + } else { + poly_shifts.push((eval.poly, vec![eval.point], vec![*eval.value()])); + } + superset.insert(eval.point()); + (poly_shifts, superset) + }, + ); + + let sets = poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, points, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.points.iter()) == BTreeSet::from_iter(points.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.points + .iter() + .map(|lhs| { + let idx = points.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); + } + } else { + let diffs = superset + .iter() + .filter(|idx| !points.contains(idx)) + .copied() + .collect(); + sets.push(EvaluationSet { + polys: vec![poly], + points, + diffs, + evals: vec![evals], + }); + } + sets + }, + ); + + (sets, superset) + } + + fn set_scalars( + sets: &[EvaluationSet], + powers_of_gamma: &[F], + points: &[F], + z: &F, + ) -> (Vec, F) { + let vanishing_diff_evals = sets + .iter() + .map(|set| set.vanishing_diff_eval(points, z)) + .collect_vec(); + // Adopt fflonk's trick to normalize the set scalars by the one of first set, + // to save 1 EC scalar multiplication for verifier. + let normalizer = vanishing_diff_evals[0].invert().unwrap_or(F::ONE); + let normalized_scalars = izip_eq!(powers_of_gamma, &vanishing_diff_evals) + .map(|(power_of_gamma, vanishing_diff_eval)| { + normalizer * vanishing_diff_eval * power_of_gamma + }) + .collect_vec(); + (normalized_scalars, normalizer) + } + + fn vanishing_eval<'a, F: Field>(points: impl IntoIterator, z: &F) -> F { + points + .into_iter() + .fold(F::ONE, |eval, point| eval * (*z - point)) + } + + fn comm_scalars( + num_polys: usize, + sets: &[EvaluationSet], + powers_of_beta: &[F], + normalized_scalars: &[F], + ) -> Vec { + sets.iter().zip(normalized_scalars).fold( + vec![F::ZERO; num_polys], + |mut scalars, (set, coeff)| { + izip!(&set.polys, powers_of_beta) + .for_each(|(poly, power_of_beta)| scalars[*poly] = *coeff * power_of_beta); + scalars + }, + ) + } +} diff --git a/plonkish_backend/src/pcs/univariate/hyrax.rs b/plonkish_backend/src/pcs/univariate/hyrax.rs new file mode 100644 index 00000000..ac94473d --- /dev/null +++ b/plonkish_backend/src/pcs/univariate/hyrax.rs @@ -0,0 +1,402 @@ +use crate::{ + pcs::{ + univariate::{ + additive, err_too_large_deree, + ipa::{ + UnivariateIpa, UnivariateIpaCommitment, UnivariateIpaParam, + UnivariateIpaVerifierParam, + }, + validate_input, + }, + Additive, Evaluation, Point, PolynomialCommitmentScheme, + }, + poly::univariate::{UnivariateBasis::*, UnivariatePolynomial}, + util::{ + arithmetic::{ + batch_projective_to_affine, div_ceil, powers, squares, variable_base_msm, CurveAffine, + Field, Group, + }, + chain, izip, + parallel::parallelize, + transcript::{TranscriptRead, TranscriptWrite}, + Deserialize, DeserializeOwned, Itertools, Serialize, + }, + Error, +}; +use rand::RngCore; +use std::{borrow::Cow, iter, marker::PhantomData}; + +#[derive(Clone, Debug)] +pub struct UnivariateHyrax(PhantomData); + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnivariateHyraxParam { + k: usize, + batch_k: usize, + row_k: usize, + ipa: UnivariateIpaParam, +} + +impl UnivariateHyraxParam { + pub fn k(&self) -> usize { + self.k + } + + pub fn degree(&self) -> usize { + (1 << self.k) - 1 + } + + pub fn batch_k(&self) -> usize { + self.batch_k + } + + pub fn row_k(&self) -> usize { + self.row_k + } + + pub fn row_len(&self) -> usize { + 1 << self.row_k + } + + pub fn num_chunks(&self) -> usize { + 1 << (self.k - self.row_k) + } + + pub fn monomial(&self) -> &[C] { + self.ipa.monomial() + } + + pub fn lagrange(&self) -> &[C] { + self.ipa.lagrange() + } + + pub fn h(&self) -> &C { + self.ipa.h() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnivariateHyraxVerifierParam { + k: usize, + batch_k: usize, + row_k: usize, + ipa: UnivariateIpaVerifierParam, +} + +impl UnivariateHyraxVerifierParam { + pub fn k(&self) -> usize { + self.k + } + + pub fn row_k(&self) -> usize { + self.row_k + } + + pub fn num_chunks(&self) -> usize { + 1 << (self.k - self.row_k) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct UnivariateHyraxCommitment(pub Vec); + +impl Default for UnivariateHyraxCommitment { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl AsRef<[C]> for UnivariateHyraxCommitment { + fn as_ref(&self) -> &[C] { + &self.0 + } +} + +// TODO: Batch all MSMs into one +impl Additive for UnivariateHyraxCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, + ) -> Self { + let (scalars, bases) = scalars + .into_iter() + .zip_eq(bases) + .filter_map(|(scalar, bases)| (bases != &Self::default()).then_some((scalar, bases))) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let num_chunks = bases[0].0.len(); + for bases in bases.iter() { + assert_eq!(bases.0.len(), num_chunks); + } + + let mut output = vec![C::CurveExt::identity(); num_chunks]; + parallelize(&mut output, |(output, start)| { + for (output, idx) in output.iter_mut().zip(start..) { + *output = variable_base_msm(scalars.clone(), bases.iter().map(|base| &base.0[idx])) + } + }); + UnivariateHyraxCommitment(batch_projective_to_affine(&output)) + } +} + +impl PolynomialCommitmentScheme for UnivariateHyrax +where + C: CurveAffine + Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + type Param = UnivariateHyraxParam; + type ProverParam = UnivariateHyraxParam; + type VerifierParam = UnivariateHyraxVerifierParam; + type Polynomial = UnivariatePolynomial; + type Commitment = UnivariateHyraxCommitment; + type CommitmentChunk = C; + + fn setup(poly_size: usize, batch_size: usize, rng: impl RngCore) -> Result { + // TODO: Support arbitrary degree. + assert!(poly_size.is_power_of_two()); + assert!(batch_size > 0 && batch_size <= poly_size); + + let k = poly_size.ilog2() as usize; + let batch_k = (poly_size * batch_size).next_power_of_two().ilog2() as usize; + let row_k = div_ceil(batch_k, 2); + + let ipa = UnivariateIpa::setup(1 << row_k, 0, rng)?; + + Ok(Self::Param { + k, + batch_k, + row_k, + ipa, + }) + } + + fn trim( + param: &Self::Param, + poly_size: usize, + batch_size: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + assert!(poly_size.is_power_of_two()); + assert!(batch_size > 0 && batch_size <= poly_size); + + let k = poly_size.ilog2() as usize; + let batch_k = (poly_size * batch_size).next_power_of_two().ilog2() as usize; + let row_k = div_ceil(batch_k, 2); + if param.row_k() < row_k { + return Err(err_too_large_deree("trim", param.degree(), poly_size - 1)); + } + + let (ipa_pp, ipa_vp) = UnivariateIpa::trim(¶m.ipa, 1 << row_k, 0)?; + + let pp = Self::ProverParam { + k, + batch_k, + row_k, + ipa: ipa_pp, + }; + let vp = Self::VerifierParam { + k, + batch_k, + row_k, + ipa: ipa_vp, + }; + Ok((pp, vp)) + } + + fn commit(pp: &Self::ProverParam, poly: &Self::Polynomial) -> Result { + validate_input("commit", pp.degree(), [poly])?; + + let bases = match poly.basis() { + Monomial => pp.monomial(), + Lagrange => pp.lagrange(), + }; + + let row_len = pp.row_len(); + let scalars = poly.coeffs(); + let comm = { + let mut comm = vec![C::CurveExt::identity(); pp.num_chunks()]; + parallelize(&mut comm, |(comm, start)| { + for (comm, offset) in comm.iter_mut().zip((start * row_len..).step_by(row_len)) { + let row = &scalars[offset..(offset + row_len).min(scalars.len())]; + *comm = variable_base_msm(row, &bases[..row.len()]); + } + }); + batch_projective_to_affine(&comm) + }; + + Ok(UnivariateHyraxCommitment(comm)) + } + + fn batch_commit<'a>( + pp: &Self::ProverParam, + polys: impl IntoIterator, + ) -> Result, Error> { + let polys = polys.into_iter().collect_vec(); + if polys.is_empty() { + return Ok(Vec::new()); + } + validate_input("batch commit", pp.degree(), polys.iter().copied())?; + + let row_len = pp.row_len(); + let scalars = polys + .iter() + .flat_map(|poly| { + chain![poly.coeffs().chunks(row_len), iter::repeat([].as_slice())] + .take(pp.num_chunks()) + }) + .collect_vec(); + let comms = { + let mut comms = vec![C::CurveExt::identity(); scalars.len()]; + parallelize(&mut comms, |(comms, start)| { + for (comm, row) in comms.iter_mut().zip(&scalars[start..]) { + *comm = variable_base_msm(*row, &pp.monomial()[..row.len()]); + } + }); + batch_projective_to_affine(&comms) + }; + + Ok(comms + .into_iter() + .chunks(pp.num_chunks()) + .into_iter() + .map(|comm| UnivariateHyraxCommitment(comm.collect_vec())) + .collect_vec()) + } + + // TODO: Batch all MSMs into one + fn open( + pp: &Self::ProverParam, + poly: &Self::Polynomial, + comm: &Self::Commitment, + point: &Point, + eval: &C::Scalar, + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + assert_eq!(poly.basis(), Monomial); + + validate_input("open", pp.degree(), [poly])?; + + if cfg!(feature = "sanity-check") { + assert_eq!(comm.0.len(), pp.num_chunks()); + assert_eq!(Self::commit(pp, poly).unwrap().0, comm.0); + assert_eq!(poly.evaluate(point), *eval); + } + + let row_len = pp.row_len(); + let scalars = powers(squares(*point).nth(pp.row_k()).unwrap()) + .take(pp.num_chunks()) + .collect_vec(); + let poly = if pp.num_chunks() == 1 { + Cow::Borrowed(poly) + } else { + let mut coeffs = vec![C::Scalar::ZERO; row_len]; + if let Some(row) = poly.coeffs().chunks(row_len).next() { + coeffs[..row.len()].copy_from_slice(row); + } + izip!(&scalars, poly.coeffs().chunks(row_len)) + .skip(1) + .for_each(|(scalar, row)| { + parallelize(&mut coeffs, |(coeffs, start)| { + let scalar = *scalar; + izip!(coeffs, &row[start..]).for_each(|(lhs, rhs)| *lhs += scalar * rhs) + }); + }); + Cow::Owned(UnivariatePolynomial::monomial(coeffs)) + }; + let comm = if cfg!(feature = "sanity-check") { + UnivariateIpaCommitment(if pp.num_chunks() == 1 { + comm.0[0] + } else { + variable_base_msm(&scalars, &comm.0).into() + }) + } else { + UnivariateIpaCommitment::default() + }; + + UnivariateIpa::open(&pp.ipa, &poly, &comm, point, eval, transcript) + } + + fn batch_open<'a>( + pp: &Self::ProverParam, + polys: impl IntoIterator, + comms: impl IntoIterator, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + let polys = polys.into_iter().collect_vec(); + let comms = comms.into_iter().collect_vec(); + additive::batch_open::<_, Self>(pp, polys, comms, points, evals, transcript) + } + + fn read_commitments( + vp: &Self::VerifierParam, + num_polys: usize, + transcript: &mut impl TranscriptRead, + ) -> Result, Error> { + iter::repeat_with(|| { + transcript + .read_commitments(vp.num_chunks()) + .map(UnivariateHyraxCommitment) + }) + .take(num_polys) + .collect() + } + + fn verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &Point, + eval: &C::Scalar, + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + assert_eq!(comm.0.len(), vp.num_chunks()); + + let comm = { + UnivariateIpaCommitment(if vp.num_chunks() == 1 { + comm.0[0] + } else { + let scalars = powers(squares(*point).nth(vp.row_k()).unwrap()) + .take(vp.num_chunks()) + .collect_vec(); + variable_base_msm(&scalars, &comm.0).into() + }) + }; + + UnivariateIpa::verify(&vp.ipa, &comm, point, eval, transcript) + } + + fn batch_verify<'a>( + vp: &Self::VerifierParam, + comms: impl IntoIterator, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + let comms = comms.into_iter().collect_vec(); + additive::batch_verify::<_, Self>(vp, comms, points, evals, transcript) + } +} + +#[cfg(test)] +mod test { + use crate::{ + pcs::{ + test::{run_batch_commit_open_verify, run_commit_open_verify}, + univariate::hyrax::UnivariateHyrax, + }, + util::transcript::Keccak256Transcript, + }; + use halo2_curves::pasta::pallas::Affine; + + type Pcs = UnivariateHyrax; + + #[test] + fn commit_open_verify() { + run_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); + } + + #[test] + fn batch_commit_open_verify() { + run_batch_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); + } +} diff --git a/plonkish_backend/src/pcs/univariate/ipa.rs b/plonkish_backend/src/pcs/univariate/ipa.rs new file mode 100644 index 00000000..c810c12d --- /dev/null +++ b/plonkish_backend/src/pcs/univariate/ipa.rs @@ -0,0 +1,449 @@ +use crate::{ + pcs::{ + univariate::{additive, err_too_large_deree, monomial_g_to_lagrange_g, validate_input}, + Additive, Evaluation, Point, PolynomialCommitmentScheme, + }, + poly::{ + multilinear, + univariate::{UnivariateBasis::*, UnivariatePolynomial}, + }, + util::{ + arithmetic::{ + batch_projective_to_affine, inner_product, powers, squares, variable_base_msm, Curve, + CurveAffine, CurveExt, Field, Group, PrimeField, + }, + chain, izip, + parallel::parallelize, + transcript::{TranscriptRead, TranscriptWrite}, + Deserialize, DeserializeOwned, Either, Itertools, Serialize, + }, + Error, +}; +use halo2_curves::group::ff::BatchInvert; +use rand::RngCore; +use std::{borrow::Cow, iter, marker::PhantomData, slice}; + +#[derive(Clone, Debug)] +pub struct UnivariateIpa(PhantomData); + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnivariateIpaParam { + k: usize, + monomial: Vec, + lagrange: Vec, + h: C, +} + +impl UnivariateIpaParam { + pub fn k(&self) -> usize { + self.k + } + + pub fn degree(&self) -> usize { + self.monomial.len() - 1 + } + + pub fn monomial(&self) -> &[C] { + &self.monomial + } + + pub fn lagrange(&self) -> &[C] { + &self.lagrange + } + + pub fn h(&self) -> &C { + &self.h + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct UnivariateIpaVerifierParam { + k: usize, + monomial: Vec, + h: C, +} + +impl UnivariateIpaVerifierParam { + pub fn k(&self) -> usize { + self.k + } + + pub fn h(&self) -> &C { + &self.h + } + + pub fn monomial(&self) -> &[C] { + &self.monomial + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct UnivariateIpaCommitment(pub C); + +impl Default for UnivariateIpaCommitment { + fn default() -> Self { + Self(C::identity()) + } +} + +impl AsRef<[C]> for UnivariateIpaCommitment { + fn as_ref(&self) -> &[C] { + slice::from_ref(&self.0) + } +} + +impl AsRef for UnivariateIpaCommitment { + fn as_ref(&self) -> &C { + &self.0 + } +} + +impl From for UnivariateIpaCommitment { + fn from(comm: C) -> Self { + Self(comm) + } +} + +impl Additive for UnivariateIpaCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, + ) -> Self { + let scalars = scalars.into_iter().collect_vec(); + let bases = bases.into_iter().map(|base| &base.0).collect_vec(); + UnivariateIpaCommitment(variable_base_msm(scalars, bases).to_affine()) + } +} + +impl PolynomialCommitmentScheme for UnivariateIpa +where + C: CurveAffine + Serialize + DeserializeOwned, + C::ScalarExt: Serialize + DeserializeOwned, +{ + type Param = UnivariateIpaParam; + type ProverParam = UnivariateIpaParam; + type VerifierParam = UnivariateIpaVerifierParam; + type Polynomial = UnivariatePolynomial; + type Commitment = UnivariateIpaCommitment; + type CommitmentChunk = C; + + fn setup(poly_size: usize, _: usize, _: impl RngCore) -> Result { + // TODO: Support arbitrary degree. + assert!(poly_size.is_power_of_two()); + assert!(poly_size.ilog2() <= C::Scalar::S); + + let k = poly_size.ilog2() as usize; + + let monomial = { + let mut g = vec![C::Curve::identity(); poly_size]; + parallelize(&mut g, |(g, start)| { + let hasher = C::CurveExt::hash_to_curve("UnivariateIpa::setup"); + for (g, idx) in g.iter_mut().zip(start as u32..) { + let mut message = [0u8; 5]; + message[1..5].copy_from_slice(&idx.to_le_bytes()); + *g = hasher(&message); + } + }); + batch_projective_to_affine(&g) + }; + + let lagrange = monomial_g_to_lagrange_g(&monomial); + + let hasher = C::CurveExt::hash_to_curve("UnivariateIpa::setup"); + let h = hasher(&[1]).to_affine(); + + Ok(Self::Param { + k, + monomial, + lagrange, + h, + }) + } + + fn trim( + param: &Self::Param, + poly_size: usize, + _: usize, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + assert!(poly_size.is_power_of_two()); + + let k = poly_size.ilog2() as usize; + + if param.monomial.len() < poly_size { + return Err(err_too_large_deree("trim", param.degree(), poly_size - 1)); + } + + let monomial = param.monomial[..poly_size].to_vec(); + let lagrange = if param.lagrange.len() == poly_size { + param.lagrange.clone() + } else { + monomial_g_to_lagrange_g(&monomial) + }; + + let pp = Self::ProverParam { + k, + monomial: monomial.clone(), + lagrange, + h: param.h, + }; + let vp = Self::VerifierParam { + k, + monomial, + h: param.h, + }; + Ok((pp, vp)) + } + + fn commit(pp: &Self::ProverParam, poly: &Self::Polynomial) -> Result { + validate_input("commit", pp.degree(), [poly])?; + + let coeffs = poly.coeffs(); + let bases = match poly.basis() { + Monomial => pp.monomial(), + Lagrange => pp.lagrange(), + }; + Ok(variable_base_msm(coeffs, &bases[..coeffs.len()]).into()).map(UnivariateIpaCommitment) + } + + fn batch_commit<'a>( + pp: &Self::ProverParam, + polys: impl IntoIterator, + ) -> Result, Error> { + polys + .into_iter() + .map(|poly| Self::commit(pp, poly)) + .collect() + } + + fn open( + pp: &Self::ProverParam, + poly: &Self::Polynomial, + comm: &Self::Commitment, + point: &Point, + eval: &C::Scalar, + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + assert_eq!(poly.basis(), Monomial); + + validate_input("open", pp.degree(), [poly])?; + + if cfg!(feature = "sanity-check") { + assert_eq!(Self::commit(pp, poly).unwrap().0, comm.0); + assert_eq!(poly.evaluate(point), *eval); + } + + let bases = pp.monomial(); + let coeffs = chain![poly.coeffs().iter().cloned(), iter::repeat(C::Scalar::ZERO)] + .take(bases.len()) + .collect_vec(); + let zs = powers(*point).take(bases.len()).collect_vec(); + prove_bulletproof_reduction(bases, pp.h(), coeffs, zs, transcript) + } + + fn batch_open<'a>( + pp: &Self::ProverParam, + polys: impl IntoIterator, + comms: impl IntoIterator, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptWrite, + ) -> Result<(), Error> { + let polys = polys.into_iter().collect_vec(); + let comms = comms.into_iter().collect_vec(); + additive::batch_open::<_, Self>(pp, polys, comms, points, evals, transcript) + } + + fn read_commitments( + _: &Self::VerifierParam, + num_polys: usize, + transcript: &mut impl TranscriptRead, + ) -> Result, Error> { + let comms = transcript.read_commitments(num_polys)?; + Ok(comms.into_iter().map(UnivariateIpaCommitment).collect()) + } + + fn verify( + vp: &Self::VerifierParam, + comm: &Self::Commitment, + point: &Point, + eval: &C::Scalar, + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + let bases = vp.monomial(); + let point = Either::Left(point); + verify_bulletproof_reduction(bases, vp.h(), comm, point, eval, transcript) + } + + fn batch_verify<'a>( + vp: &Self::VerifierParam, + comms: impl IntoIterator, + points: &[Point], + evals: &[Evaluation], + transcript: &mut impl TranscriptRead, + ) -> Result<(), Error> { + let comms = comms.into_iter().collect_vec(); + additive::batch_verify::<_, Self>(vp, comms, points, evals, transcript) + } +} + +pub(crate) fn prove_bulletproof_reduction<'a, C: CurveAffine>( + bases: impl Into>, + h: &C, + coeffs: impl Into>, + zs: impl Into>, + transcript: &mut impl TranscriptWrite, +) -> Result<(), Error> { + let mut bases = bases.into().into_owned(); + let mut coeffs = coeffs.into().into_owned(); + let mut zs = zs.into().into_owned(); + + assert_eq!(bases.len(), coeffs.len()); + assert_eq!(bases.len(), zs.len()); + assert!(bases.len().is_power_of_two()); + + let xi_0 = transcript.squeeze_challenge(); + let h_prime = (*h * xi_0).to_affine(); + + let k = bases.len().ilog2() as usize; + for i in 0..k { + let mid = 1 << (k - i - 1); + + let (bases_l, bases_r) = bases.split_at(mid); + let (coeffs_l, coeffs_r) = coeffs.split_at(mid); + let (zs_l, zs_r) = zs.split_at(mid); + let (c_l, c_r) = (inner_product(coeffs_r, zs_l), inner_product(coeffs_l, zs_r)); + let l_i = variable_base_msm(chain![coeffs_r, [&c_l]], chain![bases_l, [&h_prime]]); + let r_i = variable_base_msm(chain![coeffs_l, [&c_r]], chain![bases_r, [&h_prime]]); + transcript.write_commitment(&l_i.to_affine())?; + transcript.write_commitment(&r_i.to_affine())?; + + let xi_i = transcript.squeeze_challenge(); + let xi_i_inv = xi_i.invert().unwrap(); + + let (bases_l, bases_r) = bases.split_at_mut(mid); + let (coeffs_l, coeffs_r) = coeffs.split_at_mut(mid); + let (zs_l, zs_r) = zs.split_at_mut(mid); + parallelize(bases_l, |(bases_l, start)| { + let mut tmp = Vec::with_capacity(bases_l.len()); + for (lhs, rhs) in bases_l.iter().zip(bases_r[start..].iter()) { + tmp.push(lhs.to_curve() + *rhs * xi_i); + } + C::Curve::batch_normalize(&tmp, bases_l); + }); + parallelize(coeffs_l, |(coeffs_l, start)| { + for (lhs, rhs) in coeffs_l.iter_mut().zip(coeffs_r[start..].iter()) { + *lhs += xi_i_inv * rhs; + } + }); + parallelize(zs_l, |(zs_l, start)| { + for (lhs, rhs) in zs_l.iter_mut().zip(zs_r[start..].iter()) { + *lhs += xi_i * rhs; + } + }); + bases.truncate(mid); + coeffs.truncate(mid); + zs.truncate(mid); + } + + transcript.write_field_element(&coeffs[0])?; + + Ok(()) +} + +pub(crate) fn verify_bulletproof_reduction( + bases: &[C], + h: &C, + comm: impl AsRef, + point: Either<&C::Scalar, &[C::Scalar]>, + eval: &C::Scalar, + transcript: &mut impl TranscriptRead, +) -> Result<(), Error> { + assert!(bases.len().is_power_of_two()); + if let Either::Right(point) = point { + assert_eq!(1 << point.len(), bases.len()); + } + + let k = bases.len().ilog2() as usize; + + let xi_0 = transcript.squeeze_challenge(); + + let (ls, rs, xis) = iter::repeat_with(|| { + Ok(( + transcript.read_commitment()?, + transcript.read_commitment()?, + transcript.squeeze_challenge(), + )) + }) + .take(k) + .collect::, _>>()? + .into_iter() + .multiunzip::<(Vec<_>, Vec<_>, Vec<_>)>(); + let neg_c = -transcript.read_field_element()?; + + let xi_invs = { + let mut xi_invs = xis.clone(); + xi_invs.batch_invert(); + xi_invs + }; + let neg_c_h = h_coeffs(neg_c, &xis); + let (kind, neg_c_h_eval) = match point { + Either::Left(point) => ("univariate", h_eval(neg_c, &xis, point)), + Either::Right(point) => ("multivariate", multilinear::evaluate(&neg_c_h, point)), + }; + let u = xi_0 * (neg_c_h_eval + eval); + let scalars = chain![&xi_invs, &xis, &neg_c_h, [&u]]; + let bases = chain![&ls, &rs, bases, [h]]; + bool::from((variable_base_msm(scalars, bases) + comm.as_ref()).is_identity()) + .then_some(()) + .ok_or_else(|| Error::InvalidPcsOpen(format!("Invalid {kind} IPA open"))) +} + +pub(crate) fn h_coeffs(init: F, xi: &[F]) -> Vec { + assert!(!xi.is_empty()); + + let mut coeffs = vec![F::ZERO; 1 << xi.len()]; + coeffs[0] = init; + + for (len, xi) in xi.iter().rev().enumerate().map(|(i, xi)| (1 << i, xi)) { + let (left, right) = coeffs.split_at_mut(len); + let right = &mut right[0..len]; + right.copy_from_slice(left); + parallelize(right, |(right, _)| { + for coeff in right { + *coeff *= xi; + } + }); + } + + coeffs +} + +fn h_eval(init: F, xis: &[F], x: &F) -> F { + izip!(squares(*x), xis.iter().rev()) + .map(|(square_of_x, xi)| F::ONE + square_of_x * xi) + .fold(init, |acc, item| acc * item) +} + +#[cfg(test)] +mod test { + use crate::{ + pcs::{ + test::{run_batch_commit_open_verify, run_commit_open_verify}, + univariate::ipa::UnivariateIpa, + }, + util::transcript::Keccak256Transcript, + }; + use halo2_curves::pasta::pallas::Affine; + + type Pcs = UnivariateIpa; + + #[test] + fn commit_open_verify() { + run_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); + } + + #[test] + fn batch_commit_open_verify() { + run_batch_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); + } +} diff --git a/plonkish_backend/src/pcs/univariate/kzg.rs b/plonkish_backend/src/pcs/univariate/kzg.rs index ab1e5a8f..646cca43 100644 --- a/plonkish_backend/src/pcs/univariate/kzg.rs +++ b/plonkish_backend/src/pcs/univariate/kzg.rs @@ -1,27 +1,22 @@ use crate::{ pcs::{ - univariate::monomial_g1_to_lagrange_g1, AdditiveCommitment, Evaluation, Point, - PolynomialCommitmentScheme, - }, - poly::{ - univariate::{UnivariateBasis::*, UnivariatePolynomial}, - Polynomial, + univariate::{additive, err_too_large_deree, monomial_g_to_lagrange_g, validate_input}, + Additive, Evaluation, Point, PolynomialCommitmentScheme, }, + poly::univariate::{UnivariateBasis::*, UnivariatePolynomial}, util::{ arithmetic::{ - barycentric_interpolate, barycentric_weights, batch_projective_to_affine, fft, - fixed_base_msm, inner_product, powers, root_of_unity_inv, variable_base_msm, - window_size, window_table, Curve, CurveAffine, Field, MultiMillerLoop, - PrimeCurveAffine, PrimeField, + batch_projective_to_affine, fixed_base_msm, powers, radix2_fft, root_of_unity_inv, + variable_base_msm, window_size, window_table, Curve, CurveAffine, Field, + MultiMillerLoop, PrimeCurveAffine, PrimeField, }, - chain, izip, izip_eq, transcript::{TranscriptRead, TranscriptWrite}, Deserialize, DeserializeOwned, Itertools, Serialize, }, Error, }; use rand::RngCore; -use std::{collections::BTreeSet, marker::PhantomData, ops::Neg, slice}; +use std::{marker::PhantomData, ops::Neg, slice}; #[derive(Clone, Debug)] pub struct UnivariateKzg(PhantomData); @@ -50,12 +45,17 @@ impl UnivariateKzg { deserialize = "M::G1Affine: DeserializeOwned, M::G2Affine: DeserializeOwned", ))] pub struct UnivariateKzgParam { + k: usize, monomial_g1: Vec, lagrange_g1: Vec, powers_of_s_g2: Vec, } impl UnivariateKzgParam { + pub fn k(&self) -> usize { + self.k + } + pub fn degree(&self) -> usize { self.monomial_g1.len() - 1 } @@ -83,18 +83,28 @@ impl UnivariateKzgParam { deserialize = "M::G1Affine: DeserializeOwned", ))] pub struct UnivariateKzgProverParam { + k: usize, monomial_g1: Vec, lagrange_g1: Vec, } impl UnivariateKzgProverParam { - pub(crate) fn new(monomial_g1: Vec, lagrange_g1: Vec) -> Self { + pub(crate) fn new( + k: usize, + monomial_g1: Vec, + lagrange_g1: Vec, + ) -> Self { Self { + k, monomial_g1, lagrange_g1, } } + pub fn k(&self) -> usize { + self.k + } + pub fn degree(&self) -> usize { self.monomial_g1.len() - 1 } @@ -168,15 +178,13 @@ impl From for UnivariateKzgCommitment { } } -impl AdditiveCommitment for UnivariateKzgCommitment { - fn sum_with_scalar<'a>( - scalars: impl IntoIterator + 'a, - bases: impl IntoIterator + 'a, +impl Additive for UnivariateKzgCommitment { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, ) -> Self { let scalars = scalars.into_iter().collect_vec(); let bases = bases.into_iter().map(|base| &base.0).collect_vec(); - assert_eq!(scalars.len(), bases.len()); - UnivariateKzgCommitment(variable_base_msm(scalars, bases).to_affine()) } } @@ -213,7 +221,7 @@ where let k = poly_size.ilog2() as usize; let n_inv = M::Scalar::TWO_INV.pow_vartime([k as u64]); let mut lagrange = monomial; - fft(&mut lagrange, root_of_unity_inv(k), k); + radix2_fft(&mut lagrange, root_of_unity_inv(k), k); lagrange.iter_mut().for_each(|v| *v *= n_inv); batch_projective_to_affine(&fixed_base_msm(window_size, &window_table, &lagrange)) }; @@ -229,6 +237,7 @@ where }; Ok(Self::Param { + k: poly_size.ilog2() as usize, monomial_g1, lagrange_g1, powers_of_s_g2, @@ -243,19 +252,17 @@ where assert!(poly_size.is_power_of_two()); if param.monomial_g1.len() < poly_size { - return Err(Error::InvalidPcsParam(format!( - "Too large poly_size to trim to (param supports poly_size up to {} but got {poly_size})", - param.monomial_g1.len(), - ))); + return Err(err_too_large_deree("trim", param.degree(), poly_size - 1)); } let monomial_g1 = param.monomial_g1[..poly_size].to_vec(); let lagrange_g1 = if param.lagrange_g1.len() == poly_size { param.lagrange_g1.clone() } else { - monomial_g1_to_lagrange_g1(&monomial_g1) + monomial_g_to_lagrange_g(&monomial_g1) }; - let pp = Self::ProverParam::new(monomial_g1, lagrange_g1); + + let pp = Self::ProverParam::new(poly_size.ilog2() as usize, monomial_g1, lagrange_g1); let vp = Self::VerifierParam { g1: param.g1(), g2: param.g2(), @@ -265,27 +272,11 @@ where } fn commit(pp: &Self::ProverParam, poly: &Self::Polynomial) -> Result { + validate_input("commit", pp.degree(), [poly])?; + match poly.basis() { - Monomial => { - if pp.degree() < poly.degree() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to commit (param supports degree up to {} but got {})", - pp.degree(), - poly.degree() - ))); - } - Ok(Self::commit_monomial(pp, poly.coeffs())) - } - Lagrange => { - if pp.lagrange_g1().len() != poly.coeffs().len() { - return Err(Error::InvalidPcsParam(format!( - "Invalid number of poly evaluations to commit (param needs {} evaluations but got {})", - pp.lagrange_g1().len(), - poly.coeffs().len() - ))); - } - Ok(Self::commit_lagrange(pp, poly.coeffs())) - } + Monomial => Ok(Self::commit_monomial(pp, poly.coeffs())), + Lagrange => Ok(Self::commit_lagrange(pp, poly.coeffs())), } } @@ -307,20 +298,16 @@ where eval: &M::Scalar, transcript: &mut impl TranscriptWrite, ) -> Result<(), Error> { - if pp.degree() < poly.degree() { - return Err(Error::InvalidPcsParam(format!( - "Too large degree of poly to open (param supports degree up to {} but got {})", - pp.degree(), - poly.degree() - ))); - } + assert_eq!(poly.basis(), Monomial); + + validate_input("open", pp.degree(), [poly])?; if cfg!(feature = "sanity-check") { assert_eq!(Self::commit(pp, poly).unwrap().0, comm.0); assert_eq!(poly.evaluate(point), *eval); } - let divisor = Self::Polynomial::new(Monomial, vec![point.neg(), M::Scalar::ONE]); + let divisor = Self::Polynomial::monomial(vec![point.neg(), M::Scalar::ONE]); let (quotient, remainder) = poly.div_rem(&divisor); if cfg!(feature = "sanity-check") { @@ -345,50 +332,9 @@ where transcript: &mut impl TranscriptWrite, ) -> Result<(), Error> { let polys = polys.into_iter().collect_vec(); - let (sets, superset) = eval_sets(evals); - - let beta = transcript.squeeze_challenge(); - let gamma = transcript.squeeze_challenge(); - - let max_set_len = sets.iter().map(|set| set.polys.len()).max().unwrap(); - let powers_of_beta = powers(beta).take(max_set_len).collect_vec(); - let powers_of_gamma = powers(gamma).take(sets.len()).collect_vec(); - let (fs, (qs, rs)) = sets - .iter() - .map(|set| { - let vanishing_poly = set.vanishing_poly(points); - let f = izip!(&powers_of_beta, set.polys.iter().map(|poly| polys[*poly])) - .sum::>(); - let (q, r) = f.div_rem(&vanishing_poly); - (f, (q, r)) - }) - .unzip::<_, _, Vec<_>, (Vec<_>, Vec<_>)>(); - let q = izip_eq!(&powers_of_gamma, qs.iter()).sum::>(); - - let q_comm = Self::commit_and_write(pp, &q, transcript)?; - - let z = transcript.squeeze_challenge(); - - let (normalized_scalars, normalizer) = set_scalars(&sets, &powers_of_gamma, points, &z); - let superset_eval = vanishing_eval(superset.iter().map(|idx| &points[*idx]), &z); - let q_scalar = -superset_eval * normalizer; - let f = { - let mut f = izip_eq!(&normalized_scalars, &fs).sum::>(); - f += (&q_scalar, &q); - f - }; - let (comm, eval) = if cfg!(feature = "sanity-check") { - let comms = comms.into_iter().map(|comm| &comm.0).collect_vec(); - let scalars = comm_scalars(comms.len(), &sets, &powers_of_beta, &normalized_scalars); - let comm = UnivariateKzgCommitment( - variable_base_msm(chain![&scalars, [&q_scalar]], chain![comms, [&q_comm.0]]).into(), - ); - let r_evals = rs.iter().map(|r| r.evaluate(&z)).collect_vec(); - (comm, inner_product(&normalized_scalars, &r_evals)) - } else { - (UnivariateKzgCommitment::default(), M::Scalar::ZERO) - }; - Self::open(pp, &f, &comm, &z, &eval, transcript) + let comms = comms.into_iter().collect_vec(); + validate_input("batch open", pp.degree(), polys.clone())?; + additive::batch_open::<_, Self>(pp, polys, comms, points, evals, transcript) } fn read_commitments( @@ -396,9 +342,8 @@ where num_polys: usize, transcript: &mut impl TranscriptRead, ) -> Result, Error> { - transcript - .read_commitments(num_polys) - .map(|comms| comms.into_iter().map(UnivariateKzgCommitment).collect_vec()) + let comms = transcript.read_commitments(num_polys)?; + Ok(comms.into_iter().map(UnivariateKzgCommitment).collect()) } fn verify( @@ -423,291 +368,30 @@ where transcript: &mut impl TranscriptRead, ) -> Result<(), Error> { let comms = comms.into_iter().collect_vec(); - let (sets, superset) = eval_sets(evals); - - let beta = transcript.squeeze_challenge(); - let gamma = transcript.squeeze_challenge(); - - let q_comm = transcript.read_commitment()?; - - let z = transcript.squeeze_challenge(); - - let max_set_len = sets.iter().map(|set| set.polys.len()).max().unwrap(); - let powers_of_beta = powers(beta).take(max_set_len).collect_vec(); - let powers_of_gamma = powers(gamma).take(sets.len()).collect_vec(); - - let (normalized_scalars, normalizer) = set_scalars(&sets, &powers_of_gamma, points, &z); - let f = { - let comms = comms.iter().map(|comm| &comm.0).collect_vec(); - let scalars = comm_scalars(comms.len(), &sets, &powers_of_beta, &normalized_scalars); - let superset_eval = vanishing_eval(superset.iter().map(|idx| &points[*idx]), &z); - let q_scalar = -superset_eval * normalizer; - UnivariateKzgCommitment( - variable_base_msm(chain![&scalars, [&q_scalar]], chain![comms, [&q_comm]]).into(), - ) - }; - let eval = inner_product( - &normalized_scalars, - &sets - .iter() - .map(|set| set.r_eval(points, &z, &powers_of_beta)) - .collect_vec(), - ); - Self::verify(vp, &f, &z, &eval, transcript) + additive::batch_verify::<_, Self>(vp, comms, points, evals, transcript) } } -#[derive(Debug)] -struct EvaluationSet { - polys: Vec, - points: Vec, - diffs: Vec, - evals: Vec>, -} - -impl EvaluationSet { - fn vanishing_diff_eval(&self, points: &[F], z: &F) -> F { - self.diffs - .iter() - .map(|idx| points[*idx]) - .fold(F::ONE, |eval, point| eval * (*z - point)) - } - - fn vanishing_poly(&self, points: &[F]) -> UnivariatePolynomial { - UnivariatePolynomial::vanishing(self.points.iter().map(|point| &points[*point]), F::ONE) - } - - fn r_eval(&self, points: &[F], z: &F, powers_of_beta: &[F]) -> F { - let points = self.points.iter().map(|idx| points[*idx]).collect_vec(); - let weights = barycentric_weights(&points); - let r_evals = self - .evals - .iter() - .map(|evals| barycentric_interpolate(&weights, &points, evals, z)) - .collect_vec(); - inner_product(&powers_of_beta[..r_evals.len()], &r_evals) - } -} - -fn eval_sets(evals: &[Evaluation]) -> (Vec>, BTreeSet) { - let (poly_shifts, superset) = evals.iter().fold( - (Vec::<(usize, Vec, Vec)>::new(), BTreeSet::new()), - |(mut poly_shifts, mut superset), eval| { - if let Some(pos) = poly_shifts - .iter() - .position(|(poly, _, _)| *poly == eval.poly) - { - let (_, points, evals) = &mut poly_shifts[pos]; - if !points.contains(&eval.point) { - points.push(eval.point); - evals.push(*eval.value()); - } - } else { - poly_shifts.push((eval.poly, vec![eval.point], vec![*eval.value()])); - } - superset.insert(eval.point()); - (poly_shifts, superset) - }, - ); - - let sets = poly_shifts.into_iter().fold( - Vec::>::new(), - |mut sets, (poly, points, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.points.iter()) == BTreeSet::from_iter(points.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.points - .iter() - .map(|lhs| { - let idx = points.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx] - }) - .collect(), - ); - } - } else { - let diffs = superset - .iter() - .filter(|idx| !points.contains(idx)) - .copied() - .collect(); - sets.push(EvaluationSet { - polys: vec![poly], - points, - diffs, - evals: vec![evals], - }); - } - sets - }, - ); - - (sets, superset) -} - -fn set_scalars( - sets: &[EvaluationSet], - powers_of_gamma: &[F], - points: &[F], - z: &F, -) -> (Vec, F) { - let vanishing_diff_evals = sets - .iter() - .map(|set| set.vanishing_diff_eval(points, z)) - .collect_vec(); - // Adopt fflonk's trick to normalize the set scalars by the one of first set, - // to save 1 EC scalar multiplication for verifier. - let normalizer = vanishing_diff_evals[0].invert().unwrap_or(F::ONE); - let normalized_scalars = izip_eq!(powers_of_gamma, &vanishing_diff_evals) - .map(|(power_of_gamma, vanishing_diff_eval)| { - normalizer * vanishing_diff_eval * power_of_gamma - }) - .collect_vec(); - (normalized_scalars, normalizer) -} - -fn vanishing_eval<'a, F: Field>(points: impl IntoIterator, z: &F) -> F { - points - .into_iter() - .fold(F::ONE, |eval, point| eval * (*z - point)) -} - -fn comm_scalars( - num_polys: usize, - sets: &[EvaluationSet], - powers_of_beta: &[F], - normalized_scalars: &[F], -) -> Vec { - sets.iter().zip(normalized_scalars).fold( - vec![F::ZERO; num_polys], - |mut scalars, (set, coeff)| { - izip!(&set.polys, powers_of_beta) - .for_each(|(poly, power_of_beta)| scalars[*poly] = *coeff * power_of_beta); - scalars - }, - ) -} - #[cfg(test)] mod test { use crate::{ - pcs::{univariate::kzg::UnivariateKzg, Evaluation, PolynomialCommitmentScheme}, - poly::{univariate::UnivariatePolynomial, Polynomial}, - util::{ - chain, - transcript::{ - FieldTranscript, FieldTranscriptRead, FieldTranscriptWrite, InMemoryTranscript, - Keccak256Transcript, - }, - Itertools, + pcs::{ + test::{run_batch_commit_open_verify, run_commit_open_verify}, + univariate::kzg::UnivariateKzg, }, + util::transcript::Keccak256Transcript, }; use halo2_curves::bn256::Bn256; - use rand::{rngs::OsRng, Rng}; - use std::iter; type Pcs = UnivariateKzg; #[test] fn commit_open_verify() { - for k in 3..16 { - // Setup - let (pp, vp) = { - let mut rng = OsRng; - let poly_size = 1 << k; - let param = Pcs::setup(poly_size, 1, &mut rng).unwrap(); - Pcs::trim(¶m, poly_size, 1).unwrap() - }; - // Commit and open - let proof = { - let mut transcript = Keccak256Transcript::default(); - let poly = UnivariatePolynomial::rand(pp.degree(), OsRng); - let comm = Pcs::commit_and_write(&pp, &poly, &mut transcript).unwrap(); - let point = transcript.squeeze_challenge(); - let eval = poly.evaluate(&point); - transcript.write_field_element(&eval).unwrap(); - Pcs::open(&pp, &poly, &comm, &point, &eval, &mut transcript).unwrap(); - transcript.into_proof() - }; - // Verify - let result = { - let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); - Pcs::verify( - &vp, - &Pcs::read_commitment(&vp, &mut transcript).unwrap(), - &transcript.squeeze_challenge(), - &transcript.read_field_element().unwrap(), - &mut transcript, - ) - }; - assert_eq!(result, Ok(())); - } + run_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); } #[test] fn batch_commit_open_verify() { - for k in 3..16 { - let batch_size = 8; - let num_points = batch_size >> 1; - let mut rng = OsRng; - // Setup - let (pp, vp) = { - let poly_size = 1 << k; - let param = Pcs::setup(poly_size, batch_size, &mut rng).unwrap(); - Pcs::trim(¶m, poly_size, batch_size).unwrap() - }; - // Batch commit and open - let evals = chain![ - (0..num_points).map(|point| (0, point)), - (1..batch_size).map(|poly| (poly, 0)), - iter::repeat_with(|| (rng.gen_range(0..batch_size), rng.gen_range(0..num_points))) - .take(batch_size) - ] - .unique() - .collect_vec(); - let proof = { - let mut transcript = Keccak256Transcript::default(); - let polys = iter::repeat_with(|| Polynomial::rand(pp.degree(), OsRng)) - .take(batch_size) - .collect_vec(); - let comms = Pcs::batch_commit_and_write(&pp, &polys, &mut transcript).unwrap(); - let points = transcript.squeeze_challenges(num_points); - let evals = evals - .iter() - .copied() - .map(|(poly, point)| Evaluation { - poly, - point, - value: polys[poly].evaluate(&points[point]), - }) - .collect_vec(); - transcript - .write_field_elements(evals.iter().map(Evaluation::value)) - .unwrap(); - Pcs::batch_open(&pp, &polys, &comms, &points, &evals, &mut transcript).unwrap(); - transcript.into_proof() - }; - // Batch verify - let result = { - let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); - Pcs::batch_verify( - &vp, - &Pcs::read_commitments(&vp, batch_size, &mut transcript).unwrap(), - &transcript.squeeze_challenges(num_points), - &evals - .iter() - .copied() - .zip(transcript.read_field_elements(evals.len()).unwrap()) - .map(|((poly, point), eval)| Evaluation::new(poly, point, eval)) - .collect_vec(), - &mut transcript, - ) - }; - assert_eq!(result, Ok(())); - } + run_batch_commit_open_verify::<_, Pcs, Keccak256Transcript<_>>(); } } diff --git a/plonkish_backend/src/piop.rs b/plonkish_backend/src/piop.rs index 5739e324..7549d81f 100644 --- a/plonkish_backend/src/piop.rs +++ b/plonkish_backend/src/piop.rs @@ -1,2 +1,3 @@ pub mod gkr; +pub mod multilinear_eval; pub mod sum_check; diff --git a/plonkish_backend/src/piop/gkr/fractional_sum_check.rs b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs index 3efc8cd3..f7d4fa0f 100644 --- a/plonkish_backend/src/piop/gkr/fractional_sum_check.rs +++ b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs @@ -20,7 +20,7 @@ use crate::{ }, Error, }; -use std::{array, collections::HashMap, iter}; +use std::{array, iter}; type SumCheck = ClassicSumCheck>; @@ -143,10 +143,10 @@ pub fn prove_fractional_sum_check<'a, F: PrimeField>( let expression = sum_check_expression(num_batching); - let (p_xs, q_xs, x) = layers.iter().rev().fold( - Ok((claimed_p_0s, claimed_q_0s, Vec::new())), + let (p_xs, q_xs, x) = layers.iter().rev().try_fold( + (claimed_p_0s, claimed_q_0s, Vec::new()), |result, layers| { - let (claimed_p_ys, claimed_q_ys, y) = result?; + let (claimed_p_ys, claimed_q_ys, y) = result; let num_vars = layers[0].num_vars(); let polys = layers.iter().flat_map(|layer| layer.polys()); @@ -167,7 +167,7 @@ pub fn prove_fractional_sum_check<'a, F: PrimeField>( )? }; - (x, evals) + (x, evals.into_values().collect_vec()) }; transcript.write_field_elements(&evals)?; @@ -222,10 +222,10 @@ pub fn verify_fractional_sum_check( let expression = sum_check_expression(num_batching); - let (p_xs, q_xs, x) = (0..num_vars).fold( - Ok((claimed_p_0s, claimed_q_0s, Vec::new())), + let (p_xs, q_xs, x) = (0..num_vars).try_fold( + (claimed_p_0s, claimed_q_0s, Vec::new()), |result, num_vars| { - let (claimed_p_ys, claimed_q_ys, y) = result?; + let (claimed_p_ys, claimed_q_ys, y) = result; let (mut x, evals) = if num_vars == 0 { let evals = transcript.read_field_elements(4 * num_batching)?; @@ -249,8 +249,12 @@ pub fn verify_fractional_sum_check( let evals = transcript.read_field_elements(4 * num_batching)?; - let eval_by_query = eval_by_query(&evals); - if x_eval != evaluate(&expression, num_vars, &eval_by_query, &[gamma], &[&y], &x) { + let query_eval = { + let queries = (0..).map(|idx| Query::new(idx, Rotation::cur())); + let evals = izip!(queries, evals.iter().cloned()).collect(); + evaluate::<_, usize>(&expression, num_vars, &evals, &[gamma], &[&y], &x) + }; + if x_eval != query_eval { return Err(err_unmatched_sum_check_output()); } @@ -295,14 +299,6 @@ fn layer_down_claim(evals: &[F], mu: F) -> (Vec, Vec) { .unzip() } -fn eval_by_query(evals: &[F]) -> HashMap { - izip!( - (0..).map(|idx| Query::new(idx, Rotation::cur())), - evals.iter().cloned() - ) - .collect() -} - fn err_unmatched_sum_check_output() -> Error { Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string()) } diff --git a/plonkish_backend/src/piop/multilinear_eval.rs b/plonkish_backend/src/piop/multilinear_eval.rs new file mode 100644 index 00000000..335755d0 --- /dev/null +++ b/plonkish_backend/src/piop/multilinear_eval.rs @@ -0,0 +1 @@ +pub mod ph23; diff --git a/plonkish_backend/src/piop/multilinear_eval/ph23.rs b/plonkish_backend/src/piop/multilinear_eval/ph23.rs new file mode 100644 index 00000000..d4fe461c --- /dev/null +++ b/plonkish_backend/src/piop/multilinear_eval/ph23.rs @@ -0,0 +1,66 @@ +//! Implementation of section 5.1 of [PH23]. +//! +//! [PH23]: https://eprint.iacr.org/2023/1284.pdf + +use crate::util::{ + arithmetic::{powers, squares, BatchInvert, WithSmallOrderMulGroup}, + expression::{evaluator::quotient::Radix2Domain, Query}, + izip, + parallel::parallelize, + Itertools, +}; + +pub mod additive; + +pub fn s_polys>(num_vars: usize) -> Vec> { + let domain = Radix2Domain::::new(num_vars, 2); + let vanishing = { + let coset_scalar = match domain.n() % 3 { + 1 => domain.zeta(), + 2 => domain.zeta_inv(), + _ => unreachable!(), + }; + powers(domain.extended_omega().pow([domain.n() as u64])) + .map(|value| coset_scalar * value - F::ONE) + .take(1 << (domain.extended_k() - domain.k())) + .collect_vec() + }; + let omegas = powers(domain.extended_omega()) + .take(domain.extended_n()) + .collect_vec(); + let mut s_polys = vec![vec![F::ZERO; domain.extended_n()]; domain.k()]; + parallelize(&mut s_polys, |(s_polys, start)| { + izip!(s_polys, start..).for_each(|(s_polys, idx)| { + let exponent = 1 << idx; + let offset = match exponent % 3 { + 1 => domain.zeta(), + 2 => domain.zeta_inv(), + _ => unreachable!(), + }; + izip!((0..).step_by(exponent), s_polys.iter_mut()).for_each(|(idx, value)| { + *value = offset * omegas[idx % domain.extended_n()] - F::ONE + }); + s_polys.batch_invert(); + izip!(s_polys.iter_mut(), vanishing.iter().cycle()) + .for_each(|(denom, numer)| *denom *= numer); + }) + }); + s_polys +} + +fn s_evals>( + domain: &Radix2Domain, + poly: usize, + x: F, +) -> Vec<(Query, F)> { + let iter = &mut squares(x).map(|square_of_x| square_of_x - F::ONE); + let mut s_denom_evals = iter.take(domain.k()).collect_vec(); + let vanishing_eval = iter.next().unwrap(); + s_denom_evals.batch_invert(); + let s_evals = s_denom_evals.iter().map(|denom| vanishing_eval * denom); + izip!((poly..).map(|poly| (poly, 0).into()), s_evals).collect() +} + +fn vanishing_eval>(domain: &Radix2Domain, x: F) -> F { + x.pow([domain.n() as u64]) - F::ONE +} diff --git a/plonkish_backend/src/piop/multilinear_eval/ph23/additive.rs b/plonkish_backend/src/piop/multilinear_eval/ph23/additive.rs new file mode 100644 index 00000000..2b9ea116 --- /dev/null +++ b/plonkish_backend/src/piop/multilinear_eval/ph23/additive.rs @@ -0,0 +1,580 @@ +use crate::{ + pcs::{Additive, Evaluation, PolynomialCommitmentScheme}, + piop::multilinear_eval::ph23::{additive::QueryGroup::*, s_evals, vanishing_eval}, + poly::{multilinear::MultilinearPolynomial, univariate::UnivariatePolynomial}, + util::{ + arithmetic::{ + inner_product, powers, product, BatchInvert, Msm, PrimeField, WithSmallOrderMulGroup, + }, + chain, end_timer, + expression::{ + evaluator::quotient::{QuotientEvaluator, Radix2Domain}, + rotate::{Lexical, Rotatable}, + Expression, Query, Rotation, + }, + izip, izip_eq, + parallel::parallelize, + start_timer, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use std::{borrow::Cow, collections::BTreeMap, mem}; + +#[allow(clippy::too_many_arguments)] +pub fn prove_multilinear_eval<'a, F, Pcs>( + pp: &Pcs::ProverParam, + num_vars: usize, + s_polys: &[Vec], + polys: impl IntoIterator>, + comms: impl IntoIterator, + point: &[F], + evals: &[(Query, F)], + transcript: &mut impl TranscriptWrite, +) -> Result<(), Error> +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: 'a + Additive, +{ + let domain = &Radix2Domain::::new(num_vars, 2); + let polys = polys.into_iter().collect_vec(); + let comms = comms.into_iter().collect_vec(); + let num_polys = polys.len(); + assert_eq!(comms.len(), num_polys); + + let (queries, evals) = evals.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>(); + + let gamma = transcript.squeeze_challenge(); + let powers_of_gamma = powers(gamma).take(queries.len()).collect_vec(); + + let query_groups = query_groups(num_vars, &queries, &powers_of_gamma); + + let u_step = -domain.n_inv() * inner_product(&powers_of_gamma, &evals); + let (eq_0, eq_u_fs) = { + let fs = chain![&query_groups] + .map(|group| group.poly(&polys)) + .collect_vec(); + let (eq, u) = eq_u(point, &query_groups, &fs, u_step); + + let eq_0 = eq[0]; + let eq_u_fs = chain![[eq, u].map(Cow::Owned), fs] + .map(|buf| domain.lagrange_to_monomial(buf)) + .map(UnivariatePolynomial::monomial) + .collect_vec(); + (eq_0, eq_u_fs) + }; + + let eq_u_comm = Pcs::batch_commit_and_write(pp, &eq_u_fs[..2], transcript)?; + + let alpha = transcript.squeeze_challenge(); + + let expression = expression(&query_groups, point, u_step, eq_0, alpha); + + let q = { + let eq_u_fs = chain![&eq_u_fs] + .map(|poly| domain.monomial_to_extended_lagrange(poly.coeffs().into())) + .collect_vec(); + let polys = chain![&eq_u_fs, s_polys].map(Vec::as_slice); + + let timer = start_timer(|| "quotient"); + let ev = QuotientEvaluator::new(domain, &expression, Default::default(), polys); + let mut q = vec![F::ZERO; domain.extended_n()]; + parallelize(&mut q, |(q, start)| { + let mut cache = ev.cache(); + izip!(q, start..).for_each(|(q, row)| ev.evaluate(q, &mut cache, row)); + }); + end_timer(timer); + + UnivariatePolynomial::monomial(domain.extended_lagrange_to_monomial(q.into())) + }; + + let q_comm = Pcs::commit_and_write(pp, &q, transcript)?; + + let x = transcript.squeeze_challenge(); + + let evals = eq_u_queries(&expression) + .map(|query| { + let point = domain.rotate_point(x, query.rotation()); + (query, eq_u_fs[query.poly()].evaluate(&point)) + }) + .collect_vec(); + + transcript.write_field_elements(evals.iter().map(|(_, eval)| eval))?; + + let (lin, lin_comm, lin_eval) = { + let evals = chain![evals.iter().cloned(), s_evals(domain, eq_u_fs.len(), x)].collect(); + let vanishing_eval = vanishing_eval(domain, x); + let (constant, poly) = { + let q = Msm::term(vanishing_eval, &q); + linearization(&expression, &eq_u_fs, &evals, q) + }; + let comm = if cfg!(feature = "sanity-check") { + let comms = { + let f_comms = query_groups.iter().map(|group| group.comm(&comms)); + chain![eq_u_comm.iter().map(Msm::base), f_comms].collect_vec() + }; + let q = Msm::term(vanishing_eval, &q_comm); + let (_, comm) = linearization(&expression, &comms, &evals, Msm::base(&q)); + let (_, comm) = comm.evaluate(); + comm + } else { + Default::default() + }; + (poly, comm, -constant) + }; + + if cfg!(feature = "sanity-check") { + assert_eq!(lin.evaluate(&x), lin_eval); + assert_eq!(&Pcs::commit(pp, &lin).unwrap(), &lin_comm); + } + + let polys = chain![&eq_u_fs[..2], [&lin]]; + let comms = chain![&eq_u_comm, [&lin_comm]]; + let (points, evals) = points_evals(domain, x, &evals, lin_eval); + let _timer = start_timer(|| format!("pcs_batch_open-{}", evals.len())); + Pcs::batch_open(pp, polys, comms, &points, &evals, transcript) +} + +pub fn verify_multilinear_eval<'a, F, Pcs>( + vp: &Pcs::VerifierParam, + num_vars: usize, + comms: impl IntoIterator, + point: &[F], + evals: &[(Query, F)], + transcript: &mut impl TranscriptRead, +) -> Result<(), Error> +where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: 'a + Additive, +{ + let domain = &Radix2Domain::::new(num_vars, 2); + let comms = comms.into_iter().collect_vec(); + + let (queries, evals) = evals.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>(); + + let gamma = transcript.squeeze_challenge(); + let powers_of_gamma = powers(gamma).take(evals.len()).collect_vec(); + + let query_groups = query_groups(num_vars, &queries, &powers_of_gamma); + + let u_step = -domain.n_inv() * inner_product(&powers_of_gamma, &evals); + let eq_0 = product(point.iter().map(|point_i| F::ONE - point_i)); + + let eq_u_comm = Pcs::read_commitments(vp, 2, transcript)?; + + let alpha = transcript.squeeze_challenge(); + + let expression = expression(&query_groups, point, u_step, eq_0, alpha); + + let q_comm = Pcs::read_commitment(vp, transcript)?; + + let x = transcript.squeeze_challenge(); + + let evals = { + let queries = eq_u_queries(&expression).collect_vec(); + let evals = transcript.read_field_elements(queries.len())?; + izip_eq!(queries, evals).collect_vec() + }; + + let (lin_comm, lin_eval) = { + let comms = { + let f_comms = query_groups.iter().map(|group| group.comm(&comms)); + chain![eq_u_comm.iter().map(Msm::base), f_comms].collect_vec() + }; + let evals = chain![evals.iter().cloned(), s_evals(domain, comms.len(), x)].collect(); + let vanishing_eval = vanishing_eval(domain, x); + let q = Msm::term(vanishing_eval, &q_comm); + let (constant, comm) = linearization(&expression, &comms, &evals, Msm::base(&q)); + let (_, comm) = comm.evaluate(); + (comm, -constant) + }; + + let comms = chain![&eq_u_comm, [&lin_comm]]; + let (points, evals) = points_evals(domain, x, &evals, lin_eval); + Pcs::batch_verify(vp, comms, &points, &evals, transcript) +} + +#[derive(Clone, Debug)] +enum QueryGroup { + ByPoly { + poly: usize, + rotations: Vec, + scalars: Vec, + }, + ByRotation { + rotation: Rotation, + polys: Vec, + scalars: Vec, + }, +} + +impl QueryGroup { + fn eq(&self) -> Expression { + match self { + ByPoly { + rotations, scalars, .. + } => { + let eq_rots = + chain![rotations].map(|rotation| Expression::Polynomial((0, *rotation).into())); + izip!(eq_rots, scalars) + .map(|(eq_rot, scalar)| eq_rot * *scalar) + .sum::>() + } + ByRotation { rotation, .. } => Expression::::Polynomial((0, *rotation).into()), + } + } + + fn poly<'a>(&self, polys: &[&'a UnivariatePolynomial]) -> Cow<'a, [F]> { + match self { + ByPoly { poly, .. } => polys[*poly].coeffs().into(), + ByRotation { + polys: ps, scalars, .. + } => izip!(scalars, ps) + .map(|(scalar, poly)| (scalar, polys[*poly])) + .sum::>() + .into_coeffs() + .into(), + } + } + + fn comm<'a, T: Additive>(&self, comms: &[&'a T]) -> Msm<'a, F, T> { + match self { + ByPoly { poly, .. } => Msm::base(comms[*poly]), + ByRotation { polys, scalars, .. } => izip!(polys, scalars) + .map(|(poly, scalar)| Msm::term(*scalar, comms[*poly])) + .sum(), + } + } +} + +fn query_groups( + num_vars: usize, + queries: &[Query], + powers_of_gamma: &[F], +) -> Vec> { + let n = 1 << num_vars; + let repeated_rotations = chain![queries.iter().map(|query| (-query.rotation()).positive(n))] + .counts() + .into_iter() + .filter(|(_, count)| *count > 1) + .sorted_by(|a, b| b.1.cmp(&a.1)) + .map(|(rotation, _)| rotation); + let mut by_polys = izip!(queries, powers_of_gamma) + .fold(BTreeMap::new(), |mut polys, (query, scalar)| { + polys + .entry(query.poly()) + .and_modify(|poly| match poly { + ByPoly { + rotations, scalars, .. + } => { + rotations.push((-query.rotation()).positive(n)); + scalars.push(*scalar); + } + _ => unreachable!(), + }) + .or_insert_with(|| ByPoly { + poly: query.poly(), + rotations: vec![(-query.rotation()).positive(n)], + scalars: vec![*scalar], + }); + polys + }) + .into_values() + .collect_vec(); + let mut by_rotations = Vec::new(); + let mut output = by_polys.clone(); + for rotation in repeated_rotations { + let mut by_rotation = (Vec::new(), Vec::new()); + by_polys.retain_mut(|poly| match poly { + ByPoly { + poly, + rotations, + scalars, + } => { + if let Some(idx) = rotations.iter().position(|value| *value == rotation) { + rotations.remove(idx); + by_rotation.0.push(*poly); + by_rotation.1.push(scalars.remove(idx)); + !rotations.is_empty() + } else { + true + } + } + _ => unreachable!(), + }); + by_rotations.push(ByRotation { + rotation, + polys: by_rotation.0, + scalars: by_rotation.1, + }); + if by_polys.len() + by_rotations.len() <= output.len() { + output = chain![&by_polys, &by_rotations].cloned().collect_vec(); + } + } + output +} + +fn eq_u( + point: &[F], + query_groups: &[QueryGroup], + polys: &[Cow<[F]>], + u_step: F, +) -> (Vec, Vec) { + let _timer = start_timer(|| "u"); + + let lexical = Lexical::new(point.len()); + let eq = MultilinearPolynomial::eq_xy(point).into_evals(); + let sums = { + let mut coeffs = vec![F::ZERO; lexical.n()]; + izip!(query_groups, polys).for_each(|(group, poly)| match group { + ByPoly { + rotations, scalars, .. + } => { + parallelize(&mut coeffs, |(coeffs, start)| { + izip!(start.., coeffs, &poly[start..]).for_each(|(idx, coeffs, poly)| { + let eq_rot = + chain![rotations].map(|rotation| &eq[lexical.rotate(idx, *rotation)]); + *coeffs += inner_product(eq_rot, scalars) * poly; + }); + }); + } + ByRotation { rotation, .. } => { + parallelize(&mut coeffs, |(coeffs, start)| { + let skip = lexical.rotate(start, *rotation); + izip!(coeffs, eq.iter().cycle().skip(skip), &poly[start..]) + .for_each(|(coeffs, eq, poly)| *coeffs += *eq * poly); + }); + } + }); + coeffs + }; + let u = chain![&sums] + .scan(F::ZERO, |u, sum| mem::replace(u, *u + sum + u_step).into()) + .collect_vec(); + + if cfg!(feature = "sanity-check") { + assert_eq!(F::ZERO, u[lexical.nth(-1)] + sums[lexical.nth(-1)] + u_step); + } + + (eq, u) +} + +fn expression( + query_groups: &[QueryGroup], + point: &[F], + u_step: F, + eq_0: F, + alpha: F, +) -> Expression { + let num_vars = point.len(); + let [u_step, eq_0, alpha] = &[u_step, eq_0, alpha].map(Expression::Constant); + let eq_ratios = { + let mut denoms = point.iter().map(|point_i| F::ONE - point_i).collect_vec(); + denoms.batch_invert(); + izip!(point, denoms) + .map(|(numer, denom)| denom * numer) + .rev() + .collect_vec() + }; + let eq = &Expression::Polynomial(Query::new(0, Rotation::cur())); + let eq_rots = (0..num_vars) + .rev() + .map(|rotation| Expression::Polynomial(Query::new(0, Rotation(1 << rotation)))) + .collect_vec(); + let [u, u_next] = &[Rotation::cur(), Rotation::next()] + .map(|rotation| Expression::Polynomial(Query::new(1, rotation))); + let f = izip!(2.., query_groups) + .map(|(idx, set)| set.eq() * Expression::Polynomial((idx, 0).into())) + .sum::>(); + let s = (2 + query_groups.len()..) + .take(num_vars) + .map(|poly| Expression::::Polynomial(Query::new(poly, Rotation::cur()))) + .collect_vec(); + let constraints = chain![ + [u_next - u - f - u_step], + [&s[0] * (eq - eq_0)], + izip!(&s, &eq_rots, &eq_ratios).map(|(s, eq_rot, eq_ratio)| s * (eq * eq_ratio - eq_rot)) + ] + .collect_vec(); + Expression::distribute_powers(&constraints, alpha) + .simplified(None) + .unwrap() +} + +fn eq_u_queries(expression: &Expression) -> impl Iterator { + chain![ + chain![expression.used_query()].filter(|query| query.poly() == 0), + [(1, Rotation::next()).into()] + ] +} + +fn linearization<'a, F: PrimeField, T: Additive + 'a>( + expression: &Expression, + bases: impl IntoIterator, + evals: &BTreeMap, + vanishing_q: Msm, +) -> (F, T) { + let bases = bases.into_iter().collect_vec(); + (expression.evaluate( + &|scalar| Msm::scalar(scalar), + &|_| unreachable!(), + &|query| { + if let Some(eval) = evals.get(&query) { + Msm::scalar(*eval) + } else { + assert_eq!(query.rotation(), Rotation::cur()); + Msm::base(bases[query.poly()]) + } + }, + &|_| unreachable!(), + &|scalar| -scalar, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, + &|value, scalar| value * Msm::scalar(scalar), + ) - vanishing_q) + .evaluate() +} + +fn points_evals>( + domain: &Radix2Domain, + x: F, + evals: &[(Query, F)], + lin_eval: F, +) -> (Vec, Vec>) { + let point_index = evals + .iter() + .fold(BTreeMap::new(), |mut point_index, (query, _)| { + let rotation = query.rotation().positive(domain.n()); + let idx = point_index.len(); + point_index.entry(rotation).or_insert(idx); + point_index + }); + let points = point_index + .iter() + .sorted_by(|a, b| a.1.cmp(b.1)) + .map(|(rotation, _)| domain.rotate_point(x, *rotation)) + .collect_vec(); + let evals = chain![ + evals.iter().map(|(query, eval)| { + let point = point_index[&query.rotation().positive(domain.n())]; + Evaluation::new(query.poly(), point, *eval) + }), + [Evaluation::new(2, point_index[&Rotation::cur()], lin_eval)] + ] + .collect_vec(); + (points, evals) +} + +#[cfg(test)] +mod test { + use crate::{ + pcs::{univariate::UnivariateKzg, Additive, PolynomialCommitmentScheme}, + piop::multilinear_eval::ph23::{ + additive::{prove_multilinear_eval, verify_multilinear_eval}, + s_polys, + }, + poly::{multilinear::MultilinearPolynomial, univariate::UnivariatePolynomial}, + util::{ + arithmetic::WithSmallOrderMulGroup, + expression::{Query, Rotation}, + izip, + test::{rand_vec, seeded_std_rng}, + transcript::{ + InMemoryTranscript, Keccak256Transcript, TranscriptRead, TranscriptWrite, + }, + Itertools, + }, + }; + use halo2_curves::bn256::{Bn256, Fr}; + use rand::Rng; + use std::{io::Cursor, iter}; + + fn run_prove_verify(num_vars: usize) + where + F: WithSmallOrderMulGroup<3>, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: Additive, + Keccak256Transcript>>: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + { + let mut rng = seeded_std_rng(); + + let n = 1 << num_vars; + let param = Pcs::setup(n, 0, &mut rng).unwrap(); + let (pp, vp) = Pcs::trim(¶m, n, 0).unwrap(); + + let s_polys = s_polys(num_vars); + let polys = iter::repeat_with(|| UnivariatePolynomial::lagrange(rand_vec(n, &mut rng))) + .take(10) + .collect_vec(); + let comms = polys + .iter() + .map(|poly| Pcs::commit(&pp, poly).unwrap()) + .collect_vec(); + let point = rand_vec(num_vars, &mut rng); + let evals = izip!(0.., &polys) + .flat_map(|(idx, poly)| { + let point = &point; + let max_rotation = 1 << (num_vars - 1); + let num_rotations = rng.gen_range(1..3.min(max_rotation)); + let rotation_range = -(5.min(max_rotation) as i32)..=5.min(max_rotation) as i32; + iter::repeat_with(|| rng.gen_range(rotation_range.clone())) + .unique() + .take(num_rotations) + .map(move |rotation| { + let mut poly = poly.coeffs().to_vec(); + if rotation < 0 { + poly.rotate_right(rotation.unsigned_abs() as usize) + } else { + poly.rotate_left(rotation.unsigned_abs() as usize) + } + let eval = MultilinearPolynomial::new(poly).evaluate(point); + (Query::new(idx, Rotation(rotation)), eval) + }) + .collect_vec() + }) + .collect_vec(); + + let proof = { + let mut transcript = Keccak256Transcript::default(); + prove_multilinear_eval::( + &pp, + num_vars, + &s_polys, + &polys, + &comms, + &point, + &evals, + &mut transcript, + ) + .unwrap(); + transcript.into_proof() + }; + + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + verify_multilinear_eval::( + &vp, + num_vars, + &comms, + &point, + &evals, + &mut transcript, + ) + }; + assert_eq!(result, Ok(())); + } + + #[test] + fn prove_verify() { + type Pcs = UnivariateKzg; + + for num_vars in 2..16 { + run_prove_verify::(num_vars); + } + } +} diff --git a/plonkish_backend/src/piop/sum_check.rs b/plonkish_backend/src/piop/sum_check.rs index f335e78a..8855d296 100644 --- a/plonkish_backend/src/piop/sum_check.rs +++ b/plonkish_backend/src/piop/sum_check.rs @@ -1,14 +1,14 @@ use crate::{ poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{inner_product, powers, product, BooleanHypercube, Field, PrimeField}, - expression::{CommonPolynomial, Expression, Query}, + arithmetic::{inner_product, powers, product, Field, PrimeField}, + expression::{rotate::Rotatable, CommonPolynomial, Expression, Query}, transcript::{FieldTranscriptRead, FieldTranscriptWrite}, BitIndex, Itertools, }, Error, }; -use std::{collections::HashMap, fmt::Debug}; +use std::{collections::BTreeMap, fmt::Debug}; pub mod classic; @@ -40,13 +40,14 @@ pub trait SumCheck: Clone + Debug { type ProverParam: Clone + Debug; type VerifierParam: Clone + Debug; + #[allow(clippy::type_complexity)] fn prove( pp: &Self::ProverParam, num_vars: usize, virtual_poly: VirtualPolynomial, sum: F, transcript: &mut impl FieldTranscriptWrite, - ) -> Result<(F, Vec, Vec), Error>; + ) -> Result<(F, Vec, BTreeMap), Error>; fn verify( vp: &Self::VerifierParam, @@ -57,26 +58,23 @@ pub trait SumCheck: Clone + Debug { ) -> Result<(F, Vec), Error>; } -pub fn evaluate( +pub fn evaluate>( expression: &Expression, num_vars: usize, - evals: &HashMap, + evals: &BTreeMap, challenges: &[F], ys: &[&[F]], x: &[F], ) -> F { - assert!(num_vars > 0 && expression.max_used_rotation_distance() <= num_vars); + let rotatable = R::from(num_vars); + let identity = identity_eval(x); let lagranges = { - let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); expression .used_langrange() .into_iter() - .map(|i| { - let b = bh[i.rem_euclid(1 << num_vars as i32) as usize]; - (i, lagrange_eval(x, b)) - }) - .collect::>() + .map(|i| (i, lagrange_eval(x, rotatable.nth(i)))) + .collect::>() }; let eq_xys = ys.iter().map(|y| eq_xy_eval(x, y)).collect_vec(); expression.evaluate( @@ -128,225 +126,224 @@ fn identity_eval(x: &[F]) -> F { pub(super) mod test { use crate::{ piop::sum_check::{evaluate, SumCheck, VirtualPolynomial}, - poly::multilinear::{rotation_eval, MultilinearPolynomial}, + poly::multilinear::MultilinearPolynomial, util::{ - expression::Expression, + arithmetic::Field, + expression::{rotate::Rotatable, Expression}, transcript::{InMemoryTranscript, Keccak256Transcript}, }, }; use halo2_curves::bn256::Fr; + use itertools::Itertools; use std::ops::Range; - pub fn run_sum_check>( + pub fn run_sum_check, R: Rotatable + From>( num_vars_range: Range, expression_fn: impl Fn(usize) -> Expression, param_fn: impl Fn(usize) -> (S::ProverParam, S::VerifierParam), - assignment_fn: impl Fn(usize) -> (Vec>, Vec, Vec), + assignment_fn: impl Fn(usize) -> (Vec>, Vec, Vec>), sum: Fr, ) { for num_vars in num_vars_range { let expression = expression_fn(num_vars); let degree = expression.degree(); let (pp, vp) = param_fn(expression.degree()); - let (polys, challenges, y) = assignment_fn(num_vars); - let ys = [y]; - let proof = { + let (polys, challenges, ys) = assignment_fn(num_vars); + let (evals, proof) = { let virtual_poly = VirtualPolynomial::new(&expression, &polys, &challenges, &ys); let mut transcript = Keccak256Transcript::default(); - S::prove(&pp, num_vars, virtual_poly, sum, &mut transcript).unwrap(); - transcript.into_proof() + let (_, _, evals) = + S::prove(&pp, num_vars, virtual_poly, sum, &mut transcript).unwrap(); + (evals, transcript.into_proof()) }; let accept = { let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); let (x_eval, x) = - S::verify(&vp, num_vars, degree, Fr::zero(), &mut transcript).unwrap(); - let evals = expression - .used_query() - .into_iter() - .map(|query| { - let evaluate_for_rotation = - polys[query.poly()].evaluate_for_rotation(&x, query.rotation()); - let eval = rotation_eval(&x, query.rotation(), &evaluate_for_rotation); - (query, eval) - }) - .collect(); - x_eval == evaluate(&expression, num_vars, &evals, &challenges, &[&ys[0]], &x) + S::verify(&vp, num_vars, degree, Fr::ZERO, &mut transcript).unwrap(); + let ys = ys.iter().map(Vec::as_slice).collect_vec(); + x_eval == evaluate::<_, R>(&expression, num_vars, &evals, &challenges, &ys, &x) }; assert!(accept); } } - pub fn run_zero_check>( + pub fn run_zero_check, R: Rotatable + From>( num_vars_range: Range, expression_fn: impl Fn(usize) -> Expression, param_fn: impl Fn(usize) -> (S::ProverParam, S::VerifierParam), - assignment_fn: impl Fn(usize) -> (Vec>, Vec, Vec), + assignment_fn: impl Fn(usize) -> (Vec>, Vec, Vec>), ) { - run_sum_check::( + run_sum_check::( num_vars_range, expression_fn, param_fn, assignment_fn, - Fr::zero(), + Fr::ZERO, ) } macro_rules! tests { - ($impl:ty) => { - #[test] - fn sum_check_lagrange() { - use halo2_curves::bn256::Fr; - use $crate::{ - piop::sum_check::test::run_zero_check, - poly::multilinear::MultilinearPolynomial, - util::{ - arithmetic::{BooleanHypercube, Field}, - expression::{CommonPolynomial, Expression, Query, Rotation}, - test::{rand_vec, seeded_std_rng}, - Itertools, - }, - }; + ($suffix:ident, $impl:ty, $rotatable:ident) => { + paste::paste! { + #[test] + fn []() { + use halo2_curves::bn256::Fr; + use $crate::{ + piop::sum_check::test::run_zero_check, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::Field, + expression::{ + rotate::Rotatable, CommonPolynomial, Expression, Query, Rotation, + }, + test::{rand_vec, seeded_std_rng}, + Itertools, + }, + }; - run_zero_check::<$impl>( - 2..4, - |num_vars| { - let polys = (0..1 << num_vars) - .map(|idx| { - Expression::::Polynomial(Query::new(idx, Rotation::cur())) - }) - .collect_vec(); - let gates = polys - .iter() - .enumerate() - .map(|(i, poly)| { - Expression::CommonPolynomial(CommonPolynomial::Lagrange(i as i32)) - - poly - }) - .collect_vec(); - let alpha = Expression::Challenge(0); - let eq = Expression::eq_xy(0); - Expression::distribute_powers(&gates, &alpha) * eq - }, - |_| ((), ()), - |num_vars| { - let polys = BooleanHypercube::new(num_vars) - .iter() - .map(|idx| { - let mut polys = - MultilinearPolynomial::new(vec![Fr::zero(); 1 << num_vars]); - polys[idx] = Fr::one(); - polys - }) - .collect_vec(); - let alpha = Fr::random(seeded_std_rng()); - (polys, vec![alpha], rand_vec(num_vars, seeded_std_rng())) - }, - ); - } + run_zero_check::<$impl, $rotatable>( + 2..4, + |num_vars| { + let polys = (0..$rotatable::new(num_vars).usable_indices().len()) + .map(|idx| { + Expression::::Polynomial(Query::new(idx, Rotation::cur())) + }) + .collect_vec(); + let gates = polys + .iter() + .enumerate() + .map(|(i, poly)| { + Expression::CommonPolynomial(CommonPolynomial::Lagrange(i as i32)) + - poly + }) + .collect_vec(); + let alpha = Expression::Challenge(0); + let eq = Expression::eq_xy(0); + Expression::distribute_powers(&gates, &alpha) * eq + }, + |_| ((), ()), + |num_vars| { + let polys = $rotatable::new(num_vars) + .usable_indices() + .into_iter() + .map(|idx| { + let mut polys = + MultilinearPolynomial::new(vec![Fr::ZERO; 1 << num_vars]); + polys[idx] = Fr::ONE; + polys + }) + .collect_vec(); + let alpha = Fr::random(seeded_std_rng()); + (polys, vec![alpha], vec![rand_vec(num_vars, seeded_std_rng())]) + }, + ); + } - #[test] - fn sum_check_rotation() { - use halo2_curves::bn256::Fr; - use std::iter; - use $crate::{ - piop::sum_check::test::run_zero_check, - poly::multilinear::MultilinearPolynomial, - util::{ - arithmetic::{BooleanHypercube, Field}, - expression::{Expression, Query, Rotation}, - test::{rand_vec, seeded_std_rng}, - Itertools, - }, - }; + #[test] + fn []() { + use halo2_curves::bn256::Fr; + use std::iter; + use $crate::{ + piop::sum_check::test::run_zero_check, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::Field, + expression::{rotate::Rotatable, Expression, Query, Rotation}, + test::{rand_vec, seeded_std_rng}, + Itertools, + }, + }; - run_zero_check::<$impl>( - 2..16, - |num_vars| { - let polys = (-(num_vars as i32) + 1..num_vars as i32) - .rev() - .enumerate() - .map(|(idx, rotation)| { - Expression::::Polynomial(Query::new(idx, rotation.into())) - }) - .collect_vec(); - let gates = polys - .windows(2) - .map(|polys| &polys[1] - &polys[0]) - .collect_vec(); - let alpha = Expression::Challenge(0); - let eq = Expression::eq_xy(0); - Expression::distribute_powers(&gates, &alpha) * eq - }, - |_| ((), ()), - |num_vars| { - let bh = BooleanHypercube::new(num_vars); - let rotate = |f: &Vec| { - (0..1 << num_vars) - .map(|idx| f[bh.rotate(idx, Rotation::next())]) - .collect_vec() - }; - let poly = rand_vec(1 << num_vars, seeded_std_rng()); - let polys = iter::successors(Some(poly), |poly| Some(rotate(poly))) - .map(MultilinearPolynomial::new) - .take(2 * num_vars - 1) - .collect_vec(); - let alpha = Fr::random(seeded_std_rng()); - (polys, vec![alpha], rand_vec(num_vars, seeded_std_rng())) - }, - ); - } + run_zero_check::<$impl, $rotatable>( + 2..16, + |num_vars| { + let polys = (-(num_vars as i32) + 1..num_vars as i32) + .rev() + .enumerate() + .map(|(idx, rotation)| { + Expression::::Polynomial(Query::new(idx, rotation)) + }) + .collect_vec(); + let gates = polys + .windows(2) + .map(|polys| &polys[1] - &polys[0]) + .collect_vec(); + let alpha = Expression::Challenge(0); + let eq = Expression::eq_xy(0); + Expression::distribute_powers(&gates, &alpha) * eq + }, + |_| ((), ()), + |num_vars| { + let rotatable = $rotatable::from(num_vars); + let rotate = |f: &Vec| { + (0..1 << num_vars) + .map(|idx| f[rotatable.rotate(idx, Rotation::next())]) + .collect_vec() + }; + let poly = rand_vec(1 << num_vars, seeded_std_rng()); + let polys = iter::successors(Some(poly), |poly| Some(rotate(poly))) + .map(MultilinearPolynomial::new) + .take(2 * num_vars - 1) + .collect_vec(); + let alpha = Fr::random(seeded_std_rng()); + (polys, vec![alpha], vec![rand_vec(num_vars, seeded_std_rng())]) + }, + ); + } - #[test] - fn sum_check_vanilla_plonk() { - use halo2_curves::bn256::Fr; - use $crate::{ - backend::hyperplonk::util::{ - rand_vanilla_plonk_assignment, vanilla_plonk_expression, - }, - piop::sum_check::test::run_zero_check, - util::test::{rand_vec, seeded_std_rng}, - }; + #[test] + fn []() { + use halo2_curves::bn256::Fr; + use $crate::{ + backend::hyperplonk::util::{ + rand_vanilla_plonk_assignment, vanilla_plonk_expression, + }, + piop::sum_check::test::run_zero_check, + util::test::{rand_vec, seeded_std_rng}, + }; - run_zero_check::<$impl>( - 2..16, - |num_vars| vanilla_plonk_expression(num_vars), - |_| ((), ()), - |num_vars| { - let (polys, challenges) = rand_vanilla_plonk_assignment( - num_vars, - seeded_std_rng(), - seeded_std_rng(), - ); - (polys, challenges, rand_vec(num_vars, seeded_std_rng())) - }, - ); - } + run_zero_check::<$impl, $rotatable>( + 2..16, + |num_vars| vanilla_plonk_expression(num_vars), + |_| ((), ()), + |num_vars| { + let (polys, challenges) = rand_vanilla_plonk_assignment::<_, $rotatable>( + num_vars, + seeded_std_rng(), + seeded_std_rng(), + ); + (polys, challenges, vec![rand_vec(num_vars, seeded_std_rng())]) + }, + ); + } - #[test] - fn sum_check_vanilla_plonk_with_lookup() { - use halo2_curves::bn256::Fr; - use $crate::{ - backend::hyperplonk::util::{ - rand_vanilla_plonk_with_lookup_assignment, - vanilla_plonk_with_lookup_expression, - }, - piop::sum_check::test::run_zero_check, - util::test::{rand_vec, seeded_std_rng}, - }; + #[test] + fn []() { + use halo2_curves::bn256::Fr; + use $crate::{ + backend::hyperplonk::util::{ + rand_vanilla_plonk_w_lookup_assignment, + vanilla_plonk_w_lookup_expression, + }, + piop::sum_check::test::run_zero_check, + util::test::{rand_vec, seeded_std_rng}, + }; - run_zero_check::<$impl>( - 2..16, - |num_vars| vanilla_plonk_with_lookup_expression(num_vars), - |_| ((), ()), - |num_vars| { - let (polys, challenges) = rand_vanilla_plonk_with_lookup_assignment( - num_vars, - seeded_std_rng(), - seeded_std_rng(), - ); - (polys, challenges, rand_vec(num_vars, seeded_std_rng())) - }, - ); + run_zero_check::<$impl, $rotatable>( + 2..16, + |num_vars| vanilla_plonk_w_lookup_expression(num_vars), + |_| ((), ()), + |num_vars| { + let (polys, challenges) = rand_vanilla_plonk_w_lookup_assignment::< + _, + $rotatable, + >( + num_vars, seeded_std_rng(), seeded_std_rng() + ); + (polys, challenges, vec![rand_vec(num_vars, seeded_std_rng())]) + }, + ); + } } }; } diff --git a/plonkish_backend/src/piop/sum_check/classic.rs b/plonkish_backend/src/piop/sum_check/classic.rs index 1e9a0ab4..b82846fd 100644 --- a/plonkish_backend/src/piop/sum_check/classic.rs +++ b/plonkish_backend/src/piop/sum_check/classic.rs @@ -2,9 +2,10 @@ use crate::{ piop::sum_check::{SumCheck, VirtualPolynomial}, poly::multilinear::MultilinearPolynomial, util::{ - arithmetic::{BooleanHypercube, Field, PrimeField}, + arithmetic::{Field, PrimeField}, end_timer, - expression::{Expression, Rotation}, + expression::{rotate::Rotatable, Expression, Query, Rotation}, + izip, parallel::par_map_collect, start_timer, transcript::{FieldTranscriptRead, FieldTranscriptWrite}, @@ -13,7 +14,12 @@ use crate::{ Error, }; use num_integer::Integer; -use std::{borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData}; +use std::{ + borrow::Cow, + collections::{BTreeMap, HashMap}, + fmt::Debug, + marker::PhantomData, +}; mod coeff; mod eval; @@ -30,27 +36,28 @@ pub struct ProverState<'a, F: Field> { lagranges: HashMap, identity: F, eq_xys: Vec>, - polys: Vec>>>, + polys: HashMap>>, challenges: &'a [F], buf: MultilinearPolynomial, round: usize, - bh: BooleanHypercube, + rotatable: Box, } impl<'a, F: PrimeField> ProverState<'a, F> { - fn new(num_vars: usize, sum: F, virtual_poly: VirtualPolynomial<'a, F>) -> Self { - assert!(num_vars > 0 && virtual_poly.expression.max_used_rotation_distance() <= num_vars); - let bh = BooleanHypercube::new(num_vars); + fn new>( + num_vars: usize, + sum: F, + virtual_poly: VirtualPolynomial<'a, F>, + ) -> Self { + let rotatable = Box::new(R::from(num_vars)); + assert!(virtual_poly.expression.max_used_rotation_distance() <= rotatable.max_rotation()); + let lagranges = { - let bh = bh.iter().collect_vec(); virtual_poly .expression .used_langrange() .into_iter() - .map(|i| { - let b = bh[i.rem_euclid(1 << num_vars) as usize]; - (i, (b, F::ONE)) - }) + .map(|i| (i, (rotatable.nth(i), F::ONE))) .collect() }; let eq_xys = virtual_poly @@ -58,15 +65,9 @@ impl<'a, F: PrimeField> ProverState<'a, F> { .iter() .map(|y| MultilinearPolynomial::eq_xy(y)) .collect_vec(); - let polys = virtual_poly - .polys - .iter() - .map(|poly| { - let mut polys = vec![Cow::Owned(MultilinearPolynomial::zero()); 2 * num_vars]; - polys[num_vars] = Cow::Borrowed(*poly); - polys - }) - .collect_vec(); + let polys = izip!(0.., virtual_poly.polys) + .map(|(idx, poly)| ((idx, 0).into(), Cow::Borrowed(poly))) + .collect(); Self { num_vars, expression: virtual_poly.expression, @@ -79,7 +80,7 @@ impl<'a, F: PrimeField> ProverState<'a, F> { challenges: virtual_poly.challenges, buf: MultilinearPolynomial::new(vec![F::ZERO; 1 << (num_vars - 1)]), round: 0, - bh, + rotatable, } } @@ -87,7 +88,7 @@ impl<'a, F: PrimeField> ProverState<'a, F> { 1 << (self.num_vars - self.round - 1) } - fn next_round(&mut self, sum: F, challenge: &F) { + fn next_round>(&mut self, sum: F, challenge: &F) { self.sum = sum; self.identity += F::from(1 << self.round) * challenge; self.lagranges.values_mut().for_each(|(b, value)| { @@ -106,45 +107,45 @@ impl<'a, F: PrimeField> ProverState<'a, F> { .expression .used_rotation() .into_iter() - .filter_map(|rotation| { - (rotation != Rotation::cur()) - .then(|| (rotation, self.bh.rotation_map(rotation))) - }) + .filter(|rotation| rotation != &Rotation::cur()) + .map(|rotation| (rotation, self.rotatable.rotation_map(rotation))) .collect::>(); - for query in self.expression.used_query() { - if query.rotation() != Rotation::cur() { - let poly = &self.polys[query.poly()][self.num_vars]; + let rotated_polys = self + .expression + .used_query() + .into_iter() + .filter(|query| query.rotation() != Rotation::cur()) + .map(|query| { + let poly = &self.polys[&(query.poly(), 0).into()]; let mut rotated = MultilinearPolynomial::new(par_map_collect( &rotation_maps[&query.rotation()], |b| poly[*b], )); rotated.fix_var_in_place(challenge, &mut self.buf); - self.polys[query.poly()] - [(query.rotation().0 + self.num_vars as i32) as usize] = - Cow::Owned(rotated); - } - } - self.polys.iter_mut().for_each(|polys| { - polys[self.num_vars] = Cow::Owned(polys[self.num_vars].fix_var(challenge)); + (query, Cow::Owned(rotated)) + }) + .collect::>(); + self.polys.iter_mut().for_each(|(_, poly)| { + *poly = Cow::Owned(poly.fix_var(challenge)); }); + self.polys.extend(rotated_polys); } else { - self.polys.iter_mut().for_each(|polys| { - polys.iter_mut().for_each(|poly| { - if !poly.is_empty() { - poly.to_mut().fix_var_in_place(challenge, &mut self.buf); - } - }); - }); + self.polys + .iter_mut() + .for_each(|(_, poly)| poly.to_mut().fix_var_in_place(challenge, &mut self.buf)); } self.round += 1; - self.bh = BooleanHypercube::new(self.num_vars - self.round); + if self.round != self.num_vars { + self.rotatable = Box::new(R::from(self.num_vars - self.round)); + } } - fn into_evals(self) -> Vec { + fn into_evals(self) -> BTreeMap { assert_eq!(self.round, self.num_vars); - self.polys - .iter() - .map(|polys| polys[self.num_vars][0]) + self.expression + .used_query() + .into_iter() + .map(|query| (query, self.polys[&query][0])) .collect() } } @@ -194,13 +195,20 @@ pub trait ClassicSumCheckRoundMessage: Sized + Debug { } } -#[derive(Clone, Debug)] -pub struct ClassicSumCheck

(PhantomData

); +#[derive(Debug)] +pub struct ClassicSumCheck(PhantomData<(P, R)>); + +impl Clone for ClassicSumCheck { + fn clone(&self) -> Self { + Self(PhantomData) + } +} -impl SumCheck for ClassicSumCheck

+impl SumCheck for ClassicSumCheck where F: PrimeField, P: ClassicSumCheckProver, + R: Rotatable + From, { type ProverParam = (); type VerifierParam = (); @@ -211,13 +219,13 @@ where virtual_poly: VirtualPolynomial, sum: F, transcript: &mut impl FieldTranscriptWrite, - ) -> Result<(F, Vec, Vec), Error> { + ) -> Result<(F, Vec, BTreeMap), Error> { let _timer = start_timer(|| { let degree = virtual_poly.expression.degree(); format!("sum_check_prove-{num_vars}-{degree}") }); - let mut state = ProverState::new(num_vars, sum, virtual_poly); + let mut state = ProverState::new::(num_vars, sum, virtual_poly); let mut challenges = Vec::with_capacity(num_vars); let prover = P::new(&state); let aux = P::RoundMessage::auxiliary(state.degree); @@ -232,7 +240,7 @@ where challenges.push(challenge); let timer = start_timer(|| format!("sum_check_next_round-{round}")); - state.next_round(msg.evaluate(&aux, &challenge), &challenge); + state.next_round::(msg.evaluate(&aux, &challenge), &challenge); end_timer(timer); } diff --git a/plonkish_backend/src/piop/sum_check/classic/coeff.rs b/plonkish_backend/src/piop/sum_check/classic/coeff.rs index 78c39c76..091aa2a1 100644 --- a/plonkish_backend/src/piop/sum_check/classic/coeff.rs +++ b/plonkish_backend/src/piop/sum_check/classic/coeff.rs @@ -3,6 +3,7 @@ use crate::{ poly::multilinear::{zip_self, MultilinearPolynomial}, util::{ arithmetic::{div_ceil, horner, PrimeField}, + chain, expression::{CommonPolynomial, Expression, Rotation}, impl_index, izip_eq, parallel::{num_threads, parallelize_iter}, @@ -11,7 +12,7 @@ use crate::{ }, Error, }; -use std::{array, fmt::Debug, iter, ops::AddAssign}; +use std::{array, fmt::Debug, ops::AddAssign}; #[derive(Debug)] pub struct Coefficients(Vec); @@ -114,11 +115,7 @@ where { outputs.push(( *lhs_scalar * rhs_scalar, - iter::empty() - .chain(lhs_polys) - .chain(rhs_polys) - .cloned() - .collect_vec(), + chain![lhs_polys, rhs_polys].cloned().collect_vec(), )); } (lhs_constant * rhs_constant, outputs) @@ -209,7 +206,7 @@ fn poly<'a, F: PrimeField>( match expr { Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)) => &state.eq_xys[*idx], Expression::Polynomial(query) if query.rotation() == Rotation::cur() => { - &state.polys[query.poly()][state.num_vars] + &state.polys[&query] } _ => unimplemented!(), } diff --git a/plonkish_backend/src/piop/sum_check/classic/eval.rs b/plonkish_backend/src/piop/sum_check/classic/eval.rs index fe1f3410..429c16a1 100644 --- a/plonkish_backend/src/piop/sum_check/classic/eval.rs +++ b/plonkish_backend/src/piop/sum_check/classic/eval.rs @@ -1,10 +1,8 @@ use crate::{ piop::sum_check::classic::{ClassicSumCheckProver, ClassicSumCheckRoundMessage, ProverState}, util::{ - arithmetic::{ - barycentric_interpolate, barycentric_weights, div_ceil, steps, BooleanHypercube, - PrimeField, - }, + arithmetic::{barycentric_interpolate, barycentric_weights, div_ceil, steps, PrimeField}, + chain, expression::{ evaluator::{ExpressionRegistry, Offsets}, CommonPolynomial, Expression, @@ -79,11 +77,9 @@ where fn new(state: &ProverState) -> Self { let (dense, sparse) = split_sparse(state); Self( - iter::empty() - .chain(Some((&dense, false))) - .chain(sparse.iter().zip(iter::repeat(true))) + chain![[(&dense, false)], sparse.iter().zip(iter::repeat(true))] .filter_map(|(expression, is_sparse)| { - SumCheckEvaluator::new(state.num_vars, state.challenges, expression, is_sparse) + SumCheckEvaluator::new(state.challenges, expression, is_sparse) }) .collect(), ) @@ -133,29 +129,19 @@ impl EvaluationsProver { #[derive(Clone, Debug, Default)] struct SumCheckEvaluator { - num_vars: usize, reg: ExpressionRegistry, sparse: Option>, } impl SumCheckEvaluator { - fn new( - num_vars: usize, - challenges: &[F], - expression: &Expression, - is_sparse: bool, - ) -> Option { + fn new(challenges: &[F], expression: &Expression, is_sparse: bool) -> Option { let expression = expression.simplified(Some(challenges))?; let mut reg = ExpressionRegistry::new(); reg.register(&expression); let sparse = is_sparse.then_some(expression); - Some(Self { - num_vars, - reg, - sparse, - }) + Some(Self { reg, sparse }) } fn sparse_bs(&self, state: &ProverState) -> Option> { @@ -214,13 +200,13 @@ impl SumCheckEvaluator { b: usize, ) { if IS_FIRST_ROUND && IS_FIRST_POINT { - let bh = BooleanHypercube::new(self.num_vars); cache .bs .iter_mut() .zip(self.reg.rotations()) .for_each(|(bs, rotation)| { - let [b_0, b_1] = [b << 1, (b << 1) + 1].map(|b| bh.rotate(b, *rotation)); + let [b_0, b_1] = + [b << 1, (b << 1) + 1].map(|b| state.rotatable.rotate(b, *rotation)); *bs = (b_0, b_1); }); } @@ -261,12 +247,11 @@ impl SumCheckEvaluator { |(((eval, step), bs), (query, rotation))| { if IS_FIRST_ROUND { let (b_0, b_1) = bs[*rotation]; - let poly = &state.polys[query.poly()][self.num_vars]; + let poly = &state.polys[&(query.poly(), 0).into()]; *eval = poly[b_1]; *step = poly[b_1] - &poly[b_0]; } else { - let rotation = (self.num_vars as i32 + query.rotation().0) as usize; - let poly = &state.polys[query.poly()][rotation]; + let poly = &state.polys[query]; *eval = poly[b_1]; *step = poly[b_1] - &poly[b_0]; } @@ -397,10 +382,16 @@ fn split_sparse(state: &ProverState) -> (Expression, Vec>); + type ClassicSumCheck = classic::ClassicSumCheck, R>; + + tests!(binary_field, ClassicSumCheck, BinaryField); + tests!(lexical, ClassicSumCheck, Lexical); } diff --git a/plonkish_backend/src/poly.rs b/plonkish_backend/src/poly.rs index 57041cde..f674fa8d 100644 --- a/plonkish_backend/src/poly.rs +++ b/plonkish_backend/src/poly.rs @@ -7,13 +7,8 @@ pub mod univariate; pub trait Polynomial: Clone + Debug + Default + for<'a> AddAssign<(&'a F, &'a Self)> { - type Basis: Copy + Debug; type Point: Clone + Debug; - fn new(basis: Self::Basis, coeffs: Vec) -> Self; - - fn basis(&self) -> Self::Basis; - fn coeffs(&self) -> &[F]; fn evaluate(&self, point: &Self::Point) -> F; @@ -23,4 +18,10 @@ pub trait Polynomial: #[cfg(any(test, feature = "benchmark"))] fn rand_point(k: usize, rng: impl rand::RngCore) -> Self::Point; + + #[cfg(any(test, feature = "benchmark"))] + fn squeeze_point( + k: usize, + transcript: &mut impl crate::util::transcript::FieldTranscript, + ) -> Self::Point; } diff --git a/plonkish_backend/src/poly/multilinear.rs b/plonkish_backend/src/poly/multilinear.rs index 539ffff4..a7b24337 100644 --- a/plonkish_backend/src/poly/multilinear.rs +++ b/plonkish_backend/src/poly/multilinear.rs @@ -1,9 +1,11 @@ use crate::{ + pcs::Additive, poly::Polynomial, util::{ - arithmetic::{div_ceil, usize_from_bits_le, BooleanHypercube, Field}, - expression::Rotation, - impl_index, + arithmetic::{div_ceil, usize_from_bits_le, Field}, + chain, + expression::{rotate::BinaryField, Rotation}, + impl_index, izip_eq, parallel::{num_threads, parallelize, parallelize_iter}, BitIndex, Deserialize, Itertools, Serialize, }, @@ -11,13 +13,13 @@ use crate::{ use num_integer::Integer; use rand::RngCore; use std::{ - borrow::Cow, + borrow::{Borrow, Cow}, iter::{self, Sum}, mem, ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct MultilinearPolynomial { evals: Vec, num_vars: usize, @@ -29,6 +31,18 @@ impl Default for MultilinearPolynomial { } } +impl Additive for MultilinearPolynomial { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, + ) -> Self + where + Self: 'b, + { + izip_eq!(scalars, bases).sum() + } +} + impl MultilinearPolynomial { pub fn new(evals: Vec) -> Self { let num_vars = if evals.is_empty() { @@ -71,21 +85,14 @@ impl MultilinearPolynomial { } impl Polynomial for MultilinearPolynomial { - type Basis = (); type Point = Vec; - fn new(_: (), evals: Vec) -> Self { - Self::new(evals) - } - - fn basis(&self) {} - fn coeffs(&self) -> &[F] { self.evals.as_slice() } fn evaluate(&self, point: &Self::Point) -> F { - MultilinearPolynomial::evaluate(self, point.as_slice()) + MultilinearPolynomial::evaluate(self, point) } #[cfg(any(test, feature = "benchmark"))] @@ -97,6 +104,16 @@ impl Polynomial for MultilinearPolynomial { fn rand_point(k: usize, rng: impl rand::RngCore) -> Self::Point { crate::util::test::rand_vec(k, rng) } + + #[cfg(any(test, feature = "benchmark"))] + fn squeeze_point( + k: usize, + transcript: &mut impl crate::util::transcript::FieldTranscript, + ) -> Self::Point { + iter::repeat_with(|| transcript.squeeze_challenge()) + .take(k) + .collect() + } } impl MultilinearPolynomial { @@ -148,23 +165,7 @@ impl MultilinearPolynomial { pub fn evaluate(&self, x: &[F]) -> F { assert_eq!(x.len(), self.num_vars); - - let mut evals = Cow::Borrowed(self.evals()); - let mut bits = Vec::new(); - let mut buf = Vec::with_capacity(self.evals.len() >> 1); - for x_i in x.iter() { - if x_i == &F::ZERO || x_i == &F::ONE { - bits.push(x_i == &F::ONE); - continue; - } - - let distance = bits.len() + 1; - let skip = usize_from_bits_le(&bits); - merge_in_place(&mut evals, x_i, distance, skip, &mut buf); - bits.clear(); - } - - evals[usize_from_bits_le(&bits)] + evaluate(&self.evals, x) } pub fn fix_last_vars(&self, x: &[F]) -> Self { @@ -275,18 +276,19 @@ impl MultilinearPolynomial { } } -impl<'lhs, 'rhs, F: Field> Add<&'rhs MultilinearPolynomial> for &'lhs MultilinearPolynomial { +impl>> Add

for &MultilinearPolynomial { type Output = MultilinearPolynomial; - fn add(self, rhs: &'rhs MultilinearPolynomial) -> MultilinearPolynomial { + fn add(self, rhs: P) -> MultilinearPolynomial { let mut output = self.clone(); output += rhs; output } } -impl<'rhs, F: Field> AddAssign<&'rhs MultilinearPolynomial> for MultilinearPolynomial { - fn add_assign(&mut self, rhs: &'rhs MultilinearPolynomial) { +impl>> AddAssign

for MultilinearPolynomial { + fn add_assign(&mut self, rhs: P) { + let rhs = rhs.borrow(); match (self.is_empty(), rhs.is_empty()) { (_, true) => {} (true, false) => *self = rhs.clone(), @@ -303,10 +305,11 @@ impl<'rhs, F: Field> AddAssign<&'rhs MultilinearPolynomial> for MultilinearPo } } -impl<'rhs, F: Field> AddAssign<(&'rhs F, &'rhs MultilinearPolynomial)> +impl, P: Borrow>> AddAssign<(BF, P)> for MultilinearPolynomial { - fn add_assign(&mut self, (scalar, rhs): (&'rhs F, &'rhs MultilinearPolynomial)) { + fn add_assign(&mut self, (scalar, rhs): (BF, P)) { + let (scalar, rhs) = (scalar.borrow(), rhs.borrow()); match (self.is_empty(), rhs.is_empty() | (scalar == &F::ZERO)) { (_, true) => {} (true, false) => { @@ -332,18 +335,19 @@ impl<'rhs, F: Field> AddAssign<(&'rhs F, &'rhs MultilinearPolynomial)> } } -impl<'lhs, 'rhs, F: Field> Sub<&'rhs MultilinearPolynomial> for &'lhs MultilinearPolynomial { +impl>> Sub

for &MultilinearPolynomial { type Output = MultilinearPolynomial; - fn sub(self, rhs: &'rhs MultilinearPolynomial) -> MultilinearPolynomial { + fn sub(self, rhs: P) -> MultilinearPolynomial { let mut output = self.clone(); output -= rhs; output } } -impl<'rhs, F: Field> SubAssign<&'rhs MultilinearPolynomial> for MultilinearPolynomial { - fn sub_assign(&mut self, rhs: &'rhs MultilinearPolynomial) { +impl>> SubAssign

for MultilinearPolynomial { + fn sub_assign(&mut self, rhs: P) { + let rhs = rhs.borrow(); match (self.is_empty(), rhs.is_empty()) { (_, true) => {} (true, false) => { @@ -363,26 +367,27 @@ impl<'rhs, F: Field> SubAssign<&'rhs MultilinearPolynomial> for MultilinearPo } } -impl<'rhs, F: Field> SubAssign<(&'rhs F, &'rhs MultilinearPolynomial)> +impl, P: Borrow>> SubAssign<(BF, P)> for MultilinearPolynomial { - fn sub_assign(&mut self, (scalar, rhs): (&'rhs F, &'rhs MultilinearPolynomial)) { - *self += (&-*scalar, rhs); + fn sub_assign(&mut self, (scalar, rhs): (BF, P)) { + *self += (-*scalar.borrow(), rhs); } } -impl<'lhs, 'rhs, F: Field> Mul<&'rhs F> for &'lhs MultilinearPolynomial { +impl> Mul for &MultilinearPolynomial { type Output = MultilinearPolynomial; - fn mul(self, rhs: &'rhs F) -> MultilinearPolynomial { + fn mul(self, rhs: BF) -> MultilinearPolynomial { let mut output = self.clone(); output *= rhs; output } } -impl<'rhs, F: Field> MulAssign<&'rhs F> for MultilinearPolynomial { - fn mul_assign(&mut self, rhs: &'rhs F) { +impl> MulAssign for MultilinearPolynomial { + fn mul_assign(&mut self, rhs: BF) { + let rhs = rhs.borrow(); if rhs == &F::ZERO { self.evals = vec![F::ZERO; self.evals.len()] } else if rhs == &-F::ONE { @@ -401,46 +406,61 @@ impl<'rhs, F: Field> MulAssign<&'rhs F> for MultilinearPolynomial { } } -impl<'a, F: Field> Sum<&'a MultilinearPolynomial> for MultilinearPolynomial { - fn sum>>( - mut iter: I, - ) -> MultilinearPolynomial { +impl>> Sum

for MultilinearPolynomial { + fn sum>(mut iter: I) -> MultilinearPolynomial { let init = match (iter.next(), iter.next()) { - (Some(lhs), Some(rhs)) => lhs + rhs, - (Some(lhs), None) => return lhs.clone(), - _ => unreachable!(), + (Some(lhs), Some(rhs)) => lhs.borrow() + rhs.borrow(), + (Some(lhs), None) => return lhs.borrow().clone(), + _ => return Self::zero(), }; iter.fold(init, |mut acc, poly| { - acc += poly; + acc += poly.borrow(); acc }) } } -impl Sum> for MultilinearPolynomial { - fn sum>>(iter: I) -> MultilinearPolynomial { - iter.reduce(|mut acc, poly| { - acc += &poly; +impl, P: Borrow>> Sum<(BF, P)> + for MultilinearPolynomial +{ + fn sum>(mut iter: I) -> MultilinearPolynomial { + let init = match iter.next() { + Some((scalar, poly)) => { + let mut poly = poly.borrow().clone(); + poly *= scalar.borrow(); + poly + } + _ => return Self::zero(), + }; + iter.fold(init, |mut acc, (scalar, poly)| { + acc += (scalar.borrow(), poly.borrow()); acc }) - .unwrap() } } -impl Sum<(F, MultilinearPolynomial)> for MultilinearPolynomial { - fn sum)>>( - mut iter: I, - ) -> MultilinearPolynomial { - let (scalar, mut poly) = iter.next().unwrap(); - poly *= &scalar; - iter.fold(poly, |mut acc, (scalar, poly)| { - acc += (&scalar, &poly); - acc - }) +impl_index!(MultilinearPolynomial, evals); + +pub(crate) fn evaluate(evals: &[F], x: &[F]) -> F { + assert_eq!(1 << x.len(), evals.len()); + + let mut evals = Cow::Borrowed(evals); + let mut bits = Vec::new(); + let mut buf = Vec::with_capacity(evals.len() >> 1); + for x_i in x.iter() { + if x_i == &F::ZERO || x_i == &F::ONE { + bits.push(x_i == &F::ONE); + continue; + } + + let distance = bits.len() + 1; + let skip = usize_from_bits_le(&bits); + merge_in_place(&mut evals, x_i, distance, skip, &mut buf); + bits.clear(); } -} -impl_index!(MultilinearPolynomial, evals); + evals[usize_from_bits_le(&bits)] +} pub fn rotation_eval(x: &[F], rotation: Rotation, evals_for_rotation: &[F]) -> F { if rotation == Rotation::cur() { @@ -501,16 +521,17 @@ pub fn rotation_eval_points(x: &[F], rotation: Rotation) -> Vec pattern .iter() .map(|pat| { - iter::empty() - .chain((0..num_x).map(|idx| { + chain![ + (0..num_x).map(|idx| { if pat.nth_bit(idx) { flipped_x[idx] } else { x[idx] } - })) - .chain((0..distance).map(|idx| bit_to_field(pat.nth_bit(idx + num_x)))) - .collect_vec() + }), + (0..distance).map(|idx| bit_to_field(pat.nth_bit(idx + num_x))) + ] + .collect_vec() }) .collect() } else { @@ -520,16 +541,17 @@ pub fn rotation_eval_points(x: &[F], rotation: Rotation) -> Vec pattern .iter() .map(|pat| { - iter::empty() - .chain((0..distance).map(|idx| bit_to_field(pat.nth_bit(idx)))) - .chain((0..num_x).map(|idx| { + chain![ + (0..distance).map(|idx| bit_to_field(pat.nth_bit(idx))), + (0..num_x).map(|idx| { if pat.nth_bit(idx + distance) { flipped_x[idx] } else { x[idx] } - })) - .collect_vec() + }) + ] + .collect_vec() }) .collect() } @@ -539,8 +561,8 @@ pub(crate) fn rotation_eval_point_pattern( num_vars: usize, distance: usize, ) -> Vec { - let bh = BooleanHypercube::new(num_vars); - let remainder = if NEXT { bh.primitive() } else { bh.x_inv() }; + let bf = BinaryField::new(num_vars); + let remainder = if NEXT { bf.primitive() } else { bf.x_inv() }; let mut pattern = vec![0; 1 << distance]; for depth in 0..distance { for (e, o) in zip_self!(0..pattern.len(), 1 << (distance - depth)) { @@ -560,11 +582,11 @@ pub(crate) fn rotation_eval_coeff_pattern( num_vars: usize, distance: usize, ) -> Vec { - let bh = BooleanHypercube::new(num_vars); + let bf = BinaryField::new(num_vars); let remainder = if NEXT { - bh.primitive() - (1 << num_vars) + bf.primitive() - (1 << num_vars) } else { - bh.x_inv() << distance + bf.x_inv() << distance }; let mut pattern = vec![0; 1 << (distance - 1)]; for depth in 0..distance - 1 { @@ -651,8 +673,11 @@ mod test { use crate::{ poly::multilinear::{rotation_eval, zip_self, MultilinearPolynomial}, util::{ - arithmetic::{BooleanHypercube, Field}, - expression::Rotation, + arithmetic::Field, + expression::{ + rotate::{BinaryField, Rotatable}, + Rotation, + }, test::rand_vec, Itertools, }, @@ -672,8 +697,8 @@ mod test { #[test] fn fix_var() { let rand_x_i = || match OsRng.next_u32() % 3 { - 0 => Fr::zero(), - 1 => Fr::one(), + 0 => Fr::ZERO, + 1 => Fr::ONE, 2 => Fr::random(OsRng), _ => unreachable!(), }; @@ -691,11 +716,11 @@ mod test { #[test] fn evaluate_for_rotation() { let mut rng = OsRng; - for num_vars in 0..16 { - let bh = BooleanHypercube::new(num_vars); + for num_vars in 1..16 { + let bf = BinaryField::new(num_vars); let rotate = |f: &Vec| { (0..1 << num_vars) - .map(|idx| f[bh.rotate(idx, Rotation::next())]) + .map(|idx| f[bf.rotate(idx, Rotation::next())]) .collect_vec() }; let f = rand_vec(1 << num_vars, &mut rng); diff --git a/plonkish_backend/src/poly/univariate.rs b/plonkish_backend/src/poly/univariate.rs index 23322b64..c03728f5 100644 --- a/plonkish_backend/src/poly/univariate.rs +++ b/plonkish_backend/src/poly/univariate.rs @@ -1,8 +1,9 @@ use crate::{ + pcs::Additive, poly::{univariate::UnivariateBasis::*, Polynomial}, util::{ arithmetic::{div_ceil, horner, powers, Field}, - impl_index, + impl_index, izip_eq, parallel::{num_threads, parallelize, parallelize_iter}, Deserialize, Itertools, Serialize, }, @@ -34,6 +35,18 @@ impl Default for UnivariatePolynomial { } } +impl Additive for UnivariatePolynomial { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, + ) -> Self + where + Self: 'b, + { + izip_eq!(scalars, bases).sum() + } +} + impl UnivariatePolynomial { pub const fn zero() -> Self { Self { @@ -70,21 +83,8 @@ impl UnivariatePolynomial { } impl Polynomial for UnivariatePolynomial { - type Basis = UnivariateBasis; type Point = F; - fn new(basis: UnivariateBasis, coeffs: Vec) -> Self { - let mut poly = Self { basis, coeffs }; - if let Monomial = basis { - poly.truncate_leading_zeros() - } - poly - } - - fn basis(&self) -> Self::Basis { - self.basis - } - fn coeffs(&self) -> &[F] { &self.coeffs } @@ -95,31 +95,61 @@ impl Polynomial for UnivariatePolynomial { #[cfg(any(test, feature = "benchmark"))] fn rand(n: usize, rng: impl rand::RngCore) -> Self { - Self::new(Monomial, crate::util::test::rand_vec(n, rng)) + Self::monomial(crate::util::test::rand_vec(n, rng)) } #[cfg(any(test, feature = "benchmark"))] fn rand_point(_: usize, rng: impl rand::RngCore) -> F { F::random(rng) } + + #[cfg(any(test, feature = "benchmark"))] + fn squeeze_point( + _: usize, + transcript: &mut impl crate::util::transcript::FieldTranscript, + ) -> Self::Point { + transcript.squeeze_challenge() + } } impl UnivariatePolynomial { + pub fn monomial(coeffs: Vec) -> Self { + let mut poly = Self { + basis: Monomial, + coeffs, + }; + poly.truncate_leading_zeros(); + poly + } + + pub fn lagrange(coeffs: Vec) -> Self { + assert!(coeffs.len().is_power_of_two()); + + Self { + basis: Lagrange, + coeffs, + } + } + pub fn vanishing<'a>(points: impl IntoIterator, scalar: F) -> Self { let points = points.into_iter().collect_vec(); assert!(!points.is_empty()); let mut buf; - let mut basis = vec![F::ZERO; points.len() + 1]; - *basis.last_mut().unwrap() = scalar; + let mut coeffs = vec![F::ZERO; points.len() + 1]; + *coeffs.last_mut().unwrap() = scalar; for (point, len) in points.into_iter().zip(2..) { buf = scalar; - for idx in (0..basis.len() - 1).rev().take(len) { - buf = basis[idx] - buf * point; - mem::swap(&mut buf, &mut basis[idx]); + for idx in (0..coeffs.len() - 1).rev().take(len) { + buf = coeffs[idx] - buf * point; + mem::swap(&mut buf, &mut coeffs[idx]); } } - Self::new(Monomial, basis) + Self::monomial(coeffs) + } + + pub fn basis(&self) -> UnivariateBasis { + self.basis } pub fn evaluate(&self, x: &F) -> F { @@ -165,7 +195,7 @@ impl UnivariatePolynomial { } remainder.truncate_leading_zeros(); } - (Self::new(Monomial, quotient), remainder) + (Self::monomial(quotient), remainder) } } } @@ -193,144 +223,183 @@ impl Neg for UnivariatePolynomial { } } -impl<'lhs, 'rhs, F: Field> Add<&'rhs UnivariatePolynomial> for &'lhs UnivariatePolynomial { +impl>> Add

for &UnivariatePolynomial { type Output = UnivariatePolynomial; - fn add(self, rhs: &'rhs UnivariatePolynomial) -> UnivariatePolynomial { - assert_eq!(self.basis, rhs.basis); - + fn add(self, rhs: P) -> UnivariatePolynomial { let mut output = self.clone(); output += rhs; output } } -impl<'rhs, F: Field> AddAssign<&'rhs UnivariatePolynomial> for UnivariatePolynomial { - fn add_assign(&mut self, rhs: &'rhs UnivariatePolynomial) { - assert_eq!(self.basis, Monomial); - assert_eq!(rhs.basis, Monomial); +impl>> AddAssign

for UnivariatePolynomial { + fn add_assign(&mut self, rhs: P) { + let rhs = rhs.borrow(); + assert_eq!(self.basis, rhs.basis); - match self.degree().cmp(&rhs.degree()) { - Less => { - parallelize(&mut self.coeffs, |(lhs, start)| { - for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { - *lhs += rhs; + match self.basis { + Monomial => match self.degree().cmp(&rhs.degree()) { + Less => { + parallelize(&mut self.coeffs, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs += rhs; + } + }); + self.coeffs + .extend(rhs[self.coeffs().len()..].iter().cloned()); + } + ord @ (Greater | Equal) => { + parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs += rhs; + } + }); + if matches!(ord, Equal) { + self.truncate_leading_zeros(); } - }); - self.coeffs - .extend(rhs[self.coeffs().len()..].iter().cloned()); - } - ord @ (Greater | Equal) => { - parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { + } + }, + Lagrange => { + assert_eq!(self.coeffs.len(), rhs.coeffs.len()); + + parallelize(&mut self.coeffs, |(lhs, start)| { for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { *lhs += rhs; } }); - if matches!(ord, Equal) { - self.truncate_leading_zeros(); - } } } } } -impl<'rhs, F: Field> AddAssign<(&'rhs F, &'rhs UnivariatePolynomial)> +impl, P: Borrow>> AddAssign<(BF, P)> for UnivariatePolynomial { - fn add_assign(&mut self, (scalar, rhs): (&'rhs F, &'rhs UnivariatePolynomial)) { - assert_eq!(self.basis, Monomial); - assert_eq!(rhs.basis, Monomial); + fn add_assign(&mut self, (scalar, rhs): (BF, P)) { + let (scalar, rhs) = (scalar.borrow(), rhs.borrow()); + assert_eq!(self.basis, rhs.basis); if scalar == &F::ONE { *self += rhs; } else if scalar == &-F::ONE { *self -= rhs; } else if scalar != &F::ZERO { - match self.degree().cmp(&rhs.degree()) { - Less => { - parallelize(&mut self.coeffs, |(lhs, start)| { - for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { - *lhs += *rhs * scalar; + match self.basis { + Monomial => match self.degree().cmp(&rhs.degree()) { + Less => { + parallelize(&mut self.coeffs, |(lhs, start)| { + let scalar = *scalar; + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs += scalar * rhs; + } + }); + self.coeffs + .extend(rhs[self.coeffs().len()..].iter().map(|rhs| *rhs * scalar)); + } + ord @ (Greater | Equal) => { + parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { + let scalar = *scalar; + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs += scalar * rhs; + } + }); + if matches!(ord, Equal) { + self.truncate_leading_zeros(); } - }); - self.coeffs - .extend(rhs[self.coeffs().len()..].iter().map(|rhs| *rhs * scalar)); - } - ord @ (Greater | Equal) => { - parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { + } + }, + Lagrange => { + assert_eq!(self.coeffs.len(), rhs.coeffs.len()); + + parallelize(&mut self.coeffs, |(lhs, start)| { + let scalar = *scalar; for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { - *lhs += *rhs * scalar; + *lhs += scalar * rhs; } }); - if matches!(ord, Equal) { - self.truncate_leading_zeros(); - } } } } } } -impl<'lhs, 'rhs, F: Field> Sub<&'rhs UnivariatePolynomial> for &'lhs UnivariatePolynomial { +impl>> Sub

for &UnivariatePolynomial { type Output = UnivariatePolynomial; - fn sub(self, rhs: &'rhs UnivariatePolynomial) -> UnivariatePolynomial { - assert_eq!(self.basis, Monomial); - assert_eq!(rhs.basis, Monomial); - + fn sub(self, rhs: P) -> UnivariatePolynomial { let mut output = self.clone(); output -= rhs; output } } -impl<'rhs, F: Field> SubAssign<&'rhs UnivariatePolynomial> for UnivariatePolynomial { - fn sub_assign(&mut self, rhs: &'rhs UnivariatePolynomial) { - assert_eq!(self.basis, Monomial); - assert_eq!(rhs.basis, Monomial); +impl>> SubAssign

for UnivariatePolynomial { + fn sub_assign(&mut self, rhs: P) { + let rhs = rhs.borrow(); + assert_eq!(self.basis, rhs.basis); + + match self.basis { + Monomial => match self.degree().cmp(&rhs.degree()) { + Less => { + parallelize(&mut self.coeffs, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs -= rhs; + } + }); + self.coeffs + .extend(rhs[self.coeffs().len()..].iter().cloned().map(Neg::neg)); + } + ord @ (Greater | Equal) => { + parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { + *lhs -= rhs; + } + }); + if matches!(ord, Equal) { + self.truncate_leading_zeros(); + } + } + }, + Lagrange => { + assert_eq!(self.coeffs.len(), rhs.coeffs.len()); - match self.degree().cmp(&rhs.degree()) { - Less => { parallelize(&mut self.coeffs, |(lhs, start)| { for (lhs, rhs) in lhs.iter_mut().zip(rhs[start..].iter()) { *lhs -= rhs; } }); - self.coeffs - .extend(rhs[self.coeffs().len()..].iter().cloned().map(Neg::neg)); - } - ord @ (Greater | Equal) => { - parallelize(&mut self[..rhs.coeffs().len()], |(lhs, start)| { - for (lhs, rhs) in lhs.iter_mut().zip_eq(rhs[start..].iter()) { - *lhs -= rhs; - } - }); - if matches!(ord, Equal) { - self.truncate_leading_zeros(); - } } } } } -impl<'lhs, 'rhs, F: Field> Mul<&'rhs F> for &'lhs UnivariatePolynomial { - type Output = UnivariatePolynomial; +impl, P: Borrow>> SubAssign<(BF, P)> + for UnivariatePolynomial +{ + fn sub_assign(&mut self, (scalar, rhs): (BF, P)) { + *self += (-*scalar.borrow(), rhs); + } +} - fn mul(self, rhs: &'rhs F) -> UnivariatePolynomial { - assert_eq!(self.basis, Monomial); +impl> Mul for &UnivariatePolynomial { + type Output = UnivariatePolynomial; + fn mul(self, rhs: BF) -> UnivariatePolynomial { let mut output = self.clone(); output *= rhs; output } } -impl<'rhs, F: Field> MulAssign<&'rhs F> for UnivariatePolynomial { - fn mul_assign(&mut self, rhs: &'rhs F) { - assert_eq!(self.basis, Monomial); - +impl> MulAssign for UnivariatePolynomial { + fn mul_assign(&mut self, rhs: BF) { + let rhs = rhs.borrow(); if rhs == &F::ZERO { - self.coeffs.clear(); + match self.basis { + Monomial => self.coeffs.clear(), + Lagrange => self.coeffs.fill(F::ZERO), + } } else if rhs != &F::ONE { parallelize(&mut self.coeffs, |(lhs, _)| { for lhs in lhs.iter_mut() { @@ -341,44 +410,34 @@ impl<'rhs, F: Field> MulAssign<&'rhs F> for UnivariatePolynomial { } } -impl<'a, F: Field> Sum<&'a UnivariatePolynomial> for UnivariatePolynomial { - fn sum>>( - mut iter: I, - ) -> UnivariatePolynomial { +impl>> Sum

for UnivariatePolynomial { + fn sum>(mut iter: I) -> UnivariatePolynomial { let init = match (iter.next(), iter.next()) { - (Some(lhs), Some(rhs)) => lhs + rhs, - (Some(lhs), None) => return lhs.clone(), + (Some(lhs), Some(rhs)) => lhs.borrow() + rhs.borrow(), + (Some(lhs), None) => return lhs.borrow().clone(), _ => return Self::zero(), }; iter.fold(init, |mut acc, poly| { - acc += poly; + acc += poly.borrow(); acc }) } } -impl Sum> for UnivariatePolynomial { - fn sum>>(iter: I) -> UnivariatePolynomial { - iter.reduce(|mut acc, poly| { - acc += &poly; - acc - }) - .unwrap_or_else(Self::zero) - } -} - -impl<'a, F: Field, P: Borrow>> Sum<(&'a F, P)> for UnivariatePolynomial { - fn sum>(mut iter: I) -> UnivariatePolynomial { +impl, P: Borrow>> Sum<(BF, P)> + for UnivariatePolynomial +{ + fn sum>(mut iter: I) -> UnivariatePolynomial { let init = match iter.next() { Some((scalar, poly)) => { let mut poly = poly.borrow().clone(); - poly *= scalar; + poly *= scalar.borrow(); poly } _ => return Self::zero(), }; iter.fold(init, |mut acc, (scalar, poly)| { - acc += (scalar, poly.borrow()); + acc += (scalar.borrow(), poly.borrow()); acc }) } diff --git a/plonkish_backend/src/util.rs b/plonkish_backend/src/util.rs index 30df4dcb..328b0d0b 100644 --- a/plonkish_backend/src/util.rs +++ b/plonkish_backend/src/util.rs @@ -6,7 +6,7 @@ pub mod parallel; mod timer; pub mod transcript; -pub use itertools::{chain, izip, Itertools}; +pub use itertools::{chain, izip, Either, Itertools}; pub use num_bigint::BigUint; pub use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer}; pub use timer::{end_timer, start_timer, start_unit_timer}; diff --git a/plonkish_backend/src/util/arithmetic.rs b/plonkish_backend/src/util/arithmetic.rs index 0f38f8dd..6bea93d2 100644 --- a/plonkish_backend/src/util/arithmetic.rs +++ b/plonkish_backend/src/util/arithmetic.rs @@ -1,4 +1,4 @@ -use crate::util::{parallel::parallelize, BigUint, Itertools}; +use crate::util::{izip_eq, parallel::parallelize, BigUint, Itertools}; use halo2_curves::{ bn256, grumpkin, pairing::{self, MillerLoopResult}, @@ -7,13 +7,11 @@ use halo2_curves::{ use num_integer::Integer; use std::{borrow::Borrow, fmt::Debug, iter}; -mod bh; mod fft; mod msm; -pub use bh::BooleanHypercube; pub use bitvec::field::BitField; -pub use fft::fft; +pub use fft::radix2_fft; pub use halo2_curves::{ group::{ ff::{ @@ -25,7 +23,7 @@ pub use halo2_curves::{ }, Coordinates, CurveAffine, CurveExt, }; -pub use msm::{fixed_base_msm, variable_base_msm, window_size, window_table}; +pub use msm::{fixed_base_msm, variable_base_msm, window_size, window_table, Msm}; pub trait MultiMillerLoop: pairing::MultiMillerLoop + Debug + Sync { fn pairings_product_is_identity(terms: &[(&Self::G1Affine, &Self::G2Prepared)]) -> bool { @@ -103,8 +101,7 @@ pub fn inner_product<'a, 'b, F: Field>( lhs: impl IntoIterator, rhs: impl IntoIterator, ) -> F { - lhs.into_iter() - .zip_eq(rhs.into_iter()) + izip_eq!(lhs, rhs) .map(|(lhs, rhs)| *lhs * rhs) .reduce(|acc, product| acc + product) .unwrap_or_default() @@ -118,19 +115,20 @@ pub fn barycentric_weights(points: &[F]) -> Vec { points .iter() .enumerate() - .filter_map(|(i, point_i)| (i != j).then(|| *point_j - point_i)) + .filter(|(i, _)| i != &j) + .map(|(_, point_i)| *point_j - point_i) .reduce(|acc, value| acc * &value) .unwrap_or(F::ONE) }) .collect_vec(); - weights.iter_mut().batch_invert(); + weights.batch_invert(); weights } pub fn barycentric_interpolate(weights: &[F], points: &[F], evals: &[F], x: &F) -> F { let (coeffs, sum_inv) = { let mut coeffs = points.iter().map(|point| *x - point).collect_vec(); - coeffs.iter_mut().batch_invert(); + coeffs.batch_invert(); coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { *coeff *= weight; }); @@ -203,6 +201,10 @@ pub fn fe_truncated(fe: F, num_bits: usize) -> F { F::from_repr(repr).unwrap() } +pub fn fe_to_bytes(fe: impl Borrow) -> Vec { + fe.borrow().to_repr().as_ref().to_vec() +} + pub fn usize_from_bits_le(bits: &[bool]) -> usize { bits.iter() .rev() diff --git a/plonkish_backend/src/util/arithmetic/fft.rs b/plonkish_backend/src/util/arithmetic/fft.rs index 7ae99506..68583650 100644 --- a/plonkish_backend/src/util/arithmetic/fft.rs +++ b/plonkish_backend/src/util/arithmetic/fft.rs @@ -3,6 +3,7 @@ use crate::util::{ arithmetic::{Field, GroupOpsOwned, ScalarMulOwned}, parallel::{join, num_threads}, + start_timer, }; pub trait FftGroup: @@ -17,7 +18,9 @@ where { } -pub fn fft>(a: &mut [G], omega: Scalar, log_n: usize) { +pub fn radix2_fft>(a: &mut [G], omega: Scalar, log2_n: usize) { + let _timer = start_timer(|| "fft"); + fn bitreverse(mut n: usize, l: usize) -> usize { let mut r = 0; for _ in 0..l { @@ -29,10 +32,10 @@ pub fn fft>(a: &mut [G], omega: Scalar, log_n let log_num_threads = num_threads().ilog2() as usize; let n = a.len(); - assert_eq!(n, 1 << log_n); + assert_eq!(n, 1 << log2_n); for k in 0..n { - let rk = bitreverse(k, log_n); + let rk = bitreverse(k, log2_n); if k < rk { a.swap(rk, k); } @@ -46,10 +49,10 @@ pub fn fft>(a: &mut [G], omega: Scalar, log_n }) .collect(); - if log_n <= log_num_threads { + if log2_n <= log_num_threads { let mut chunk = 2; let mut twiddle_chunk = n / 2; - for _ in 0..log_n { + for _ in 0..log2_n { a.chunks_mut(chunk).for_each(|coeffs| { let (left, right) = coeffs.split_at_mut(chunk / 2); diff --git a/plonkish_backend/src/util/arithmetic/msm.rs b/plonkish_backend/src/util/arithmetic/msm.rs index 728910b7..3ca3d2a2 100644 --- a/plonkish_backend/src/util/arithmetic/msm.rs +++ b/plonkish_backend/src/util/arithmetic/msm.rs @@ -1,9 +1,17 @@ -use crate::util::{ - arithmetic::{div_ceil, field_size, CurveAffine, Field, Group, PrimeField}, - parallel::{num_threads, parallelize, parallelize_iter}, - start_timer, Itertools, +use crate::{ + pcs::Additive, + util::{ + arithmetic::{div_ceil, field_size, CurveAffine, Field, Group, PrimeField}, + chain, izip_eq, + parallel::{num_threads, parallelize, parallelize_iter}, + start_timer, Itertools, + }, +}; +use std::{ + iter::Sum, + mem::size_of, + ops::{Add, Mul, Neg, Sub}, }; -use std::mem::size_of; pub fn window_size(num_scalars: usize) -> usize { if num_scalars < 32 { @@ -179,3 +187,118 @@ fn variable_base_msm_serial( } } } + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Msm<'a, F: Field, T: Additive> { + Scalar(F), + Terms(F, Vec<(F, &'a T)>), +} + +impl<'a, F: Field, T: Additive> Msm<'a, F, T> { + pub fn scalar(scalar: F) -> Self { + Self::Scalar(scalar) + } + + pub fn base(base: &'a T) -> Self { + Self::term(F::ONE, base) + } + + pub fn term(scalar: F, base: &'a T) -> Self { + Self::Terms(F::ZERO, vec![(scalar, base)]) + } + + pub fn evaluate(self) -> (F, T) { + match self { + Msm::Scalar(constant) => (constant, T::default()), + Msm::Terms(constant, terms) => { + let (scalars, bases) = terms.into_iter().unzip::<_, _, Vec<_>, Vec<_>>(); + (constant, T::msm(&scalars, bases)) + } + } + } +} + +impl<'a, F: Field, T: Additive> Default for Msm<'a, F, T> { + fn default() -> Self { + Msm::Terms(F::ZERO, Vec::new()) + } +} + +impl<'a, F: Field, T: Additive> Neg for Msm<'a, F, T> { + type Output = Self; + + fn neg(mut self) -> Self::Output { + match &mut self { + Msm::Scalar(constant) => *constant = -*constant, + Msm::Terms(constant, terms) => { + *constant = -*constant; + terms.iter_mut().for_each(|(scalar, _)| *scalar = -*scalar); + } + } + self + } +} + +impl<'a, F: Field, T: Additive> Add for Msm<'a, F, T> { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Msm::Scalar(lhs), Msm::Scalar(rhs)) => Msm::Scalar(lhs + rhs), + (Msm::Scalar(scalar), Msm::Terms(constant, terms)) + | (Msm::Terms(constant, terms), Msm::Scalar(scalar)) => { + Msm::Terms(constant + scalar, terms) + } + (Msm::Terms(lhs_constant, lhs_terms), Msm::Terms(rhs_constant, rhs_terms)) => { + Msm::Terms( + lhs_constant + rhs_constant, + chain![lhs_terms, rhs_terms].collect(), + ) + } + } + } +} + +impl<'a, F: Field, T: Additive> Sub for Msm<'a, F, T> { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + self + (-rhs) + } +} + +impl<'a, F: Field, T: Additive> Mul for Msm<'a, F, T> { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Msm::Scalar(lhs), Msm::Scalar(rhs)) => Msm::Scalar(lhs * rhs), + (Msm::Scalar(rhs), Msm::Terms(constant, terms)) + | (Msm::Terms(constant, terms), Msm::Scalar(rhs)) => Msm::Terms( + constant * rhs, + chain![terms].map(|(lhs, base)| (lhs * rhs, base)).collect(), + ), + (Msm::Terms(_, _), Msm::Terms(_, _)) => unreachable!(), + } + } +} + +impl<'a, F: Field, T: Additive> Sum for Msm<'a, F, T> { + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + item).unwrap() + } +} + +impl<'c, F: Field, T: Additive> Additive for Msm<'c, F, T> { + fn msm<'a, 'b>( + scalars: impl IntoIterator, + bases: impl IntoIterator, + ) -> Self + where + Self: 'b, + { + izip_eq!(scalars, bases) + .map(|(scalar, base)| Msm::scalar(*scalar) * base.clone()) + .sum() + } +} diff --git a/plonkish_backend/src/util/code/brakedown.rs b/plonkish_backend/src/util/code/brakedown.rs index 5eeafa00..9e47b434 100644 --- a/plonkish_backend/src/util/code/brakedown.rs +++ b/plonkish_backend/src/util/code/brakedown.rs @@ -6,6 +6,7 @@ use crate::util::{ arithmetic::{horner, steps, Field, PrimeField}, + chain, code::LinearCodes, Deserialize, Itertools, Serialize, }; @@ -207,12 +208,13 @@ pub trait BrakedownSpec: Debug { fn codeword_len(log2_q: usize, n: usize, n_0: usize) -> usize { let (a, b) = Self::dimensions(log2_q, n, n_0); - iter::empty() - .chain(Some(a[0].n)) - .chain(a[..a.len() - 1].iter().map(|a| a.m)) - .chain(Some(b.last().unwrap().n)) - .chain(b.iter().map(|b| b.m)) - .sum() + chain![ + [a[0].n], + a[..a.len() - 1].iter().map(|a| a.m), + [b.last().unwrap().n], + b.iter().map(|b| b.m), + ] + .sum() } fn matrices( diff --git a/plonkish_backend/src/util/expression.rs b/plonkish_backend/src/util/expression.rs index 102cbe94..775b70be 100644 --- a/plonkish_backend/src/util/expression.rs +++ b/plonkish_backend/src/util/expression.rs @@ -1,5 +1,6 @@ use crate::util::{arithmetic::Field, izip, Deserialize, Itertools, Serialize}; use std::{ + borrow::Borrow, collections::BTreeSet, fmt::Debug, io::{self, Cursor}, @@ -7,35 +8,11 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; -pub(crate) mod evaluator; +pub mod evaluator; pub mod relaxed; +pub mod rotate; -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct Rotation(pub i32); - -impl Rotation { - pub const fn cur() -> Self { - Rotation(0) - } - - pub const fn prev() -> Self { - Rotation(-1) - } - - pub const fn next() -> Self { - Rotation(1) - } - - pub const fn distance(&self) -> usize { - self.0.unsigned_abs() as usize - } -} - -impl From for Rotation { - fn from(rotation: i32) -> Self { - Self(rotation) - } -} +pub use rotate::Rotation; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct Query { @@ -44,8 +21,11 @@ pub struct Query { } impl Query { - pub fn new(poly: usize, rotation: Rotation) -> Self { - Self { poly, rotation } + pub fn new(poly: usize, rotation: impl Into) -> Self { + Self { + poly, + rotation: rotation.into(), + } } pub fn poly(&self) -> usize { @@ -57,6 +37,12 @@ impl Query { } } +impl> From<(usize, T)> for Query { + fn from((poly, rotation): (usize, T)) -> Self { + Self::new(poly, rotation) + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum CommonPolynomial { Identity, @@ -90,18 +76,18 @@ impl Expression { Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)) } - pub fn distribute_powers<'a>( - exprs: impl IntoIterator + 'a, - base: &Self, - ) -> Self - where - F: 'a, - { - let mut exprs = exprs.into_iter().cloned().collect_vec(); + pub fn distribute_powers( + exprs: impl IntoIterator>, + base: impl Borrow, + ) -> Self { + let exprs = exprs + .into_iter() + .map(|expr| expr.borrow().clone()) + .collect_vec(); match exprs.len() { 0 => unreachable!(), - 1 => exprs.pop().unwrap(), - _ => Expression::DistributePowers(exprs, base.clone().into()), + 1 => exprs.into_iter().next().unwrap(), + _ => Expression::DistributePowers(exprs, base.borrow().clone().into()), } } diff --git a/plonkish_backend/src/util/expression/evaluator.rs b/plonkish_backend/src/util/expression/evaluator.rs index 35318918..d5cf3f60 100644 --- a/plonkish_backend/src/util/expression/evaluator.rs +++ b/plonkish_backend/src/util/expression/evaluator.rs @@ -4,7 +4,8 @@ use crate::util::{ }; use std::{fmt::Debug, ops::Deref}; -pub(crate) mod quotient; +pub mod hadamard; +pub mod quotient; #[derive(Clone, Debug, Default)] pub(crate) struct ExpressionRegistry { diff --git a/plonkish_backend/src/util/expression/evaluator/hadamard.rs b/plonkish_backend/src/util/expression/evaluator/hadamard.rs new file mode 100644 index 00000000..03185f2c --- /dev/null +++ b/plonkish_backend/src/util/expression/evaluator/hadamard.rs @@ -0,0 +1,79 @@ +use crate::util::{ + arithmetic::PrimeField, + expression::{evaluator::ExpressionRegistry, rotate::Rotatable, Expression}, + izip_eq, Itertools, +}; +use std::borrow::Cow; + +#[derive(Clone, Debug)] +pub(crate) struct HadamardEvaluator<'a, F: PrimeField, R: Rotatable + From> { + pub(crate) num_vars: usize, + pub(crate) reg: ExpressionRegistry, + lagranges: Vec, + polys: Vec>, + rotatable: R, +} + +impl<'a, F: PrimeField, R: Rotatable + From> HadamardEvaluator<'a, F, R> { + pub(crate) fn new( + num_vars: usize, + expressions: &[Expression], + polys: impl IntoIterator>, + ) -> Self { + let mut reg = ExpressionRegistry::new(); + for expression in expressions.iter() { + reg.register(expression); + } + assert!(reg.eq_xys().is_empty()); + + let rotatable = R::from(num_vars); + let lagranges = reg + .lagranges() + .iter() + .map(|i| rotatable.nth(*i)) + .collect_vec(); + + Self { + num_vars, + reg, + lagranges, + polys: polys.into_iter().collect(), + rotatable, + } + } + + pub(crate) fn cache(&self) -> Vec { + self.reg.cache() + } + + pub(crate) fn evaluate(&self, evals: &mut [F], cache: &mut [F], b: usize) { + self.evaluate_calculations(cache, b); + izip_eq!(evals, self.reg.indexed_outputs()).for_each(|(eval, idx)| *eval = cache[*idx]) + } + + pub(crate) fn evaluate_and_sum(&self, sums: &mut [F], cache: &mut [F], b: usize) { + self.evaluate_calculations(cache, b); + izip_eq!(sums, self.reg.indexed_outputs()).for_each(|(sum, idx)| *sum += cache[*idx]) + } + + fn evaluate_calculations(&self, cache: &mut [F], b: usize) { + if self.reg.has_identity() { + cache[self.reg.offsets().identity()] = F::from(b as u64); + } + cache[self.reg.offsets().lagranges()..] + .iter_mut() + .zip(&self.lagranges) + .for_each(|(value, i)| *value = if &b == i { F::ONE } else { F::ZERO }); + cache[self.reg.offsets().polys()..] + .iter_mut() + .zip(self.reg.polys()) + .for_each(|(value, (query, _))| { + *value = self.polys[query.poly()][self.rotatable.rotate(b, query.rotation())] + }); + self.reg + .indexed_calculations() + .iter() + .zip(self.reg.offsets().calculations()..) + .for_each(|(calculation, idx)| calculation.calculate(cache, idx)); + } +} diff --git a/plonkish_backend/src/util/expression/evaluator/quotient.rs b/plonkish_backend/src/util/expression/evaluator/quotient.rs index e9ef9d2d..fe050613 100644 --- a/plonkish_backend/src/util/expression/evaluator/quotient.rs +++ b/plonkish_backend/src/util/expression/evaluator/quotient.rs @@ -1,26 +1,31 @@ -#![allow(dead_code)] - use crate::util::{ arithmetic::{ - fft, root_of_unity, root_of_unity_inv, BatchInvert, PrimeField, WithSmallOrderMulGroup, + radix2_fft, root_of_unity, root_of_unity_inv, BatchInvert, WithSmallOrderMulGroup, + }, + chain, + expression::{ + evaluator::ExpressionRegistry, + rotate::{Lexical, Rotatable}, + CommonPolynomial, Expression, Query, Rotation, }, - expression::{evaluator::ExpressionRegistry, CommonPolynomial, Expression, Query, Rotation}, izip, parallel::parallelize, - Itertools, + BitIndex, Itertools, }; use std::{ borrow::Cow, + cmp::Ordering, collections::{BTreeMap, HashMap}, iter, }; #[derive(Clone, Debug)] -pub(crate) struct Domain { +pub struct Radix2Domain { k: usize, n: usize, extended_k: usize, extended_n: usize, + magnification: usize, omega: F, omega_inv: F, extended_omega: F, @@ -35,9 +40,8 @@ pub(crate) struct Domain { extended_n_inv_zeta_inv: F, } -impl> Domain { - pub(crate) fn new(k: usize, degree: usize) -> Self { - let n = 1 << k; +impl> Radix2Domain { + pub fn new(k: usize, degree: usize) -> Self { let quotient_degree = degree.checked_sub(1).unwrap_or_default(); let extended_k = k + quotient_degree.next_power_of_two().ilog2() as usize; let extended_n = 1 << extended_k; @@ -61,9 +65,10 @@ impl> Domain { Self { k, - n, + n: 1 << k, extended_k, extended_n, + magnification: 1 << (extended_k - k), omega, omega_inv, extended_omega, @@ -79,55 +84,104 @@ impl> Domain { } } - pub(crate) fn k(&self) -> usize { + pub fn k(&self) -> usize { self.k } - pub(crate) fn n(&self) -> usize { + pub fn n(&self) -> usize { self.n } - pub(crate) fn n_inv(&self) -> F { + pub fn n_inv(&self) -> F { self.n_inv } - pub(crate) fn extended_k(&self) -> usize { + pub fn extended_k(&self) -> usize { self.extended_k } - pub(crate) fn extended_n(&self) -> usize { + pub fn extended_n(&self) -> usize { self.extended_n } - pub(crate) fn omega(&self) -> F { - self.omega - } - - pub(crate) fn extended_omega(&self) -> F { + pub fn extended_omega(&self) -> F { self.extended_omega } - pub(crate) fn zeta(&self) -> F { + pub fn zeta(&self) -> F { self.zeta } - pub(crate) fn zeta_inv(&self) -> F { + pub fn zeta_inv(&self) -> F { self.zeta_inv } - pub(crate) fn identity(&self) -> Vec { - iter::successors(Some(F::ZETA), move |state| { - Some(self.extended_omega * state) - }) - .take(self.extended_n()) - .collect_vec() + pub fn rotate_point(&self, x: F, rotation: Rotation) -> F { + let rotation = Lexical::new(self.k).rotate(0, rotation); + let rotation = if rotation > self.n >> 1 { + rotation as i32 - self.n as i32 + } else { + rotation as i32 + }; + let omega = match rotation.cmp(&0) { + Ordering::Less => self.omega_inv, + Ordering::Greater => self.omega, + Ordering::Equal => return x, + }; + let exponent = rotation.unsigned_abs() as usize; + let mut scalar = F::ONE; + for nth in (1..=(usize::BITS - exponent.leading_zeros()) as usize).rev() { + if exponent.nth_bit(nth) { + scalar *= omega; + } + scalar = scalar.square(); + } + if exponent.nth_bit(0) { + scalar *= omega; + } + scalar * x } - pub(crate) fn lagrange_to_monomial(&self, buf: Cow<[F]>) -> Vec { + pub fn evaluate( + &self, + expression: &Expression, + evals: &HashMap, + challenges: &[F], + x: F, + ) -> F { + let lagrange = { + let common = (x.pow_vartime([self.n as u64]) - F::ONE) * self.n_inv; + let used_lagrange = expression.used_langrange(); + let mut denoms = chain![&used_lagrange] + .map(|i| x - self.rotate_point(F::ONE, Rotation(*i))) + .collect_vec(); + denoms.batch_invert(); + izip!(used_lagrange, denoms) + .map(|(i, denom)| (i, self.rotate_point(common * denom, Rotation(i)))) + .collect::>() + }; + + expression.evaluate( + &|scalar| scalar, + &|poly| match poly { + CommonPolynomial::Identity => x, + CommonPolynomial::Lagrange(i) => lagrange[&i], + CommonPolynomial::EqXY(_) => unreachable!(), + }, + &|query| evals[&query], + &|idx| challenges[idx], + &|scalar| -scalar, + &|lhs, rhs| lhs + &rhs, + &|lhs, rhs| lhs * &rhs, + &|value, scalar| scalar * value, + ) + } + + pub fn lagrange_to_monomial(&self, buf: Cow<[F]>) -> Vec { assert_eq!(buf.len(), self.n); let mut buf = buf.into_owned(); - fft(&mut buf, self.omega_inv, self.k); + radix2_fft(&mut buf, self.omega_inv, self.k); parallelize(&mut buf, |(buf, _)| { buf.iter_mut().for_each(|buf| *buf *= self.n_inv) @@ -136,11 +190,11 @@ impl> Domain { buf } - pub(crate) fn lagrange_to_extended_lagrange(&self, buf: Cow<[F]>) -> Vec { + pub fn lagrange_to_extended_lagrange(&self, buf: Cow<[F]>) -> Vec { assert_eq!(buf.len(), self.n); let mut buf = buf.into_owned(); - fft(&mut buf, self.omega_inv, self.k); + radix2_fft(&mut buf, self.omega_inv, self.k); let scalars = [self.n_inv, self.n_inv_zeta, self.n_inv_zeta_inv]; parallelize(&mut buf, |(buf, start)| { @@ -149,12 +203,12 @@ impl> Domain { }); buf.resize(self.extended_n, F::ZERO); - fft(&mut buf, self.extended_omega, self.extended_k); + radix2_fft(&mut buf, self.extended_omega, self.extended_k); buf } - pub(crate) fn monomial_to_extended_lagrange(&self, buf: Cow<[F]>) -> Vec { + pub fn monomial_to_extended_lagrange(&self, buf: Cow<[F]>) -> Vec { assert!(buf.len() <= self.n); let mut buf = buf.into_owned(); @@ -168,16 +222,16 @@ impl> Domain { }); buf.resize(self.extended_n, F::ZERO); - fft(&mut buf, self.extended_omega, self.extended_k); + radix2_fft(&mut buf, self.extended_omega, self.extended_k); buf } - pub(crate) fn extended_lagrange_to_monomial(&self, buf: Cow<[F]>) -> Vec { + pub fn extended_lagrange_to_monomial(&self, buf: Cow<[F]>) -> Vec { assert_eq!(buf.len(), self.extended_n); let mut buf = buf.into_owned(); - fft(&mut buf, self.extended_omega_inv, self.extended_k); + radix2_fft(&mut buf, self.extended_omega_inv, self.extended_k); let scalars = [ self.extended_n_inv, @@ -194,8 +248,9 @@ impl> Domain { } #[derive(Clone, Debug)] -pub(crate) struct QuotientEvaluator<'a, F: WithSmallOrderMulGroup<3>> { - domain: &'a Domain, +pub struct QuotientEvaluator<'a, F: WithSmallOrderMulGroup<3>> { + magnification: i32, + extended_n: i32, reg: ExpressionRegistry, eval_idx: usize, identity: Vec, @@ -205,8 +260,8 @@ pub(crate) struct QuotientEvaluator<'a, F: WithSmallOrderMulGroup<3>> { } impl<'a, F: WithSmallOrderMulGroup<3>> QuotientEvaluator<'a, F> { - pub(crate) fn new( - domain: &'a Domain, + pub fn new( + domain: &'a Radix2Domain, expression: &Expression, lagranges: BTreeMap, polys: impl IntoIterator, @@ -219,14 +274,20 @@ impl<'a, F: WithSmallOrderMulGroup<3>> QuotientEvaluator<'a, F> { let identity = reg .has_identity() - .then(|| domain.identity()) + .then(|| { + iter::successors(Some(F::ZETA), move |state| { + Some(domain.extended_omega * state) + }) + .take(domain.extended_n) + .collect_vec() + }) .unwrap_or_default(); let lagranges = reg.lagranges().iter().map(|i| lagranges[i]).collect_vec(); let polys = polys.into_iter().collect_vec(); let vanishing_invs = { - let step = domain.extended_omega.pow([domain.n as u64]); + let step = domain.extended_omega.pow([domain.n() as u64]); let mut vanishing_invs = iter::successors( - Some(match domain.n % 3 { + Some(match domain.n() % 3 { 1 => domain.zeta, 2 => domain.zeta_inv, _ => unreachable!(), @@ -234,14 +295,15 @@ impl<'a, F: WithSmallOrderMulGroup<3>> QuotientEvaluator<'a, F> { |value| Some(step * value), ) .map(|value| value - F::ONE) - .take(1 << (domain.extended_k - domain.k)) + .take(domain.magnification) .collect_vec(); vanishing_invs.batch_invert(); vanishing_invs }; Self { - domain, + magnification: domain.magnification as i32, + extended_n: domain.extended_n as i32, reg, eval_idx, identity, @@ -251,11 +313,11 @@ impl<'a, F: WithSmallOrderMulGroup<3>> QuotientEvaluator<'a, F> { } } - pub(crate) fn cache(&self) -> Vec { + pub fn cache(&self) -> Vec { self.reg.cache() } - pub(crate) fn evaluate(&self, eval: &mut F, cache: &mut [F], row: usize) { + pub fn evaluate(&self, eval: &mut F, cache: &mut [F], row: usize) { if self.reg.has_identity() { cache[self.reg.offsets().identity()] = self.identity[row]; } @@ -282,38 +344,15 @@ impl<'a, F: WithSmallOrderMulGroup<3>> QuotientEvaluator<'a, F> { } fn rotated_row(&self, row: usize, rotation: Rotation) -> usize { - ((row as i32 + rotation.0).rem_euclid(self.domain.extended_n() as i32)) as usize + ((row as i32 + self.magnification * rotation.0).rem_euclid(self.extended_n)) as usize } } -pub(crate) fn evaluate( - expression: &Expression, - _k: usize, - evals: &HashMap, - challenges: &[F], - x: F, -) -> F { - expression.evaluate( - &|scalar| scalar, - &|poly| match poly { - CommonPolynomial::Identity => x, - CommonPolynomial::Lagrange(_i) => unimplemented!(), - CommonPolynomial::EqXY(_) => unreachable!(), - }, - &|query| evals[&query], - &|idx| challenges[idx], - &|scalar| -scalar, - &|lhs, rhs| lhs + &rhs, - &|lhs, rhs| lhs * &rhs, - &|value, scalar| scalar * value, - ) -} - #[cfg(test)] mod test { use crate::util::{ arithmetic::Field, - expression::evaluator::quotient::Domain, + expression::evaluator::quotient::Radix2Domain, test::{rand_vec, seeded_std_rng}, Itertools, }; @@ -325,7 +364,7 @@ mod test { let lagrange = rand_vec::(1 << 16, &mut rng); for (k, degree) in (1..16).cartesian_product(1..9) { - let domain = Domain::new(k, degree); + let domain = Radix2Domain::new(k, degree); let extended = domain.lagrange_to_extended_lagrange((&lagrange[..domain.n()]).into()); let monomial = domain.extended_lagrange_to_monomial(extended.into()); assert_eq!( diff --git a/plonkish_backend/src/util/expression/relaxed.rs b/plonkish_backend/src/util/expression/relaxed.rs index 8642521e..7b3b7a7c 100644 --- a/plonkish_backend/src/util/expression/relaxed.rs +++ b/plonkish_backend/src/util/expression/relaxed.rs @@ -1,5 +1,6 @@ use crate::util::{ arithmetic::PrimeField, + chain, expression::{CommonPolynomial, Expression, Query}, BitIndex, Itertools, }; @@ -69,41 +70,40 @@ pub(crate) fn cross_term_expressions( &|(lhs, expr), rhs| (lhs * rhs, expr), ); for idx in 1usize..(1 << folding_degree) - 1 { - let (scalar, mut polys) = iter::empty() - .chain(iter::repeat(None).take(folding_degree - product.folding_degree())) - .chain(product.foldees.iter().map(Some)) - .enumerate() - .fold( - (Expression::Constant(common_scalar), common_poly.clone()), - |(mut scalar, mut polys), (nth, foldee)| { - let (poly_offset, challenge_offset) = if idx.nth_bit(nth) { - ( - preprocess_poly_indices.len() + folding_poly_indices.len(), - num_challenges + 1, - ) - } else { - (preprocess_poly_indices.len(), 0) - }; - match foldee { - None => { - scalar = - &scalar * Expression::Challenge(challenge_offset + u) - } - Some(Expression::Challenge(challenge)) => { - scalar = &scalar - * Expression::Challenge(challenge_offset + challenge) - } - Some(Expression::Polynomial(query)) => { - let poly = - poly_offset + folding_poly_indices[&query.poly()]; - let query = Query::new(poly, query.rotation()); - polys.push(ExpressionPolynomial::Polynomial(query)); - } - _ => unreachable!(), + let (scalar, mut polys) = chain![ + iter::repeat(None).take(folding_degree - product.folding_degree()), + product.foldees.iter().map(Some), + ] + .enumerate() + .fold( + (Expression::Constant(common_scalar), common_poly.clone()), + |(mut scalar, mut polys), (nth, foldee)| { + let (poly_offset, challenge_offset) = if idx.nth_bit(nth) { + ( + preprocess_poly_indices.len() + folding_poly_indices.len(), + num_challenges + 1, + ) + } else { + (preprocess_poly_indices.len(), 0) + }; + match foldee { + None => { + scalar = &scalar * Expression::Challenge(challenge_offset + u) } - (scalar, polys) - }, - ); + Some(Expression::Challenge(challenge)) => { + scalar = &scalar + * Expression::Challenge(challenge_offset + challenge) + } + Some(Expression::Polynomial(query)) => { + let poly = poly_offset + folding_poly_indices[&query.poly()]; + let query = Query::new(poly, query.rotation()); + polys.push(ExpressionPolynomial::Polynomial(query)); + } + _ => unreachable!(), + } + (scalar, polys) + }, + ); polys.sort_unstable(); scalars[idx.count_ones() as usize - 1] .entry(polys) @@ -186,11 +186,7 @@ pub(crate) fn products( .map(|(lhs, rhs)| { Product::new( &lhs.preprocess * &rhs.preprocess, - iter::empty() - .chain(&lhs.foldees) - .chain(&rhs.foldees) - .cloned() - .collect(), + chain![&lhs.foldees, &rhs.foldees].cloned().collect(), ) }) .collect_vec() diff --git a/plonkish_backend/src/util/expression/rotate.rs b/plonkish_backend/src/util/expression/rotate.rs new file mode 100644 index 00000000..e5ff739f --- /dev/null +++ b/plonkish_backend/src/util/expression/rotate.rs @@ -0,0 +1,106 @@ +use crate::util::{Deserialize, Serialize}; +use std::{fmt::Debug, ops::Neg}; + +mod binary_field; +mod lexical; + +pub use binary_field::BinaryField; +pub use lexical::Lexical; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct Rotation(pub i32); + +impl Rotation { + pub const fn cur() -> Self { + Rotation(0) + } + + pub const fn prev() -> Self { + Rotation(-1) + } + + pub const fn next() -> Self { + Rotation(1) + } + + pub const fn distance(&self) -> usize { + self.0.unsigned_abs() as usize + } + + pub fn positive(&self, n: usize) -> Rotation { + Rotation(self.0.rem_euclid(n as i32)) + } +} + +impl From for Rotation { + fn from(rotation: i32) -> Self { + Self(rotation) + } +} + +impl Neg for Rotation { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(-self.0) + } +} + +pub trait Rotatable: 'static + Debug + Send + Sync { + /// Return `self.n().ilog2()` + fn k(&self) -> usize; + + /// Return `self.usable_indices().next_power_of_two()` + fn n(&self) -> usize; + + /// Return usable indices that are cyclic + fn usable_indices(&self) -> Vec; + + /// Return maximum rotation the implementation supports + fn max_rotation(&self) -> usize; + + /// Rotate `idx` by `rotation`` + fn rotate(&self, idx: usize, rotation: Rotation) -> usize; + + /// Return a map from `idx` to `self.rotate(idx, rotation)` + fn rotation_map(&self, rotation: Rotation) -> Vec; + + /// Return `self.usable_indices()[nth]` + fn nth(&self, nth: i32) -> usize; +} + +impl Rotatable for usize { + fn k(&self) -> usize { + *self + } + + fn n(&self) -> usize { + 1 << self + } + + fn usable_indices(&self) -> Vec { + (0..1 << self).collect() + } + + fn max_rotation(&self) -> usize { + 0 + } + + fn rotate(&self, idx: usize, rotation: Rotation) -> usize { + if rotation.0 == 0 { + return idx; + } + unreachable!() + } + + fn rotation_map(&self, rotation: Rotation) -> Vec { + if rotation.0 == 0 { + return self.usable_indices(); + } + unreachable!() + } + + fn nth(&self, nth: i32) -> usize { + nth.rem_euclid(1 << self) as usize + } +} diff --git a/plonkish_backend/src/util/arithmetic/bh.rs b/plonkish_backend/src/util/expression/rotate/binary_field.rs similarity index 75% rename from plonkish_backend/src/util/arithmetic/bh.rs rename to plonkish_backend/src/util/expression/rotate/binary_field.rs index 9988af66..da7001f2 100644 --- a/plonkish_backend/src/util/arithmetic/bh.rs +++ b/plonkish_backend/src/util/expression/rotate/binary_field.rs @@ -1,4 +1,7 @@ -use crate::util::{expression::Rotation, parallel::par_map_collect}; +use crate::util::{ + expression::{rotate::Rotatable, Rotation}, + parallel::par_map_collect, +}; use std::{cmp::Ordering, iter}; /// Integer representation of primitive polynomial in GF(2). @@ -74,14 +77,15 @@ const X_INVS: [usize; 32] = [ ]; #[derive(Debug, Clone, Copy)] -pub struct BooleanHypercube { +pub struct BinaryField { num_vars: usize, primitive: usize, x_inv: usize, } -impl BooleanHypercube { +impl BinaryField { pub const fn new(num_vars: usize) -> Self { + assert!(num_vars > 0); assert!(num_vars < 32); Self { num_vars, @@ -102,7 +106,47 @@ impl BooleanHypercube { self.x_inv } - pub fn rotate(&self, mut b: usize, Rotation(rotation): Rotation) -> usize { + pub fn iter(&self) -> impl Iterator + '_ { + iter::once(0) + .chain(iter::successors(Some(1), |b| { + next(*b, self.num_vars, self.primitive).into() + })) + .take(1 << self.num_vars) + } + + pub fn nth_map(&self) -> Vec { + let mut nth_map = vec![0; 1 << self.num_vars]; + for (nth, b) in self.iter().enumerate() { + nth_map[b] = nth; + } + nth_map + } +} + +impl From for BinaryField { + fn from(k: usize) -> Self { + Self::new(k) + } +} + +impl Rotatable for BinaryField { + fn k(&self) -> usize { + self.num_vars + } + + fn n(&self) -> usize { + 1 << self.num_vars + } + + fn usable_indices(&self) -> Vec { + self.iter().skip(1).collect() + } + + fn max_rotation(&self) -> usize { + self.num_vars + } + + fn rotate(&self, mut b: usize, Rotation(rotation): Rotation) -> usize { match rotation.cmp(&0) { Ordering::Equal => {} Ordering::Less => { @@ -119,24 +163,13 @@ impl BooleanHypercube { b } - pub fn iter(&self) -> impl Iterator + '_ { - iter::once(0) - .chain(iter::successors(Some(1), |b| { - next(*b, self.num_vars, self.primitive).into() - })) - .take(1 << self.num_vars) - } - - pub fn nth_map(&self) -> Vec { - let mut nth_map = vec![0; 1 << self.num_vars]; - for (nth, b) in self.iter().enumerate() { - nth_map[b] = nth; - } - nth_map + fn rotation_map(&self, rotation: Rotation) -> Vec { + par_map_collect(0..1 << self.num_vars, |b| self.rotate(b, rotation)) } - pub fn rotation_map(&self, rotation: Rotation) -> Vec { - par_map_collect(0..1 << self.num_vars, |b| self.rotate(b, rotation)) + fn nth(&self, nth: i32) -> usize { + let usable_indices = self.usable_indices(); + usable_indices[nth.rem_euclid(usable_indices.len() as i32) as usize] } } @@ -154,15 +187,18 @@ fn prev(b: usize, x_inv: usize) -> usize { #[cfg(test)] mod test { - use crate::util::{arithmetic::BooleanHypercube, expression::Rotation}; + use crate::util::expression::{ + rotate::{binary_field::BinaryField, Rotatable}, + Rotation, + }; #[test] #[ignore = "cause it takes some minutes to run with release profile"] - fn boolean_hypercube_iter() { - for num_vars in 0..32 { - let bh = BooleanHypercube::new(num_vars); + fn iter() { + for num_vars in 1..32 { + let bf = BinaryField::new(num_vars); let mut set = vec![false; 1 << num_vars]; - for i in bh.iter() { + for i in bf.iter() { assert!(!set[i]); set[i] = true; } @@ -171,11 +207,11 @@ mod test { #[test] #[ignore = "cause it takes some minutes to run with release profile"] - fn boolean_hypercube_prev() { - for num_vars in 0..32 { - let bh = BooleanHypercube::new(num_vars); - for (b, b_next) in bh.iter().skip(1).zip(bh.iter().skip(2).chain(Some(1))) { - assert_eq!(b, bh.rotate(b_next, Rotation::prev())) + fn prev() { + for num_vars in 1..32 { + let bf = BinaryField::new(num_vars); + for (b, b_next) in bf.iter().skip(1).zip(bf.iter().skip(2).chain([1])) { + assert_eq!(b, bf.rotate(b_next, Rotation::prev())) } } } diff --git a/plonkish_backend/src/util/expression/rotate/lexical.rs b/plonkish_backend/src/util/expression/rotate/lexical.rs new file mode 100644 index 00000000..cb7bc7f9 --- /dev/null +++ b/plonkish_backend/src/util/expression/rotate/lexical.rs @@ -0,0 +1,54 @@ +use crate::util::expression::{rotate::Rotatable, Rotation}; + +#[derive(Clone, Copy, Debug)] +pub struct Lexical { + k: usize, + n: usize, +} + +impl Lexical { + pub const fn new(k: usize) -> Self { + assert!(k > 0); + Self { k, n: 1 << k } + } +} + +impl From for Lexical { + fn from(k: usize) -> Self { + Self::new(k) + } +} + +impl Rotatable for Lexical { + fn k(&self) -> usize { + self.k + } + + fn n(&self) -> usize { + self.n + } + + fn usable_indices(&self) -> Vec { + (0..self.n).collect() + } + + fn max_rotation(&self) -> usize { + self.n + } + + fn rotate(&self, idx: usize, rotation: Rotation) -> usize { + (idx as i32 + rotation.0).rem_euclid(self.n as i32) as usize + } + + fn rotation_map(&self, rotation: Rotation) -> Vec { + (0..self.n) + .cycle() + .skip(self.rotate(0, rotation)) + .take(self.n) + .collect() + } + + fn nth(&self, nth: i32) -> usize { + self.rotate(0, Rotation(nth)) + } +} diff --git a/rust-toolchain b/rust-toolchain index 77c582d8..837f16a7 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.67.0 \ No newline at end of file +1.73.0 \ No newline at end of file