diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1799fc43..ee924a3c 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -35,7 +35,7 @@ jobs: - name: Compare gates reports id: gates_diff - uses: noir-lang/noir-gates-diff@1931aaaa848a1a009363d6115293f7b7fc72bb87 + uses: noir-lang/noir-gates-diff@dbe920a8dcc3370af4be4f702ca9cef29317bec1 with: report: gates_report.json summaryQuantile: 0.9 # only display the 10% most significant circuit size diffs in the summary (defaults to 20%) diff --git a/src/bignum.nr b/src/bignum.nr index 872572b2..6258fb90 100644 --- a/src/bignum.nr +++ b/src/bignum.nr @@ -25,40 +25,40 @@ pub struct BigNum { pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq { // TODO: this crashes the compiler? v0.32 // fn default() -> Self { std::default::Default::default () } - pub fn new() -> Self; + fn new() -> Self; fn zero() -> Self; - pub fn one() -> Self; - pub fn derive_from_seed(seed: [u8; SeedBytes]) -> Self; - pub unconstrained fn __derive_from_seed(seed: [u8; SeedBytes]) -> Self; - pub fn from_slice(limbs: [Field]) -> Self; - pub fn from_be_bytes(x: [u8; NBytes]) -> Self; - pub fn to_le_bytes(self) -> [u8; NBytes]; - - pub fn modulus() -> Self; - pub fn modulus_bits(self) -> u32; - pub fn num_limbs(self) -> u32; - pub fn get_limbs_slice(self) -> [Field]; - pub fn get_limb(self, idx: u32) -> Field; - pub fn set_limb(&mut self, idx: u32, value: Field); - - pub unconstrained fn __eq(self, other: Self) -> bool; - pub unconstrained fn __is_zero(self) -> bool; - - pub unconstrained fn __neg(self) -> Self; - pub unconstrained fn __add(self, other: Self) -> Self; - pub unconstrained fn __sub(self, other: Self) -> Self; - pub unconstrained fn __mul(self, other: Self) -> Self; - pub unconstrained fn __div(self, other: Self) -> Self; - pub unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self); - pub unconstrained fn __invmod(self) -> Self; - pub unconstrained fn __pow(self, exponent: Self) -> Self; - - pub unconstrained fn __batch_invert(to_invert: [Self; M]) -> [Self; M]; - pub unconstrained fn __batch_invert_slice(to_invert: [Self]) -> [Self]; - - pub unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option; - - pub unconstrained fn __compute_quadratic_expression( + fn one() -> Self; + fn derive_from_seed(seed: [u8; SeedBytes]) -> Self; + unconstrained fn __derive_from_seed(seed: [u8; SeedBytes]) -> Self; + fn from_slice(limbs: [Field]) -> Self; + fn from_be_bytes(x: [u8; NBytes]) -> Self; + fn to_le_bytes(self) -> [u8; NBytes]; + + fn modulus() -> Self; + fn modulus_bits(self) -> u32; + fn num_limbs(self) -> u32; + fn get_limbs_slice(self) -> [Field]; + fn get_limb(self, idx: u32) -> Field; + fn set_limb(&mut self, idx: u32, value: Field); + + unconstrained fn __eq(self, other: Self) -> bool; + unconstrained fn __is_zero(self) -> bool; + + unconstrained fn __neg(self) -> Self; + unconstrained fn __add(self, other: Self) -> Self; + unconstrained fn __sub(self, other: Self) -> Self; + unconstrained fn __mul(self, other: Self) -> Self; + unconstrained fn __div(self, other: Self) -> Self; + unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self); + unconstrained fn __invmod(self) -> Self; + unconstrained fn __pow(self, exponent: Self) -> Self; + + unconstrained fn __batch_invert(to_invert: [Self; M]) -> [Self; M]; + unconstrained fn __batch_invert_slice(to_invert: [Self]) -> [Self]; + + unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option; + + unconstrained fn __compute_quadratic_expression( lhs: [[Self; LHS_N]; NUM_PRODUCTS], lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], rhs: [[Self; RHS_N]; NUM_PRODUCTS], @@ -67,7 +67,7 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq { add_flags: [bool; ADD_N], ) -> (Self, Self); - pub fn evaluate_quadratic_expression( + fn evaluate_quadratic_expression( lhs: [[Self; LHS_N]; NUM_PRODUCTS], lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], rhs: [[Self; RHS_N]; NUM_PRODUCTS], @@ -76,16 +76,16 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq { add_flags: [bool; ADD_N], ); - pub fn assert_is_not_equal(self, other: Self); - pub fn validate_in_range(self); - pub fn validate_in_field(self); + fn assert_is_not_equal(self, other: Self); + fn validate_in_range(self); + fn validate_in_field(self); - pub fn udiv_mod(self, divisor: Self) -> (Self, Self); - pub fn udiv(self, divisor: Self) -> Self; - pub fn umod(self, divisor: Self) -> Self; + fn udiv_mod(self, divisor: Self) -> (Self, Self); + fn udiv(self, divisor: Self) -> Self; + fn umod(self, divisor: Self) -> Self; - pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self; - pub fn is_zero(self) -> bool; + fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self; + fn is_zero(self) -> bool; } impl std::convert::From for BigNum diff --git a/src/fns/unconstrained_ops.nr b/src/fns/unconstrained_ops.nr index cc7b6647..a9de60df 100644 --- a/src/fns/unconstrained_ops.nr +++ b/src/fns/unconstrained_ops.nr @@ -237,21 +237,20 @@ pub(crate) unconstrained fn __batch_invert [[Field; N]; M] { // TODO: ugly! Will fail if input slice is empty let mut accumulator: [Field; N] = __one::(); - let mut result: [[Field; N]; M] = [[0; N]; M]; let mut temporaries: [[Field; N]; N] = std::mem::zeroed(); - for i in 0..N { + for i in 0..M { temporaries[i] = accumulator; if (!__is_zero(x[i])) { accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]); } } + let mut result: [[Field; N]; M] = [[0; N]; M]; accumulator = __invmod::<_, MOD_BITS>(params, accumulator); - let mut T0: [Field; N] = [0; N]; - for i in 0..N { - let idx = N - 1 - i; + for i in 0..M { + let idx = M - 1 - i; if (!__is_zero(x[idx])) { - T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]); + let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]); accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]); result[idx] = T0; } @@ -265,26 +264,27 @@ pub(crate) unconstrained fn __batch_invert_slice( ) -> [[Field; N]] { // TODO: ugly! Will fail if input slice is empty let mut accumulator: [Field; N] = __one::(); - let mut result: [[Field; N]] = [[0; N]]; - let mut temporaries: [[Field; N]; N] = std::mem::zeroed(); - for i in 0..N { - temporaries[i] = accumulator; + let mut temporaries: [[Field; N]] = &[]; + for i in 0..x.len() { + temporaries = temporaries.push_back(accumulator); if (!__is_zero(x[i])) { accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]); } - result = result.push_back([0; N]); } + let mut result: [[Field; N]] = []; accumulator = __invmod::<_, MOD_BITS>(params, accumulator); - let mut T0: [Field; N] = [0; N]; for i in 0..x.len() { let idx = x.len() - 1 - i; - if (__is_zero(x[idx]) == false) { - T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]); + if (!__is_zero(x[idx])) { + let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]); accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]); - result[idx] = T0; - } + result = result.push_front(T0); + } else { + result = result.push_front([0; N]); + }; } + result } diff --git a/src/params.nr b/src/params.nr index f5beb504..aab9d9f2 100644 --- a/src/params.nr +++ b/src/params.nr @@ -19,7 +19,7 @@ pub struct BigNumParams { // To be implemented by the user for any BigNum they define, or within the predefined BigNums in the `fields/` dir. pub trait BigNumParamsGetter { - pub fn get_params() -> BigNumParams; + fn get_params() -> BigNumParams; } impl BigNumParams { diff --git a/src/runtime_bignum.nr b/src/runtime_bignum.nr index e2743072..ff14eeac 100644 --- a/src/runtime_bignum.nr +++ b/src/runtime_bignum.nr @@ -25,61 +25,58 @@ impl RuntimeBigNum {} // All functions prefixed `__` are unconstrained! // They're not actually decorated as `unconstrained` because to return the `params` (as part of Self) from an `unconstrained` fn would cause range constraints. Instead, each `__` fn wraps a call to an unconstrained fn, so that the already-range-constrained `params` can be inserted into Self after the unconstrained call. pub(crate) trait RuntimeBigNumTrait: Neg + Add + Sub + Mul + Div + Eq { - pub fn new(params: BigNumParams) -> Self; - pub fn one(params: BigNumParams) -> Self; - pub fn derive_from_seed( + fn new(params: BigNumParams) -> Self; + fn one(params: BigNumParams) -> Self; + fn derive_from_seed( params: BigNumParams, seed: [u8; SeedBytes], ) -> Self; - pub unconstrained fn __derive_from_seed( + unconstrained fn __derive_from_seed( params: BigNumParams, seed: [u8; SeedBytes], ) -> Self; - pub fn from_slice(params: BigNumParams, limbs: [Field]) -> Self; - pub fn from_array(params: BigNumParams, limbs: [Field; N]) -> Self; - pub fn from_be_bytes( - params: BigNumParams, - x: [u8; NBytes], - ) -> Self; + fn from_slice(params: BigNumParams, limbs: [Field]) -> Self; + fn from_array(params: BigNumParams, limbs: [Field; N]) -> Self; + fn from_be_bytes(params: BigNumParams, x: [u8; NBytes]) -> Self; - pub fn to_le_bytes(self) -> [u8; NBytes]; + fn to_le_bytes(self) -> [u8; NBytes]; - pub fn modulus(self) -> Self; - pub fn modulus_bits() -> u32; - pub fn num_limbs() -> u32; + fn modulus(self) -> Self; + fn modulus_bits() -> u32; + fn num_limbs() -> u32; // pub fn get(self) -> [Field]; - pub fn get_limbs(self) -> [Field; N]; - pub fn get_limb(self, idx: u32) -> Field; - pub fn set_limb(&mut self, idx: u32, value: Field); + fn get_limbs(self) -> [Field; N]; + fn get_limb(self, idx: u32) -> Field; + fn set_limb(&mut self, idx: u32, value: Field); unconstrained fn __eq(self, other: Self) -> bool; unconstrained fn __is_zero(self) -> bool; // unconstrained - pub fn __neg(self) -> Self; + fn __neg(self) -> Self; // unconstrained - pub fn __add(self, other: Self) -> Self; + fn __add(self, other: Self) -> Self; // unconstrained - pub fn __sub(self, other: Self) -> Self; + fn __sub(self, other: Self) -> Self; // unconstrained - pub fn __mul(self, other: Self) -> Self; + fn __mul(self, other: Self) -> Self; // unconstrained - pub fn __div(self, other: Self) -> Self; + fn __div(self, other: Self) -> Self; // unconstrained - pub fn __udiv_mod(self, divisor: Self) -> (Self, Self); + fn __udiv_mod(self, divisor: Self) -> (Self, Self); // unconstrained - pub fn __invmod(self) -> Self; + fn __invmod(self) -> Self; // unconstrained - pub fn __pow(self, exponent: Self) -> Self; + fn __pow(self, exponent: Self) -> Self; // unconstrained - pub fn __batch_invert(x: [Self; M]) -> [Self; M]; + fn __batch_invert(x: [Self; M]) -> [Self; M]; unconstrained fn __batch_invert_slice(to_invert: [Self]) -> [Self]; - pub fn __tonelli_shanks_sqrt(self) -> std::option::Option; + fn __tonelli_shanks_sqrt(self) -> std::option::Option; // unconstrained - pub fn __compute_quadratic_expression( + fn __compute_quadratic_expression( params: BigNumParams, lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS], lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], @@ -89,7 +86,7 @@ pub(crate) trait RuntimeBigNumTrait: Neg + Add + linear_flags: [bool; ADD_N], ) -> (Self, Self); - pub fn evaluate_quadratic_expression( + fn evaluate_quadratic_expression( params: BigNumParams, lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS], lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS], @@ -99,20 +96,20 @@ pub(crate) trait RuntimeBigNumTrait: Neg + Add + linear_flags: [bool; ADD_N], ); - pub fn eq(lhs: Self, rhs: Self) -> bool { + fn eq(lhs: Self, rhs: Self) -> bool { lhs == rhs } - pub fn assert_is_not_equal(self, other: Self); - pub fn validate_in_field(self); - pub fn validate_in_range(self); + fn assert_is_not_equal(self, other: Self); + fn validate_in_field(self); + fn validate_in_range(self); // pub fn validate_gt(self, lhs: Self, rhs: Self); - pub fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self); - pub fn udiv(numerator: Self, divisor: Self) -> Self; - pub fn umod(numerator: Self, divisor: Self) -> Self; + fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self); + fn udiv(numerator: Self, divisor: Self) -> Self; + fn umod(numerator: Self, divisor: Self) -> Self; - pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self; - pub fn is_zero(self) -> bool; + fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self; + fn is_zero(self) -> bool; } impl Neg for RuntimeBigNum { diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index 23d29705..e7fec0ce 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -9,7 +9,6 @@ use crate::params::BigNumParamsGetter; use crate::fields::bls12_381Fq::BLS12_381_Fq_Params; use crate::fields::bn254Fq::BN254_Fq_Params; use crate::fields::U256::U256Params; -use std::bigint::Bn254Fq; struct Test2048Params {} @@ -938,3 +937,39 @@ fn test_to_from_field_2() { assert(a == c); } +unconstrained fn test_batch_inversion(fields: [BN; N]) +where + BN: BigNumTrait, +{ + let inverted_fields = BN::__batch_invert(fields); + for i in 0..N { + assert_eq(fields[i] * inverted_fields[i], BN::one()); + } +} + +#[test] +unconstrained fn test_batch_inversion_BN381(seeds: [[u8; 2]; 3]) { + let fields = seeds.map(|seed| BN381::derive_from_seed(seed)); + unsafe { + test_batch_inversion(fields) + } +} + +unconstrained fn test_batch_inversion_slice(fields: [BN]) +where + BN: BigNumTrait, +{ + let inverted_fields = BN::__batch_invert_slice(fields); + assert_eq(fields.len(), inverted_fields.len()); + for i in 0..fields.len() { + assert_eq(fields[i] * inverted_fields[i], BN::one()); + } +} + +#[test] +unconstrained fn test_batch_inversion_slice_BN381(seeds: [[u8; 2]; 3]) { + let fields = seeds.map(|seed| BN381::derive_from_seed(seed)).as_slice(); + unsafe { + test_batch_inversion_slice(fields) + } +}