Skip to content

Commit

Permalink
Merge pull request #29 from zkmopro/test/metal/field
Browse files Browse the repository at this point in the history
Fix ff.metal and refactor rust test
  • Loading branch information
moven0831 authored Jan 3, 2025
2 parents 6394db7 + b14c60e commit a6df775
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 70 deletions.
22 changes: 6 additions & 16 deletions mopro-msm/src/msm/metal_msm/shader/field/ff.metal
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,21 @@ BigInt ff_add(
BigInt b,
BigInt p
) {
// Assign p to p_wide
BigIntWide p_wide;
for (uint i = 0; i < NUM_LIMBS; i ++) {
p_wide.limbs[i] = p.limbs[i];
}

// a + b
BigIntWide sum_wide = bigint_add_wide(a, b);
BigInt sum = bigint_add_unsafe(a, b);

BigInt res;

// if (a + b) >= p
if (bigint_wide_gte(sum_wide, p_wide)) {
if (bigint_gte(sum, p)) {
// s = a + b - p
BigIntWide s = bigint_sub_wide(sum_wide, p_wide);

BigInt s = bigint_sub(sum, p);
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = s.limbs[i];
}
} else {
}
else {
for (uint i = 0; i < NUM_LIMBS; i ++) {
res.limbs[i] = sum_wide.limbs[i];
res.limbs[i] = sum.limbs[i];
}
}

return res;
}

Expand Down
62 changes: 38 additions & 24 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,38 @@ use crate::msm::metal_msm::host::gpu::{
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand};
use ark_std::rand;
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_add() {
let log_limb_size = 13;
let num_limbs = 20;
let log_limb_size = 16;
let num_limbs = 16;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);
let p = BaseField::MODULUS;

let mut rng = rand::thread_rng();
let mut a = BigInt::rand(&mut rng);
let mut b = BigInt::rand(&mut rng);

// Reduce a and b if they are greater than or equal to the prime field modulus
while a >= p {
a.sub_with_borrow(&p);
}

while b >= p {
b.sub_with_borrow(&p);
}

// Ensure a and b are non-negative and less than p
assert!(a >= BigInt::from(0u64), "a must be non-negative");
assert!(b >= BigInt::from(0u64), "b must be non-negative");
assert!(a < p, "a must be less than p");
assert!(b < p, "b must be less than p");

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
Expand All @@ -51,6 +52,19 @@ pub fn test_ff_add() {
while expected >= p {
expected.sub_with_borrow(&p);
}
// Ensure expected is non-negative and less than p
assert!(
expected >= BigInt::from(0u64),
"expected must be non-negative"
);
assert!(expected < p, "expected must be less than p");

// Ensure the operation is correct using Arkworks
let a_field = BaseField::from(a);
let b_field = BaseField::from(b);
let expected_field = a_field + b_field;
assert!(expected_field == expected.into());

let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
Expand Down
78 changes: 48 additions & 30 deletions mopro-msm/src/msm/metal_msm/tests/field/ff_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,70 @@ use crate::msm::metal_msm::host::gpu::{
};
use crate::msm::metal_msm::host::shader::{compile_metal, write_constants};
use crate::msm::metal_msm::utils::limbs_conversion::{FromLimbs, ToLimbs};
use ark_bn254::Fr as ScalarField;
use ark_ff::{BigInt, BigInteger, PrimeField};
use ark_bn254::Fq as BaseField;
use ark_ff::{BigInt, BigInteger, PrimeField, UniformRand};
use ark_std::rand;
use metal::*;

#[test]
#[serial_test::serial]
pub fn test_ff_sub() {
let log_limb_size = 13;
let num_limbs = 20;
let log_limb_size = 16;
let num_limbs = 16;

// Scalar field modulus for bn254
let p = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A029,
]);
assert!(p == ScalarField::MODULUS);

let a = BigInt::new([
0x43E1F593F0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E72E131A028,
]);
let b = BigInt::new([
0xAAAAAAAAF0000001,
0x2833E84879B97091,
0xB85045B68181585D,
0x30644E7200000000,
]);
let p = BaseField::MODULUS;

let mut rng = rand::thread_rng();
let mut a = BigInt::rand(&mut rng);
let mut b = BigInt::rand(&mut rng);

// Reduce a and b if they are greater than or equal to the prime field modulus
while a >= p {
a.sub_with_borrow(&p);
}

while b >= p {
b.sub_with_borrow(&p);
}

// Ensure a and b are non-negative and less than p
assert!(a >= BigInt::from(0u64), "a must be non-negative");
assert!(b >= BigInt::from(0u64), "b must be non-negative");
assert!(a < p, "a must be less than p");
assert!(b < p, "b must be less than p");

let device = get_default_device();
let a_buf = create_buffer(&device, &a.to_limbs(num_limbs, log_limb_size));
let b_buf = create_buffer(&device, &b.to_limbs(num_limbs, log_limb_size));
let p_buf = create_buffer(&device, &p.to_limbs(num_limbs, log_limb_size));
let result_buf = create_empty_buffer(&device, num_limbs);

// Perform (a - b) % p
// (a - b) % p
let mut expected = a.clone();
expected.sub_with_borrow(&b);

// If result is negative, add p until it's positive
while expected < BigInt::zero() {
expected.add_with_carry(&p);
if a >= b {
expected.sub_with_borrow(&b);
}
// p - (b - a)
else {
let mut p_sub_b = p.clone();
p_sub_b.sub_with_borrow(&b);
expected.add_with_carry(&p_sub_b);
}

// Ensure expected is non-negative and less than p
assert!(
expected >= BigInt::from(0u64),
"expected must be non-negative"
);
assert!(expected < p, "expected must be less than p");

// Ensure the operation is correct using Arkworks
let a_field = BaseField::from(a);
let b_field = BaseField::from(b);
let expected_field = a_field - b_field;
assert!(expected_field == expected.into());

let expected_limbs = expected.to_limbs(num_limbs, log_limb_size);

let command_queue = device.new_command_queue();
Expand Down

0 comments on commit a6df775

Please sign in to comment.