Skip to content

Commit

Permalink
all test passed with build_trace_commitment_sync
Browse files Browse the repository at this point in the history
Signed-off-by: GopherJ <[email protected]>
  • Loading branch information
GopherJ committed Aug 20, 2024
1 parent cae2b90 commit cbdde53
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 72 deletions.
46 changes: 28 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 0 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,3 @@ inherits = "release"
debug = true
debug-assertions = true
overflow-checks = true

[patch.crates-io]
winter-prover = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-air = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-utils = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-verifier = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-fri = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-crypto = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-math = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-maybe-async = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-rand-utils = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
miden-crypto = { git = "https://github.com/GopherJ/miden-crypto", rev = "657b4922f3abe3577d1b0de1f6ae6ad52400cb11" }
6 changes: 3 additions & 3 deletions prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ crate-type = ["cdylib", "rlib"]

[features]
concurrent = ["processor/concurrent", "std", "winter-prover/concurrent"]
default = ["std"]
default = ["webgpu"]
metal = ["dep:miden-gpu", "dep:pollster", "concurrent", "std"]
webgpu = ["dep:miden-gpu", "miden-gpu/webgpu", "async"]
async = ["maybe-async/async", "winter-prover/async", "dep:async-trait"]
Expand All @@ -34,11 +34,11 @@ async-trait = { version = "0.1", optional = true }
elsa = { version = "1.9" }

[target.'cfg(all(target_arch = "aarch64", target_os = "macos"))'.dependencies]
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true }
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "ee2e51bf1ed54a40212d60ddb55a5b97bfdfae5b", default-features = false, optional = true }
pollster = { version = "0.3", optional = true }

[target.'cfg(target_family = "wasm")'.dependencies]
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "b8b3a4273dd76a3af2ba8cb7739c9e20823a8f70", default-features = false, optional = true }
miden-gpu = { git = "https://github.com/0xPolygonMiden/miden-gpu", rev = "ee2e51bf1ed54a40212d60ddb55a5b97bfdfae5b", default-features = false, optional = true }

[target.'cfg(target_family = "wasm")'.dev-dependencies]
wasm-bindgen-test = "0.3"
Expand Down
85 changes: 47 additions & 38 deletions prover/src/gpu/webgpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ use air::{AuxRandElements, LagrangeKernelEvaluationFrame};
use elsa::FrozenVec;
use maybe_async::maybe_async;
use miden_gpu::{
webgpu::{build_merkle_tree, get_or_init_wgpu_helper, init_wgpu_helper, RowHasher},
webgpu::{build_merkle_tree, get_or_init_wgpu_helper, RowHasher},
HashFn,
};
use processor::{
crypto::{ElementHasher, Hasher},
ONE,
};
use tracing::info_span;
#[cfg(feature = "std")]
use tracing::{event, Level};
use winter_prover::{
crypto::{Digest, MerkleTree},
Expand Down Expand Up @@ -248,9 +249,10 @@ where
let num_base_columns =
composition_poly.num_columns() * <E as FieldElement>::EXTENSION_DEGREE;
let rpo_requires_padding = num_base_columns % RATE != 0;
let is_rpo = self.webgpu_hash_fn == HashFn::Rpo256;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher =
RowHasher::new(helper, lde_domain_size, rpo_requires_padding, self.webgpu_hash_fn);
RowHasher::new(helper, lde_domain_size, num_base_columns, self.webgpu_hash_fn);
let rpo_padded_segment: Vec<[Felt; RATE]>;
for (segment_idx, segment) in segments.iter().enumerate() {
// check if the segment requires padding
Expand All @@ -260,14 +262,18 @@ where
// padded with "0"s we only need to add the "1"s.
let rpo_pad_column = num_base_columns % RATE;

rpo_padded_segment = segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect();
rpo_padded_segment = if is_rpo {
segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect()
} else {
segment.iter().map(|x| *x).collect()
};
row_hasher.update(helper, &rpo_padded_segment);
assert_eq!(segments.len() - 1, segment_idx, "padded segment should be the last");
break;
Expand Down Expand Up @@ -337,7 +343,6 @@ impl<
domain: &StarkDomain<Felt>,
webgpu_hash_fn: HashFn,
) -> (Self, TracePolyTable<E>) {
init_wgpu_helper().await;
// extend the main execution trace and build a Merkle tree from the extended trace
let (main_segment_lde, main_segment_tree, main_segment_polys) =
build_trace_commitment_sync::<E, Felt, H>(main_trace, domain);
Expand Down Expand Up @@ -410,28 +415,27 @@ impl<
/// This function will panic if any of the following are true:
/// - the number of rows in the provided `aux_trace` does not match the main trace.
/// - this segment would exceed the number of segments specified by the trace layout.
async fn set_aux_trace(
fn set_aux_trace(
&mut self,
aux_trace: &ColMatrix<E>,
domain: &StarkDomain<Felt>,
) -> (ColMatrix<E>, D) {
todo!()
// // extend the auxiliary trace segment and build a Merkle tree from the extended trace
// let (aux_segment_lde, aux_segment_tree, aux_segment_polys) =
// build_trace_commitment::<E, H, D>(aux_trace, domain, self.webgpu_hash_fn).await;
//
// assert_eq!(
// self.main_segment_lde.num_rows(),
// aux_segment_lde.num_rows(),
// "the number of rows in the auxiliary segment must be the same as in the main segment"
// );
//
// // save the lde and commitment
// self.aux_segment_lde = Some(aux_segment_lde);
// let root_hash = *aux_segment_tree.root();
// self.aux_segment_tree = Some(aux_segment_tree);
//
// (aux_segment_polys, root_hash)
// extend the auxiliary trace segment and build a Merkle tree from the extended trace
let (aux_segment_lde, aux_segment_tree, aux_segment_polys) =
build_trace_commitment_sync::<E, E, H>(aux_trace, domain);

assert_eq!(
self.main_segment_lde.num_rows(),
aux_segment_lde.num_rows(),
"the number of rows in the auxiliary segment must be the same as in the main segment"
);

// save the lde and commitment
self.aux_segment_lde = Some(aux_segment_lde);
let root_hash = *aux_segment_tree.root();
self.aux_segment_tree = Some(aux_segment_tree);

(aux_segment_polys, root_hash)
}

/// Reads current and next rows from the main trace segment into the specified frame.
Expand Down Expand Up @@ -610,8 +614,9 @@ async fn build_trace_commitment<
let lde_domain_size = domain.lde_domain_size();
let num_base_columns = trace.num_base_cols();
let rpo_requires_padding = num_base_columns % RATE != 0;
let is_rpo = hash_fn == HashFn::Rpo256;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher = RowHasher::new(&helper, lde_domain_size, rpo_requires_padding, hash_fn);
let mut row_hasher = RowHasher::new(&helper, lde_domain_size, num_base_columns, hash_fn);
let rpo_padded_segment: Vec<[Felt; RATE]>;
let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain);
let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate();
Expand All @@ -623,14 +628,18 @@ async fn build_trace_commitment<
// rule ("1" followed by "0"s). Our segments are already
// padded with "0"s we only need to add the "1"s.
let rpo_pad_column = num_base_columns % RATE;
rpo_padded_segment = segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect();
rpo_padded_segment = if is_rpo {
segment
.iter()
.map(|x| {
let mut s = *x;
s[rpo_pad_column] = ONE;
s
})
.collect()
} else {
segment.iter().map(|x| *x).collect()
};
row_hasher.update(&helper, &rpo_padded_segment);
assert!(lde_segment_iter.next().is_none(), "padded segment should be the last");
break;
Expand Down
1 change: 0 additions & 1 deletion prover/src/gpu/webgpu/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use winter_prover::{crypto::Digest, math::fields::CubeExtension, CompositionPoly

use crate::*;

use wasm_bindgen::prelude::*;
use wasm_bindgen_test::*;

wasm_bindgen_test_configure!(run_in_browser);
Expand Down

0 comments on commit cbdde53

Please sign in to comment.