Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: dynamic Dory commitment computation on the GPU should efficiently handle the offset #291

Merged
merged 7 commits into from
Oct 22, 2024
201 changes: 176 additions & 25 deletions crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::{
dynamic_dory_structure::{full_width_of_row, index_from_row_and_column, matrix_size},
dynamic_dory_structure::{
full_width_of_row, index_from_row_and_column, matrix_size, row_and_column_from_index,
},
G1Affine, F,
};
use crate::{
Expand Down Expand Up @@ -153,6 +155,10 @@ pub fn create_blitzar_metadata_tables(
(0, 0)
};

// We will ignore the rows that are zero from the offsets.
let offset_row = row_and_column_from_index(offset).0;
JayWhite2357 marked this conversation as resolved.
Show resolved Hide resolved
let offset_height = max_height - offset_row;

// Find the single packed byte size of all committable columns.
let num_of_bytes_in_committable_columns: usize = committable_columns
.iter()
Expand All @@ -171,15 +177,15 @@ pub fn create_blitzar_metadata_tables(
.iter()
.copied()
.cycle()
.take(single_entry_in_blitzar_output_bit_table.len() * max_height)
.take(single_entry_in_blitzar_output_bit_table.len() * offset_height)
.collect();

// Create the full length vector to be used by Blitzar's vlen_msm algorithm.
let blitzar_output_length_table: Vec<u32> = (0..blitzar_output_bit_table.len()
/ single_entry_in_blitzar_output_bit_table.len())
.flat_map(|i| {
itertools::repeat_n(
full_width_of_row(i) as u32,
full_width_of_row(i + offset_row) as u32,
single_entry_in_blitzar_output_bit_table.len(),
)
})
Expand All @@ -196,7 +202,7 @@ pub fn create_blitzar_metadata_tables(
// Create scalars array. Note, scalars need to be stored in a column-major order.
let num_scalar_rows = max_width;
let num_scalar_columns =
(num_of_bytes_in_committable_columns + ones_columns_lengths.len()) * max_height;
(num_of_bytes_in_committable_columns + ones_columns_lengths.len()) * offset_height;
let mut blitzar_scalars = vec![0u8; num_scalar_rows * num_scalar_columns];

// Populate the scalars array.
Expand All @@ -207,12 +213,14 @@ pub fn create_blitzar_metadata_tables(
.enumerate()
.for_each(|(scalar_row, scalar_row_slice)| {
// Iterate over the columns and populate the scalars array.
for scalar_col in 0..max_height {
for scalar_col in 0..offset_height {
// Find index in the committable columns. Note, the scalar is in
// column major order, that is why the (row, col) arguments are flipped.
if let Some(index) = index_from_row_and_column(scalar_col, scalar_row).and_then(
|committable_column_idx| committable_column_idx.checked_sub(offset),
) {
if let Some(index) =
index_from_row_and_column(scalar_col + offset_row, scalar_row).and_then(
|committable_column_idx| committable_column_idx.checked_sub(offset),
)
{
for (i, committable_column) in committable_columns
.iter()
.enumerate()
Expand Down Expand Up @@ -265,27 +273,83 @@ mod tests {
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};

#[test]
fn we_can_populate_blitzar_metadata_tables_with_empty_columns() {
fn we_can_populate_blitzar_metadata_tables_with_empty_columns_with_full_column_offsets() {
let committable_columns = [CommittableColumn::BigInt(&[0; 0])];
let offset = 0;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);
let offsets = vec![
0, 1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128,
];
for &offset in &offsets {
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert!(bit_table.is_empty());
assert!(length_table.is_empty());
assert!(scalars.is_empty());
assert!(bit_table.is_empty());
assert!(length_table.is_empty());
assert!(scalars.is_empty());
}
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_empty_columns_and_an_offset() {
fn we_can_populate_blitzar_metadata_tables_with_empty_columns_and_an_offset_with_partial_column_offsets(
) {
let committable_columns = [CommittableColumn::BigInt(&[0; 0])];
let offset = 1;

let offset = 3;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![1, 1]);
assert_eq!(scalars, vec![0, 0, 0, 0, 0, 0, 0, 0, 0]);
assert_eq!(length_table, vec![2, 2]);
assert_eq!(
scalars,
vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
);

let offset = 5;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![4, 4]);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0
]
);

let offset = 17;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![8, 8]);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
]
);

let offset = 65;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![16, 16]);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0
]
);
}

#[test]
Expand All @@ -307,14 +371,11 @@ mod tests {
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8, 64, 8]);
assert_eq!(length_table, vec![1, 1, 2, 2]);
assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![2, 2]);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 128, 1
]
vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 128, 1]
);
}

Expand Down Expand Up @@ -357,4 +418,94 @@ mod tests {
]
);
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_mixed_columns_and_partial_column_offset() {
let committable_columns = [
CommittableColumn::TinyInt(&[1]),
CommittableColumn::SmallInt(&[2]),
CommittableColumn::Int(&[3]),
CommittableColumn::BigInt(&[4]),
CommittableColumn::Int128(&[5]),
CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]),
CommittableColumn::Scalar(vec![[7, 0, 0, 0]]),
CommittableColumn::VarChar(vec![[8, 0, 0, 0]]),
CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::Utc, &[9]),
CommittableColumn::Boolean(&[true]),
];

let offset = 1;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);
assert_eq!(
bit_table,
vec![8, 16, 32, 64, 128, 256, 256, 256, 64, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
);

assert_eq!(
length_table,
vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 129, 2, 128, 3, 0, 0, 128, 4, 0, 0, 0, 0, 0, 0, 128, 5, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0,
0, 0, 0, 0, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
]
);
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_mixed_columns_and_full_column_offset() {
let committable_columns = [
CommittableColumn::TinyInt(&[1]),
CommittableColumn::SmallInt(&[2]),
CommittableColumn::Int(&[3]),
CommittableColumn::BigInt(&[4]),
CommittableColumn::Int128(&[5]),
CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]),
CommittableColumn::Scalar(vec![[7, 0, 0, 0]]),
CommittableColumn::VarChar(vec![[8, 0, 0, 0]]),
CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::Utc, &[9]),
CommittableColumn::Boolean(&[true]),
];

let offset = 2;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);
assert_eq!(
bit_table,
vec![8, 16, 32, 64, 128, 256, 256, 256, 64, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
);

assert_eq!(
length_table,
vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
);
assert_eq!(
scalars,
vec![
129, 2, 128, 3, 0, 0, 128, 4, 0, 0, 0, 0, 0, 0, 128, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 128, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 128,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
]
);
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
blitzar_metadata_table::{create_blitzar_metadata_tables, signed_commits},
dynamic_dory_structure::row_and_column_from_index,
pairings, DynamicDoryCommitment, G1Affine, ProverSetup,
};
use crate::base::{commitment::CommittableColumn, slice_ops::slice_cast};
Expand Down Expand Up @@ -32,6 +33,7 @@ pub(super) fn compute_dynamic_dory_commitments(
setup: &ProverSetup,
) -> Vec<DynamicDoryCommitment> {
let Gamma_2 = setup.Gamma_2.last().unwrap();
let gamma_2_offset = row_and_column_from_index(offset).0;

// Get metadata tables for Blitzar's vlen_msm algorithm.
let (blitzar_output_bit_table, blitzar_output_length_table, blitzar_scalars) =
Expand Down Expand Up @@ -75,7 +77,7 @@ pub(super) fn compute_dynamic_dory_commitments(
.take(num_commits);
DynamicDoryCommitment(pairings::multi_pairing(
sub_slice,
&Gamma_2[..num_commits],
&Gamma_2[gamma_2_offset..gamma_2_offset + num_commits],
))
})
.collect()
Expand Down
Loading