diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_gpu.rs b/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_gpu.rs index 58d9b354c..1e32257bb 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_gpu.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dory_commitment_helper_gpu.rs @@ -1,148 +1,12 @@ -use super::{pairings, DoryCommitment, DoryProverPublicSetup, DoryScalar, G1Affine}; +use super::{pairings, transpose, DoryCommitment, DoryProverPublicSetup, DoryScalar, G1Affine}; use crate::base::commitment::CommittableColumn; use ark_bls12_381::Fr; use ark_ec::CurveGroup; use ark_std::ops::Mul; use blitzar::{compute::ElementP2, sequence::Sequence}; -use num_traits::ToBytes; use rayon::prelude::*; use zerocopy::AsBytes; -trait OffsetToBytes { - const IS_SIGNED: bool; - fn min_as_fr() -> Fr; - fn offset_to_bytes(&self) -> Vec; -} - -impl OffsetToBytes for u8 { - const IS_SIGNED: bool = false; - - fn min_as_fr() -> Fr { - Fr::from(0) - } - - fn offset_to_bytes(&self) -> Vec { - vec![*self] - } -} - -impl OffsetToBytes for i16 { - const IS_SIGNED: bool = true; - - fn min_as_fr() -> Fr { - Fr::from(i16::MIN) - } - - fn offset_to_bytes(&self) -> Vec { - let shifted = self.wrapping_sub(i16::MIN); - shifted.to_le_bytes().to_vec() - } -} - -impl OffsetToBytes for i32 { - const IS_SIGNED: bool = true; - - fn min_as_fr() -> Fr { - Fr::from(i32::MIN) - } - - fn offset_to_bytes(&self) -> Vec { - let shifted = self.wrapping_sub(i32::MIN); - shifted.to_le_bytes().to_vec() - } -} - -impl OffsetToBytes for i64 { - const IS_SIGNED: bool = true; - - fn min_as_fr() -> Fr { - Fr::from(i64::MIN) - } - - fn offset_to_bytes(&self) -> Vec { - let shifted = self.wrapping_sub(i64::MIN); - shifted.to_le_bytes().to_vec() - } -} - -impl OffsetToBytes for i128 { - const IS_SIGNED: bool = true; - - fn min_as_fr() -> Fr { - Fr::from(i128::MIN) - } - - fn offset_to_bytes(&self) -> Vec { - let shifted = self.wrapping_sub(i128::MIN); - shifted.to_le_bytes().to_vec() - } -} - -impl OffsetToBytes for bool { - const IS_SIGNED: bool = false; - - fn min_as_fr() -> Fr { - Fr::from(false) - } - - fn offset_to_bytes(&self) -> Vec { - vec![*self as u8] - } -} - -impl OffsetToBytes for u64 { - const IS_SIGNED: bool = false; - - fn min_as_fr() -> Fr { - Fr::from(0) - } - - fn offset_to_bytes(&self) -> Vec { - let bytes = self.to_le_bytes(); - bytes.to_vec() - } -} - -impl OffsetToBytes for [u64; 4] { - const IS_SIGNED: bool = false; - - fn min_as_fr() -> Fr { - Fr::from(0) - } - - fn offset_to_bytes(&self) -> Vec { - let slice = self.as_bytes(); - slice.to_vec() - } -} - -#[tracing::instrument(name = "transpose_column (gpu)", level = "debug", skip_all)] -fn transpose_column( - column: &[T], - offset: usize, - num_columns: usize, - data_size: usize, -) -> Vec { - let column_len_with_offset = column.len() + offset; - let total_length_bytes = - data_size * (((column_len_with_offset + num_columns - 1) / num_columns) * num_columns); - let cols = num_columns; - let rows = total_length_bytes / (data_size * cols); - - let mut transpose = vec![0_u8; total_length_bytes]; - for n in offset..(column.len() + offset) { - let i = n / cols; - let j = n % cols; - let t_idx = (j * rows + i) * data_size; - let p_idx = (i * cols + j) - offset; - - transpose[t_idx..t_idx + data_size] - .copy_from_slice(column[p_idx].offset_to_bytes().as_slice()); - } - - transpose -} - #[tracing::instrument(name = "get_offset_commits (gpu)", level = "debug", skip_all)] fn get_offset_commits( column_len: usize, @@ -166,10 +30,15 @@ fn get_offset_commits( if num_zero_commits < num_of_commits { // Get the commit of the first non-zero row let first_row_offset = offset - (num_zero_commits * num_columns); - let first_row_transpose = - transpose_column(first_row, first_row_offset, num_columns, data_size); + let first_row_transpose = transpose::transpose_for_fixed_msm( + first_row, + first_row_offset, + 1, + num_columns, + data_size, + ); - setup.public_parameters().blitzar_handle.msm( + setup.public_parameters().blitzar_msm( &mut ones_blitzar_commits[num_zero_commits..num_zero_commits + 1], data_size as u32, first_row_transpose.as_slice(), @@ -179,11 +48,12 @@ fn get_offset_commits( let mut chunks = remaining_elements.chunks(num_columns); if chunks.len() > 1 { if let Some(middle_row) = chunks.next() { - let middle_row_transpose = transpose_column(middle_row, 0, num_columns, data_size); + let middle_row_transpose = + transpose::transpose_for_fixed_msm(middle_row, 0, 1, num_columns, data_size); let mut middle_row_blitzar_commit = vec![ElementP2::::default(); 1]; - setup.public_parameters().blitzar_handle.msm( + setup.public_parameters().blitzar_msm( &mut middle_row_blitzar_commit, data_size as u32, middle_row_transpose.as_slice(), @@ -197,9 +67,10 @@ fn get_offset_commits( // Get the commit of the last row to handle an zero padding at the end of the column if let Some(last_row) = remaining_elements.chunks(num_columns).last() { - let last_row_transpose = transpose_column(last_row, 0, num_columns, data_size); + let last_row_transpose = + transpose::transpose_for_fixed_msm(last_row, 0, 1, num_columns, data_size); - setup.public_parameters().blitzar_handle.msm( + setup.public_parameters().blitzar_msm( &mut ones_blitzar_commits[num_of_commits - 1..num_of_commits], data_size as u32, last_row_transpose.as_slice(), @@ -223,20 +94,21 @@ fn compute_dory_commitment_impl<'a, T>( where &'a T: Into, &'a [T]: Into>, - T: AsBytes + Copy + OffsetToBytes, + T: AsBytes + Copy + transpose::OffsetToBytes, { let num_columns = 1 << setup.sigma(); let data_size = std::mem::size_of::(); // Format column to match column major data layout required by blitzar's msm - let column_transpose = transpose_column(column, offset, num_columns, data_size); - let num_of_commits = column_transpose.len() / (data_size * num_columns); + let num_of_commits = ((column.len() + offset) + num_columns - 1) / num_columns; + let column_transpose = + transpose::transpose_for_fixed_msm(column, offset, num_of_commits, num_columns, data_size); let gamma_2_slice = &setup.public_parameters().Gamma_2[0..num_of_commits]; // Compute the commitment for the entire data set let mut blitzar_commits = vec![ElementP2::::default(); num_of_commits]; - setup.public_parameters().blitzar_handle.msm( + setup.public_parameters().blitzar_msm( &mut blitzar_commits, data_size as u32, column_transpose.as_slice(), @@ -293,194 +165,3 @@ pub(super) fn compute_dory_commitments( .map(|column| compute_dory_commitment(column, offset, setup)) .collect() } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn we_can_transpose_empty_column() { - type T = u64; - let column: Vec = vec![]; - let offset = 0; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - assert!(transpose.is_empty()); - } - - #[test] - fn we_can_transpose_u64_column() { - type T = u64; - let column: Vec = vec![0, 1, 2, 3]; - let offset = 0; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!(&transpose[0..data_size], column[0].as_bytes()); - assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes()); - assert_eq!( - &transpose[2 * data_size..3 * data_size], - column[1].as_bytes() - ); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[3].as_bytes() - ); - } - - #[test] - fn we_can_transpose_u64_column_with_offset() { - type T = u64; - let column: Vec = vec![1, 2, 3]; - let offset = 2; - let num_columns = 3; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset + 1); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!(&transpose[0..data_size], 0_u64.as_bytes()); - assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes()); - assert_eq!(&transpose[2 * data_size..3 * data_size], 0_u64.as_bytes()); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[2].as_bytes() - ); - assert_eq!( - &transpose[4 * data_size..5 * data_size], - column[0].as_bytes() - ); - assert_eq!(&transpose[5 * data_size..6 * data_size], 0_u64.as_bytes()); - } - - #[test] - fn we_can_transpose_boolean_column_with_offset() { - type T = bool; - let column: Vec = vec![true, false, true]; - let offset = 1; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!(&transpose[0..data_size], 0_u8.as_bytes()); - assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes()); - assert_eq!( - &transpose[2 * data_size..3 * data_size], - column[0].as_bytes() - ); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[2].as_bytes() - ); - } - - #[test] - fn we_can_transpose_i64_column() { - type T = i64; - let column: Vec = vec![0, 1, 2, 3]; - let offset = 0; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!( - &transpose[0..data_size], - column[0].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[data_size..2 * data_size], - column[2].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[2 * data_size..3 * data_size], - column[1].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[3].wrapping_sub(T::MIN).as_bytes() - ); - } - - #[test] - fn we_can_transpose_i128_column() { - type T = i128; - let column: Vec = vec![0, 1, 2, 3]; - let offset = 0; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!( - &transpose[0..data_size], - column[0].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[data_size..2 * data_size], - column[2].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[2 * data_size..3 * data_size], - column[1].wrapping_sub(T::MIN).as_bytes() - ); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[3].wrapping_sub(T::MIN).as_bytes() - ); - } - - #[test] - fn we_can_transpose_u64_array_column() { - type T = [u64; 4]; - let column: Vec = vec![[0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]; - let offset = 0; - let num_columns = 2; - let data_size = std::mem::size_of::(); - - let expected_len = data_size * (column.len() + offset); - - let transpose = transpose_column(&column, offset, num_columns, data_size); - - assert_eq!(transpose.len(), expected_len); - - assert_eq!(&transpose[0..data_size], column[0].as_bytes()); - assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes()); - assert_eq!( - &transpose[2 * data_size..3 * data_size], - column[1].as_bytes() - ); - assert_eq!( - &transpose[3 * data_size..4 * data_size], - column[3].as_bytes() - ); - } -} diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dory_vmv_helper.rs b/crates/proof-of-sql/src/proof_primitive/dory/dory_vmv_helper.rs index b4539ea7b..e9095ef0f 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dory_vmv_helper.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dory_vmv_helper.rs @@ -1,7 +1,15 @@ -use super::{G1Affine, G1Projective, ProverSetup, F}; +#[cfg(not(feature = "blitzar"))] +use super::G1Projective; +use super::{transpose, G1Affine, ProverSetup, F}; use crate::base::polynomial::compute_evaluation_vector; +#[cfg(not(feature = "blitzar"))] use ark_ec::{AffineRepr, VariableBaseMSM}; +use ark_ff::{BigInt, MontBackend}; +#[cfg(feature = "blitzar")] +use blitzar::compute::ElementP2; use num_traits::{One, Zero}; +#[cfg(feature = "blitzar")] +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; /// Compute the evaluations of the columns of the matrix M that is derived from `a`. pub(super) fn compute_v_vec(a: &[F], L_vec: &[F], sigma: usize, nu: usize) -> Vec { @@ -13,8 +21,47 @@ pub(super) fn compute_v_vec(a: &[F], L_vec: &[F], sigma: usize, nu: usize) -> Ve }) } +/// Converts a bls12-381 scalar to a u64 array. +#[cfg(feature = "blitzar")] +fn convert_scalar_to_array( + scalars: &[ark_ff::Fp, 4>], +) -> Vec<[u64; 4]> { + scalars + .iter() + .map(|&element| BigInt::<4>::from(element).0) + .collect() +} + /// Compute the commitments to the rows of the matrix M that is derived from `a`. #[tracing::instrument(level = "debug", skip_all)] +#[cfg(feature = "blitzar")] +pub(super) fn compute_T_vec_prime( + a: &[F], + sigma: usize, + nu: usize, + prover_setup: &ProverSetup, +) -> Vec { + let num_columns = 1 << sigma; + let num_outputs = 1 << nu; + let data_size = std::mem::size_of::(); + + let a_array = convert_scalar_to_array(a); + let a_transpose = + transpose::transpose_for_fixed_msm(&a_array, 0, num_outputs, num_columns, data_size); + + let mut blitzar_commits = vec![ElementP2::::default(); num_outputs]; + + prover_setup.blitzar_msm( + &mut blitzar_commits, + data_size as u32, + a_transpose.as_slice(), + ); + + blitzar_commits.par_iter().map(Into::into).collect() +} + +#[tracing::instrument(level = "debug", skip_all)] +#[cfg(not(feature = "blitzar"))] pub(super) fn compute_T_vec_prime( a: &[F], sigma: usize, diff --git a/crates/proof-of-sql/src/proof_primitive/dory/eval_vmv_re.rs b/crates/proof-of-sql/src/proof_primitive/dory/eval_vmv_re.rs index c379fe087..d42f0edc7 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/eval_vmv_re.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/eval_vmv_re.rs @@ -4,7 +4,6 @@ use super::{ }; use ark_ec::VariableBaseMSM; use merlin::Transcript; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; /// This is the prover side of the Eval-VMV-RE algorithm in section 5 of https://eprint.iacr.org/2020/1274.pdf. /// @@ -35,7 +34,7 @@ pub fn eval_vmv_re_prove( messages.prover_send_G1_message(transcript, E_1); let v2 = state .v_vec - .par_iter() + .iter() .map(|c| (setup.Gamma_2_fin * c).into()) .collect::>(); ExtendedProverState::from_vmv_prover_state(state, v2) diff --git a/crates/proof-of-sql/src/proof_primitive/dory/mod.rs b/crates/proof-of-sql/src/proof_primitive/dory/mod.rs index a35a41dd4..fba5a91d4 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/mod.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/mod.rs @@ -140,3 +140,4 @@ type DeferredG1 = deferred_msm::DeferredMSM; type DeferredG2 = deferred_msm::DeferredMSM; mod pairings; +mod transpose; diff --git a/crates/proof-of-sql/src/proof_primitive/dory/public_parameters.rs b/crates/proof-of-sql/src/proof_primitive/dory/public_parameters.rs index c141ef668..de0f1573d 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/public_parameters.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/public_parameters.rs @@ -72,4 +72,15 @@ impl PublicParameters { blitzar_handle, } } + + #[cfg(feature = "blitzar")] + #[tracing::instrument(name = "PublicParameters::blitzar_msm", level = "debug", skip_all)] + pub(super) fn blitzar_msm( + &self, + res: &mut [ElementP2], + element_num_bytes: u32, + scalars: &[u8], + ) { + self.blitzar_handle.msm(res, element_num_bytes, scalars) + } } diff --git a/crates/proof-of-sql/src/proof_primitive/dory/setup.rs b/crates/proof-of-sql/src/proof_primitive/dory/setup.rs index c17bac798..18718dc3b 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/setup.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/setup.rs @@ -2,6 +2,8 @@ use super::{G1Affine, G2Affine, PublicParameters, GT}; use crate::base::impl_serde_for_ark_serde_unchecked; use ark_ec::pairing::{Pairing, PairingOutput}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +#[cfg(feature = "blitzar")] +use blitzar::compute::{ElementP2, MsmHandle}; use itertools::MultiUnzip; use num_traits::One; @@ -26,6 +28,9 @@ pub struct ProverSetup<'a> { pub(super) Gamma_2_fin: G2Affine, /// `max_nu` is the maximum nu that this setup will work for pub(super) max_nu: usize, + /// The handle to the `blitzar` Gamma_1 instances. + #[cfg(feature = "blitzar")] + blitzar_handle: &'a MsmHandle>, } impl<'a> ProverSetup<'a> { @@ -37,10 +42,13 @@ impl<'a> ProverSetup<'a> { H_2: G2Affine, Gamma_2_fin: G2Affine, max_nu: usize, + #[cfg(feature = "blitzar")] blitzar_handle: &'a MsmHandle< + ElementP2, + >, ) -> Self { assert_eq!(Gamma_1.len(), 1 << max_nu); assert_eq!(Gamma_2.len(), 1 << max_nu); - let (Gamma_1, Gamma_2) = (0..max_nu + 1) + let (Gamma_1, Gamma_2): (Vec<_>, Vec<_>) = (0..max_nu + 1) .map(|k| (&Gamma_1[..1 << k], &Gamma_2[..1 << k])) .unzip(); ProverSetup { @@ -50,8 +58,21 @@ impl<'a> ProverSetup<'a> { H_2, Gamma_2_fin, max_nu, + #[cfg(feature = "blitzar")] + blitzar_handle, } } + + #[cfg(feature = "blitzar")] + #[tracing::instrument(name = "ProverSetup::blitzar_msm", level = "debug", skip_all)] + pub(super) fn blitzar_msm( + &self, + res: &mut [ElementP2], + element_num_bytes: u32, + scalars: &[u8], + ) { + self.blitzar_handle.msm(res, element_num_bytes, scalars) + } } impl<'a> From<&'a PublicParameters> for ProverSetup<'a> { @@ -63,6 +84,8 @@ impl<'a> From<&'a PublicParameters> for ProverSetup<'a> { value.H_2, value.Gamma_2_fin, value.max_nu, + #[cfg(feature = "blitzar")] + &value.blitzar_handle, ) } } diff --git a/crates/proof-of-sql/src/proof_primitive/dory/transpose.rs b/crates/proof-of-sql/src/proof_primitive/dory/transpose.rs new file mode 100644 index 000000000..ee8794e2c --- /dev/null +++ b/crates/proof-of-sql/src/proof_primitive/dory/transpose.rs @@ -0,0 +1,357 @@ +use ark_bls12_381::Fr; +use zerocopy::AsBytes; + +pub trait OffsetToBytes { + const IS_SIGNED: bool; + fn min_as_fr() -> Fr; + fn offset_to_bytes(&self) -> Vec; +} + +impl OffsetToBytes for u8 { + const IS_SIGNED: bool = false; + + fn min_as_fr() -> Fr { + Fr::from(0) + } + + fn offset_to_bytes(&self) -> Vec { + vec![*self] + } +} + +impl OffsetToBytes for i16 { + const IS_SIGNED: bool = true; + + fn min_as_fr() -> Fr { + Fr::from(i16::MIN) + } + + fn offset_to_bytes(&self) -> Vec { + let shifted = self.wrapping_sub(i16::MIN); + shifted.to_le_bytes().to_vec() + } +} + +impl OffsetToBytes for i32 { + const IS_SIGNED: bool = true; + + fn min_as_fr() -> Fr { + Fr::from(i32::MIN) + } + + fn offset_to_bytes(&self) -> Vec { + let shifted = self.wrapping_sub(i32::MIN); + shifted.to_le_bytes().to_vec() + } +} + +impl OffsetToBytes for i64 { + const IS_SIGNED: bool = true; + + fn min_as_fr() -> Fr { + Fr::from(i64::MIN) + } + + fn offset_to_bytes(&self) -> Vec { + let shifted = self.wrapping_sub(i64::MIN); + shifted.to_le_bytes().to_vec() + } +} + +impl OffsetToBytes for i128 { + const IS_SIGNED: bool = true; + + fn min_as_fr() -> Fr { + Fr::from(i128::MIN) + } + + fn offset_to_bytes(&self) -> Vec { + let shifted = self.wrapping_sub(i128::MIN); + shifted.to_le_bytes().to_vec() + } +} + +impl OffsetToBytes for bool { + const IS_SIGNED: bool = false; + + fn min_as_fr() -> Fr { + Fr::from(false) + } + + fn offset_to_bytes(&self) -> Vec { + vec![*self as u8] + } +} + +impl OffsetToBytes for u64 { + const IS_SIGNED: bool = false; + + fn min_as_fr() -> Fr { + Fr::from(0) + } + + fn offset_to_bytes(&self) -> Vec { + let bytes = self.to_le_bytes(); + bytes.to_vec() + } +} + +impl OffsetToBytes for [u64; 4] { + const IS_SIGNED: bool = false; + + fn min_as_fr() -> Fr { + Fr::from(0) + } + + fn offset_to_bytes(&self) -> Vec { + let slice = self.as_bytes(); + slice.to_vec() + } +} + +#[tracing::instrument(name = "transpose_for_fixed_msm (gpu)", level = "debug", skip_all)] +pub fn transpose_for_fixed_msm( + column: &[T], + offset: usize, + rows: usize, + cols: usize, + data_size: usize, +) -> Vec { + let total_length_bytes = data_size * rows * cols; + let mut transpose = vec![0_u8; total_length_bytes]; + for n in offset..(column.len() + offset) { + let i = n / cols; + let j = n % cols; + let t_idx = (j * rows + i) * data_size; + let p_idx = (i * cols + j) - offset; + + transpose[t_idx..t_idx + data_size] + .copy_from_slice(column[p_idx].offset_to_bytes().as_slice()); + } + transpose +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn we_can_transpose_empty_column() { + type T = u64; + let column: Vec = vec![]; + let offset = 0; + let rows = 0; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + assert!(transpose.is_empty()); + } + + #[test] + fn we_can_transpose_u64_column() { + type T = u64; + let column: Vec = vec![0, 1, 2, 3]; + let offset = 0; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!(&transpose[0..data_size], column[0].as_bytes()); + assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes()); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[1].as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[3].as_bytes() + ); + } + + #[test] + fn we_can_transpose_u64_column_with_offset() { + type T = u64; + let column: Vec = vec![1, 2, 3]; + let offset = 2; + let rows = 2; + let cols = 3; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset + 1); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!(&transpose[0..data_size], 0_u64.as_bytes()); + assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes()); + assert_eq!(&transpose[2 * data_size..3 * data_size], 0_u64.as_bytes()); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[2].as_bytes() + ); + assert_eq!( + &transpose[4 * data_size..5 * data_size], + column[0].as_bytes() + ); + assert_eq!(&transpose[5 * data_size..6 * data_size], 0_u64.as_bytes()); + } + + #[test] + fn we_can_transpose_boolean_column_with_offset() { + type T = bool; + let column: Vec = vec![true, false, true]; + let offset = 1; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!(&transpose[0..data_size], 0_u8.as_bytes()); + assert_eq!(&transpose[data_size..2 * data_size], column[1].as_bytes()); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[0].as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[2].as_bytes() + ); + } + + #[test] + fn we_can_transpose_i64_column() { + type T = i64; + let column: Vec = vec![0, 1, 2, 3]; + let offset = 0; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!( + &transpose[0..data_size], + column[0].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[data_size..2 * data_size], + column[2].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[1].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[3].wrapping_sub(T::MIN).as_bytes() + ); + } + + #[test] + fn we_can_transpose_i128_column() { + type T = i128; + let column: Vec = vec![0, 1, 2, 3]; + let offset = 0; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!( + &transpose[0..data_size], + column[0].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[data_size..2 * data_size], + column[2].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[1].wrapping_sub(T::MIN).as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[3].wrapping_sub(T::MIN).as_bytes() + ); + } + + #[test] + fn we_can_transpose_u64_array_column() { + type T = [u64; 4]; + let column: Vec = vec![[0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]; + let offset = 0; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!(&transpose[0..data_size], column[0].as_bytes()); + assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes()); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[1].as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[3].as_bytes() + ); + } + + #[test] + fn we_can_transpose_u64_array_column_update() { + type T = [u64; 4]; + let column: Vec = vec![[0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]; + let offset = 0; + let rows = 2; + let cols = 2; + let data_size = std::mem::size_of::(); + + let expected_len = data_size * (column.len() + offset); + + let transpose = transpose_for_fixed_msm(&column, offset, rows, cols, data_size); + + assert_eq!(transpose.len(), expected_len); + + assert_eq!(&transpose[0..data_size], column[0].as_bytes()); + assert_eq!(&transpose[data_size..2 * data_size], column[2].as_bytes()); + assert_eq!( + &transpose[2 * data_size..3 * data_size], + column[1].as_bytes() + ); + assert_eq!( + &transpose[3 * data_size..4 * data_size], + column[3].as_bytes() + ); + } +}