diff --git a/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs b/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs index 72e2ca8cb..4c24b8d15 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs @@ -5,7 +5,7 @@ use super::{ G1Affine, F, }; use crate::{ - base::{commitment::CommittableColumn, database::ColumnType}, + base::{commitment::CommittableColumn, database::ColumnType, if_rayon}, proof_primitive::dory::offset_to_bytes::OffsetToBytes, }; use alloc::{vec, vec::Vec}; @@ -14,6 +14,11 @@ use ark_ff::MontFp; use ark_std::ops::Mul; use core::iter; use itertools::Itertools; +#[cfg(feature = "rayon")] +use rayon::{ + iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator}, + prelude::ParallelSliceMut, +}; use tracing::{span, Level}; const BYTE_SIZE: u32 = 8; @@ -72,11 +77,16 @@ pub fn signed_commits( } } - unsigned_sub_commits - .into_iter() - .zip(min_sub_commits.into_iter()) - .map(|(unsigned, min)| (unsigned + min).into()) - .collect() + if_rayon!( + unsigned_sub_commits + .into_par_iter() + .zip(min_sub_commits.into_par_iter()), + unsigned_sub_commits + .into_iter() + .zip(min_sub_commits.into_iter()) + ) + .map(|(unsigned, min)| (unsigned + min).into()) + .collect() } /// Copies the column data to the scalar row slice. @@ -123,7 +133,6 @@ fn copy_column_data_to_slice( } } -#[allow(clippy::cast_possible_truncation)] /// Creates the metadata tables for Blitzar's `vlen_msm` algorithm. /// /// # Arguments @@ -135,6 +144,10 @@ fn copy_column_data_to_slice( /// /// A tuple containing the output bit table, output length table, /// and scalars required to call Blitzar's `vlen_msm` function. +/// +/// # Panics +/// +/// Panics if the row of a column exceeds `u32::MAX`. #[tracing::instrument(name = "create_blitzar_metadata_tables", level = "debug", skip_all)] pub fn create_blitzar_metadata_tables( committable_columns: &[CommittableColumn], @@ -185,7 +198,8 @@ pub fn create_blitzar_metadata_tables( / single_entry_in_blitzar_output_bit_table.len()) .flat_map(|i| { itertools::repeat_n( - full_width_of_row(i + offset_row) as u32, + u32::try_from(full_width_of_row(i + offset_row)) + .expect("row lengths should never exceed u32::MAX"), single_entry_in_blitzar_output_bit_table.len(), ) }) @@ -208,54 +222,52 @@ pub fn create_blitzar_metadata_tables( // Populate the scalars array. let span = span!(Level::INFO, "pack_blitzar_scalars").entered(); if !blitzar_scalars.is_empty() { - blitzar_scalars - .chunks_exact_mut(num_scalar_columns) - .enumerate() - .for_each(|(scalar_row, scalar_row_slice)| { - // Iterate over the columns and populate the scalars array. - for scalar_col in 0..offset_height { - // Find index in the committable columns. Note, the scalar is in - // column major order, that is why the (row, col) arguments are flipped. - if let Some(index) = - index_from_row_and_column(scalar_col + offset_row, scalar_row).and_then( - |committable_column_idx| committable_column_idx.checked_sub(offset), - ) + if_rayon!( + blitzar_scalars.par_chunks_exact_mut(num_scalar_columns), + blitzar_scalars.chunks_exact_mut(num_scalar_columns) + ) + .enumerate() + .for_each(|(scalar_row, scalar_row_slice)| { + for scalar_col in 0..offset_height { + // Find index in the committable columns. Note, the scalar is in + // column major order, that is why the (row, col) arguments are flipped. + if let Some(index) = index_from_row_and_column(scalar_col + offset_row, scalar_row) + .and_then(|committable_column_idx| committable_column_idx.checked_sub(offset)) + { + for (i, committable_column) in committable_columns + .iter() + .enumerate() + .filter(|(_, committable_column)| index < committable_column.len()) { - for (i, committable_column) in committable_columns - .iter() - .enumerate() - .filter(|(_, committable_column)| index < committable_column.len()) - { - let start = cumulative_byte_length_table - [i + scalar_col * single_entry_in_blitzar_output_bit_table.len()]; - let end = start - + (single_entry_in_blitzar_output_bit_table[i] / BYTE_SIZE) - as usize; - - copy_column_data_to_slice( - committable_column, - scalar_row_slice, - start, - end, - index, - ); - } - - ones_columns_lengths - .iter() - .positions(|ones_columns_length| index < *ones_columns_length) - .for_each(|i| { - let ones_index = i - + scalar_col - * (num_of_bytes_in_committable_columns - + ones_columns_lengths.len()) - + num_of_bytes_in_committable_columns; - - scalar_row_slice[ones_index] = 1_u8; - }); + let start = cumulative_byte_length_table + [i + scalar_col * single_entry_in_blitzar_output_bit_table.len()]; + let end = start + + (single_entry_in_blitzar_output_bit_table[i] / BYTE_SIZE) as usize; + + copy_column_data_to_slice( + committable_column, + scalar_row_slice, + start, + end, + index, + ); } + + ones_columns_lengths + .iter() + .positions(|ones_columns_length| index < *ones_columns_length) + .for_each(|i| { + let ones_index = i + + scalar_col + * (num_of_bytes_in_committable_columns + + ones_columns_lengths.len()) + + num_of_bytes_in_committable_columns; + + scalar_row_slice[ones_index] = 1_u8; + }); } - }); + } + }); } span.exit(); diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_gpu.rs b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_gpu.rs index 7a5c8c022..f3206f7c6 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_gpu.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_commitment_helper_gpu.rs @@ -3,8 +3,10 @@ use super::{ dynamic_dory_structure::row_and_column_from_index, pairings, DynamicDoryCommitment, G1Affine, ProverSetup, }; -use crate::base::{commitment::CommittableColumn, slice_ops::slice_cast}; +use crate::base::{commitment::CommittableColumn, if_rayon, slice_ops::slice_cast}; use blitzar::compute::ElementP2; +#[cfg(feature = "rayon")] +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use tracing::{span, Level}; /// Computes the dynamic Dory commitment using the GPU implementation of the `vlen_msm` algorithm. @@ -69,18 +71,21 @@ pub(super) fn compute_dynamic_dory_commitments( committable_columns.len() ]) .unwrap_or_else(|| { - (0..committable_columns.len()) - .map(|i| { - let sub_slice = signed_sub_commits[i..] - .iter() - .step_by(committable_columns.len()) - .take(num_commits); - DynamicDoryCommitment(pairings::multi_pairing( - sub_slice, - &Gamma_2[gamma_2_offset..gamma_2_offset + num_commits], - )) - }) - .collect() + if_rayon!( + (0..committable_columns.len()).into_par_iter(), + (0..committable_columns.len()) + ) + .map(|i| { + let sub_slice = signed_sub_commits[i..] + .iter() + .step_by(committable_columns.len()) + .take(num_commits); + DynamicDoryCommitment(pairings::multi_pairing( + sub_slice, + &Gamma_2[gamma_2_offset..gamma_2_offset + num_commits], + )) + }) + .collect() }); span.exit();