Skip to content

Commit

Permalink
fixed Rust test and made Montgomery mul work in the browser demo
Browse files Browse the repository at this point in the history
  • Loading branch information
weijiekoh committed Aug 21, 2023
1 parent 91f5e32 commit ed7ce4c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ pub fn test_poseidon() {

// Convert to Montgomery form
let a = Fr::from_le_bytes_mod_order(random_bytes.as_slice());
inputs.push((a * r).into_bigint().into());
a_inputs.push(a);
inputs.push((a * r).into_bigint().into());
}

let mut constants: Vec<BigUint> = Vec::with_capacity(p_constants.c.len() + 4);
Expand Down Expand Up @@ -153,11 +153,11 @@ pub fn test_poseidon() {

let mut from_mont_results: Vec<BigUint> = Vec::with_capacity(num_inputs);
for r in &result {
from_mont_results.push((Fr::from_be_bytes_mod_order(&result[0].to_bytes_be()) * rinv).into_bigint().into());
from_mont_results.push((Fr::from_be_bytes_mod_order(&r.to_bytes_be()) * rinv).into_bigint().into());
}
//println!("{}, {}", Fr::from_be_bytes_mod_order(&result[0].to_bytes_be()) * rinv, expected_hashes[0]);
//println!("Input: {:?}", inputs.clone());
//println!("Result from GPU: {:?}", result.clone());
//println!("Results from GPU converted to Montgomery form: {:?}", from_mont_results.clone());
//assert_eq!(result[0], expected_final_state[0]);
assert_eq!(from_mont_results, expected_hashes);

Expand Down
56 changes: 31 additions & 25 deletions web/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,27 @@ async function poseidon(input: BigInt) {
const constants_c = constants.default.C
const constants_m = constants.default.M

const num_inputs = 256 * 64;
const num_inputs = 256 * 64;
const numXWorkgroups = 256;

//let inputs: BigInt[] = [BigInt(1), BigInt(1)]
let inputs: BigInt[] = []
let mont_inputs: BigInt[] = []

const p = BigInt('0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001')
const r = BigInt('0xe0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb')
const ri = BigInt('0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e')

for (let i = 0; i < num_inputs; i ++) {
const rand = BigInt('0x' + crypto.randomBytes(32).toString('hex')) % p
inputs.push(rand)
mont_inputs.push(rand * r % p)
}

const hasher = await buildPoseidon();
let expectedHashes: BigInt[] = []
let start = Date.now()
for (const input of inputs) {
const hash = utils.leBuff2int(hasher.F.fromMontgomery(hasher([input])))
const hash = hasher([input])
expectedHashes.push(hash)
}
let elapsed = Date.now() - start
Expand All @@ -59,37 +63,40 @@ async function poseidon(input: BigInt) {
// Append the C constants
for (const c_val of constants_c[t - 2]) {
//inputs.push(BigInt(c_val));
constants_flat.push(BigInt(c_val));
constants_flat.push(BigInt(c_val) * r % p);
}

// Append the M constants
for (const vs of constants_m[t - 2]) {
for (const v_val of vs) {
constants_flat.push(BigInt(v_val))
constants_flat.push(BigInt(v_val) * r % p)
}
}

const input_bytes = new Uint8Array(bigints_to_limbs(inputs).buffer);
const input_bytes = new Uint8Array(bigints_to_limbs(mont_inputs).buffer);
const constants_bytes = new Uint8Array(bigints_to_limbs(constants_flat).buffer);

const INPUT_BUFFER_SIZE = input_bytes.length;
const CONSTANTS_BUFFER_SIZE = constants_bytes.length;
console.log(inputs.length, INPUT_BUFFER_SIZE)
//console.log(inputs.length, INPUT_BUFFER_SIZE)

console.log(0)
const gpuErrMsg = "Please use a browser that has WebGPU enabled.";
//console.log(0)
// 1: request adapter and device
// @ts-ignore
if (!navigator.gpu) {
codeOutput.innerHTML += gpuErrMsg;
throw Error('WebGPU not supported.');
}

console.log(1)
//console.log(1)

// @ts-ignore
const adapter = await navigator.gpu.requestAdapter({
powerPreference: 'high-performance',
});
if (!adapter) {
codeOutput.innerHTML += gpuErrMsg;
throw Error('Couldn\'t request WebGPU adapter.');
}

Expand All @@ -100,9 +107,9 @@ async function poseidon(input: BigInt) {
code: shader
});

console.log(2)
//console.log(2)

//3: Create an output buffer to read GPU calculations to, and a staging
// 3: Create an output buffer to read GPU calculations to, and a staging
//buffer to be mapped for JavaScript access

const storageBuffer = device.createBuffer({
Expand All @@ -125,7 +132,7 @@ async function poseidon(input: BigInt) {
usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST
});

console.log(3)
//console.log(3)

// 4: Create a GPUBindGroupLayout to define the bind group structure,
// create a GPUBindGroup from it, then use it to create a
Expand Down Expand Up @@ -179,18 +186,18 @@ async function poseidon(input: BigInt) {
entryPoint: 'main'
}
});
console.log(4)
//console.log(4)

// 5: Create GPUCommandEncoder to issue commands to the GPU
const commandEncoder = device.createCommandEncoder();

console.log(5)
//console.log(5)

start = Date.now()
// 6: Initiate render pass
const passEncoder = commandEncoder.beginComputePass();

console.log(6)
//console.log(6)

// 7: Issue commands
passEncoder.setPipeline(computePipeline);
Expand All @@ -209,11 +216,11 @@ async function poseidon(input: BigInt) {
INPUT_BUFFER_SIZE
);

console.log(7)
//console.log(7)

// 8: End frame by passing array of command buffers to command queue for execution
device.queue.submit([commandEncoder.finish()]);
console.log(7.1)
//console.log(7.1)

// map staging buffer to read results back to JS
await stagingBuffer.mapAsync(
Expand All @@ -222,13 +229,13 @@ async function poseidon(input: BigInt) {
0, // Offset
INPUT_BUFFER_SIZE // Length
);
console.log(7.2)
//console.log(7.2)

const copyArrayBuffer = stagingBuffer.getMappedRange(0, INPUT_BUFFER_SIZE);
const data = copyArrayBuffer.slice();
stagingBuffer.unmap();

console.log(8)
//console.log(8)

const dataBuf = new Uint32Array(data);
elapsed = Date.now() - start
Expand All @@ -237,13 +244,12 @@ async function poseidon(input: BigInt) {

const results: BigInt[] = []
for (let i = 0; i < dataBuf.length / 16; i ++) {
const result = uint32ArrayToBigint(dataBuf.slice(i * 16, i * 16 + 16))
results.push(result)
const result = BigInt(uint32ArrayToBigint(dataBuf.slice(i * 16, i * 16 + 16)))
results.push(result * ri % p)
}
console.log(results)
console.log(expectedHashes)
for (let i = 0; i < results.length; i ++) {
assert(results[i] === expectedHashes[i])
let e = utils.leBuff2int(hasher.F.fromMontgomery(expectedHashes[i]));
assert(results[i] === e);
}
assert(results.length === expectedHashes.length)
}
Expand Down Expand Up @@ -273,7 +279,7 @@ const bytes_to_bigints = (limbs: Uint8Array): BigInt[] => {
chunks.push(chunk);
}

console.log(chunks);
//console.log(chunks);
return []
}

Expand Down

0 comments on commit ed7ce4c

Please sign in to comment.