Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PrimeField as generic bound across the codebase #67

Merged
merged 3 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 64 additions & 64 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions benches/groth16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};
use ark_circom::{read_zkey, CircomReduction, WitnessCalculator};
use ark_std::rand::thread_rng;

use ark_bn254::Bn254;
use ark_bn254::{Bn254, Fr};
use ark_groth16::Groth16;
use wasmer::Store;

Expand Down Expand Up @@ -39,7 +39,7 @@ fn bench_groth(c: &mut Criterion, num_validators: u32, num_constraints: u32) {
)
.unwrap();
let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.calculate_witness_element::<Fr, _>(&mut store, inputs, false)
.unwrap();

let mut rng = thread_rng();
Expand Down
25 changes: 13 additions & 12 deletions src/circom/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ark_ec::pairing::Pairing;
use std::{fs::File, path::Path};
use wasmer::Store;

use ark_ff::PrimeField;

use super::{CircomCircuit, R1CS};

use num_bigint::BigInt;
Expand All @@ -14,21 +15,21 @@ use crate::{
use color_eyre::Result;

#[derive(Debug)]
pub struct CircomBuilder<E: Pairing> {
pub cfg: CircomConfig<E>,
pub struct CircomBuilder<F: PrimeField> {
pub cfg: CircomConfig<F>,
pub inputs: HashMap<String, Vec<BigInt>>,
}

// Add utils for creating this from files / directly from bytes
#[derive(Debug)]
pub struct CircomConfig<E: Pairing> {
pub r1cs: R1CS<E>,
pub struct CircomConfig<F: PrimeField> {
pub r1cs: R1CS<F>,
pub wtns: WitnessCalculator,
pub store: Store,
pub sanity_check: bool,
}

impl<E: Pairing> CircomConfig<E> {
impl<F: PrimeField> CircomConfig<F> {
pub fn new(wtns: impl AsRef<Path>, r1cs: impl AsRef<Path>) -> Result<Self> {
let mut store = Store::default();
let wtns = WitnessCalculator::new(&mut store, wtns).unwrap();
Expand Down Expand Up @@ -56,10 +57,10 @@ impl<E: Pairing> CircomConfig<E> {
}
}

impl<E: Pairing> CircomBuilder<E> {
impl<F: PrimeField> CircomBuilder<F> {
/// Instantiates a new builder using the provided WitnessGenerator and R1CS files
/// for your circuit
pub fn new(cfg: CircomConfig<E>) -> Self {
pub fn new(cfg: CircomConfig<F>) -> Self {
Self {
cfg,
inputs: HashMap::new(),
Expand All @@ -74,7 +75,7 @@ impl<E: Pairing> CircomBuilder<E> {

/// Generates an empty circom circuit with no witness set, to be used for
/// generation of the trusted setup parameters
pub fn setup(&self) -> CircomCircuit<E> {
pub fn setup(&self) -> CircomCircuit<F> {
let mut circom = CircomCircuit {
r1cs: self.cfg.r1cs.clone(),
witness: None,
Expand All @@ -88,11 +89,11 @@ impl<E: Pairing> CircomBuilder<E> {

/// Creates the circuit populated with the witness corresponding to the previously
/// provided inputs
pub fn build(mut self) -> Result<CircomCircuit<E>> {
pub fn build(mut self) -> Result<CircomCircuit<F>> {
let mut circom = self.setup();

// calculate the witness
let witness = self.cfg.wtns.calculate_witness_element::<E, _>(
let witness = self.cfg.wtns.calculate_witness_element::<F, _>(
&mut self.cfg.store,
self.inputs,
self.cfg.sanity_check,
Expand All @@ -102,7 +103,7 @@ impl<E: Pairing> CircomBuilder<E> {
// sanity check
debug_assert!({
use ark_relations::r1cs::{ConstraintSynthesizer, ConstraintSystem};
let cs = ConstraintSystem::<E::ScalarField>::new_ref();
let cs = ConstraintSystem::<F>::new_ref();
circom.clone().generate_constraints(cs.clone()).unwrap();
let is_satisfied = cs.is_satisfied().unwrap();
if !is_satisfied {
Expand Down
36 changes: 16 additions & 20 deletions src/circom/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use ark_ec::pairing::Pairing;
use ark_relations::r1cs::{
ConstraintSynthesizer, ConstraintSystemRef, LinearCombination, SynthesisError, Variable,
};

use ark_ff::PrimeField;

use super::R1CS;

use color_eyre::Result;

#[derive(Clone, Debug)]
pub struct CircomCircuit<E: Pairing> {
pub r1cs: R1CS<E>,
pub witness: Option<Vec<E::ScalarField>>,
pub struct CircomCircuit<F: PrimeField> {
pub r1cs: R1CS<F>,
pub witness: Option<Vec<F>>,
}

impl<E: Pairing> CircomCircuit<E> {
pub fn get_public_inputs(&self) -> Option<Vec<E::ScalarField>> {
impl<F: PrimeField> CircomCircuit<F> {
pub fn get_public_inputs(&self) -> Option<Vec<F>> {
match &self.witness {
None => None,
Some(w) => match &self.r1cs.wire_mapping {
Expand All @@ -25,19 +26,16 @@ impl<E: Pairing> CircomCircuit<E> {
}
}

impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
fn generate_constraints(
self,
cs: ConstraintSystemRef<E::ScalarField>,
) -> Result<(), SynthesisError> {
impl<F: PrimeField> ConstraintSynthesizer<F> for CircomCircuit<F> {
fn generate_constraints(self, cs: ConstraintSystemRef<F>) -> Result<(), SynthesisError> {
let witness = &self.witness;
let wire_mapping = &self.r1cs.wire_mapping;

// Start from 1 because Arkworks implicitly allocates One for the first input
for i in 1..self.r1cs.num_inputs {
cs.new_input_variable(|| {
Ok(match witness {
None => E::ScalarField::from(1u32),
None => F::from(1u32),
Some(w) => match wire_mapping {
Some(m) => w[m[i]],
None => w[i],
Expand All @@ -49,7 +47,7 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
for i in 0..self.r1cs.num_aux {
cs.new_witness_variable(|| {
Ok(match witness {
None => E::ScalarField::from(1u32),
None => F::from(1u32),
Some(w) => match wire_mapping {
Some(m) => w[m[i + self.r1cs.num_inputs]],
None => w[i + self.r1cs.num_inputs],
Expand All @@ -65,12 +63,10 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
Variable::Witness(index - self.r1cs.num_inputs)
}
};
let make_lc = |lc_data: &[(usize, E::ScalarField)]| {
let make_lc = |lc_data: &[(usize, F)]| {
lc_data.iter().fold(
LinearCombination::<E::ScalarField>::zero(),
|lc: LinearCombination<E::ScalarField>, (index, coeff)| {
lc + (*coeff, make_index(*index))
},
LinearCombination::<F>::zero(),
|lc: LinearCombination<F>, (index, coeff)| lc + (*coeff, make_index(*index)),
)
};

Expand All @@ -90,12 +86,12 @@ impl<E: Pairing> ConstraintSynthesizer<E::ScalarField> for CircomCircuit<E> {
mod tests {
use super::*;
use crate::{CircomBuilder, CircomConfig};
use ark_bn254::{Bn254, Fr};
use ark_bn254::Fr;
use ark_relations::r1cs::ConstraintSystem;

#[tokio::test]
async fn satisfied() {
let cfg = CircomConfig::<Bn254>::new(
let cfg = CircomConfig::<Fr>::new(
"./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs",
)
Expand Down
6 changes: 2 additions & 4 deletions src/circom/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use ark_ec::pairing::Pairing;

pub mod r1cs_reader;
pub use r1cs_reader::{R1CSFile, R1CS};

Expand All @@ -12,5 +10,5 @@ pub use builder::{CircomBuilder, CircomConfig};
mod qap;
pub use qap::CircomReduction;

pub type Constraints<E> = (ConstraintVec<E>, ConstraintVec<E>, ConstraintVec<E>);
pub type ConstraintVec<E> = Vec<(usize, <E as Pairing>::ScalarField)>;
pub type Constraints<F> = (ConstraintVec<F>, ConstraintVec<F>, ConstraintVec<F>);
pub type ConstraintVec<F> = Vec<(usize, F)>;
40 changes: 20 additions & 20 deletions src/circom/r1cs_reader.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! R1CS circom file reader
//! Copied from <https://github.com/poma/zkutil>
//! Spec: <https://github.com/iden3/r1csfile/blob/master/doc/r1cs_bin_format.md>
use ark_ff::PrimeField;
use byteorder::{LittleEndian, ReadBytesExt};
use std::io::{Error, ErrorKind};

use ark_ec::pairing::Pairing;
use ark_serialize::{CanonicalDeserialize, SerializationError, SerializationError::IoError};
use ark_serialize::{SerializationError, SerializationError::IoError};
use ark_std::io::{Read, Seek, SeekFrom};

use std::collections::HashMap;
Expand All @@ -15,16 +15,16 @@ type IoResult<T> = Result<T, SerializationError>;
use super::{ConstraintVec, Constraints};

#[derive(Clone, Debug)]
pub struct R1CS<E: Pairing> {
pub struct R1CS<F> {
pub num_inputs: usize,
pub num_aux: usize,
pub num_variables: usize,
pub constraints: Vec<Constraints<E>>,
pub constraints: Vec<Constraints<F>>,
pub wire_mapping: Option<Vec<usize>>,
}

impl<E: Pairing> From<R1CSFile<E>> for R1CS<E> {
fn from(file: R1CSFile<E>) -> Self {
impl<F: PrimeField> From<R1CSFile<F>> for R1CS<F> {
fn from(file: R1CSFile<F>) -> Self {
let num_inputs = (1 + file.header.n_pub_in + file.header.n_pub_out) as usize;
let num_variables = file.header.n_wires as usize;
let num_aux = num_variables - num_inputs;
Expand All @@ -38,20 +38,20 @@ impl<E: Pairing> From<R1CSFile<E>> for R1CS<E> {
}
}

pub struct R1CSFile<E: Pairing> {
pub struct R1CSFile<F: PrimeField> {
pub version: u32,
pub header: Header,
pub constraints: Vec<Constraints<E>>,
pub constraints: Vec<Constraints<F>>,
pub wire_mapping: Vec<u64>,
}

impl<E: Pairing> R1CSFile<E> {
impl<F: PrimeField> R1CSFile<F> {
/// reader must implement the Seek trait, for example with a Cursor
///
/// ```rust,ignore
/// let reader = BufReader::new(Cursor::new(&data[..]));
/// ```
pub fn new<R: Read + Seek>(mut reader: R) -> IoResult<R1CSFile<E>> {
pub fn new<R: Read + Seek>(mut reader: R) -> IoResult<R1CSFile<F>> {
let mut magic = [0u8; 4];
reader.read_exact(&mut magic)?;
if magic != [0x72, 0x31, 0x63, 0x73] {
Expand Down Expand Up @@ -117,7 +117,7 @@ impl<E: Pairing> R1CSFile<E> {

reader.seek(SeekFrom::Start(*constraint_offset?))?;

let constraints = read_constraints::<&mut R, E>(&mut reader, &header)?;
let constraints = read_constraints::<&mut R, F>(&mut reader, &header)?;

let wire2label_offset = sec_offsets.get(&wire2label_type).ok_or_else(|| {
Error::new(
Expand Down Expand Up @@ -200,29 +200,29 @@ impl Header {
}
}

fn read_constraint_vec<R: Read, E: Pairing>(mut reader: R) -> IoResult<ConstraintVec<E>> {
fn read_constraint_vec<R: Read, F: PrimeField>(mut reader: R) -> IoResult<ConstraintVec<F>> {
let n_vec = reader.read_u32::<LittleEndian>()? as usize;
let mut vec = Vec::with_capacity(n_vec);
for _ in 0..n_vec {
vec.push((
reader.read_u32::<LittleEndian>()? as usize,
E::ScalarField::deserialize_uncompressed(&mut reader)?,
F::deserialize_uncompressed(&mut reader)?,
));
}
Ok(vec)
}

fn read_constraints<R: Read, E: Pairing>(
fn read_constraints<R: Read, F: PrimeField>(
mut reader: R,
header: &Header,
) -> IoResult<Vec<Constraints<E>>> {
) -> IoResult<Vec<Constraints<F>>> {
// todo check section size
let mut vec = Vec::with_capacity(header.n_constraints as usize);
for _ in 0..header.n_constraints {
vec.push((
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, E>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
read_constraint_vec::<&mut R, F>(&mut reader)?,
));
}
Ok(vec)
Expand Down Expand Up @@ -251,7 +251,7 @@ fn read_map<R: Read>(mut reader: R, size: u64, header: &Header) -> IoResult<Vec<
#[cfg(test)]
mod tests {
use super::*;
use ark_bn254::{Bn254, Fr};
use ark_bn254::Fr;
use ark_std::io::{BufReader, Cursor};

#[test]
Expand Down Expand Up @@ -309,7 +309,7 @@ mod tests {
);

let reader = BufReader::new(Cursor::new(&data[..]));
let file = R1CSFile::<Bn254>::new(reader).unwrap();
let file = R1CSFile::<Fr>::new(reader).unwrap();
assert_eq!(file.version, 1);

assert_eq!(file.header.field_size, 32);
Expand Down
10 changes: 5 additions & 5 deletions src/witness/witness_calculator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{fnv, CircomBase, SafeMemory, Wasm};
use ark_ff::PrimeField;
use color_eyre::Result;
use num_bigint::BigInt;
use num_traits::Zero;
Expand Down Expand Up @@ -284,17 +285,16 @@ impl WitnessCalculator {
}

pub fn calculate_witness_element<
E: ark_ec::pairing::Pairing,
F: PrimeField,
I: IntoIterator<Item = (String, Vec<BigInt>)>,
>(
&mut self,
store: &mut Store,
inputs: I,
sanity_check: bool,
) -> Result<Vec<E::ScalarField>> {
use ark_ff::PrimeField;
) -> Result<Vec<F>> {
let modulus = F::MODULUS;
let witness = self.calculate_witness(store, inputs, sanity_check)?;
let modulus = <E::ScalarField as PrimeField>::MODULUS;

// convert it to field elements
use num_traits::Signed;
Expand All @@ -307,7 +307,7 @@ impl WitnessCalculator {
} else {
w.to_biguint().unwrap()
};
E::ScalarField::from(w)
F::from(w)
})
.collect::<Vec<_>>();

Expand Down
4 changes: 2 additions & 2 deletions src/zkey.rs
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ mod tests {
let mut file = File::open(path).unwrap();
let (params, _matrices) = read_zkey(&mut file).unwrap(); // binfile.proving_key().unwrap();

let cfg = CircomConfig::<Bn254>::new(
let cfg = CircomConfig::<Fr>::new(
"./test-vectors/mycircuit.wasm",
"./test-vectors/mycircuit.r1cs",
)
Expand Down Expand Up @@ -896,7 +896,7 @@ mod tests {
let s = ark_bn254::Fr::rand(rng);

let full_assignment = wtns
.calculate_witness_element::<Bn254, _>(&mut store, inputs, false)
.calculate_witness_element::<Fr, _>(&mut store, inputs, false)
.unwrap();
let proof = Groth16::<Bn254, CircomReduction>::create_proof_with_reduction_and_matrices(
&params,
Expand Down
Loading
Loading