Skip to content

Commit

Permalink
fix: gpu (metal) tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bobbinth committed Oct 31, 2023
1 parent d76da5d commit 770e1c5
Showing 1 changed file with 53 additions and 21 deletions.
74 changes: 53 additions & 21 deletions prover/src/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,36 +561,54 @@ mod tests {
use processor::{crypto::RpoRandomCoin, StackInputs, StackOutputs};
use winter_prover::math::fields::CubeExtension;

type CubeFelt = CubeExtension<Felt>;

#[test]
fn build_trace_commitment_on_gpu_with_padding_matches_cpu() {
let cpu_prover = create_test_prover();
let gpu_prover = MetalRpoExecutionProver(create_test_prover());
let num_rows = 1 << 8;
let trace_info = get_trace_info(1, num_rows);
let trace = gen_random_trace(num_rows, RPO_RATE + 1);
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);
let (cpu_lde, cpu_mt, cpu_polys) = cpu_prover.build_trace_commitment(&trace, &domain);

let (gpu_lde, gpu_mt, gpu_polys) = gpu_prover.build_trace_commitment(&trace, &domain);
let (cpu_trace_lde, cpu_polys) =
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
let (gpu_trace_lde, gpu_polys) =
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);

assert_eq!(cpu_lde.data(), gpu_lde.data());
assert_eq!(cpu_mt.root(), gpu_mt.root());
assert_eq!(cpu_polys.into_columns(), gpu_polys.into_columns());
assert_eq!(
cpu_trace_lde.get_main_trace_commitment(),
gpu_trace_lde.get_main_trace_commitment()
);
assert_eq!(
cpu_polys.main_trace_polys().collect::<Vec<_>>(),
gpu_polys.main_trace_polys().collect::<Vec<_>>()
);
}

#[test]
fn build_trace_commitment_on_gpu_without_padding_matches_cpu() {
let cpu_prover = create_test_prover();
let gpu_prover = MetalRpoExecutionProver(create_test_prover());
let num_rows = 1 << 8;
let trace_info = get_trace_info(1, num_rows);
let trace = gen_random_trace(num_rows, RPO_RATE);
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);
let (cpu_lde, cpu_mt, cpu_polys) = cpu_prover.build_trace_commitment(&trace, &domain);

let (gpu_lde, gpu_mt, gpu_polys) = gpu_prover.build_trace_commitment(&trace, &domain);
let (cpu_trace_lde, cpu_polys) =
cpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);
let (gpu_trace_lde, gpu_polys) =
gpu_prover.new_trace_lde::<CubeFelt>(&trace_info, &trace, &domain);

assert_eq!(cpu_lde.data(), gpu_lde.data());
assert_eq!(cpu_mt.root(), gpu_mt.root());
assert_eq!(cpu_polys.into_columns(), gpu_polys.into_columns());
assert_eq!(
cpu_trace_lde.get_main_trace_commitment(),
gpu_trace_lde.get_main_trace_commitment()
);
assert_eq!(
cpu_polys.main_trace_polys().collect::<Vec<_>>(),
gpu_polys.main_trace_polys().collect::<Vec<_>>()
);
}

#[test]
Expand All @@ -599,15 +617,20 @@ mod tests {
let gpu_prover = MetalRpoExecutionProver(create_test_prover());
let num_rows = 1 << 8;
let ce_blowup_factor = 2;
let coeffs = gen_random_coeffs::<CubeExtension<Felt>>(num_rows * ce_blowup_factor);
let composition_poly = CompositionPoly::new(coeffs, num_rows, 2);
let values = get_random_values::<CubeFelt>(num_rows * ce_blowup_factor);
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);
let commitment_cpu = cpu_prover.build_constraint_commitment(&composition_poly, &domain);

let commitment_gpu = gpu_prover.build_constraint_commitment(&composition_poly, &domain);
let (commitment_cpu, composition_poly_cpu) = cpu_prover.build_constraint_commitment(
CompositionPolyTrace::new(values.clone()),
2,
&domain,
);
let (commitment_gpu, composition_poly_gpu) =
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 2, &domain);

assert_eq!(commitment_cpu.root(), commitment_gpu.root());
assert_ne!(0, composition_poly.data().num_base_cols() % RPO_RATE);
assert_ne!(0, composition_poly_cpu.data().num_base_cols() % RPO_RATE);
assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns());
}

#[test]
Expand All @@ -616,25 +639,34 @@ mod tests {
let gpu_prover = MetalRpoExecutionProver(create_test_prover());
let num_rows = 1 << 8;
let ce_blowup_factor = 8;
let coeffs = gen_random_coeffs::<Felt>(num_rows * ce_blowup_factor);
let composition_poly = CompositionPoly::new(coeffs, num_rows, 8);
let values = get_random_values::<Felt>(num_rows * ce_blowup_factor);
let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR);
let commitment_cpu = cpu_prover.build_constraint_commitment(&composition_poly, &domain);

let commitment_gpu = gpu_prover.build_constraint_commitment(&composition_poly, &domain);
let (commitment_cpu, composition_poly_cpu) = cpu_prover.build_constraint_commitment(
CompositionPolyTrace::new(values.clone()),
8,
&domain,
);
let (commitment_gpu, composition_poly_gpu) =
gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 8, &domain);

assert_eq!(commitment_cpu.root(), commitment_gpu.root());
assert_eq!(0, composition_poly.data().num_base_cols() % RPO_RATE);
assert_eq!(0, composition_poly_cpu.data().num_base_cols() % RPO_RATE);
assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns());
}

fn gen_random_trace(num_rows: usize, num_cols: usize) -> ColMatrix<Felt> {
ColMatrix::new((0..num_cols as u64).map(|col| vec![Felt::new(col); num_rows]).collect())
}

fn gen_random_coeffs<E: FieldElement>(num_rows: usize) -> Vec<E> {
fn get_random_values<E: FieldElement>(num_rows: usize) -> Vec<E> {
(0..num_rows).map(|i| E::from(i as u32)).collect()
}

fn get_trace_info(num_cols: usize, num_rows: usize) -> TraceInfo {
TraceInfo::new(num_cols, num_rows)
}

fn create_test_prover() -> ExecutionProver<Rpo256, RpoRandomCoin> {
ExecutionProver::new(
ProvingOptions::with_128_bit_security(true),
Expand Down

0 comments on commit 770e1c5

Please sign in to comment.