diff --git a/Cargo.lock b/Cargo.lock index 7c9503073..292c0b4d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1414,7 +1414,7 @@ dependencies = [ [[package]] name = "miden-gpu" version = "0.2.0" -source = "git+https://github.com/0xPolygonMiden/miden-gpu?rev=9a55000ff8cc113c1f4acb42115c542c1bb8533c#9a55000ff8cc113c1f4acb42115c542c1bb8533c" +source = "git+https://github.com/0xPolygonMiden/miden-gpu?rev=b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70#b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70" dependencies = [ "bytemuck", "elsa", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 96e859ab4..dbb944ea9 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -18,7 +18,7 @@ crate-type = ["cdylib", "rlib"] [features] concurrent = ["processor/concurrent", "std", "winter-prover/concurrent"] -default = ["webgpu", "async", "std"] +default = ["std"] metal = ["dep:miden-gpu", "dep:pollster", "concurrent", "std"] webgpu = ["dep:miden-gpu", "miden-gpu/webgpu", "async"] async = ["maybe-async/async", "winter-prover/async", "dep:async-trait"] @@ -34,11 +34,11 @@ async-trait = { version = "0.1", optional = true } elsa = { version = "1.9" } [target.'cfg(all(target_arch = "aarch64", target_os = "macos"))'.dependencies] -miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "9a55000ff8cc113c1f4acb42115c542c1bb8533c", default-features = false, optional = true } +miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true } pollster = { version = "0.3", optional = true } [target.'cfg(target_family = "wasm")'.dependencies] -miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "9a55000ff8cc113c1f4acb42115c542c1bb8533c", default-features = false, optional = true } +miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true } [target.'cfg(target_family = "wasm")'.dev-dependencies] wasm-bindgen-test = "0.3" diff --git a/prover/src/gpu/webgpu/mod.rs b/prover/src/gpu/webgpu/mod.rs index ae44f45c1..f1ef5d4b2 100644 --- a/prover/src/gpu/webgpu/mod.rs +++ b/prover/src/gpu/webgpu/mod.rs @@ -13,15 +13,13 @@ extern crate alloc; #[cfg(not(feature = "std"))] use alloc::boxed::Box; #[cfg(not(feature = "std"))] -use alloc::vec; -#[cfg(not(feature = "std"))] -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use air::{AuxRandElements, LagrangeKernelEvaluationFrame}; use elsa::FrozenVec; use maybe_async::maybe_async; use miden_gpu::{ - webgpu::{build_merkle_tree, get_wgpu_helper, RowHasher}, + webgpu::{build_merkle_tree, get_or_init_wgpu_helper, init_wgpu_helper, RowHasher}, HashFn, }; use processor::{ @@ -29,6 +27,7 @@ use processor::{ ONE, }; use tracing::info_span; +use tracing::{event, Level}; use winter_prover::{ crypto::{Digest, MerkleTree}, matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment}, @@ -37,7 +36,6 @@ use winter_prover::{ DefaultConstraintEvaluator, EvaluationFrame, Prover, StarkDomain, TraceInfo, TraceLde, TracePolyTable, }; -use tracing::{event, Level}; use crate::{ crypto::{RandomCoin, Rpo256}, @@ -49,6 +47,13 @@ use crate::{ #[cfg(test)] mod tests; +#[allow(clippy::uninit_vec)] +pub unsafe fn uninit_vector_real(length: usize) -> Vec { + let mut vector = Vec::with_capacity(length); + vector.set_len(length); + vector +} + // CONSTANTS // ================================================================================================ @@ -102,15 +107,15 @@ where assert!(poly_offset < polys.num_base_cols()); // allocate memory for the segment - // let data = if polys.num_base_cols() - poly_offset >= N { - // // if we will fill the entire segment, we allocate uninitialized memory - // unsafe { page_aligned_uninit_vector(domain_size) } - // } else { - // but if some columns in the segment will remain unfilled, we allocate memory - // initialized to zeros to make sure we don't end up with memory with - // undefined values - let data = alloc::vec![[E::BaseField::ZERO; N]; domain_size]; - // }; + let data = if polys.num_base_cols() - poly_offset >= N { + // if we will fill the entire segment, we allocate uninitialized memory + unsafe { uninit_vector_real(domain_size) } + } else { + // but if some columns in the segment will remain unfilled, we allocate memory + // initialized to zeros to make sure we don't end up with memory with + // undefined values + vec![[E::BaseField::ZERO; N]; domain_size] + }; Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles) } @@ -234,7 +239,7 @@ where offsets.len().ilog2(), now.elapsed().as_millis() ); - let helper = get_wgpu_helper().unwrap(); + let helper = get_or_init_wgpu_helper().await; // build constraint evaluation commitment #[cfg(feature = "std")] @@ -246,7 +251,7 @@ where let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); let mut row_hasher = RowHasher::new(helper, lde_domain_size, rpo_requires_padding, self.webgpu_hash_fn); - let mut rpo_padded_segment: Vec<[Felt; RATE]>; + let rpo_padded_segment: Vec<[Felt; RATE]>; for (segment_idx, segment) in segments.iter().enumerate() { // check if the segment requires padding if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) { @@ -263,11 +268,6 @@ where s }) .collect(); - // For rpx, skip this step - if self.webgpu_hash_fn == HashFn::Rpo256 { - let rpo_pad_column = num_base_columns % RATE; - rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE); - } row_hasher.update(helper, &rpo_padded_segment); assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last"); break; @@ -337,6 +337,7 @@ impl< domain: &StarkDomain, webgpu_hash_fn: HashFn, ) -> (Self, TracePolyTable) { + init_wgpu_helper().await; // extend the main execution trace and build a Merkle tree from the extended trace let (main_segment_lde, main_segment_tree, main_segment_polys) = build_trace_commitment_sync::(main_trace, domain); @@ -602,6 +603,7 @@ async fn build_trace_commitment< fft::interpolate_poly(&mut poly, &inv_twiddles); poly }); + let helper = get_or_init_wgpu_helper().await; // extend the execution trace and generate hashes on the gpu let lde_segments = FrozenVec::new(); @@ -609,9 +611,8 @@ async fn build_trace_commitment< let num_base_columns = trace.num_base_cols(); let rpo_requires_padding = num_base_columns % RATE != 0; let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE); - let mut row_hasher = - RowHasher::new(get_wgpu_helper().unwrap(), lde_domain_size, rpo_requires_padding, hash_fn); - let mut rpo_padded_segment: Vec<[Felt; RATE]>; + let mut row_hasher = RowHasher::new(&helper, lde_domain_size, rpo_requires_padding, hash_fn); + let rpo_padded_segment: Vec<[Felt; RATE]>; let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain); let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate(); for (segment_idx, segment) in &mut lde_segment_iter { @@ -630,21 +631,14 @@ async fn build_trace_commitment< s }) .collect(); - // skip this in case of Rpx - if hash_fn == HashFn::Rpo256 { - let rpo_pad_column = num_base_columns % RATE; - rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE); - } - row_hasher.update(get_wgpu_helper().unwrap(), &rpo_padded_segment); + row_hasher.update(&helper, &rpo_padded_segment); assert!(lde_segment_iter.next().is_none(), "padded segment should be the last"); break; } - row_hasher.update(get_wgpu_helper().unwrap(), segment); + row_hasher.update(&helper, segment); } - let row_hashes = row_hasher.finish(get_wgpu_helper().unwrap()).await.unwrap(); - let tree_nodes = build_merkle_tree(get_wgpu_helper().unwrap(), &row_hashes, hash_fn) - .await - .unwrap(); + let row_hashes = row_hasher.finish(&helper).await.unwrap(); + let tree_nodes = build_merkle_tree(&helper, &row_hashes, hash_fn).await.unwrap(); // aggregate segments at the same time as the GPU generates the merkle tree nodes let lde_segments = lde_segments.into_vec().into_iter().map(|p| *p).collect(); let trace_lde = RowMatrix::from_segments(lde_segments, num_base_columns); diff --git a/prover/src/lib.rs b/prover/src/lib.rs index ee0077dd7..69ccc8809 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -9,6 +9,7 @@ use core::marker::PhantomData; use maybe_async::maybe_async; use air::{AuxRandElements, ProcessorAir, PublicInputs}; +#[cfg(any(feature = "metal", feature = "webgpu"))] use miden_gpu::HashFn; use processor::{ crypto::{