Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
Signed-off-by: GopherJ <[email protected]>
  • Loading branch information
GopherJ committed Aug 19, 2024
1 parent 873a49c commit c604954
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 62 deletions.
19 changes: 4 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@ winter-verifier = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a0
winter-fri = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-crypto = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-math = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-maybe-async = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
winter-rand-utils = { git = "https://github.com/GopherJ/winterfell", rev = "50a9a088d30ff58e88cf19805f4dfa13b4be356e" }
miden-crypto = { git = "https://github.com/GopherJ/miden-crypto", rev = "657b4922f3abe3577d1b0de1f6ae6ad52400cb11" }
4 changes: 2 additions & 2 deletions prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ edition.workspace = true

[features]
concurrent = ["processor/concurrent", "std", "winter-prover/concurrent"]
default = ["webgpu"]
default = ["std"]
metal = ["dep:miden-gpu", "dep:elsa", "dep:pollster", "concurrent", "std"]
webgpu = ["dep:miden-gpu", "dep:elsa", "miden-gpu/webgpu", "async"]
async = ["maybe-async/async", "winter-prover/async", "winter-prover/async-trait", "dep:async-trait"]
async = ["maybe-async/async", "winter-prover/async", "dep:async-trait"]
std = ["air/std", "processor/std", "winter-prover/std", "miden-gpu/std", "tracing/std"]

[dependencies]
Expand Down
101 changes: 68 additions & 33 deletions prover/src/gpu/webgpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,24 @@
//! For now, the logic is limited to GPU accelerating trace and constraint commitments,
//! using the RPO 256 or RPX 256 hash functions.

#[cfg(feature = "std")]
use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec};

#[cfg(not(feature = "std"))]
use core::marker::PhantomData;
#[cfg(feature = "std")]
use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec};

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::boxed::Box;
#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use alloc::boxed::Box;

use air::{AuxRandElements, LagrangeKernelEvaluationFrame};
use elsa::FrozenVec;
use maybe_async::maybe_async;

use air::{AuxRandElements, LagrangeKernelEvaluationFrame};
use miden_gpu::{
webgpu::{
build_merkle_tree, get_dispatch_linear, get_wgpu_helper, init_wgpu_helper, RowHasher,
Expand All @@ -33,7 +31,7 @@ use processor::{
crypto::{ElementHasher, Hasher},
ONE,
};
use tracing::{event, Level};
use tracing::{event, info_span, Level};
use winter_prover::{
crypto::{Digest, MerkleTree},
matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment},
Expand Down Expand Up @@ -342,10 +340,9 @@ impl<
domain: &StarkDomain<Felt>,
webgpu_hash_fn: HashFn,
) -> (Self, TracePolyTable<E>) {
init_wgpu_helper().await;
// extend the main execution trace and build a Merkle tree from the extended trace
let (main_segment_lde, main_segment_tree, main_segment_polys) =
build_trace_commitment(main_trace, domain, webgpu_hash_fn).await;
build_trace_commitment_sync::<E, Felt, H>(main_trace, domain);

let trace_poly_table = TracePolyTable::new(main_segment_polys);
let trace_lde = WebGPUTraceLde {
Expand Down Expand Up @@ -420,22 +417,23 @@ impl<
aux_trace: &ColMatrix<E>,
domain: &StarkDomain<Felt>,
) -> (ColMatrix<E>, D) {
// extend the auxiliary trace segment and build a Merkle tree from the extended trace
let (aux_segment_lde, aux_segment_tree, aux_segment_polys) =
build_trace_commitment::<E, H, D>(aux_trace, domain, self.webgpu_hash_fn).await;

assert_eq!(
self.main_segment_lde.num_rows(),
aux_segment_lde.num_rows(),
"the number of rows in the auxiliary segment must be the same as in the main segment"
);

// save the lde and commitment
self.aux_segment_lde = Some(aux_segment_lde);
let root_hash = *aux_segment_tree.root();
self.aux_segment_tree = Some(aux_segment_tree);

(aux_segment_polys, root_hash)
todo!()
// // extend the auxiliary trace segment and build a Merkle tree from the extended trace
// let (aux_segment_lde, aux_segment_tree, aux_segment_polys) =
// build_trace_commitment::<E, H, D>(aux_trace, domain, self.webgpu_hash_fn).await;
//
// assert_eq!(
// self.main_segment_lde.num_rows(),
// aux_segment_lde.num_rows(),
// "the number of rows in the auxiliary segment must be the same as in the main segment"
// );
//
// // save the lde and commitment
// self.aux_segment_lde = Some(aux_segment_lde);
// let root_hash = *aux_segment_tree.root();
// self.aux_segment_tree = Some(aux_segment_tree);
//
// (aux_segment_polys, root_hash)
}

/// Reads current and next rows from the main trace segment into the specified frame.
Expand Down Expand Up @@ -550,6 +548,45 @@ impl<
/// ────┼────────┼────────┼────────┼────────┼────────┼────
/// t=n t=n+1 t=n+2 t=n+3 t=n+4 t=n+5
/// ```
// Segment width supplied as the const generic argument to
// `RowMatrix::evaluate_polys_over` in `build_trace_commitment_sync`.
// NOTE(review): the `///` diagram lines immediately above now attach to this constant;
// they look like the tail of the docs for `build_trace_commitment` further down — confirm
// and move them back next to that function.
const DEFAULT_SEGMENT_WIDTH: usize = 8;

/// Extends `trace` over the LDE domain and commits to the rows of the extended trace,
/// without awaiting anything.
///
/// The columns of `trace` are interpolated into polynomials, those polynomials are evaluated
/// over `domain` into a [`RowMatrix`], and `commit_to_rows` is then called on that matrix to
/// obtain the Merkle tree. Each of the two stages runs inside its own `tracing` info span
/// (`extend_execution_trace` and `compute_execution_trace_commitment`).
///
/// Returns the tuple `(trace_lde, trace_tree, trace_polys)`.
///
/// # Panics
/// Panics if any of the following does not hold:
/// - the extended trace has the same number of columns as `trace`;
/// - the interpolated polynomials have the same number of rows as `trace`;
/// - the extended trace has `domain.lde_domain_size()` rows;
/// - the depth of the resulting tree equals `ilog2` of the number of extended rows.
fn build_trace_commitment_sync<E, F, H>(
    trace: &ColMatrix<F>,
    domain: &StarkDomain<E::BaseField>,
) -> (RowMatrix<F>, MerkleTree<H>, ColMatrix<F>)
where
    E: FieldElement,
    F: FieldElement<BaseField = E::BaseField>,
    H: ElementHasher<BaseField = E::BaseField>,
{
    // Stage 1: interpolate the columns and evaluate them over the LDE domain. The span is a
    // temporary, so it is exited and dropped as soon as this statement completes.
    let (extended, polys) = info_span!(
        "extend_execution_trace",
        num_cols = trace.num_cols(),
        blowup = domain.trace_to_lde_blowup()
    )
    .in_scope(|| {
        let polys = trace.interpolate_columns();
        let extended = RowMatrix::evaluate_polys_over::<DEFAULT_SEGMENT_WIDTH>(&polys, domain);
        (extended, polys)
    });

    // Sanity-check the shapes of the intermediate results before committing to them.
    assert_eq!(extended.num_cols(), trace.num_cols());
    assert_eq!(polys.num_rows(), trace.num_rows());
    assert_eq!(extended.num_rows(), domain.lde_domain_size());

    // Stage 2: commit to the rows of the extended trace.
    let expected_depth = extended.num_rows().ilog2() as usize;
    let tree = {
        // The guard keeps the span entered until the end of this block.
        let _entered =
            info_span!("compute_execution_trace_commitment", tree_depth = expected_depth)
                .entered();
        extended.commit_to_rows()
    };
    assert_eq!(tree.depth(), expected_depth);

    (extended, tree, polys)
}

async fn build_trace_commitment<
E: FieldElement<BaseField = Felt>,
H: Hasher<Digest = D> + ElementHasher<BaseField = E::BaseField>,
Expand All @@ -575,12 +612,8 @@ async fn build_trace_commitment<
let num_base_columns = trace.num_base_cols();
let rpo_requires_padding = num_base_columns % RATE != 0;
let rpo_padded_segment_idx = rpo_requires_padding.then_some(num_base_columns / RATE);
let mut row_hasher = RowHasher::new(
&get_wgpu_helper().unwrap(),
num_base_columns,
rpo_requires_padding,
hash_fn,
);
let mut row_hasher =
RowHasher::new(&get_wgpu_helper().unwrap(), lde_domain_size, rpo_requires_padding, hash_fn);
let mut rpo_padded_segment: Vec<[Felt; RATE]>;
let mut lde_segment_generator = SegmentGenerator::new(trace_polys, domain);
let mut lde_segment_iter = lde_segment_generator.gen_segment_iter().enumerate();
Expand Down Expand Up @@ -612,7 +645,9 @@ async fn build_trace_commitment<
row_hasher.update(&get_wgpu_helper().unwrap(), segment);
}
let row_hashes = row_hasher.finish(&get_wgpu_helper().unwrap()).await.unwrap();
let tree_nodes = build_merkle_tree(&get_wgpu_helper().unwrap(), &row_hashes, hash_fn).await.unwrap();
let tree_nodes = build_merkle_tree(&get_wgpu_helper().unwrap(), &row_hashes, hash_fn)
.await
.unwrap();
// aggregate segments at the same time as the GPU generates the merkle tree nodes
let lde_segments = lde_segments.into_vec().into_iter().map(|p| *p).collect();
let trace_lde = RowMatrix::from_segments(lde_segments, num_base_columns);
Expand Down
24 changes: 12 additions & 12 deletions prover/src/gpu/webgpu/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::*;

type CubeFelt = CubeExtension<Felt>;

fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
async fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
Expand All @@ -30,9 +30,9 @@ fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);

let (cpu_trace_lde, cpu_polys) =
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain).await;
let (gpu_trace_lde, gpu_polys) =
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain).await;

assert_eq!(
cpu_trace_lde.get_main_trace_commitment(),
Expand All @@ -44,7 +44,7 @@ fn build_trace_commitment_on_gpu_with_padding_matches_cpu<
);
}

fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
async fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
Expand All @@ -61,9 +61,9 @@ fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);

let (cpu_trace_lde, cpu_polys) =
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain).await;
let (gpu_trace_lde, gpu_polys) =
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain).await;

assert_eq!(
cpu_trace_lde.get_main_trace_commitment(),
Expand All @@ -75,7 +75,7 @@ fn build_trace_commitment_on_gpu_without_padding_matches_cpu<
);
}

fn build_constraint_commitment_on_gpu_with_padding_matches_cpu<
async fn build_constraint_commitment_on_gpu_with_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
Expand All @@ -95,16 +95,16 @@ fn build_constraint_commitment_on_gpu_with_padding_matches_cpu<
CompositionPolyTrace::new(values.clone()),
2,
&domain,
);
).await;
let (commitment_gpu, composition_poly_gpu) =
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 2, &domain);
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 2, &domain).await;

assert_eq!(commitment_cpu.root(), commitment_gpu.root());
assert_ne!(0, composition_poly_cpu.data().num_base_cols() % RATE);
assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns());
}

fn build_constraint_commitment_on_gpu_without_padding_matches_cpu<
async fn build_constraint_commitment_on_gpu_without_padding_matches_cpu<
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
H: ElementHasher<BaseField = Felt> + Hasher<Digest = D>,
D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>,
Expand All @@ -124,9 +124,9 @@ fn build_constraint_commitment_on_gpu_without_padding_matches_cpu<
CompositionPolyTrace::new(values.clone()),
8,
&domain,
);
).await;
let (commitment_gpu, composition_poly_gpu) =
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 8, &domain);
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 8, &domain).await;

assert_eq!(commitment_cpu.root(), commitment_gpu.root());
assert_eq!(0, composition_poly_cpu.data().num_base_cols() % RATE);
Expand Down

0 comments on commit c604954

Please sign in to comment.