diff --git a/src/poseidon.rs b/src/poseidon.rs index 53bae8f..71de7cd 100644 --- a/src/poseidon.rs +++ b/src/poseidon.rs @@ -84,8 +84,8 @@ pub fn test_poseidon() { // Convert to Montgomery form let a = Fr::from_le_bytes_mod_order(random_bytes.as_slice()); - inputs.push((a * r).into_bigint().into()); a_inputs.push(a); + inputs.push((a * r).into_bigint().into()); } let mut constants: Vec = Vec::with_capacity(p_constants.c.len() + 4); @@ -153,11 +153,11 @@ pub fn test_poseidon() { let mut from_mont_results: Vec = Vec::with_capacity(num_inputs); for r in &result { - from_mont_results.push((Fr::from_be_bytes_mod_order(&result[0].to_bytes_be()) * rinv).into_bigint().into()); + from_mont_results.push((Fr::from_be_bytes_mod_order(&r.to_bytes_be()) * rinv).into_bigint().into()); } //println!("{}, {}", Fr::from_be_bytes_mod_order(&result[0].to_bytes_be()) * rinv, expected_hashes[0]); //println!("Input: {:?}", inputs.clone()); - //println!("Result from GPU: {:?}", result.clone()); + //println!("Results from GPU converted to Montgomery form: {:?}", from_mont_results.clone()); //assert_eq!(result[0], expected_final_state[0]); assert_eq!(from_mont_results, expected_hashes); diff --git a/web/app.ts b/web/app.ts index 1a38de2..c760cb1 100644 --- a/web/app.ts +++ b/web/app.ts @@ -32,23 +32,27 @@ async function poseidon(input: BigInt) { const constants_c = constants.default.C const constants_m = constants.default.M - const num_inputs = 256 * 64; + const num_inputs = 256 * 64; const numXWorkgroups = 256; - //let inputs: BigInt[] = [BigInt(1), BigInt(1)] let inputs: BigInt[] = [] + let mont_inputs: BigInt[] = [] const p = BigInt('0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001') + const r = BigInt('0xe0a77c19a07df2f666ea36f7879462e36fc76959f60cd29ac96341c4ffffffb') + const ri = BigInt('0x15ebf95182c5551cc8260de4aeb85d5d090ef5a9e111ec87dc5ba0056db1194e') + for (let i = 0; i < num_inputs; i ++) { const rand = BigInt('0x' + crypto.randomBytes(32).toString('hex')) % p inputs.push(rand) + mont_inputs.push(rand * r % p) } const hasher = await buildPoseidon(); let expectedHashes: BigInt[] = [] let start = Date.now() for (const input of inputs) { - const hash = utils.leBuff2int(hasher.F.fromMontgomery(hasher([input]))) + const hash = hasher([input]) expectedHashes.push(hash) } let elapsed = Date.now() - start @@ -59,37 +63,40 @@ async function poseidon(input: BigInt) { // Append the C constants for (const c_val of constants_c[t - 2]) { //inputs.push(BigInt(c_val)); - constants_flat.push(BigInt(c_val)); + constants_flat.push(BigInt(c_val) * r % p); } // Append the M constants for (const vs of constants_m[t - 2]) { for (const v_val of vs) { - constants_flat.push(BigInt(v_val)) + constants_flat.push(BigInt(v_val) * r % p) } } - const input_bytes = new Uint8Array(bigints_to_limbs(inputs).buffer); + const input_bytes = new Uint8Array(bigints_to_limbs(mont_inputs).buffer); const constants_bytes = new Uint8Array(bigints_to_limbs(constants_flat).buffer); const INPUT_BUFFER_SIZE = input_bytes.length; const CONSTANTS_BUFFER_SIZE = constants_bytes.length; - console.log(inputs.length, INPUT_BUFFER_SIZE) + //console.log(inputs.length, INPUT_BUFFER_SIZE) - console.log(0) + const gpuErrMsg = "Please use a browser that has WebGPU enabled."; + //console.log(0) // 1: request adapter and device // @ts-ignore if (!navigator.gpu) { + codeOutput.innerHTML += gpuErrMsg; throw Error('WebGPU not supported.'); } - console.log(1) + //console.log(1) // @ts-ignore const adapter = await navigator.gpu.requestAdapter({ powerPreference: 'high-performance', }); if (!adapter) { + codeOutput.innerHTML += gpuErrMsg; throw Error('Couldn\'t request WebGPU adapter.'); } @@ -100,9 +107,9 @@ async function poseidon(input: BigInt) { code: shader }); - console.log(2) + //console.log(2) - //3: Create an output buffer to read GPU calculations to, and a staging + // 3: Create an output buffer to read GPU calculations to, and a staging //buffer to be mapped for JavaScript access const storageBuffer = device.createBuffer({ @@ -125,7 +132,7 @@ async function poseidon(input: BigInt) { usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST }); - console.log(3) + //console.log(3) // 4: Create a GPUBindGroupLayout to define the bind group structure, // create a GPUBindGroup from it, then use it to create a @@ -179,18 +186,18 @@ async function poseidon(input: BigInt) { entryPoint: 'main' } }); - console.log(4) + //console.log(4) // 5: Create GPUCommandEncoder to issue commands to the GPU const commandEncoder = device.createCommandEncoder(); - console.log(5) + //console.log(5) start = Date.now() // 6: Initiate render pass const passEncoder = commandEncoder.beginComputePass(); - console.log(6) + //console.log(6) // 7: Issue commands passEncoder.setPipeline(computePipeline); @@ -209,11 +216,11 @@ async function poseidon(input: BigInt) { INPUT_BUFFER_SIZE ); - console.log(7) + //console.log(7) // 8: End frame by passing array of command buffers to command queue for execution device.queue.submit([commandEncoder.finish()]); - console.log(7.1) + //console.log(7.1) // map staging buffer to read results back to JS await stagingBuffer.mapAsync( @@ -222,13 +229,13 @@ async function poseidon(input: BigInt) { 0, // Offset INPUT_BUFFER_SIZE // Length ); - console.log(7.2) + //console.log(7.2) const copyArrayBuffer = stagingBuffer.getMappedRange(0, INPUT_BUFFER_SIZE); const data = copyArrayBuffer.slice(); stagingBuffer.unmap(); - console.log(8) + //console.log(8) const dataBuf = new Uint32Array(data); elapsed = Date.now() - start @@ -237,13 +244,12 @@ async function poseidon(input: BigInt) { const results: BigInt[] = [] for (let i = 0; i < dataBuf.length / 16; i ++) { - const result = uint32ArrayToBigint(dataBuf.slice(i * 16, i * 16 + 16)) - results.push(result) + const result = BigInt(uint32ArrayToBigint(dataBuf.slice(i * 16, i * 16 + 16))) + results.push(result * ri % p) } - console.log(results) - console.log(expectedHashes) for (let i = 0; i < results.length; i ++) { - assert(results[i] === expectedHashes[i]) + let e = utils.leBuff2int(hasher.F.fromMontgomery(expectedHashes[i])); + assert(results[i] === e); } assert(results.length === expectedHashes.length) } @@ -273,7 +279,7 @@ const bytes_to_bigints = (limbs: Uint8Array): BigInt[] => { chunks.push(chunk); } - console.log(chunks); + //console.log(chunks); return [] }