diff --git a/crates/proof-of-sql/src/base/arrow/mod.rs b/crates/proof-of-sql/src/base/arrow/mod.rs index 48197e05b..0bcac183d 100644 --- a/crates/proof-of-sql/src/base/arrow/mod.rs +++ b/crates/proof-of-sql/src/base/arrow/mod.rs @@ -10,6 +10,9 @@ pub mod owned_and_arrow_conversions; /// Tests for owned and Arrow conversions. mod owned_and_arrow_conversions_test; +/// Module for converting record batches. +pub mod record_batch_conversion; + /// Module for record batch error definitions. pub mod record_batch_errors; diff --git a/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs new file mode 100644 index 000000000..6f24457cc --- /dev/null +++ b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs @@ -0,0 +1,160 @@ +use super::{ + arrow_array_to_column_conversion::ArrayRefExt, + record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError}, +}; +use crate::base::{ + commitment::{ + AppendColumnCommitmentsError, AppendTableCommitmentError, Commitment, TableCommitment, + TableCommitmentFromColumnsError, + }, + database::Column, + scalar::Scalar, +}; +use arrow::record_batch::RecordBatch; +use bumpalo::Bump; +use proof_of_sql_parser::Identifier; + +/// This function will return an error if: +/// - The field name cannot be parsed into an [`Identifier`]. +/// - The conversion of an Arrow array to a [`Column`] fails. +pub fn batch_to_columns<'a, S: Scalar + 'a>( + batch: &'a RecordBatch, + alloc: &'a Bump, +) -> Result)>, RecordBatchToColumnsError> { + batch + .schema() + .fields() + .into_iter() + .zip(batch.columns()) + .map(|(field, array)| { + let identifier: Identifier = field.name().parse()?; + let column: Column = array.to_column(alloc, &(0..array.len()), None)?; + Ok((identifier, column)) + }) + .collect() +} + +impl TableCommitment { + /// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`]. + /// + /// The row offset is assumed to be the end of the [`TableCommitment`]'s current range. + /// + /// Will error on a variety of mismatches, or if the provided columns have mixed length. + #[allow(clippy::missing_panics_doc)] + pub fn try_append_record_batch( + &mut self, + batch: &RecordBatch, + setup: &C::PublicSetup<'_>, + ) -> Result<(), AppendRecordBatchTableCommitmentError> { + match self.try_append_rows( + batch_to_columns::(batch, &Bump::new())? + .iter() + .map(|(a, b)| (a, b)), + setup, + ) { + Ok(()) => Ok(()), + Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => { + panic!("RecordBatches cannot have columns of mixed length") + } + Err(AppendTableCommitmentError::AppendColumnCommitments { + source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. }, + }) => { + panic!("RecordBatches cannot have duplicate identifiers") + } + Err(AppendTableCommitmentError::AppendColumnCommitments { + source: AppendColumnCommitmentsError::Mismatch { source: e }, + }) => Err(e)?, + } + } + /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`]. + pub fn try_from_record_batch( + batch: &RecordBatch, + setup: &C::PublicSetup<'_>, + ) -> Result, RecordBatchToColumnsError> { + Self::try_from_record_batch_with_offset(batch, 0, setup) + } + + /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset. + #[allow(clippy::missing_panics_doc)] + pub fn try_from_record_batch_with_offset( + batch: &RecordBatch, + offset: usize, + setup: &C::PublicSetup<'_>, + ) -> Result, RecordBatchToColumnsError> { + match Self::try_from_columns_with_offset( + batch_to_columns::(batch, &Bump::new())? + .iter() + .map(|(a, b)| (a, b)), + offset, + setup, + ) { + Ok(commitment) => Ok(commitment), + Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => { + panic!("RecordBatches cannot have columns of mixed length") + } + Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => { + panic!("RecordBatches cannot have duplicate identifiers") + } + } + } +} + +#[cfg(all(test, feature = "blitzar"))] +mod tests { + use super::*; + use crate::{base::scalar::Curve25519Scalar, record_batch}; + use curve25519_dalek::RistrettoPoint; + + #[test] + fn we_can_create_and_append_table_commitments_with_record_batchs() { + let batch = record_batch!( + "a" => [1i64, 2, 3], + "b" => ["1", "2", "3"], + ); + + let b_scals = ["1".into(), "2".into(), "3".into()]; + + let columns = [ + ( + &"a".parse().unwrap(), + &Column::::BigInt(&[1, 2, 3]), + ), + ( + &"b".parse().unwrap(), + &Column::::VarChar((&["1", "2", "3"], &b_scals)), + ), + ]; + + let mut expected_commitment = + TableCommitment::::try_from_columns_with_offset(columns, 0, &()) + .unwrap(); + + let mut commitment = + TableCommitment::::try_from_record_batch(&batch, &()).unwrap(); + + assert_eq!(commitment, expected_commitment); + + let batch2 = record_batch!( + "a" => [4i64, 5, 6], + "b" => ["4", "5", "6"], + ); + + let b_scals2 = ["4".into(), "5".into(), "6".into()]; + + let columns2 = [ + ( + &"a".parse().unwrap(), + &Column::::BigInt(&[4, 5, 6]), + ), + ( + &"b".parse().unwrap(), + &Column::::VarChar((&["4", "5", "6"], &b_scals2)), + ), + ]; + + expected_commitment.try_append_rows(columns2, &()).unwrap(); + commitment.try_append_record_batch(&batch2, &()).unwrap(); + + assert_eq!(commitment, expected_commitment); + } +} diff --git a/crates/proof-of-sql/src/base/commitment/table_commitment.rs b/crates/proof-of-sql/src/base/commitment/table_commitment.rs index b4387a765..1a52b7cea 100644 --- a/crates/proof-of-sql/src/base/commitment/table_commitment.rs +++ b/crates/proof-of-sql/src/base/commitment/table_commitment.rs @@ -2,19 +2,11 @@ use super::{ committable_column::CommittableColumn, AppendColumnCommitmentsError, ColumnCommitments, ColumnCommitmentsMismatch, Commitment, DuplicateIdentifiers, }; -#[cfg(feature = "arrow")] -use crate::base::arrow::{ - arrow_array_to_column_conversion::ArrayRefExt, - record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError}, -}; use crate::base::{ - database::{Column, ColumnField, CommitmentAccessor, OwnedTable, TableRef}, + database::{ColumnField, CommitmentAccessor, OwnedTable, TableRef}, scalar::Scalar, }; use alloc::vec::Vec; -#[cfg(feature = "arrow")] -use arrow::record_batch::RecordBatch; -use bumpalo::Bump; use core::ops::Range; use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; @@ -365,90 +357,6 @@ impl TableCommitment { range, }) } - - /// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`]. - /// - /// The row offset is assumed to be the end of the [`TableCommitment`]'s current range. - /// - /// Will error on a variety of mismatches, or if the provided columns have mixed length. - #[cfg(feature = "arrow")] - #[allow(clippy::missing_panics_doc)] - pub fn try_append_record_batch( - &mut self, - batch: &RecordBatch, - setup: &C::PublicSetup<'_>, - ) -> Result<(), AppendRecordBatchTableCommitmentError> { - match self.try_append_rows( - batch_to_columns::(batch, &Bump::new())? - .iter() - .map(|(a, b)| (a, b)), - setup, - ) { - Ok(()) => Ok(()), - Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => { - panic!("RecordBatches cannot have columns of mixed length") - } - Err(AppendTableCommitmentError::AppendColumnCommitments { - source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. }, - }) => { - panic!("RecordBatches cannot have duplicate identifiers") - } - Err(AppendTableCommitmentError::AppendColumnCommitments { - source: AppendColumnCommitmentsError::Mismatch { source: e }, - }) => Err(e)?, - } - } - /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`]. - #[cfg(feature = "arrow")] - pub fn try_from_record_batch( - batch: &RecordBatch, - setup: &C::PublicSetup<'_>, - ) -> Result, RecordBatchToColumnsError> { - Self::try_from_record_batch_with_offset(batch, 0, setup) - } - - /// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset. - #[allow(clippy::missing_panics_doc)] - #[cfg(feature = "arrow")] - pub fn try_from_record_batch_with_offset( - batch: &RecordBatch, - offset: usize, - setup: &C::PublicSetup<'_>, - ) -> Result, RecordBatchToColumnsError> { - match Self::try_from_columns_with_offset( - batch_to_columns::(batch, &Bump::new())? - .iter() - .map(|(a, b)| (a, b)), - offset, - setup, - ) { - Ok(commitment) => Ok(commitment), - Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => { - panic!("RecordBatches cannot have columns of mixed length") - } - Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => { - panic!("RecordBatches cannot have duplicate identifiers") - } - } - } -} - -#[cfg(feature = "arrow")] -fn batch_to_columns<'a, S: Scalar + 'a>( - batch: &'a RecordBatch, - alloc: &'a Bump, -) -> Result)>, RecordBatchToColumnsError> { - batch - .schema() - .fields() - .into_iter() - .zip(batch.columns()) - .map(|(field, array)| { - let identifier: Identifier = field.name().parse()?; - let column: Column = array.to_column(alloc, &(0..array.len()), None)?; - Ok((identifier, column)) - }) - .collect() } /// Return the number of rows for the provided columns, erroring if they have mixed length. @@ -472,13 +380,10 @@ fn num_rows_of_columns<'a>( #[cfg(all(test, feature = "arrow", feature = "blitzar"))] mod tests { use super::*; - use crate::{ - base::{ - database::{owned_table_utility::*, OwnedColumn}, - map::IndexMap, - scalar::Curve25519Scalar, - }, - record_batch, + use crate::base::{ + database::{owned_table_utility::*, OwnedColumn}, + map::IndexMap, + scalar::Curve25519Scalar, }; use curve25519_dalek::RistrettoPoint; @@ -1230,57 +1135,4 @@ mod tests { Err(TableCommitmentArithmeticError::NegativeRange { .. }) )); } - - #[test] - fn we_can_create_and_append_table_commitments_with_record_batchs() { - let batch = record_batch!( - "a" => [1i64, 2, 3], - "b" => ["1", "2", "3"], - ); - - let b_scals = ["1".into(), "2".into(), "3".into()]; - - let columns = [ - ( - &"a".parse().unwrap(), - &Column::::BigInt(&[1, 2, 3]), - ), - ( - &"b".parse().unwrap(), - &Column::::VarChar((&["1", "2", "3"], &b_scals)), - ), - ]; - - let mut expected_commitment = - TableCommitment::::try_from_columns_with_offset(columns, 0, &()) - .unwrap(); - - let mut commitment = - TableCommitment::::try_from_record_batch(&batch, &()).unwrap(); - - assert_eq!(commitment, expected_commitment); - - let batch2 = record_batch!( - "a" => [4i64, 5, 6], - "b" => ["4", "5", "6"], - ); - - let b_scals2 = ["4".into(), "5".into(), "6".into()]; - - let columns2 = [ - ( - &"a".parse().unwrap(), - &Column::::BigInt(&[4, 5, 6]), - ), - ( - &"b".parse().unwrap(), - &Column::::VarChar((&["4", "5", "6"], &b_scals2)), - ), - ]; - - expected_commitment.try_append_rows(columns2, &()).unwrap(); - commitment.try_append_record_batch(&batch2, &()).unwrap(); - - assert_eq!(commitment, expected_commitment); - } }