Skip to content

Commit

Permalink
perf: compute_dynamic_T_vec_prime should have a GPU implementation (#276
Browse files Browse the repository at this point in the history
)

# Rationale for this change
Dynamic Dory commitment computation has a GPU implementation using
Blitzar's `vlen_msm` function. This work extends the GPU implementation
to the `dynamic_dory_helper` module by allowing the
`compute_dynamic_T_vec_prime` function to also use the GPU.

New, non-refactoring, changes are isolated to this commit:
e9699db.

# What changes are included in this PR?
- A GPU implementation is added to the `compute_dynamic_T_vec_prime`
function.
- Code creating Blitzar's metadata tables are refactored into a separate
module to be used by both the `dynamic_dory_commitment_helper_gpu` and
`dynamic_dory_helper` modules.

# Are these changes tested?
Yes
  • Loading branch information
jacobtrombetta authored Oct 19, 2024
1 parent 867cc00 commit 5b4b962
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 359 deletions.
359 changes: 359 additions & 0 deletions crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
use super::{
dynamic_dory_structure::{full_width_of_row, index_from_row_and_column, matrix_size},
G1Affine, F,
};
use crate::{
base::{commitment::CommittableColumn, database::ColumnType},
proof_primitive::dory::offset_to_bytes::OffsetToBytes,
};
use alloc::{vec, vec::Vec};
use ark_ec::CurveGroup;
use ark_ff::MontFp;
use ark_std::ops::Mul;
use core::iter;
use itertools::Itertools;
use tracing::{span, Level};

const BYTE_SIZE: u32 = 8;

/// Returns the minimum value of a column as F.
///
/// # Arguments
///
/// * `column_type` - The type of a committable column.
pub const fn min_as_f(column_type: ColumnType) -> F {
match column_type {
ColumnType::TinyInt => MontFp!("-128"),
ColumnType::SmallInt => MontFp!("-32768"),
ColumnType::Int => MontFp!("-2147483648"),
ColumnType::BigInt | ColumnType::TimestampTZ(_, _) => MontFp!("-9223372036854775808"),
ColumnType::Int128 => MontFp!("-170141183460469231731687303715884105728"),
ColumnType::Decimal75(_, _)
| ColumnType::Scalar
| ColumnType::VarChar
| ColumnType::Boolean => MontFp!("0"),
}
}

/// Modifies the sub commits by adding the minimum commitment of the column type to the signed sub commits.
///
/// # Arguments
///
/// * `all_sub_commits` - A reference to the sub commits.
/// * `committable_columns` - A reference to the committable columns.
///
/// # Returns
///
/// A vector containing the modified sub commits to be used by the dynamic Dory commitment computation.
#[tracing::instrument(name = "signed_commits", level = "debug", skip_all)]
pub fn signed_commits(
all_sub_commits: &Vec<G1Affine>,
committable_columns: &[CommittableColumn],
) -> Vec<G1Affine> {
let mut unsigned_sub_commits: Vec<G1Affine> = Vec::new();
let mut min_sub_commits: Vec<G1Affine> = Vec::new();
let mut counter = 0;

// Every sub_commit has a corresponding offset sub_commit committable_columns.len() away.
// The commits and respective ones commits are interleaved in the all_sub_commits vector.
for commit in all_sub_commits {
if counter < committable_columns.len() {
unsigned_sub_commits.push(*commit);
} else {
let min =
min_as_f(committable_columns[counter - committable_columns.len()].column_type());
min_sub_commits.push(commit.mul(min).into_affine());
}
counter += 1;
if counter == 2 * committable_columns.len() {
counter = 0;
}
}

unsigned_sub_commits
.into_iter()
.zip(min_sub_commits.into_iter())
.map(|(unsigned, min)| (unsigned + min).into())
.collect()
}

/// Copies the column data to the scalar row slice.
///
/// # Arguments
///
/// * `column` - A reference to the committable column.
/// * `scalar_row_slice` - A mutable reference to the scalar row slice.
/// * `start` - The start index of the slice.
/// * `end` - The end index of the slice.
/// * `index` - The index of the column.
fn copy_column_data_to_slice(
column: &CommittableColumn,
scalar_row_slice: &mut [u8],
start: usize,
end: usize,
index: usize,
) {
match column {
CommittableColumn::Boolean(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::TinyInt(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::SmallInt(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::Int(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::BigInt(column) | CommittableColumn::TimestampTZ(_, _, column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::Int128(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::Scalar(column)
| CommittableColumn::Decimal75(_, _, column)
| CommittableColumn::VarChar(column) => {
scalar_row_slice[start..end].copy_from_slice(&column[index].offset_to_bytes());
}
CommittableColumn::RangeCheckWord(_) => todo!(),
}
}

/// Creates the metadata tables for Blitzar's `vlen_msm` algorithm.
///
/// # Arguments
///
/// * `committable_columns` - A reference to the committable columns.
/// * `offset` - The offset to the data.
///
/// # Returns
///
/// A tuple containing the output bit table, output length table,
/// and scalars required to call Blitzar's `vlen_msm` function.
#[tracing::instrument(name = "create_blitzar_metadata_tables", level = "debug", skip_all)]
pub fn create_blitzar_metadata_tables(
committable_columns: &[CommittableColumn],
offset: usize,
) -> (Vec<u32>, Vec<u32>, Vec<u8>) {
// Keep track of the lengths of the columns to handled signed data columns.
let ones_columns_lengths = committable_columns
.iter()
.map(CommittableColumn::len)
.collect_vec();

// The maximum matrix size will be used to create the scalars vector.
let (max_height, max_width) = if let Some(max_column_len) =
committable_columns.iter().map(CommittableColumn::len).max()
{
matrix_size(max_column_len, offset)
} else {
(0, 0)
};

// Find the single packed byte size of all committable columns.
let num_of_bytes_in_committable_columns: usize = committable_columns
.iter()
.map(|column| column.column_type().byte_size())
.sum();

// Get a single bit table entry with ones added for all committable columns that are signed.
let single_entry_in_blitzar_output_bit_table: Vec<u32> = committable_columns
.iter()
.map(|column| column.column_type().bit_size())
.chain(iter::repeat(BYTE_SIZE).take(ones_columns_lengths.len()))
.collect();

// Create the full bit table vector to be used by Blitzar's vlen_msm algorithm.
let blitzar_output_bit_table: Vec<u32> = single_entry_in_blitzar_output_bit_table
.iter()
.copied()
.cycle()
.take(single_entry_in_blitzar_output_bit_table.len() * max_height)
.collect();

// Create the full length vector to be used by Blitzar's vlen_msm algorithm.
let blitzar_output_length_table: Vec<u32> = (0..blitzar_output_bit_table.len()
/ single_entry_in_blitzar_output_bit_table.len())
.flat_map(|i| {
itertools::repeat_n(
full_width_of_row(i) as u32,
single_entry_in_blitzar_output_bit_table.len(),
)
})
.collect();

// Create a cumulative length table to be used when packing the scalar vector.
let cumulative_byte_length_table: Vec<usize> = iter::once(0)
.chain(blitzar_output_bit_table.iter().scan(0usize, |acc, &x| {
*acc += (x / BYTE_SIZE) as usize;
Some(*acc)
}))
.collect();

// Create scalars array. Note, scalars need to be stored in a column-major order.
let num_scalar_rows = max_width;
let num_scalar_columns =
(num_of_bytes_in_committable_columns + ones_columns_lengths.len()) * max_height;
let mut blitzar_scalars = vec![0u8; num_scalar_rows * num_scalar_columns];

// Populate the scalars array.
let span = span!(Level::INFO, "pack_blitzar_scalars").entered();
if !blitzar_scalars.is_empty() {
blitzar_scalars
.chunks_exact_mut(num_scalar_columns)
.enumerate()
.for_each(|(scalar_row, scalar_row_slice)| {
// Iterate over the columns and populate the scalars array.
for scalar_col in 0..max_height {
// Find index in the committable columns. Note, the scalar is in
// column major order, that is why the (row, col) arguments are flipped.
if let Some(index) = index_from_row_and_column(scalar_col, scalar_row).and_then(
|committable_column_idx| committable_column_idx.checked_sub(offset),
) {
for (i, committable_column) in committable_columns
.iter()
.enumerate()
.filter(|(_, committable_column)| index < committable_column.len())
{
let start = cumulative_byte_length_table
[i + scalar_col * single_entry_in_blitzar_output_bit_table.len()];
let end = start
+ (single_entry_in_blitzar_output_bit_table[i] / BYTE_SIZE)
as usize;

copy_column_data_to_slice(
committable_column,
scalar_row_slice,
start,
end,
index,
);
}

ones_columns_lengths
.iter()
.positions(|ones_columns_length| index < *ones_columns_length)
.for_each(|i| {
let ones_index = i
+ scalar_col
* (num_of_bytes_in_committable_columns
+ ones_columns_lengths.len())
+ num_of_bytes_in_committable_columns;

scalar_row_slice[ones_index] = 1_u8;
});
}
}
});
}
span.exit();

(
blitzar_output_bit_table,
blitzar_output_length_table,
blitzar_scalars,
)
}

#[cfg(test)]
mod tests {
use super::*;
use crate::base::math::decimal::Precision;
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};

#[test]
fn we_can_populate_blitzar_metadata_tables_with_empty_columns() {
let committable_columns = [CommittableColumn::BigInt(&[0; 0])];
let offset = 0;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert!(bit_table.is_empty());
assert!(length_table.is_empty());
assert!(scalars.is_empty());
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_empty_columns_and_an_offset() {
let committable_columns = [CommittableColumn::BigInt(&[0; 0])];
let offset = 1;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![1, 1]);
assert_eq!(scalars, vec![0, 0, 0, 0, 0, 0, 0, 0, 0]);
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_simple_column() {
let committable_columns = [CommittableColumn::BigInt(&[1])];
let offset = 0;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8]);
assert_eq!(length_table, vec![1, 1]);
assert_eq!(scalars, vec![1, 0, 0, 0, 0, 0, 0, 128, 1]);
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_simple_column_and_offset() {
let committable_columns = [CommittableColumn::BigInt(&[1])];
let offset = 1;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);

assert_eq!(bit_table, vec![64, 8, 64, 8]);
assert_eq!(length_table, vec![1, 1, 2, 2]);
assert_eq!(
scalars,
vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 128, 1
]
);
}

#[test]
fn we_can_populate_blitzar_metadata_tables_with_mixed_columns() {
let committable_columns = [
CommittableColumn::TinyInt(&[1]),
CommittableColumn::SmallInt(&[2]),
CommittableColumn::Int(&[3]),
CommittableColumn::BigInt(&[4]),
CommittableColumn::Int128(&[5]),
CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]),
CommittableColumn::Scalar(vec![[7, 0, 0, 0]]),
CommittableColumn::VarChar(vec![[8, 0, 0, 0]]),
CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::Utc, &[9]),
CommittableColumn::Boolean(&[true]),
];

let offset = 0;
let (bit_table, length_table, scalars) =
create_blitzar_metadata_tables(&committable_columns, offset);
assert_eq!(
bit_table,
vec![8, 16, 32, 64, 128, 256, 256, 256, 64, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
);

assert_eq!(
length_table,
vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
);
assert_eq!(
scalars,
vec![
129, 2, 128, 3, 0, 0, 128, 4, 0, 0, 0, 0, 0, 0, 128, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 128, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 128,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
]
);
}
}
Loading

0 comments on commit 5b4b962

Please sign in to comment.