Skip to content

Commit

Permalink
fix: correct batch inversion function (#117)
Browse files Browse the repository at this point in the history
Co-authored-by: Khashayar Barooti <[email protected]>
  • Loading branch information
TomAFrench and Khashayar Barooti authored Feb 3, 2025
1 parent 53a4d4a commit 976d3ef
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 99 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:

- name: Compare gates reports
id: gates_diff
uses: noir-lang/noir-gates-diff@1931aaaa848a1a009363d6115293f7b7fc72bb87
uses: noir-lang/noir-gates-diff@dbe920a8dcc3370af4be4f702ca9cef29317bec1
with:
report: gates_report.json
summaryQuantile: 0.9 # only display the 10% most significant circuit size diffs in the summary (defaults to 20%)
Expand Down
84 changes: 42 additions & 42 deletions src/bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,40 @@ pub struct BigNum<let N: u32, let MOD_BITS: u32, Params> {
pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
// TODO: this crashes the compiler? v0.32
// fn default() -> Self { std::default::Default::default () }
pub fn new() -> Self;
fn new() -> Self;
fn zero() -> Self;
pub fn one() -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub fn from_slice(limbs: [Field]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus() -> Self;
pub fn modulus_bits(self) -> u32;
pub fn num_limbs(self) -> u32;
pub fn get_limbs_slice(self) -> [Field];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);

pub unconstrained fn __eq(self, other: Self) -> bool;
pub unconstrained fn __is_zero(self) -> bool;

pub unconstrained fn __neg(self) -> Self;
pub unconstrained fn __add(self, other: Self) -> Self;
pub unconstrained fn __sub(self, other: Self) -> Self;
pub unconstrained fn __mul(self, other: Self) -> Self;
pub unconstrained fn __div(self, other: Self) -> Self;
pub unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self);
pub unconstrained fn __invmod(self) -> Self;
pub unconstrained fn __pow(self, exponent: Self) -> Self;

pub unconstrained fn __batch_invert<let M: u32>(to_invert: [Self; M]) -> [Self; M];
pub unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

pub unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

pub unconstrained fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn one() -> Self;
fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
fn from_slice(limbs: [Field]) -> Self;
fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

fn modulus() -> Self;
fn modulus_bits(self) -> u32;
fn num_limbs(self) -> u32;
fn get_limbs_slice(self) -> [Field];
fn get_limb(self, idx: u32) -> Field;
fn set_limb(&mut self, idx: u32, value: Field);

unconstrained fn __eq(self, other: Self) -> bool;
unconstrained fn __is_zero(self) -> bool;

unconstrained fn __neg(self) -> Self;
unconstrained fn __add(self, other: Self) -> Self;
unconstrained fn __sub(self, other: Self) -> Self;
unconstrained fn __mul(self, other: Self) -> Self;
unconstrained fn __div(self, other: Self) -> Self;
unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self);
unconstrained fn __invmod(self) -> Self;
unconstrained fn __pow(self, exponent: Self) -> Self;

unconstrained fn __batch_invert<let M: u32>(to_invert: [Self; M]) -> [Self; M];
unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

unconstrained fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
lhs: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
rhs: [[Self; RHS_N]; NUM_PRODUCTS],
Expand All @@ -67,7 +67,7 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
add_flags: [bool; ADD_N],
) -> (Self, Self);

pub fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
lhs: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
rhs: [[Self; RHS_N]; NUM_PRODUCTS],
Expand All @@ -76,16 +76,16 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
add_flags: [bool; ADD_N],
);

pub fn assert_is_not_equal(self, other: Self);
pub fn validate_in_range(self);
pub fn validate_in_field(self);
fn assert_is_not_equal(self, other: Self);
fn validate_in_range(self);
fn validate_in_field(self);

pub fn udiv_mod(self, divisor: Self) -> (Self, Self);
pub fn udiv(self, divisor: Self) -> Self;
pub fn umod(self, divisor: Self) -> Self;
fn udiv_mod(self, divisor: Self) -> (Self, Self);
fn udiv(self, divisor: Self) -> Self;
fn umod(self, divisor: Self) -> Self;

pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
pub fn is_zero(self) -> bool;
fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
fn is_zero(self) -> bool;
}

impl<let N: u32, let MOD_BITS: u32, Params> std::convert::From<Field> for BigNum<N, MOD_BITS, Params>
Expand Down
32 changes: 16 additions & 16 deletions src/fns/unconstrained_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,20 @@ pub(crate) unconstrained fn __batch_invert<let N: u32, let MOD_BITS: u32, let M:
) -> [[Field; N]; M] {
// TODO: ugly! Will fail if input slice is empty
let mut accumulator: [Field; N] = __one::<N>();
let mut result: [[Field; N]; M] = [[0; N]; M];
let mut temporaries: [[Field; N]; N] = std::mem::zeroed();
for i in 0..N {
for i in 0..M {
temporaries[i] = accumulator;
if (!__is_zero(x[i])) {
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]);
}
}

let mut result: [[Field; N]; M] = [[0; N]; M];
accumulator = __invmod::<_, MOD_BITS>(params, accumulator);
let mut T0: [Field; N] = [0; N];
for i in 0..N {
let idx = N - 1 - i;
for i in 0..M {
let idx = M - 1 - i;
if (!__is_zero(x[idx])) {
T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]);
result[idx] = T0;
}
Expand All @@ -265,26 +264,27 @@ pub(crate) unconstrained fn __batch_invert_slice<let N: u32, let MOD_BITS: u32>(
) -> [[Field; N]] {
// TODO: ugly! Will fail if input slice is empty
let mut accumulator: [Field; N] = __one::<N>();
let mut result: [[Field; N]] = [[0; N]];
let mut temporaries: [[Field; N]; N] = std::mem::zeroed();
for i in 0..N {
temporaries[i] = accumulator;
let mut temporaries: [[Field; N]] = &[];
for i in 0..x.len() {
temporaries = temporaries.push_back(accumulator);
if (!__is_zero(x[i])) {
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]);
}
result = result.push_back([0; N]);
}

let mut result: [[Field; N]] = [];
accumulator = __invmod::<_, MOD_BITS>(params, accumulator);
let mut T0: [Field; N] = [0; N];
for i in 0..x.len() {
let idx = x.len() - 1 - i;
if (__is_zero(x[idx]) == false) {
T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
if (!__is_zero(x[idx])) {
let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]);
result[idx] = T0;
}
result = result.push_front(T0);
} else {
result = result.push_front([0; N]);
};
}

result
}

Expand Down
2 changes: 1 addition & 1 deletion src/params.nr
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct BigNumParams<let N: u32, let MOD_BITS: u32> {

// To be implemented by the user for any BigNum they define, or within the predefined BigNums in the `fields/` dir.
pub trait BigNumParamsGetter<let N: u32, let MOD_BITS: u32> {
pub fn get_params() -> BigNumParams<N, MOD_BITS>;
fn get_params() -> BigNumParams<N, MOD_BITS>;
}

impl<let N: u32, let MOD_BITS: u32> BigNumParams<N, MOD_BITS> {
Expand Down
73 changes: 35 additions & 38 deletions src/runtime_bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,58 @@ impl<let N: u32, let MOD_BITS: u32> RuntimeBigNum<N, MOD_BITS> {}
// All functions prefixed `__` are unconstrained!
// They're not actually decorated as `unconstrained` because to return the `params` (as part of Self) from an `unconstrained` fn would cause range constraints. Instead, each `__` fn wraps a call to an unconstrained fn, so that the already-range-constrained `params` can be inserted into Self after the unconstrained call.
pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add + Sub + Mul + Div + Eq {
pub fn new(params: BigNumParams<N, MOD_BITS>) -> Self;
pub fn one(params: BigNumParams<N, MOD_BITS>) -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(
fn new(params: BigNumParams<N, MOD_BITS>) -> Self;
fn one(params: BigNumParams<N, MOD_BITS>) -> Self;
fn derive_from_seed<let SeedBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
seed: [u8; SeedBytes],
) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(
unconstrained fn __derive_from_seed<let SeedBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
seed: [u8; SeedBytes],
) -> Self;
pub fn from_slice(params: BigNumParams<N, MOD_BITS>, limbs: [Field]) -> Self;
pub fn from_array(params: BigNumParams<N, MOD_BITS>, limbs: [Field; N]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
x: [u8; NBytes],
) -> Self;
fn from_slice(params: BigNumParams<N, MOD_BITS>, limbs: [Field]) -> Self;
fn from_array(params: BigNumParams<N, MOD_BITS>, limbs: [Field; N]) -> Self;
fn from_be_bytes<let NBytes: u32>(params: BigNumParams<N, MOD_BITS>, x: [u8; NBytes]) -> Self;

pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];
fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus(self) -> Self;
pub fn modulus_bits() -> u32;
pub fn num_limbs() -> u32;
fn modulus(self) -> Self;
fn modulus_bits() -> u32;
fn num_limbs() -> u32;
// pub fn get(self) -> [Field];
pub fn get_limbs(self) -> [Field; N];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);
fn get_limbs(self) -> [Field; N];
fn get_limb(self, idx: u32) -> Field;
fn set_limb(&mut self, idx: u32, value: Field);

unconstrained fn __eq(self, other: Self) -> bool;
unconstrained fn __is_zero(self) -> bool;

// unconstrained
pub fn __neg(self) -> Self;
fn __neg(self) -> Self;
// unconstrained
pub fn __add(self, other: Self) -> Self;
fn __add(self, other: Self) -> Self;
// unconstrained
pub fn __sub(self, other: Self) -> Self;
fn __sub(self, other: Self) -> Self;
// unconstrained
pub fn __mul(self, other: Self) -> Self;
fn __mul(self, other: Self) -> Self;
// unconstrained
pub fn __div(self, other: Self) -> Self;
fn __div(self, other: Self) -> Self;
// unconstrained
pub fn __udiv_mod(self, divisor: Self) -> (Self, Self);
fn __udiv_mod(self, divisor: Self) -> (Self, Self);
// unconstrained
pub fn __invmod(self) -> Self;
fn __invmod(self) -> Self;
// unconstrained
pub fn __pow(self, exponent: Self) -> Self;
fn __pow(self, exponent: Self) -> Self;

// unconstrained
pub fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M];
fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M];
unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

pub fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;
fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

// unconstrained
pub fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
params: BigNumParams<N, MOD_BITS>,
lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
Expand All @@ -89,7 +86,7 @@ pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add +
linear_flags: [bool; ADD_N],
) -> (Self, Self);

pub fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
params: BigNumParams<N, MOD_BITS>,
lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
Expand All @@ -99,20 +96,20 @@ pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add +
linear_flags: [bool; ADD_N],
);

pub fn eq(lhs: Self, rhs: Self) -> bool {
fn eq(lhs: Self, rhs: Self) -> bool {
lhs == rhs
}
pub fn assert_is_not_equal(self, other: Self);
pub fn validate_in_field(self);
pub fn validate_in_range(self);
fn assert_is_not_equal(self, other: Self);
fn validate_in_field(self);
fn validate_in_range(self);
// pub fn validate_gt(self, lhs: Self, rhs: Self);

pub fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self);
pub fn udiv(numerator: Self, divisor: Self) -> Self;
pub fn umod(numerator: Self, divisor: Self) -> Self;
fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self);
fn udiv(numerator: Self, divisor: Self) -> Self;
fn umod(numerator: Self, divisor: Self) -> Self;

pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
pub fn is_zero(self) -> bool;
fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
fn is_zero(self) -> bool;
}

impl<let N: u32, let MOD_BITS: u32> Neg for RuntimeBigNum<N, MOD_BITS> {
Expand Down
37 changes: 36 additions & 1 deletion src/tests/bignum_test.nr
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::params::BigNumParamsGetter;
use crate::fields::bls12_381Fq::BLS12_381_Fq_Params;
use crate::fields::bn254Fq::BN254_Fq_Params;
use crate::fields::U256::U256Params;
use std::bigint::Bn254Fq;

struct Test2048Params {}

Expand Down Expand Up @@ -938,3 +937,39 @@ fn test_to_from_field_2() {
assert(a == c);
}

unconstrained fn test_batch_inversion<let N: u32, BN>(fields: [BN; N])
where
BN: BigNumTrait,
{
let inverted_fields = BN::__batch_invert(fields);
for i in 0..N {
assert_eq(fields[i] * inverted_fields[i], BN::one());
}
}

#[test]
unconstrained fn test_batch_inversion_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed));
unsafe {
test_batch_inversion(fields)
}
}

unconstrained fn test_batch_inversion_slice<BN>(fields: [BN])
where
BN: BigNumTrait,
{
let inverted_fields = BN::__batch_invert_slice(fields);
assert_eq(fields.len(), inverted_fields.len());
for i in 0..fields.len() {
assert_eq(fields[i] * inverted_fields[i], BN::one());
}
}

#[test]
unconstrained fn test_batch_inversion_slice_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed)).as_slice();
unsafe {
test_batch_inversion_slice(fields)
}
}

0 comments on commit 976d3ef

Please sign in to comment.