Skip to content

Commit

Permalink
fix:segment page alignment issue in metal prover
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth authored Jun 20, 2024
2 parents 34fde66 + d7862cd commit 978c142
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
8 changes: 7 additions & 1 deletion miden/src/examples/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{Example, ONE, ZERO};
use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs};
use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs};

Check warning on line 2 in miden/src/examples/fibonacci.rs

View workflow job for this annotation

GitHub Actions / Check Rust stable on ubuntu with --all-targets --all-features

unused import: `ProvingOptions`

// EXAMPLE BUILDER
// ================================================================================================
Expand Down Expand Up @@ -68,3 +68,9 @@ fn test_fib_example_fail() {
let example = get_example(16);
super::test_example(example, true);
}

#[test]
fn test_fib_example_rpo() {
let example = get_example(16);
super::test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true));
}
13 changes: 11 additions & 2 deletions miden/src/examples/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl ExampleOptions {
// ================================================================================================

#[cfg(test)]
pub fn test_example<H>(example: Example<H>, fail: bool)
pub fn test_example_with_options<H>(example: Example<H>, fail: bool, options: ProvingOptions)
where
H: Host,
{
Expand All @@ -175,7 +175,7 @@ where
} = example;

let (mut outputs, proof) =
miden_vm::prove(&program, stack_inputs.clone(), host, ProvingOptions::default()).unwrap();
miden_vm::prove(&program, stack_inputs.clone(), host, options).unwrap();

assert_eq!(
expected_result,
Expand All @@ -193,3 +193,12 @@ where
assert!(miden_vm::verify(program_info, stack_inputs, outputs, proof).is_ok());
}
}


#[cfg(test)]
pub fn test_example<H>(example: Example<H>, fail: bool)
where
H: Host,
{
test_example_with_options(example, fail, ProvingOptions::default());
}
64 changes: 56 additions & 8 deletions prover/src/gpu/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,11 @@ use miden_gpu::{
HashFn,
};
use pollster::block_on;
use processor::{
crypto::{ElementHasher, Hasher},
ONE,
};
use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec};
use tracing::{event, Level};
use processor::{utils::group_vector_elements, ONE};
use std::time::Instant;
use winter_prover::{
crypto::{Digest, MerkleTree},
matrix::{build_segments, get_evaluation_offsets, ColMatrix, RowMatrix, Segment},
matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment},
proof::Queries,
AuxTraceRandElements, CompositionPoly, CompositionPolyTrace, ConstraintCommitment,
ConstraintCompositionCoefficients, DefaultConstraintEvaluator, EvaluationFrame, Prover,
Expand Down Expand Up @@ -68,6 +64,58 @@ where
phantom_data: PhantomData,
}
}

fn build_aligned_segement<E, const N: usize>(
polys: &ColMatrix<E>,
poly_offset: usize,
offsets: &[Felt],
twiddles: &[Felt],
) -> Segment<Felt, N>
where
E: FieldElement<BaseField = Felt>,
{
let poly_size = polys.num_rows();
let domain_size = offsets.len();
assert!(domain_size.is_power_of_two());
assert!(domain_size > poly_size);
assert_eq!(poly_size, twiddles.len() * 2);
assert!(poly_offset < polys.num_base_cols());

// allocate memory for the segment
let data = if polys.num_base_cols() - poly_offset >= N {
// if we will fill the entire segment, we allocate uninitialized memory
unsafe { page_aligned_uninit_vector(domain_size) }
} else {
// but if some columns in the segment will remain unfilled, we allocate memory initialized
// to zeros to make sure we don't end up with memory with undefined values
group_vector_elements(Felt::zeroed_vector(N * domain_size))
};

Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles)
}

fn build_aligned_segements<E, const N: usize>(
polys: &ColMatrix<E>,
twiddles: &[Felt],
offsets: &[Felt],
) -> Vec<Segment<Felt, N>>
where
E: FieldElement<BaseField = Felt>,
{
assert!(N > 0, "batch size N must be greater than zero");
debug_assert_eq!(polys.num_rows(), twiddles.len() * 2);
debug_assert_eq!(offsets.len() % polys.num_rows(), 0);

let num_segments = if polys.num_base_cols() % N == 0 {
polys.num_base_cols() / N
} else {
polys.num_base_cols() / N + 1
};

(0..num_segments)
.map(|i| Self::build_aligned_segement(polys, i * N, offsets, twiddles))
.collect()
}
}

impl<H, D, R> Prover for MetalExecutionProver<H, D, R>
Expand Down Expand Up @@ -149,7 +197,7 @@ where
let blowup = domain.trace_to_lde_blowup();
let offsets =
get_evaluation_offsets::<E>(composition_poly.column_len(), blowup, domain.offset());
let segments = build_segments(composition_poly.data(), domain.trace_twiddles(), &offsets);
let segments = Self::build_aligned_segements(composition_poly.data(), domain.trace_twiddles(), &offsets);
event!(
Level::INFO,
"Evaluated {} composition polynomial columns over LDE domain (2^{} elements) in {} ms",
Expand Down

0 comments on commit 978c142

Please sign in to comment.