Skip to content

Commit

Permalink
refactor: update and expr tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Jun 20, 2024
1 parent b1a01ae commit d6bcc65
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 53 deletions.
5 changes: 4 additions & 1 deletion crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ mod record_batch_test_accessor_test;
#[cfg(any(test, feature = "test"))]
mod test_accessor_utility;
#[cfg(any(test, feature = "test"))]
pub use test_accessor_utility::{make_random_test_accessor_data, RandomTestAccessorDescriptor};
pub use test_accessor_utility::{
make_random_test_accessor_data, make_random_test_accessor_owned_table,
RandomTestAccessorDescriptor,
};

mod owned_column;
pub use owned_column::OwnedColumn;
Expand Down
66 changes: 66 additions & 0 deletions crates/proof-of-sql/src/base/database/owned_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,70 @@ impl<S: Scalar> OwnedColumn<S> {
OwnedColumn::TimestampTZ(tu, tz, _) => ColumnType::TimestampTZ(*tu, *tz),
}
}

#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [i16], panicking if it is not.
pub fn i16_iter(&self) -> impl Iterator<Item = &i16> {
match self {
OwnedColumn::SmallInt(col) => col.iter(),
_ => panic!("Expected SmallInt column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [i32], panicking if it is not.
pub fn i32_iter(&self) -> impl Iterator<Item = &i32> {
match self {
OwnedColumn::Int(col) => col.iter(),
_ => panic!("Expected Int column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [i64], panicking if it is not.
pub fn i64_iter(&self) -> impl Iterator<Item = &i64> {
match self {
OwnedColumn::BigInt(col) => col.iter(),
OwnedColumn::TimestampTZ(_, _, col) => col.iter(),
_ => panic!("Expected TimestampTZ or BigInt column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [i128], panicking if it is not.
pub fn i128_iter(&self) -> impl Iterator<Item = &i128> {
match self {
OwnedColumn::Int128(col) => col.iter(),
_ => panic!("Expected Int128 column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [bool], panicking if it is not.
pub fn bool_iter(&self) -> impl Iterator<Item = &bool> {
match self {
OwnedColumn::Boolean(col) => col.iter(),
_ => panic!("Expected Boolean column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is a [Scalar], panicking if it is not.
pub fn scalar_iter(&self) -> impl Iterator<Item = &S> {
match self {
OwnedColumn::Scalar(col) => col.iter(),
OwnedColumn::Decimal75(_, _, col) => col.iter(),
_ => panic!("Expected Scalar or Decimal75 column"),
}
}
#[cfg(test)]
/// Returns an iterator over the raw data of the column
/// assuming the underlying type is [String], panicking if it is not.
pub fn string_iter(&self) -> impl Iterator<Item = &String> {
match self {
OwnedColumn::VarChar(col) => col.iter(),
_ => panic!("Expected VarChar column"),
}
}
}
10 changes: 10 additions & 0 deletions crates/proof-of-sql/src/base/database/owned_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,13 @@ impl<S: Scalar> PartialEq for OwnedTable<S> {
.all(|(a, b)| a == b)
}
}

#[cfg(test)]
impl<S: Scalar> core::ops::Index<&str> for OwnedTable<S> {
type Output = OwnedColumn<S>;
fn index(&self, index: &str) -> &Self::Output {
self.table
.get(&index.parse::<Identifier>().unwrap())
.unwrap()
}
}
12 changes: 12 additions & 0 deletions crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,16 @@ where
res.setup = Some(setup);
res
}

/// Create a new test accessor containing the provided table.
pub fn new_from_table(
table_ref: TableRef,
owned_table: OwnedTable<CP::Scalar>,
offset: usize,
setup: CP::ProverPublicSetup,
) -> Self {
let mut res = Self::new_empty_with_setup(setup);
res.add_table(table_ref, owned_table, offset);
res
}
}
54 changes: 53 additions & 1 deletion crates/proof-of-sql/src/base/database/test_accessor_utility.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::base::{database::ColumnType, time::timestamp::PoSQLTimeUnit};
use super::{OwnedColumn, OwnedTable};
use crate::base::{database::ColumnType, scalar::Scalar, time::timestamp::PoSQLTimeUnit};
use arrow::{
array::{
Array, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, Int64Array,
Expand Down Expand Up @@ -144,6 +145,57 @@ pub fn make_random_test_accessor_data(
RecordBatch::try_new(schema, columns).unwrap()
}

/// Generate a OwnedTable with random data
///
/// Currently, this mirrors [make_random_test_accessor_data] and is intended to replace it.
pub fn make_random_test_accessor_owned_table<S: Scalar>(
rng: &mut StdRng,
cols: &[(&str, ColumnType)],
descriptor: &RandomTestAccessorDescriptor,
) -> OwnedTable<S> {
let n = Uniform::new(descriptor.min_rows, descriptor.max_rows + 1).sample(rng);
let dist = Uniform::new(descriptor.min_value, descriptor.max_value + 1);

OwnedTable::try_from_iter(cols.iter().map(|(col_name, col_type)| {
let values = dist.sample_iter(&mut *rng).take(n);
(
col_name.parse().unwrap(),
match col_type {
ColumnType::Boolean => OwnedColumn::Boolean(values.map(|x| x % 2 != 0).collect()),
ColumnType::SmallInt => {
OwnedColumn::SmallInt(
values
.map(|x| ((x >> 48) as i16)) // Shift right to align the lower 16 bits
.collect(),
)
}
ColumnType::Int => {
OwnedColumn::Int(
values
.map(|x| ((x >> 32) as i32)) // Shift right to align the lower 32 bits
.collect(),
)
}
ColumnType::BigInt => OwnedColumn::BigInt(values.collect()),
ColumnType::Int128 => OwnedColumn::Int128(values.map(|x| x as i128).collect()),
ColumnType::Decimal75(precision, scale) => {
OwnedColumn::Decimal75(*precision, *scale, values.map(Into::into).collect())
}
ColumnType::VarChar => OwnedColumn::VarChar(
values
.map(|v| "s".to_owned() + &v.to_string()[..])
.collect(),
),
ColumnType::Scalar => OwnedColumn::Scalar(values.map(Into::into).collect()),
ColumnType::TimestampTZ(tu, tz) => {
OwnedColumn::TimestampTZ(*tu, *tz, values.collect())
}
},
)
}))
.unwrap()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
119 changes: 70 additions & 49 deletions crates/proof-of-sql/src/sql/ast/and_expr_test.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use super::{test_utility::*, FilterExpr, ProvableExpr};
use super::{test_utility::*, ProvableExpr};
use crate::{
base::{
commitment::InnerProductProof,
database::{
make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, OwnedTable,
make_random_test_accessor_owned_table, owned_table_utility::*, Column, ColumnType,
OwnedTableTestAccessor, RandomTestAccessorDescriptor, TestAccessor,
},
scalar::Curve25519Scalar,
},
sql::{
ast::{
Expand All @@ -18,46 +17,12 @@ use crate::{
};
use bumpalo::Bump;
use curve25519_dalek::ristretto::RistrettoPoint;
use polars::prelude::*;
use itertools::{multizip, MultiUnzip};
use rand::{
distributions::{Distribution, Uniform},
rngs::StdRng,
};
use rand_core::SeedableRng;
/// This function creates a TestAccessor, adds a table, and then creates a FilterExpr with the given parameters.
/// It then executes the query, verifies the result, and returns the table.
fn create_and_verify_test_and_expr(
table_ref: &str,
results: &[&str],
lhs: (&str, impl Into<Curve25519Scalar>),
rhs: (&str, impl Into<Curve25519Scalar>),
data: OwnedTable<Curve25519Scalar>,
offset: usize,
) -> OwnedTable<Curve25519Scalar> {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = table_ref.parse().unwrap();
accessor.add_table(t, data, offset);
let and_expr = and(
equal(column(t, lhs.0, &accessor), const_scalar(lhs.1.into())),
equal(column(t, rhs.0, &accessor), const_scalar(rhs.1.into())),
);
let ast = FilterExpr::new(cols_result(t, results, &accessor), tab(t), and_expr);
let res = VerifiableQueryResult::new(&ast, &accessor, &());
exercise_verification(&res, &ast, &accessor, t);
res.verify(&ast, &accessor, &()).unwrap().table
}
/// This function filters the given data using polars with the given parameters.
fn filter_test_and_expr(
results: &[&str],
lhs: (&str, impl polars::prelude::Literal),
rhs: (&str, impl polars::prelude::Literal),
data: OwnedTable<Curve25519Scalar>,
) -> OwnedTable<Curve25519Scalar> {
let df_filter = polars::prelude::col(lhs.0)
.eq(lit(lhs.1))
.and(polars::prelude::col(rhs.0).eq(lit(rhs.1)));
data.apply_polars_filter(results, df_filter)
}

#[test]
fn we_can_prove_a_simple_and_query() {
Expand All @@ -67,7 +32,19 @@ fn we_can_prove_a_simple_and_query() {
varchar("d", ["ab", "t", "efg", "g"]),
bigint("c", [0, 2, 2, 0]),
]);
let res = create_and_verify_test_and_expr("sxt.t", &["a", "d"], ("b", 1), ("d", "t"), data, 0);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = dense_filter(
cols_expr_plan(t, &["a", "d"], &accessor),
tab(t),
and(
equal(column(t, "b", &accessor), const_scalar(1)),
equal(column(t, "d", &accessor), const_scalar("t")),
),
);
let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
exercise_verification(&verifiable_res, &ast, &accessor, t);
let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;
let expected_res = owned_table([bigint("a", [2]), varchar("d", ["t"])]);
assert_eq!(res, expected_res);
}
Expand All @@ -80,7 +57,19 @@ fn we_can_prove_a_simple_and_query_with_128_bits() {
varchar("d", ["ab", "t", "efg", "g"]),
int128("c", [0, 2, 2, 0]),
]);
let res = create_and_verify_test_and_expr("sxt.t", &["a", "d"], ("b", 1), ("d", "t"), data, 0);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = dense_filter(
cols_expr_plan(t, &["a", "d"], &accessor),
tab(t),
and(
equal(column(t, "b", &accessor), const_scalar(1)),
equal(column(t, "d", &accessor), const_scalar("t")),
),
);
let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
exercise_verification(&verifiable_res, &ast, &accessor, t);
let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;
let expected_res = owned_table([int128("a", [2]), varchar("d", ["t"])]);
assert_eq!(res, expected_res);
}
Expand All @@ -100,18 +89,50 @@ fn test_random_tables_with_given_offset(offset: usize) {
("d", ColumnType::VarChar),
];
for _ in 0..20 {
let data = make_random_test_accessor_data(&mut rng, &cols, &descr);
let data = OwnedTable::try_from(data).unwrap();
let data = make_random_test_accessor_owned_table(&mut rng, &cols, &descr);
let filter_val1 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng);
let filter_val1 = format!("s{filter_val1}");
let filter_val2 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng);
let results = &["a", "d"];
let lhs = ("b", filter_val1.as_str());
let rhs = ("c", filter_val2);
assert_eq!(
filter_test_and_expr(results, lhs, rhs, data.clone()),
create_and_verify_test_and_expr("sxt.t", results, lhs, rhs, data, offset)
)

let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(
t,
data.clone(),
offset,
(),
);
let ast = dense_filter(
cols_expr_plan(t, &["a", "d"], &accessor),
tab(t),
and(
equal(
column(t, "b", &accessor),
const_scalar(filter_val1.as_str()),
),
equal(column(t, "c", &accessor), const_scalar(filter_val2)),
),
);
let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
exercise_verification(&verifiable_res, &ast, &accessor, t);
let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;

let (expected_a, expected_d): (Vec<_>, Vec<_>) = multizip((
data["a"].i64_iter(),
data["b"].string_iter(),
data["c"].i64_iter(),
data["d"].string_iter(),
))
.filter_map(|(a, b, c, d)| {
if b == &filter_val1 && c == &filter_val2 {
Some((*a, d.clone()))
} else {
None
}
})
.multiunzip();
let expected_result = owned_table([bigint("a", expected_a), varchar("d", expected_d)]);

assert_eq!(expected_result, res)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
};
use crate::{
base::{
database::{CommitmentAccessor, RecordBatchTestAccessor, TableRef, TestAccessor},
database::{CommitmentAccessor, OwnedTableTestAccessor, TableRef, TestAccessor},
scalar::{compute_commitment_for_testing, Curve25519Scalar},
},
sql::proof::Indexes,
Expand Down Expand Up @@ -90,7 +90,7 @@ fn tamper_no_result(
counts,
..Default::default()
};
let accessor_p = RecordBatchTestAccessor::new_empty();
let accessor_p = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let (proof, _result) = QueryProof::new(&expr_p, &accessor_p, &());
res_p.proof = Some(proof);
assert!(res_p.verify(expr, accessor, &()).is_err());
Expand Down

0 comments on commit d6bcc65

Please sign in to comment.