Skip to content

Commit

Permalink
montgomery mul for rust impl done
Browse files Browse the repository at this point in the history
  • Loading branch information
weijiekoh committed Aug 21, 2023
1 parent a09bf75 commit 91f5e32
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 104 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ ark-ff = "0.4.0"
ark-bn254 = "0.4.0"
poseidon-ark = { git = "https://github.com/arnaucube/poseidon-ark.git" }
naga = "0.12.3"
hex = "0.4.3"
50 changes: 44 additions & 6 deletions src/poseidon.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use rand::Rng;
use ark_bn254::Fr;
use ark_ff::PrimeField;
use stopwatch::Stopwatch;
use num_bigint::BigUint;
//use std::str::FromStr;
Expand Down Expand Up @@ -45,21 +46,46 @@ pub fn test_poseidon() {
// Number of inputs: 1
// t = 1 + 1 = 2

let r_bytes = hex::decode("010000000000000000000000000000000000000000000000000000000000000000").unwrap();
let r = Fr::from_be_bytes_mod_order(r_bytes.as_slice());
let rinv_bytes = hex::decode("15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e").unwrap();
let rinv = Fr::from_be_bytes_mod_order(rinv_bytes.as_slice());

let poseidon = Poseidon::new();
let p_constants = load_constants();
let mut p_constants = load_constants();

// Convert constants to Montgomery form
for i in 0..p_constants.c.len() {
for j in 0..p_constants.c[i].len() {
p_constants.c[i][j] = p_constants.c[i][j] * r;
}
}
for i in 0..p_constants.m.len() {
for j in 0..p_constants.m[i].len() {
for k in 0..p_constants.m[i][j].len() {
p_constants.m[i][j][k] = p_constants.m[i][j][k] * r;
}
}
}

let num_inputs = 256 * 64;
let num_x_workgroups = 256;

println!("Computing {} Poseidon hashes in Rust / WebGPU", num_inputs);

let mut inputs: Vec<BigUint> = Vec::with_capacity(num_inputs);
let mut a_inputs: Vec<Fr> = Vec::with_capacity(num_inputs);

let mut rng = rand::thread_rng();
for _ in 0..num_inputs {
let random_bytes = rng.gen::<[u8; 32]>();
let a = BigUint::from_bytes_be(random_bytes.as_slice()) % get_fr();
inputs.push(a);
//let random_bytes = rng.gen::<[u8; 32]>();
let mut random_bytes = [1u8];
//let a = BigUint::from_bytes_be(random_bytes.as_slice()) % get_fr();

// Convert to Montgomery form
let a = Fr::from_le_bytes_mod_order(random_bytes.as_slice());
inputs.push((a * r).into_bigint().into());
a_inputs.push(a);
}

let mut constants: Vec<BigUint> = Vec::with_capacity(p_constants.c.len() + 4);
Expand All @@ -78,8 +104,14 @@ pub fn test_poseidon() {
}
}

// Compute the hashes using CPU
let sw = Stopwatch::start_new();
let expected_hashes: Vec<BigUint> = inputs.iter().map(|a| poseidon.hash(vec![a.clone().into()]).unwrap().into()).collect();
let mut expected_hashes: Vec<BigUint> = Vec::with_capacity(num_inputs);
//let mut expected_hashes: Vec<Fr> = Vec::with_capacity(num_inputs);
for i in 0..num_inputs {
let h = poseidon.hash(vec![a_inputs[i].clone().into()]).unwrap();
expected_hashes.push(h.into_bigint().into());
}
println!("CPU took {}ms", sw.elapsed_ms());

//// For debugging:
Expand Down Expand Up @@ -118,9 +150,15 @@ pub fn test_poseidon() {
let result = pollster::block_on(double_buffer_compute(&wgsl, &buf, &constants, num_x_workgroups, 1)).unwrap();

let result = u32s_to_bigints(result);

// Convert every GPU output out of Montgomery form by multiplying it by
// `rinv` in Fr. Each element of `result` is converted individually; the
// previous loop converted `result[0]` on every iteration, so the final
// assertion only ever checked the first hash.
let from_mont_results: Vec<BigUint> = result
    .iter()
    .map(|gpu_out| {
        (Fr::from_be_bytes_mod_order(&gpu_out.to_bytes_be()) * rinv)
            .into_bigint()
            .into()
    })
    .collect();
//println!("{}, {}", Fr::from_be_bytes_mod_order(&result[0].to_bytes_be()) * rinv, expected_hashes[0]);
//println!("Input: {:?}", inputs.clone());
//println!("Result from GPU: {:?}", result.clone());
//assert_eq!(result[0], expected_final_state[0]);
assert_eq!(result, expected_hashes);
assert_eq!(from_mont_results, expected_hashes);

}
168 changes: 80 additions & 88 deletions src/wgsl/fr.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,11 @@ fn fr_get_p() -> BigInt256 {
p.limbs[13] = 57649u;
p.limbs[14] = 20082u;
p.limbs[15] = 12388u;

return p;
}

// Returns the 16-limb constant that fr_mul() uses as `mu`.
// Limbs are listed least-significant index first (limbs[0] .. limbs[15]).
fn fr_get_mu() -> BigInt256 {
    var mu: BigInt256;
    mu.limbs = array<u32, 16>(
        59685u, 48669u, 934u, 25095u,
        32942u, 2536u, 34080u, 28996u,
        12308u, 26631u, 19032u, 43783u,
        1191u, 25146u, 29794u, 21668u
    );
    return mu;
}

fn fr_get_p_wide() -> BigInt512 {
var p: BigInt512;
fn gen_p_medium_wide() -> BigInt272 {
var p: BigInt272;
p.limbs[0] = 1u;
p.limbs[1] = 61440u;
p.limbs[2] = 62867u;
Expand All @@ -60,72 +37,9 @@ fn fr_get_p_wide() -> BigInt512 {
p.limbs[13] = 57649u;
p.limbs[14] = 20082u;
p.limbs[15] = 12388u;
p.limbs[16] = 0u;
p.limbs[17] = 0u;
p.limbs[18] = 0u;
p.limbs[19] = 0u;
p.limbs[20] = 0u;
p.limbs[21] = 0u;
p.limbs[22] = 0u;
p.limbs[23] = 0u;
p.limbs[24] = 0u;
p.limbs[25] = 0u;
p.limbs[26] = 0u;
p.limbs[27] = 0u;
p.limbs[28] = 0u;
p.limbs[29] = 0u;
p.limbs[30] = 0u;
p.limbs[31] = 0u;
return p;
}

// Builds a 16-limb value from the upper part of a 32-limb value.
// Output limb i combines limb (i + 16) shifted left by `slack` bits with the
// top `slack` bits of limb (i + 15), then keeps the low 16 bits. The least
// significant limbs of the input are discarded.
// NOTE(review): this equals (a >> 254) truncated to 256 bits only if every
// input limb holds a 16-bit value — confirm against bigint_mul.
fn get_higher_with_slack(a: ptr<function, BigInt512>) -> BigInt256 {
    // 256 minus the bit width of the Fr modulus.
    let slack = 2u;
    // Bits kept per limb, and the matching mask.
    let word_bits = 16u;
    let word_mask = 65535u;

    var out: BigInt256;
    for (var i = 0u; i < 16u; i++) {
        let from_upper = (*a).limbs[i + 16u] << slack;
        let from_lower = (*a).limbs[i + 15u] >> (word_bits - slack);
        out.limbs[i] = (from_upper + from_lower) & word_mask;
    }
    return out;
}


// Multiplies a and b and reduces the product using the precomputed constant
// from fr_get_mu() (the sequence of steps looks like a Barrett-style
// reduction — confirm against the reference this was derived from).
//
// Steps:
//   1. product   = a * b                         (32 limbs)
//   2. quotient  = hi(hi(product) * mu)          (estimate of product / p)
//   3. remainder = product - quotient * p
//   4. one conditional subtraction of p, then fr_reduce() on the low 16 limbs.
fn fr_mul(a: ptr<function, BigInt256>, b: ptr<function, BigInt256>) -> BigInt256 {
    var modulus = fr_get_p();
    var modulus_wide = fr_get_p_wide();
    var mu_const = fr_get_mu();

    var product: BigInt512 = bigint_mul(a, b);
    var product_hi: BigInt256 = get_higher_with_slack(&product);

    var quotient_wide: BigInt512 = bigint_mul(&product_hi, &mu_const);
    var quotient: BigInt256 = get_higher_with_slack(&quotient_wide);

    var quotient_times_p: BigInt512 = bigint_mul(&quotient, &modulus);

    var remainder: BigInt512;
    bigint_512_sub(&product, &quotient_times_p, &remainder);

    // Keep the subtracted value only when the subtraction did not underflow.
    var candidate: BigInt512;
    if (bigint_512_sub(&remainder, &modulus_wide, &candidate) == 0u) {
        remainder = candidate;
    }

    // Only the low 16 limbs are passed on.
    var low: BigInt256;
    for (var k = 0u; k < 16u; k++) {
        low.limbs[k] = remainder.limbs[k];
    }
    return fr_reduce(&low);
}

// Squares a by delegating to fr_mul(a, a).
fn fr_sqr(a: ptr<function, BigInt256>) -> BigInt256 {
    var squared: BigInt256 = fr_mul(a, a);
    return squared;
}

fn fr_add(a: ptr<function, BigInt256>, b: ptr<function, BigInt256>) -> BigInt256 {
var res: BigInt256;
/*var res = bigint_add(a, b);*/
Expand All @@ -144,3 +58,81 @@ fn fr_reduce(a: ptr<function, BigInt256>) -> BigInt256 {

return res;
}

// Returns the upper 16 bits of a 32-bit value (val / 2^16).
fn hi(val: u32) -> u32 {
    let upper_half = val / 65536u;
    return upper_half;
}

// Returns the lower 16 bits of a 32-bit value (val mod 2^16).
fn lo(val: u32) -> u32 {
    let lower_half = val % 65536u;
    return lower_half;
}

// Montgomery product using the Coarsely Integrated Operand Scanning (CIOS)
// method with 16 limbs of 16 bits each (R = 2^256).
//
// Parameters:
//   a, b - operands, one 16-bit value per limb. For the result to be the
//          product in Montgomery form, both are expected to already be in
//          Montgomery form and reduced below the modulus.
// Returns:
//   a * b * R^-1 mod n, where n is the value returned by gen_p_medium_wide().
fn cios_mon_pro(a: ptr<function, BigInt256>, b: ptr<function, BigInt256>) -> BigInt256 {
    var n = gen_p_medium_wide();
    // n0 = -n^-1 mod 2^16. The lowest limb of n is 1, so n^-1 mod 2^16 is 1
    // and n0 is 2^16 - 1.
    var n0 = 65535u;
    var num_words = 16u;

    // Accumulator with num_words + 2 limbs; WGSL zero-initialises it.
    var t: array<u32, 18u>;

    for (var i = 0u; i < num_words; i ++) {
        // t = t + a * b[i]
        var c = 0u;
        for (var j = 0u; j < num_words; j ++) {
            // Max value is 65535 + 65535 * 65535 + 65535 = 2^32 - 1: no overflow.
            var r = t[j] + (*a).limbs[j] * (*b).limbs[i] + c;
            c = hi(r);
            t[j] = lo(r);
        }
        var r = t[num_words] + c;
        t[num_words + 1u] = hi(r);
        t[num_words] = lo(r);

        // Pick m so that the lowest limb of t + m * n is zero, then add m * n
        // and shift t right by one limb.
        var m = (t[0] * n0) % 65536u;
        r = t[0] + m * n.limbs[0];
        c = hi(r);

        for (var j = 1u; j < num_words; j ++) {
            r = t[j] + m * n.limbs[j] + c;
            c = hi(r);
            t[j - 1u] = lo(r);
        }

        r = t[num_words] + c;
        c = hi(r);
        t[num_words - 1u] = lo(r);
        t[num_words] = t[num_words + 1u] + c;
    }

    // Final reduction: if t >= n return t - n, otherwise return t.
    // The comparison walks from the overflow limb t[num_words] down to limb 0.
    // (The previous loop started at limb num_words - 1 but still ran
    // num_words + 1 times, so its last iteration computed an index of
    // 0xFFFFFFFF and never looked at the overflow limb.)
    // NOTE(review): relies on n.limbs[num_words] being 0, i.e. the modulus
    // fits in 16 limbs — confirm in gen_p_medium_wide().
    var t_lt_n = false;
    for (var idx = 0u; idx <= num_words; idx ++) {
        var i = num_words - idx;
        if (t[i] < n.limbs[i]) {
            t_lt_n = true;
            break;
        } else if (t[i] > n.limbs[i]) {
            break;
        }
    }

    var res: BigInt256;
    if (t_lt_n) {
        for (var i = 0u; i < num_words; i ++) {
            res.limbs[i] = t[i];
        }
        return res;
    }

    // t >= n: limb-wise subtraction with borrow. u32 arithmetic wraps, so a
    // negative limb difference is brought back into [0, 2^16) by adding 2^16.
    var borrow = 0u;
    for (var i = 0u; i < num_words; i ++) {
        var diff = t[i] - n.limbs[i] - borrow;
        if (t[i] < (n.limbs[i] + borrow)) {
            diff = diff + 65536u;
            borrow = 1u;
        } else {
            borrow = 0u;
        }
        res.limbs[i] = diff;
    }
    return res;
}
20 changes: 10 additions & 10 deletions src/wgsl/poseidon_t2.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {

// S-Box
var s0 = state_0;
state_0 = fr_mul(&state_0, &state_0);
state_0 = fr_mul(&state_0, &state_0);
state_0 = fr_mul(&s0, &state_0);
state_0 = cios_mon_pro(&state_0, &state_0);
state_0 = cios_mon_pro(&state_0, &state_0);
state_0 = cios_mon_pro(&s0, &state_0);

if (i < 4u || i >= 60u) {
var s1 = state_1;
state_1 = fr_mul(&state_1, &state_1);
state_1 = fr_mul(&state_1, &state_1);
state_1 = fr_mul(&s1, &state_1);
state_1 = cios_mon_pro(&state_1, &state_1);
state_1 = cios_mon_pro(&state_1, &state_1);
state_1 = cios_mon_pro(&s1, &state_1);
}

// Mix
var m00s0 = fr_mul(&m_0_0, &state_0);
var m01s1 = fr_mul(&m_0_1, &state_1);
var m10s0 = fr_mul(&m_1_0, &state_0);
var m11s1 = fr_mul(&m_1_1, &state_1);
var m00s0 = cios_mon_pro(&m_0_0, &state_0);
var m01s1 = cios_mon_pro(&m_0_1, &state_1);
var m10s0 = cios_mon_pro(&m_1_0, &state_0);
var m11s1 = cios_mon_pro(&m_1_1, &state_1);

var new_state_0: BigInt256 = fr_add(&m00s0, &m01s1);
var new_state_1: BigInt256 = fr_add(&m10s0, &m11s1);
Expand Down
4 changes: 4 additions & 0 deletions src/wgsl/structs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ struct BigInt256 {
limbs: array<u32, 16>
}

// Big integer stored as 17 u32 limbs.
// NOTE(review): the name suggests 16 significant bits per limb
// (17 * 16 = 272) — confirm against the arithmetic code that uses it.
struct BigInt272 {
limbs: array<u32, 17>
}

// Big integer stored as 32 u32 limbs.
// NOTE(review): the name suggests 16 significant bits per limb
// (32 * 16 = 512) — confirm against the arithmetic code that uses it.
struct BigInt512 {
limbs: array<u32, 32>
}

0 comments on commit 91f5e32

Please sign in to comment.