Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: dynamic dory commitment computation on the GPU should have parallel option #288

Merged
merged 6 commits into from
Oct 22, 2024
118 changes: 65 additions & 53 deletions crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::{
G1Affine, F,
};
use crate::{
base::{commitment::CommittableColumn, database::ColumnType},
base::{commitment::CommittableColumn, database::ColumnType, if_rayon},
proof_primitive::dory::offset_to_bytes::OffsetToBytes,
};
use alloc::{vec, vec::Vec};
Expand All @@ -14,6 +14,11 @@ use ark_ff::MontFp;
use ark_std::ops::Mul;
use core::iter;
use itertools::Itertools;
#[cfg(feature = "rayon")]
use rayon::{
iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
prelude::ParallelSliceMut,
};
use tracing::{span, Level};

const BYTE_SIZE: u32 = 8;
Expand Down Expand Up @@ -72,11 +77,16 @@ pub fn signed_commits(
}
}

unsigned_sub_commits
.into_iter()
.zip(min_sub_commits.into_iter())
.map(|(unsigned, min)| (unsigned + min).into())
.collect()
if_rayon!(
unsigned_sub_commits
.into_par_iter()
.zip(min_sub_commits.into_par_iter()),
unsigned_sub_commits
.into_iter()
.zip(min_sub_commits.into_iter())
)
.map(|(unsigned, min)| (unsigned + min).into())
.collect()
}

/// Copies the column data to the scalar row slice.
Expand Down Expand Up @@ -123,7 +133,6 @@ fn copy_column_data_to_slice(
}
}

#[allow(clippy::cast_possible_truncation)]
/// Creates the metadata tables for Blitzar's `vlen_msm` algorithm.
///
/// # Arguments
Expand All @@ -135,6 +144,10 @@ fn copy_column_data_to_slice(
///
/// A tuple containing the output bit table, output length table,
/// and scalars required to call Blitzar's `vlen_msm` function.
///
/// # Panics
///
/// Panics if the row of a column exceeds `u32::MAX`.
#[tracing::instrument(name = "create_blitzar_metadata_tables", level = "debug", skip_all)]
pub fn create_blitzar_metadata_tables(
committable_columns: &[CommittableColumn],
Expand Down Expand Up @@ -185,7 +198,8 @@ pub fn create_blitzar_metadata_tables(
/ single_entry_in_blitzar_output_bit_table.len())
.flat_map(|i| {
itertools::repeat_n(
full_width_of_row(i + offset_row) as u32,
u32::try_from(full_width_of_row(i + offset_row))
.expect("row lengths should never exceed u32::MAX"),
single_entry_in_blitzar_output_bit_table.len(),
)
})
Expand All @@ -208,54 +222,52 @@ pub fn create_blitzar_metadata_tables(
// Populate the scalars array.
let span = span!(Level::INFO, "pack_blitzar_scalars").entered();
if !blitzar_scalars.is_empty() {
blitzar_scalars
.chunks_exact_mut(num_scalar_columns)
.enumerate()
.for_each(|(scalar_row, scalar_row_slice)| {
// Iterate over the columns and populate the scalars array.
for scalar_col in 0..offset_height {
// Find index in the committable columns. Note, the scalar is in
// column major order, that is why the (row, col) arguments are flipped.
if let Some(index) =
index_from_row_and_column(scalar_col + offset_row, scalar_row).and_then(
|committable_column_idx| committable_column_idx.checked_sub(offset),
)
if_rayon!(
blitzar_scalars.par_chunks_exact_mut(num_scalar_columns),
blitzar_scalars.chunks_exact_mut(num_scalar_columns)
)
.enumerate()
.for_each(|(scalar_row, scalar_row_slice)| {
for scalar_col in 0..offset_height {
// Find index in the committable columns. Note, the scalar is in
// column major order, that is why the (row, col) arguments are flipped.
if let Some(index) = index_from_row_and_column(scalar_col + offset_row, scalar_row)
.and_then(|committable_column_idx| committable_column_idx.checked_sub(offset))
{
for (i, committable_column) in committable_columns
.iter()
.enumerate()
.filter(|(_, committable_column)| index < committable_column.len())
{
for (i, committable_column) in committable_columns
.iter()
.enumerate()
.filter(|(_, committable_column)| index < committable_column.len())
{
let start = cumulative_byte_length_table
[i + scalar_col * single_entry_in_blitzar_output_bit_table.len()];
let end = start
+ (single_entry_in_blitzar_output_bit_table[i] / BYTE_SIZE)
as usize;

copy_column_data_to_slice(
committable_column,
scalar_row_slice,
start,
end,
index,
);
}

ones_columns_lengths
.iter()
.positions(|ones_columns_length| index < *ones_columns_length)
.for_each(|i| {
let ones_index = i
+ scalar_col
* (num_of_bytes_in_committable_columns
+ ones_columns_lengths.len())
+ num_of_bytes_in_committable_columns;

scalar_row_slice[ones_index] = 1_u8;
});
let start = cumulative_byte_length_table
[i + scalar_col * single_entry_in_blitzar_output_bit_table.len()];
let end = start
+ (single_entry_in_blitzar_output_bit_table[i] / BYTE_SIZE) as usize;

copy_column_data_to_slice(
committable_column,
scalar_row_slice,
start,
end,
index,
);
}

ones_columns_lengths
.iter()
.positions(|ones_columns_length| index < *ones_columns_length)
.for_each(|i| {
let ones_index = i
+ scalar_col
* (num_of_bytes_in_committable_columns
+ ones_columns_lengths.len())
+ num_of_bytes_in_committable_columns;

scalar_row_slice[ones_index] = 1_u8;
});
}
});
}
});
}
span.exit();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use super::{
dynamic_dory_structure::row_and_column_from_index,
pairings, DynamicDoryCommitment, G1Affine, ProverSetup,
};
use crate::base::{commitment::CommittableColumn, slice_ops::slice_cast};
use crate::base::{commitment::CommittableColumn, if_rayon, slice_ops::slice_cast};
use blitzar::compute::ElementP2;
#[cfg(feature = "rayon")]
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use tracing::{span, Level};

/// Computes the dynamic Dory commitment using the GPU implementation of the `vlen_msm` algorithm.
Expand Down Expand Up @@ -69,18 +71,21 @@ pub(super) fn compute_dynamic_dory_commitments(
committable_columns.len()
])
.unwrap_or_else(|| {
(0..committable_columns.len())
.map(|i| {
let sub_slice = signed_sub_commits[i..]
.iter()
.step_by(committable_columns.len())
.take(num_commits);
DynamicDoryCommitment(pairings::multi_pairing(
sub_slice,
&Gamma_2[gamma_2_offset..gamma_2_offset + num_commits],
))
})
.collect()
if_rayon!(
(0..committable_columns.len()).into_par_iter(),
(0..committable_columns.len())
)
.map(|i| {
let sub_slice = signed_sub_commits[i..]
.iter()
.step_by(committable_columns.len())
.take(num_commits);
DynamicDoryCommitment(pairings::multi_pairing(
sub_slice,
&Gamma_2[gamma_2_offset..gamma_2_offset + num_commits],
))
})
.collect()
});
span.exit();

Expand Down
Loading