diff --git a/prover/src/gpu.rs b/prover/src/gpu.rs index ad1ac70d4f..530d18fa74 100644 --- a/prover/src/gpu.rs +++ b/prover/src/gpu.rs @@ -561,20 +561,30 @@ mod tests { use processor::{crypto::RpoRandomCoin, StackInputs, StackOutputs}; use winter_prover::math::fields::CubeExtension; + type CubeFelt = CubeExtension; + #[test] fn build_trace_commitment_on_gpu_with_padding_matches_cpu() { let cpu_prover = create_test_prover(); let gpu_prover = MetalRpoExecutionProver(create_test_prover()); let num_rows = 1 << 8; + let trace_info = get_trace_info(1, num_rows); let trace = gen_random_trace(num_rows, RPO_RATE + 1); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let (cpu_lde, cpu_mt, cpu_polys) = cpu_prover.build_trace_commitment(&trace, &domain); - let (gpu_lde, gpu_mt, gpu_polys) = gpu_prover.build_trace_commitment(&trace, &domain); + let (cpu_trace_lde, cpu_polys) = + cpu_prover.new_trace_lde::(&trace_info, &trace, &domain); + let (gpu_trace_lde, gpu_polys) = + gpu_prover.new_trace_lde::(&trace_info, &trace, &domain); - assert_eq!(cpu_lde.data(), gpu_lde.data()); - assert_eq!(cpu_mt.root(), gpu_mt.root()); - assert_eq!(cpu_polys.into_columns(), gpu_polys.into_columns()); + assert_eq!( + cpu_trace_lde.get_main_trace_commitment(), + gpu_trace_lde.get_main_trace_commitment() + ); + assert_eq!( + cpu_polys.main_trace_polys().collect::>(), + gpu_polys.main_trace_polys().collect::>() + ); } #[test] @@ -582,15 +592,23 @@ mod tests { let cpu_prover = create_test_prover(); let gpu_prover = MetalRpoExecutionProver(create_test_prover()); let num_rows = 1 << 8; + let trace_info = get_trace_info(1, num_rows); let trace = gen_random_trace(num_rows, RPO_RATE); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let (cpu_lde, cpu_mt, cpu_polys) = cpu_prover.build_trace_commitment(&trace, &domain); - let (gpu_lde, gpu_mt, gpu_polys) = gpu_prover.build_trace_commitment(&trace, &domain); + let (cpu_trace_lde, cpu_polys) = + cpu_prover.new_trace_lde::(&trace_info, &trace, &domain); + let (gpu_trace_lde, gpu_polys) = + gpu_prover.new_trace_lde::(&trace_info, &trace, &domain); - assert_eq!(cpu_lde.data(), gpu_lde.data()); - assert_eq!(cpu_mt.root(), gpu_mt.root()); - assert_eq!(cpu_polys.into_columns(), gpu_polys.into_columns()); + assert_eq!( + cpu_trace_lde.get_main_trace_commitment(), + gpu_trace_lde.get_main_trace_commitment() + ); + assert_eq!( + cpu_polys.main_trace_polys().collect::>(), + gpu_polys.main_trace_polys().collect::>() + ); } #[test] @@ -599,15 +617,20 @@ mod tests { let gpu_prover = MetalRpoExecutionProver(create_test_prover()); let num_rows = 1 << 8; let ce_blowup_factor = 2; - let coeffs = gen_random_coeffs::>(num_rows * ce_blowup_factor); - let composition_poly = CompositionPoly::new(coeffs, num_rows, 2); + let values = get_random_values::(num_rows * ce_blowup_factor); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let commitment_cpu = cpu_prover.build_constraint_commitment(&composition_poly, &domain); - let commitment_gpu = gpu_prover.build_constraint_commitment(&composition_poly, &domain); + let (commitment_cpu, composition_poly_cpu) = cpu_prover.build_constraint_commitment( + CompositionPolyTrace::new(values.clone()), + 2, + &domain, + ); + let (commitment_gpu, composition_poly_gpu) = + gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 2, &domain); assert_eq!(commitment_cpu.root(), commitment_gpu.root()); - assert_ne!(0, composition_poly.data().num_base_cols() % RPO_RATE); + assert_ne!(0, composition_poly_cpu.data().num_base_cols() % RPO_RATE); + assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns()); } #[test] @@ -616,25 +639,34 @@ mod tests { let gpu_prover = MetalRpoExecutionProver(create_test_prover()); let num_rows = 1 << 8; let ce_blowup_factor = 8; - let coeffs = gen_random_coeffs::(num_rows * ce_blowup_factor); - let composition_poly = CompositionPoly::new(coeffs, num_rows, 8); + let values = get_random_values::(num_rows * ce_blowup_factor); let domain = StarkDomain::from_twiddles(fft::get_twiddles(num_rows), 8, Felt::GENERATOR); - let commitment_cpu = cpu_prover.build_constraint_commitment(&composition_poly, &domain); - let commitment_gpu = gpu_prover.build_constraint_commitment(&composition_poly, &domain); + let (commitment_cpu, composition_poly_cpu) = cpu_prover.build_constraint_commitment( + CompositionPolyTrace::new(values.clone()), + 8, + &domain, + ); + let (commitment_gpu, composition_poly_gpu) = + gpu_prover.build_constraint_commitment(CompositionPolyTrace::new(values), 8, &domain); assert_eq!(commitment_cpu.root(), commitment_gpu.root()); - assert_eq!(0, composition_poly.data().num_base_cols() % RPO_RATE); + assert_eq!(0, composition_poly_cpu.data().num_base_cols() % RPO_RATE); + assert_eq!(composition_poly_cpu.into_columns(), composition_poly_gpu.into_columns()); } fn gen_random_trace(num_rows: usize, num_cols: usize) -> ColMatrix { ColMatrix::new((0..num_cols as u64).map(|col| vec![Felt::new(col); num_rows]).collect()) } - fn gen_random_coeffs(num_rows: usize) -> Vec { + fn get_random_values(num_rows: usize) -> Vec { (0..num_rows).map(|i| E::from(i as u32)).collect() } + fn get_trace_info(num_cols: usize, num_rows: usize) -> TraceInfo { + TraceInfo::new(num_cols, num_rows) + } + fn create_test_prover() -> ExecutionProver { ExecutionProver::new( ProvingOptions::with_128_bit_security(true),