Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: GopherJ <[email protected]>
  • Loading branch information
GopherJ committed Aug 19, 2024
1 parent a807f8f commit cae2b90
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 39 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ crate-type = ["cdylib", "rlib"]

[features]
concurrent = ["processor/concurrent", "std", "winter-prover/concurrent"]
default = ["webgpu", "async", "std"]
default = ["std"]
metal = ["dep:miden-gpu", "dep:pollster", "concurrent", "std"]
webgpu = ["dep:miden-gpu", "miden-gpu/webgpu", "async"]
async = ["maybe-async/async", "winter-prover/async", "dep:async-trait"]
Expand All @@ -34,11 +34,11 @@ async-trait = { version = "0.1", optional = true }
elsa = { version = "1.9" }

[target.'cfg(all(target_arch = "aarch64", target_os = "macos"))'.dependencies]
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "9a55000ff8cc113c1f4acb42115c542c1bb8533c", default-features = false, optional = true }
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true }
pollster = { version = "0.3", optional = true }

[target.'cfg(target_family = "wasm")'.dependencies]
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "9a55000ff8cc113c1f4acb42115c542c1bb8533c", default-features = false, optional = true }
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true }

[target.'cfg(target_family = "wasm")'.dev-dependencies]
wasm-bindgen-test = "0.3"
Expand Down
64 changes: 29 additions & 35 deletions prover/src/gpu/webgpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@ extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use alloc::{vec, vec::Vec};

use air::{AuxRandElements, LagrangeKernelEvaluationFrame};
use elsa::FrozenVec;
use maybe_async::maybe_async;
use miden_gpu::{
webgpu::{build_merkle_tree, get_wgpu_helper, RowHasher},
webgpu::{build_merkle_tree, get_or_init_wgpu_helper, init_wgpu_helper, RowHasher},
HashFn,
};
use processor::{
crypto::{ElementHasher, Hasher},
ONE,
};
use tracing::info_span;
use tracing::{event, Level};
use winter_prover::{
crypto::{Digest, MerkleTree},
matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment},
Expand All @@ -37,7 +36,6 @@ use winter_prover::{
DefaultConstraintEvaluator, EvaluationFrame, Prover, StarkDomain, TraceInfo, TraceLde,
TracePolyTable,
};
use tracing::{event, Level};

use crate::{
crypto::{RandomCoin, Rpo256},
Expand All @@ -49,6 +47,13 @@ use crate::{
#[cfg(test)]
mod tests;

#[allow(clippy::uninit_vec)]
pub unsafe fn uninit_vector_real<T>(length: usize) -> Vec<T> {
let mut vector = Vec::with_capacity(length);
vector.set_len(length);
vector
}

// CONSTANTS
// ================================================================================================

Expand Down Expand Up @@ -102,15 +107,15 @@ where
assert!(poly_offset < polys.num_base_cols());

// allocate memory for the segment
// let data = if polys.num_base_cols() - poly_offset >= N {
// // if we will fill the entire segment, we allocate uninitialized memory
// unsafe { page_aligned_uninit_vector(domain_size) }
// } else {
// but if some columns in the segment will remain unfilled, we allocate memory
// initialized to zeros to make sure we don't end up with memory with
// undefined values
let data = alloc::vec![[E::BaseField::ZERO; N]; domain_size];
// };
let data = if polys.num_base_cols() - poly_offset >= N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { uninit_vector_real(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory
// initialized to zeros to make sure we don't end up with memory with
// undefined values
vec![[E::BaseField::ZERO; N]; domain_size]
};

Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
}
Expand Down Expand Up @@ -234,7 +239,7 @@ where
offsets.len().ilog2(),
now.elapsed().as_millis()
);
let helper = get_wgpu_helper().unwrap();
let helper = get_or_init_wgpu_helper().await;

// build constraint evaluation commitment
#[cfg(feature = "std")]
Expand All @@ -246,7 +251,7 @@ where
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher =
RowHasher::new(helper, lde_domain_size, rpo_requires_padding, self.webgpu_hash_fn);
let mut rpo_padded_segment: Vec<[Felt; RATE]>;
let rpo_padded_segment: Vec<[Felt; RATE]>;
for (segment_idx, segment) in segments.iter().enumerate() {
// check if the segment requires padding
if rpo_padded_segment_idx.map_or(false, |pad_idx| pad_idx == segment_idx) {
Expand All @@ -263,11 +268,6 @@ where
s
})
.collect();
// For rpx, skip this step
if self.webgpu_hash_fn == HashFn::Rpo256 {
let rpo_pad_column = num_base_columns % RATE;
rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE);
}
row_hasher.update(helper, &rpo_padded_segment);
assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last");
break;
Expand Down Expand Up @@ -337,6 +337,7 @@ impl<
domain: &StarkDomain<Felt>,
webgpu_hash_fn: HashFn,
) -> (Self, TracePolyTable<E>) {
init_wgpu_helper().await;
// extend the main execution trace and build a Merkle tree from the extended trace
let (main_segment_lde, main_segment_tree, main_segment_polys) =
build_trace_commitment_sync::<E, Felt, H>(main_trace, domain);
Expand Down Expand Up @@ -602,16 +603,16 @@ async fn build_trace_commitment<
fft::interpolate_poly(&mut poly, &inv_twiddles);
poly
});
let helper = get_or_init_wgpu_helper().await;

// extend the execution trace and generate hashes on the gpu
let lde_segments = FrozenVec::new();
let lde_domain_size = domain.lde_domain_size();
let num_base_columns = trace.num_base_cols();
let rpo_requires_padding = num_base_columns % RATE != 0;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher =
RowHasher::new(get_wgpu_helper().unwrap(), lde_domain_size, rpo_requires_padding, hash_fn);
let mut rpo_padded_segment: Vec<[Felt; RATE]>;
let mut row_hasher = RowHasher::new(&helper, lde_domain_size, rpo_requires_padding, hash_fn);
let rpo_padded_segment: Vec<[Felt; RATE]>;
let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain);
let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate();
for (segment_idx, segment) in &mut lde_segment_iter {
Expand All @@ -630,21 +631,14 @@ async fn build_trace_commitment<
s
})
.collect();
// skip this in case of Rpx
if hash_fn == HashFn::Rpo256 {
let rpo_pad_column = num_base_columns % RATE;
rpo_padded_segment.iter_mut().for_each(|row| row[rpo_pad_column] = ONE);
}
row_hasher.update(get_wgpu_helper().unwrap(), &rpo_padded_segment);
row_hasher.update(&helper, &rpo_padded_segment);
assert!(lde_segment_iter.next().is_none(), "padded segment should be the last");
break;
}
row_hasher.update(get_wgpu_helper().unwrap(), segment);
row_hasher.update(&helper, segment);
}
let row_hashes = row_hasher.finish(get_wgpu_helper().unwrap()).await.unwrap();
let tree_nodes = build_merkle_tree(get_wgpu_helper().unwrap(), &row_hashes, hash_fn)
.await
.unwrap();
let row_hashes = row_hasher.finish(&helper).await.unwrap();
let tree_nodes = build_merkle_tree(&helper, &row_hashes, hash_fn).await.unwrap();
// aggregate segments at the same time as the GPU generates the merkle tree nodes
let lde_segments = lde_segments.into_vec().into_iter().map(|p| *p).collect();
let trace_lde = RowMatrix::from_segments(lde_segments, num_base_columns);
Expand Down
1 change: 1 addition & 0 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use core::marker::PhantomData;
use maybe_async::maybe_async;

use air::{AuxRandElements, ProcessorAir, PublicInputs};
#[cfg(any(feature = "metal", feature = "webgpu"))]
use miden_gpu::HashFn;
use processor::{
crypto::{
Expand Down

0 comments on commit cae2b90

Please sign in to comment.