Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct batch inversion function #117

Merged
merged 6 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:

- name: Compare gates reports
id: gates_diff
uses: noir-lang/noir-gates-diff@1931aaaa848a1a009363d6115293f7b7fc72bb87
uses: noir-lang/noir-gates-diff@dbe920a8dcc3370af4be4f702ca9cef29317bec1
with:
report: gates_report.json
summaryQuantile: 0.9 # only display the 10% most significant circuit size diffs in the summary (defaults to 20%)
Expand Down
84 changes: 42 additions & 42 deletions src/bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,40 +25,40 @@ pub struct BigNum<let N: u32, let MOD_BITS: u32, Params> {
pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
// TODO: this crashes the compiler? v0.32
// fn default() -> Self { std::default::Default::default () }
pub fn new() -> Self;
fn new() -> Self;
fn zero() -> Self;
pub fn one() -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub fn from_slice(limbs: [Field]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus() -> Self;
pub fn modulus_bits(self) -> u32;
pub fn num_limbs(self) -> u32;
pub fn get_limbs_slice(self) -> [Field];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);

pub unconstrained fn __eq(self, other: Self) -> bool;
pub unconstrained fn __is_zero(self) -> bool;

pub unconstrained fn __neg(self) -> Self;
pub unconstrained fn __add(self, other: Self) -> Self;
pub unconstrained fn __sub(self, other: Self) -> Self;
pub unconstrained fn __mul(self, other: Self) -> Self;
pub unconstrained fn __div(self, other: Self) -> Self;
pub unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self);
pub unconstrained fn __invmod(self) -> Self;
pub unconstrained fn __pow(self, exponent: Self) -> Self;

pub unconstrained fn __batch_invert<let M: u32>(to_invert: [Self; M]) -> [Self; M];
pub unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

pub unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

pub unconstrained fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn one() -> Self;
fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
fn from_slice(limbs: [Field]) -> Self;
fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

fn modulus() -> Self;
fn modulus_bits(self) -> u32;
fn num_limbs(self) -> u32;
fn get_limbs_slice(self) -> [Field];
fn get_limb(self, idx: u32) -> Field;
fn set_limb(&mut self, idx: u32, value: Field);

unconstrained fn __eq(self, other: Self) -> bool;
unconstrained fn __is_zero(self) -> bool;

unconstrained fn __neg(self) -> Self;
unconstrained fn __add(self, other: Self) -> Self;
unconstrained fn __sub(self, other: Self) -> Self;
unconstrained fn __mul(self, other: Self) -> Self;
unconstrained fn __div(self, other: Self) -> Self;
unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self);
unconstrained fn __invmod(self) -> Self;
unconstrained fn __pow(self, exponent: Self) -> Self;

unconstrained fn __batch_invert<let M: u32>(to_invert: [Self; M]) -> [Self; M];
unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

unconstrained fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
lhs: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
rhs: [[Self; RHS_N]; NUM_PRODUCTS],
Expand All @@ -67,7 +67,7 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
add_flags: [bool; ADD_N],
) -> (Self, Self);

pub fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
lhs: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
rhs: [[Self; RHS_N]; NUM_PRODUCTS],
Expand All @@ -76,16 +76,16 @@ pub trait BigNumTrait: Neg + Add + Sub + Mul + Div + Eq {
add_flags: [bool; ADD_N],
);

pub fn assert_is_not_equal(self, other: Self);
pub fn validate_in_range(self);
pub fn validate_in_field(self);
fn assert_is_not_equal(self, other: Self);
fn validate_in_range(self);
fn validate_in_field(self);

pub fn udiv_mod(self, divisor: Self) -> (Self, Self);
pub fn udiv(self, divisor: Self) -> Self;
pub fn umod(self, divisor: Self) -> Self;
fn udiv_mod(self, divisor: Self) -> (Self, Self);
fn udiv(self, divisor: Self) -> Self;
fn umod(self, divisor: Self) -> Self;

pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
pub fn is_zero(self) -> bool;
fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
fn is_zero(self) -> bool;
}

impl<let N: u32, let MOD_BITS: u32, Params> std::convert::From<Field> for BigNum<N, MOD_BITS, Params>
Expand Down
33 changes: 17 additions & 16 deletions src/fns/unconstrained_ops.nr
Original file line number Diff line number Diff line change
Expand Up @@ -237,21 +237,20 @@ pub(crate) unconstrained fn __batch_invert<let N: u32, let MOD_BITS: u32, let M:
) -> [[Field; N]; M] {
// TODO: ugly! Will fail if input slice is empty
let mut accumulator: [Field; N] = __one::<N>();
let mut result: [[Field; N]; M] = [[0; N]; M];
let mut temporaries: [[Field; N]; N] = std::mem::zeroed();
for i in 0..N {
for i in 0..M {
temporaries[i] = accumulator;
if (!__is_zero(x[i])) {
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]);
}
}

let mut result: [[Field; N]; M] = [[0; N]; M];
accumulator = __invmod::<_, MOD_BITS>(params, accumulator);
let mut T0: [Field; N] = [0; N];
for i in 0..N {
let idx = N - 1 - i;
for i in 0..M {
let idx = M - 1 - i;
if (!__is_zero(x[idx])) {
T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]);
result[idx] = T0;
}
Expand All @@ -265,26 +264,28 @@ pub(crate) unconstrained fn __batch_invert_slice<let N: u32, let MOD_BITS: u32>(
) -> [[Field; N]] {
// TODO: ugly! Will fail if input slice is empty
let mut accumulator: [Field; N] = __one::<N>();
let mut result: [[Field; N]] = [[0; N]];
let mut temporaries: [[Field; N]; N] = std::mem::zeroed();
for i in 0..N {
temporaries[i] = accumulator;
let mut temporaries: [[Field; N]] = &[];
for i in 0..x.len() {
temporaries = temporaries.push_back(accumulator);
if (!__is_zero(x[i])) {
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[i]);
}
result = result.push_back([0; N]);
}

let mut result: [[Field; N]] = [];
accumulator = __invmod::<_, MOD_BITS>(params, accumulator);
let mut T0: [Field; N] = [0; N];
for i in 0..x.len() {
let idx = x.len() - 1 - i;
if (__is_zero(x[idx]) == false) {
T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
if (!__is_zero(x[idx])) {
let T0 = __mul::<_, MOD_BITS>(params, accumulator, temporaries[idx]);
accumulator = __mul::<_, MOD_BITS>(params, accumulator, x[idx]);
result[idx] = T0;
}
assert(__mul::<_, MOD_BITS>(params, T0, x[idx]) == __one::<N>());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added an assertion for easier debugging later if something goes wrong, it wasn't showing the call stack properly.

result = result.push_front(T0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order was reversed, so we should push things to the front instead of the back

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed it in a later commit

} else {
result = result.push_front([0; N]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result.push_front does not update the slice, and we should rather do result = result.push_front()

};
}

result
}

Expand Down
2 changes: 1 addition & 1 deletion src/params.nr
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct BigNumParams<let N: u32, let MOD_BITS: u32> {

// To be implemented by the user for any BigNum they define, or within the predefined BigNums in the `fields/` dir.
pub trait BigNumParamsGetter<let N: u32, let MOD_BITS: u32> {
pub fn get_params() -> BigNumParams<N, MOD_BITS>;
fn get_params() -> BigNumParams<N, MOD_BITS>;
}

impl<let N: u32, let MOD_BITS: u32> BigNumParams<N, MOD_BITS> {
Expand Down
73 changes: 35 additions & 38 deletions src/runtime_bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,61 +25,58 @@ impl<let N: u32, let MOD_BITS: u32> RuntimeBigNum<N, MOD_BITS> {}
// All functions prefixed `__` are unconstrained!
// They're not actually decorated as `unconstrained` because to return the `params` (as part of Self) from an `unconstrained` fn would cause range constraints. Instead, each `__` fn wraps a call to an unconstrained fn, so that the already-range-constrained `params` can be inserted into Self after the unconstrained call.
pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add + Sub + Mul + Div + Eq {
pub fn new(params: BigNumParams<N, MOD_BITS>) -> Self;
pub fn one(params: BigNumParams<N, MOD_BITS>) -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(
fn new(params: BigNumParams<N, MOD_BITS>) -> Self;
fn one(params: BigNumParams<N, MOD_BITS>) -> Self;
fn derive_from_seed<let SeedBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
seed: [u8; SeedBytes],
) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(
unconstrained fn __derive_from_seed<let SeedBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
seed: [u8; SeedBytes],
) -> Self;
pub fn from_slice(params: BigNumParams<N, MOD_BITS>, limbs: [Field]) -> Self;
pub fn from_array(params: BigNumParams<N, MOD_BITS>, limbs: [Field; N]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(
params: BigNumParams<N, MOD_BITS>,
x: [u8; NBytes],
) -> Self;
fn from_slice(params: BigNumParams<N, MOD_BITS>, limbs: [Field]) -> Self;
fn from_array(params: BigNumParams<N, MOD_BITS>, limbs: [Field; N]) -> Self;
fn from_be_bytes<let NBytes: u32>(params: BigNumParams<N, MOD_BITS>, x: [u8; NBytes]) -> Self;

pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];
fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus(self) -> Self;
pub fn modulus_bits() -> u32;
pub fn num_limbs() -> u32;
fn modulus(self) -> Self;
fn modulus_bits() -> u32;
fn num_limbs() -> u32;
// pub fn get(self) -> [Field];
pub fn get_limbs(self) -> [Field; N];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);
fn get_limbs(self) -> [Field; N];
fn get_limb(self, idx: u32) -> Field;
fn set_limb(&mut self, idx: u32, value: Field);

unconstrained fn __eq(self, other: Self) -> bool;
unconstrained fn __is_zero(self) -> bool;

// unconstrained
pub fn __neg(self) -> Self;
fn __neg(self) -> Self;
// unconstrained
pub fn __add(self, other: Self) -> Self;
fn __add(self, other: Self) -> Self;
// unconstrained
pub fn __sub(self, other: Self) -> Self;
fn __sub(self, other: Self) -> Self;
// unconstrained
pub fn __mul(self, other: Self) -> Self;
fn __mul(self, other: Self) -> Self;
// unconstrained
pub fn __div(self, other: Self) -> Self;
fn __div(self, other: Self) -> Self;
// unconstrained
pub fn __udiv_mod(self, divisor: Self) -> (Self, Self);
fn __udiv_mod(self, divisor: Self) -> (Self, Self);
// unconstrained
pub fn __invmod(self) -> Self;
fn __invmod(self) -> Self;
// unconstrained
pub fn __pow(self, exponent: Self) -> Self;
fn __pow(self, exponent: Self) -> Self;

// unconstrained
pub fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M];
fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M];
unconstrained fn __batch_invert_slice<let M: u32>(to_invert: [Self]) -> [Self];

pub fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;
fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self>;

// unconstrained
pub fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn __compute_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
params: BigNumParams<N, MOD_BITS>,
lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
Expand All @@ -89,7 +86,7 @@ pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add +
linear_flags: [bool; ADD_N],
) -> (Self, Self);

pub fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
fn evaluate_quadratic_expression<let LHS_N: u32, let RHS_N: u32, let NUM_PRODUCTS: u32, let ADD_N: u32>(
params: BigNumParams<N, MOD_BITS>,
lhs_terms: [[Self; LHS_N]; NUM_PRODUCTS],
lhs_flags: [[bool; LHS_N]; NUM_PRODUCTS],
Expand All @@ -99,20 +96,20 @@ pub(crate) trait RuntimeBigNumTrait<let N: u32, let MOD_BITS: u32>: Neg + Add +
linear_flags: [bool; ADD_N],
);

pub fn eq(lhs: Self, rhs: Self) -> bool {
fn eq(lhs: Self, rhs: Self) -> bool {
lhs == rhs
}
pub fn assert_is_not_equal(self, other: Self);
pub fn validate_in_field(self);
pub fn validate_in_range(self);
fn assert_is_not_equal(self, other: Self);
fn validate_in_field(self);
fn validate_in_range(self);
// pub fn validate_gt(self, lhs: Self, rhs: Self);

pub fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self);
pub fn udiv(numerator: Self, divisor: Self) -> Self;
pub fn umod(numerator: Self, divisor: Self) -> Self;
fn udiv_mod(numerator: Self, divisor: Self) -> (Self, Self);
fn udiv(numerator: Self, divisor: Self) -> Self;
fn umod(numerator: Self, divisor: Self) -> Self;

pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
pub fn is_zero(self) -> bool;
fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
fn is_zero(self) -> bool;
}

impl<let N: u32, let MOD_BITS: u32> Neg for RuntimeBigNum<N, MOD_BITS> {
Expand Down
37 changes: 36 additions & 1 deletion src/tests/bignum_test.nr
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use crate::params::BigNumParamsGetter;
use crate::fields::bls12_381Fq::BLS12_381_Fq_Params;
use crate::fields::bn254Fq::BN254_Fq_Params;
use crate::fields::U256::U256Params;
use std::bigint::Bn254Fq;

struct Test2048Params {}

Expand Down Expand Up @@ -938,3 +937,39 @@ fn test_to_from_field_2() {
assert(a == c);
}

unconstrained fn test_batch_inversion<let N: u32, BN>(fields: [BN; N])
where
BN: BigNumTrait,
{
let inverted_fields = BN::__batch_invert(fields);
for i in 0..N {
assert_eq(fields[i] * inverted_fields[i], BN::one());
}
}

#[test]
unconstrained fn test_batch_inversion_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed));
unsafe {
test_batch_inversion(fields)
}
}

unconstrained fn test_batch_inversion_slice<BN>(fields: [BN])
where
BN: BigNumTrait,
{
let inverted_fields = BN::__batch_invert_slice(fields);
assert_eq(fields.len(), inverted_fields.len());
for i in 0..fields.len() {
assert_eq(fields[i] * inverted_fields[i], BN::one());
}
}

#[test]
unconstrained fn test_batch_inversion_slice_BN381(seeds: [[u8; 2]; 3]) {
let fields = seeds.map(|seed| BN381::derive_from_seed(seed)).as_slice();
unsafe {
test_batch_inversion_slice(fields)
}
}