chore: create record_batch_conversion module within arrow module
varshith257 committed Oct 27, 2024
1 parent ff87390 commit d6a1eda
Showing 3 changed files with 168 additions and 153 deletions.
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/base/arrow/mod.rs
@@ -10,6 +10,9 @@ pub mod owned_and_arrow_conversions;
/// Tests for owned and Arrow conversions.
mod owned_and_arrow_conversions_test;

/// Module for converting record batches.
pub mod record_batch_conversion;

/// Module for record batch error definitions.
pub mod record_batch_errors;

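With the module declared above, the conversion helpers are reachable at their new path. A minimal import sketch (hypothetical downstream usage; assumes `base::arrow` remains publicly exported from the crate root):

use proof_of_sql::base::arrow::record_batch_conversion::batch_to_columns;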
160 changes: 160 additions & 0 deletions crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs
@@ -0,0 +1,160 @@
use super::{
arrow_array_to_column_conversion::ArrayRefExt,
record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError},
};
use crate::base::{
commitment::{
AppendColumnCommitmentsError, AppendTableCommitmentError, Commitment, TableCommitment,
TableCommitmentFromColumnsError,
},
database::Column,
scalar::Scalar,
};
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use proof_of_sql_parser::Identifier;

/// This function will return an error if:
/// - The field name cannot be parsed into an [`Identifier`].
/// - The conversion of an Arrow array to a [`Column`] fails.
pub fn batch_to_columns<'a, S: Scalar + 'a>(
batch: &'a RecordBatch,
alloc: &'a Bump,
) -> Result<Vec<(Identifier, Column<'a, S>)>, RecordBatchToColumnsError> {
batch
.schema()
.fields()
.into_iter()
.zip(batch.columns())
.map(|(field, array)| {
let identifier: Identifier = field.name().parse()?;
let column: Column<S> = array.to_column(alloc, &(0..array.len()), None)?;
Ok((identifier, column))
})
.collect()
}

impl<C: Commitment> TableCommitment<C> {
/// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`].
///
/// The row offset is assumed to be the end of the [`TableCommitment`]'s current range.
///
/// Will error on a variety of mismatches, or if the provided columns have mixed length.
#[allow(clippy::missing_panics_doc)]
pub fn try_append_record_batch(
&mut self,
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<(), AppendRecordBatchTableCommitmentError> {
match self.try_append_rows(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
setup,
) {
Ok(()) => Ok(()),
Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. },
}) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::Mismatch { source: e },
}) => Err(e)?,
}
}
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`].
pub fn try_from_record_batch(
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
Self::try_from_record_batch_with_offset(batch, 0, setup)
}

/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset.
#[allow(clippy::missing_panics_doc)]
pub fn try_from_record_batch_with_offset(
batch: &RecordBatch,
offset: usize,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
match Self::try_from_columns_with_offset(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
offset,
setup,
) {
Ok(commitment) => Ok(commitment),
Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
}
}
}

#[cfg(all(test, feature = "blitzar"))]
mod tests {
use super::*;
use crate::{base::scalar::Curve25519Scalar, record_batch};
use curve25519_dalek::RistrettoPoint;

#[test]
fn we_can_create_and_append_table_commitments_with_record_batchs() {
let batch = record_batch!(
"a" => [1i64, 2, 3],
"b" => ["1", "2", "3"],
);

let b_scals = ["1".into(), "2".into(), "3".into()];

let columns = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[1, 2, 3]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["1", "2", "3"], &b_scals)),
),
];

let mut expected_commitment =
TableCommitment::<RistrettoPoint>::try_from_columns_with_offset(columns, 0, &())
.unwrap();

let mut commitment =
TableCommitment::<RistrettoPoint>::try_from_record_batch(&batch, &()).unwrap();

assert_eq!(commitment, expected_commitment);

let batch2 = record_batch!(
"a" => [4i64, 5, 6],
"b" => ["4", "5", "6"],
);

let b_scals2 = ["4".into(), "5".into(), "6".into()];

let columns2 = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[4, 5, 6]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["4", "5", "6"], &b_scals2)),
),
];

expected_commitment.try_append_rows(columns2, &()).unwrap();
commitment.try_append_record_batch(&batch2, &()).unwrap();

assert_eq!(commitment, expected_commitment);
}
}
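For reference, here is a minimal usage sketch of the new `batch_to_columns` helper. It is not part of the commit; it assumes the `arrow` feature is enabled and that `Curve25519Scalar` (used in the tests above) is in scope:

use arrow::array::{ArrayRef, Int64Array};
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use std::sync::Arc;

fn example() -> Result<(), RecordBatchToColumnsError> {
    // Build a single-column arrow batch in memory.
    let batch = RecordBatch::try_from_iter(vec![(
        "a",
        Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef,
    )])
    .expect("valid batch");

    // The bump allocator must outlive the returned columns, which borrow from it.
    let alloc = Bump::new();
    let columns = batch_to_columns::<Curve25519Scalar>(&batch, &alloc)?;
    assert_eq!(columns.len(), 1);
    Ok(())
}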
158 changes: 5 additions & 153 deletions crates/proof-of-sql/src/base/commitment/table_commitment.rs
@@ -2,19 +2,11 @@ use super::{
committable_column::CommittableColumn, AppendColumnCommitmentsError, ColumnCommitments,
ColumnCommitmentsMismatch, Commitment, DuplicateIdentifiers,
};
#[cfg(feature = "arrow")]
use crate::base::arrow::{
arrow_array_to_column_conversion::ArrayRefExt,
record_batch_errors::{AppendRecordBatchTableCommitmentError, RecordBatchToColumnsError},
};
use crate::base::{
database::{Column, ColumnField, CommitmentAccessor, OwnedTable, TableRef},
database::{ColumnField, CommitmentAccessor, OwnedTable, TableRef},
scalar::Scalar,
};
use alloc::vec::Vec;
#[cfg(feature = "arrow")]
use arrow::record_batch::RecordBatch;
use bumpalo::Bump;
use core::ops::Range;
use proof_of_sql_parser::Identifier;
use serde::{Deserialize, Serialize};
@@ -365,90 +357,6 @@ impl<C: Commitment> TableCommitment<C> {
range,
})
}

/// Append an arrow [`RecordBatch`] to the existing [`TableCommitment`].
///
/// The row offset is assumed to be the end of the [`TableCommitment`]'s current range.
///
/// Will error on a variety of mismatches, or if the provided columns have mixed length.
#[cfg(feature = "arrow")]
#[allow(clippy::missing_panics_doc)]
pub fn try_append_record_batch(
&mut self,
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<(), AppendRecordBatchTableCommitmentError> {
match self.try_append_rows(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
setup,
) {
Ok(()) => Ok(()),
Err(AppendTableCommitmentError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::DuplicateIdentifiers { .. },
}) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
Err(AppendTableCommitmentError::AppendColumnCommitments {
source: AppendColumnCommitmentsError::Mismatch { source: e },
}) => Err(e)?,
}
}
/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`].
#[cfg(feature = "arrow")]
pub fn try_from_record_batch(
batch: &RecordBatch,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
Self::try_from_record_batch_with_offset(batch, 0, setup)
}

/// Returns a [`TableCommitment`] to the provided arrow [`RecordBatch`] with the given row offset.
#[allow(clippy::missing_panics_doc)]
#[cfg(feature = "arrow")]
pub fn try_from_record_batch_with_offset(
batch: &RecordBatch,
offset: usize,
setup: &C::PublicSetup<'_>,
) -> Result<TableCommitment<C>, RecordBatchToColumnsError> {
match Self::try_from_columns_with_offset(
batch_to_columns::<C::Scalar>(batch, &Bump::new())?
.iter()
.map(|(a, b)| (a, b)),
offset,
setup,
) {
Ok(commitment) => Ok(commitment),
Err(TableCommitmentFromColumnsError::MixedLengthColumns { .. }) => {
panic!("RecordBatches cannot have columns of mixed length")
}
Err(TableCommitmentFromColumnsError::DuplicateIdentifiers { .. }) => {
panic!("RecordBatches cannot have duplicate identifiers")
}
}
}
}

#[cfg(feature = "arrow")]
fn batch_to_columns<'a, S: Scalar + 'a>(
batch: &'a RecordBatch,
alloc: &'a Bump,
) -> Result<Vec<(Identifier, Column<'a, S>)>, RecordBatchToColumnsError> {
batch
.schema()
.fields()
.into_iter()
.zip(batch.columns())
.map(|(field, array)| {
let identifier: Identifier = field.name().parse()?;
let column: Column<S> = array.to_column(alloc, &(0..array.len()), None)?;
Ok((identifier, column))
})
.collect()
}

/// Return the number of rows for the provided columns, erroring if they have mixed length.
@@ -472,13 +380,10 @@ fn num_rows_of_columns<'a>(
#[cfg(all(test, feature = "arrow", feature = "blitzar"))]
mod tests {
use super::*;
use crate::{
base::{
database::{owned_table_utility::*, OwnedColumn},
map::IndexMap,
scalar::Curve25519Scalar,
},
record_batch,
use crate::base::{
database::{owned_table_utility::*, OwnedColumn},
map::IndexMap,
scalar::Curve25519Scalar,
};
use curve25519_dalek::RistrettoPoint;

@@ -1230,57 +1135,4 @@ mod tests {
Err(TableCommitmentArithmeticError::NegativeRange { .. })
));
}

#[test]
fn we_can_create_and_append_table_commitments_with_record_batchs() {
let batch = record_batch!(
"a" => [1i64, 2, 3],
"b" => ["1", "2", "3"],
);

let b_scals = ["1".into(), "2".into(), "3".into()];

let columns = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[1, 2, 3]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["1", "2", "3"], &b_scals)),
),
];

let mut expected_commitment =
TableCommitment::<RistrettoPoint>::try_from_columns_with_offset(columns, 0, &())
.unwrap();

let mut commitment =
TableCommitment::<RistrettoPoint>::try_from_record_batch(&batch, &()).unwrap();

assert_eq!(commitment, expected_commitment);

let batch2 = record_batch!(
"a" => [4i64, 5, 6],
"b" => ["4", "5", "6"],
);

let b_scals2 = ["4".into(), "5".into(), "6".into()];

let columns2 = [
(
&"a".parse().unwrap(),
&Column::<Curve25519Scalar>::BigInt(&[4, 5, 6]),
),
(
&"b".parse().unwrap(),
&Column::<Curve25519Scalar>::VarChar((&["4", "5", "6"], &b_scals2)),
),
];

expected_commitment.try_append_rows(columns2, &()).unwrap();
commitment.try_append_record_batch(&batch2, &()).unwrap();

assert_eq!(commitment, expected_commitment);
}
}
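The relocated methods behave the same for callers: mixed-length or duplicate-identifier columns still panic (they cannot occur in a well-formed `RecordBatch`), while schema mismatches are surfaced as errors. A small error-handling sketch (hypothetical, mirroring the test setup above):

let batch = record_batch!("a" => [1i64, 2, 3]);
let mismatched = record_batch!("b" => [4i64, 5, 6]);

let mut commitment =
    TableCommitment::<RistrettoPoint>::try_from_record_batch(&batch, &()).unwrap();

// Appending a batch with a different schema yields a mismatch error rather
// than panicking, since the column identifiers no longer line up.
assert!(commitment.try_append_record_batch(&mismatched, &()).is_err());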
