From 4d38d8046b88dc405a054396851d133234d9c440 Mon Sep 17 00:00:00 2001 From: cf Date: Fri, 15 Mar 2024 19:30:58 +0800 Subject: [PATCH 1/2] fixed page alignment issue in metal prover and added rpo test for fibonacci Signed-off-by: GopherJ --- miden/src/examples/fibonacci.rs | 8 +++- miden/src/examples/mod.rs | 13 ++++++- prover/src/gpu/metal/mod.rs | 66 ++++++++++++++++++++++++++++----- 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/miden/src/examples/fibonacci.rs b/miden/src/examples/fibonacci.rs index f81dcb32ff..eb76bb9e77 100644 --- a/miden/src/examples/fibonacci.rs +++ b/miden/src/examples/fibonacci.rs @@ -1,5 +1,5 @@ use super::{Example, ONE, ZERO}; -use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, StackInputs}; +use miden_vm::{math::Felt, Assembler, DefaultHost, MemAdviceProvider, Program, ProvingOptions, StackInputs}; // EXAMPLE BUILDER // ================================================================================================ @@ -68,3 +68,9 @@ fn test_fib_example_fail() { let example = get_example(16); super::test_example(example, true); } + +#[test] +fn test_fib_example_rpo() { + let example = get_example(16); + super::test_example_with_options(example, false, ProvingOptions::with_96_bit_security(true)); +} diff --git a/miden/src/examples/mod.rs b/miden/src/examples/mod.rs index 7086a1d8f0..4b7b2ce37c 100644 --- a/miden/src/examples/mod.rs +++ b/miden/src/examples/mod.rs @@ -162,7 +162,7 @@ impl ExampleOptions { // ================================================================================================ #[cfg(test)] -pub fn test_example(example: Example, fail: bool) +pub fn test_example_with_options(example: Example, fail: bool, options: ProvingOptions) where H: Host, { @@ -175,7 +175,7 @@ where } = example; let (mut outputs, proof) = - miden_vm::prove(&program, stack_inputs.clone(), host, ProvingOptions::default()).unwrap(); + miden_vm::prove(&program, stack_inputs.clone(), host, options).unwrap(); assert_eq!( expected_result, @@ -193,3 +193,12 @@ where assert!(miden_vm::verify(program_info, stack_inputs, outputs, proof).is_ok()); } } + + +#[cfg(test)] +pub fn test_example(example: Example, fail: bool) +where + H: Host, +{ + test_example_with_options(example, fail, ProvingOptions::default()); +} diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index 7dc6a3b511..af31bc339d 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -15,15 +15,11 @@ use miden_gpu::{ HashFn, }; use pollster::block_on; -use processor::{ - crypto::{ElementHasher, Hasher}, - ONE, -}; -use std::{boxed::Box, marker::PhantomData, time::Instant, vec::Vec}; -use tracing::{event, Level}; +use processor::{utils::group_vector_elements, ONE}; +use std::time::Instant; use winter_prover::{ crypto::{Digest, MerkleTree}, - matrix::{build_segments, get_evaluation_offsets, ColMatrix, RowMatrix, Segment}, + matrix::{get_evaluation_offsets, ColMatrix, RowMatrix, Segment}, proof::Queries, AuxTraceRandElements, CompositionPoly, CompositionPolyTrace, ConstraintCommitment, ConstraintCompositionCoefficients, DefaultConstraintEvaluator, EvaluationFrame, Prover, @@ -61,13 +57,65 @@ where D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, R: RandomCoin, { - pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn) -> Self { + pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn) -> Self { MetalExecutionProver { execution_prover, metal_hash_fn: hash_fn, phantom_data: PhantomData, } } + + fn build_aligned_segement( + polys: &ColMatrix, + poly_offset: usize, + offsets: &[Felt], + twiddles: &[Felt], + ) -> Segment + where + E: FieldElement, + { + let poly_size = polys.num_rows(); + let domain_size = offsets.len(); + assert!(domain_size.is_power_of_two()); + assert!(domain_size > poly_size); + assert_eq!(poly_size, twiddles.len() * 2); + assert!(poly_offset < polys.num_base_cols()); + + // allocate memory for the segment + let data = if polys.num_base_cols() - poly_offset >= N { + // if we will fill the entire segment, we allocate uninitialized memory + unsafe { page_aligned_uninit_vector(domain_size) } + } else { + // but if some columns in the segment will remain unfilled, we allocate memory initialized + // to zeros to make sure we don't end up with memory with undefined values + group_vector_elements(Felt::zeroed_vector(N * domain_size)) + }; + + Segment::new_with_buffer(data, polys, poly_offset, offsets, twiddles) + } + + fn build_aligned_segements( + polys: &ColMatrix, + twiddles: &[Felt], + offsets: &[Felt], + ) -> Vec> + where + E: FieldElement, + { + assert!(N > 0, "batch size N must be greater than zero"); + debug_assert_eq!(polys.num_rows(), twiddles.len() * 2); + debug_assert_eq!(offsets.len() % polys.num_rows(), 0); + + let num_segments = if polys.num_base_cols() % N == 0 { + polys.num_base_cols() / N + } else { + polys.num_base_cols() / N + 1 + }; + + (0..num_segments) + .map(|i| Self::build_aligned_segement(polys, i * N, offsets, twiddles)) + .collect() + } } impl Prover for MetalExecutionProver @@ -149,7 +197,7 @@ where let blowup = domain.trace_to_lde_blowup(); let offsets = get_evaluation_offsets::(composition_poly.column_len(), blowup, domain.offset()); - let segments = build_segments(composition_poly.data(), domain.trace_twiddles(), &offsets); + let segments = Self::build_aligned_segements(composition_poly.data(), domain.trace_twiddles(), &offsets); event!( Level::INFO, "Evaluated {} composition polynomial columns over LDE domain (2^{} elements) in {} ms", From d7862cdbce7e29a7ce883ea5a51fa93d6810b5ff Mon Sep 17 00:00:00 2001 From: GopherJ Date: Wed, 12 Jun 2024 16:39:34 +0800 Subject: [PATCH 2/2] fix format Signed-off-by: GopherJ --- prover/src/gpu/metal/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prover/src/gpu/metal/mod.rs b/prover/src/gpu/metal/mod.rs index af31bc339d..1b233eae71 100644 --- a/prover/src/gpu/metal/mod.rs +++ b/prover/src/gpu/metal/mod.rs @@ -57,7 +57,7 @@ where D: Digest + for<'a> From<&'a [Felt; DIGEST_SIZE]>, R: RandomCoin, { - pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn) -> Self { + pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn) -> Self { MetalExecutionProver { execution_prover, metal_hash_fn: hash_fn,