diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs
index dd89afdb2..79884e4ea 100644
--- a/crates/proof-of-sql/src/base/database/mod.rs
+++ b/crates/proof-of-sql/src/base/database/mod.rs
@@ -38,7 +38,10 @@ mod record_batch_test_accessor_test;
 #[cfg(any(test, feature = "test"))]
 mod test_accessor_utility;
 #[cfg(any(test, feature = "test"))]
-pub use test_accessor_utility::{make_random_test_accessor_data, RandomTestAccessorDescriptor};
+pub use test_accessor_utility::{
+    make_random_test_accessor_data, make_random_test_accessor_owned_table,
+    RandomTestAccessorDescriptor,
+};
 
 mod owned_column;
 pub use owned_column::OwnedColumn;
diff --git a/crates/proof-of-sql/src/base/database/owned_column.rs b/crates/proof-of-sql/src/base/database/owned_column.rs
index f6b9fb6a0..466c40871 100644
--- a/crates/proof-of-sql/src/base/database/owned_column.rs
+++ b/crates/proof-of-sql/src/base/database/owned_column.rs
@@ -77,4 +77,70 @@ impl<S: Scalar> OwnedColumn<S> {
             OwnedColumn::TimestampTZ(tu, tz, _) => ColumnType::TimestampTZ(*tu, *tz),
         }
     }
+
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [i16], panicking if it is not.
+    pub fn i16_iter(&self) -> impl Iterator<Item = &i16> {
+        match self {
+            OwnedColumn::SmallInt(col) => col.iter(),
+            _ => panic!("Expected SmallInt column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [i32], panicking if it is not.
+    pub fn i32_iter(&self) -> impl Iterator<Item = &i32> {
+        match self {
+            OwnedColumn::Int(col) => col.iter(),
+            _ => panic!("Expected Int column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [i64], panicking if it is not.
+    pub fn i64_iter(&self) -> impl Iterator<Item = &i64> {
+        match self {
+            OwnedColumn::BigInt(col) => col.iter(),
+            OwnedColumn::TimestampTZ(_, _, col) => col.iter(),
+            _ => panic!("Expected TimestampTZ or BigInt column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [i128], panicking if it is not.
+    pub fn i128_iter(&self) -> impl Iterator<Item = &i128> {
+        match self {
+            OwnedColumn::Int128(col) => col.iter(),
+            _ => panic!("Expected Int128 column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [bool], panicking if it is not.
+    pub fn bool_iter(&self) -> impl Iterator<Item = &bool> {
+        match self {
+            OwnedColumn::Boolean(col) => col.iter(),
+            _ => panic!("Expected Boolean column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is a [Scalar], panicking if it is not.
+    pub fn scalar_iter(&self) -> impl Iterator<Item = &S> {
+        match self {
+            OwnedColumn::Scalar(col) => col.iter(),
+            OwnedColumn::Decimal75(_, _, col) => col.iter(),
+            _ => panic!("Expected Scalar or Decimal75 column"),
+        }
+    }
+    #[cfg(test)]
+    /// Returns an iterator over the raw data of the column
+    /// assuming the underlying type is [String], panicking if it is not.
+    pub fn string_iter(&self) -> impl Iterator<Item = &String> {
+        match self {
+            OwnedColumn::VarChar(col) => col.iter(),
+            _ => panic!("Expected VarChar column"),
+        }
+    }
 }
diff --git a/crates/proof-of-sql/src/base/database/owned_table.rs b/crates/proof-of-sql/src/base/database/owned_table.rs
index a26817394..b4c315a0b 100644
--- a/crates/proof-of-sql/src/base/database/owned_table.rs
+++ b/crates/proof-of-sql/src/base/database/owned_table.rs
@@ -112,3 +112,13 @@ impl<S: Scalar> PartialEq for OwnedTable<S> {
             .all(|(a, b)| a == b)
     }
 }
+
+#[cfg(test)]
+impl<S: Scalar> core::ops::Index<&str> for OwnedTable<S> {
+    type Output = OwnedColumn<S>;
+    fn index(&self, index: &str) -> &Self::Output {
+        self.table
+            .get(&index.parse::<Identifier>().unwrap())
+            .unwrap()
+    }
+}
diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs
index 8b88519d2..8b897f27c 100644
--- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs
+++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs
@@ -153,4 +153,16 @@ where
         res.setup = Some(setup);
         res
     }
+
+    /// Create a new test accessor containing the provided table.
+    pub fn new_from_table(
+        table_ref: TableRef,
+        owned_table: OwnedTable<CP::Scalar>,
+        offset: usize,
+        setup: CP::ProverPublicSetup,
+    ) -> Self {
+        let mut res = Self::new_empty_with_setup(setup);
+        res.add_table(table_ref, owned_table, offset);
+        res
+    }
 }
diff --git a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs
index f14397c4d..cfa9ad51a 100644
--- a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs
+++ b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs
@@ -1,4 +1,5 @@
-use crate::base::{database::ColumnType, time::timestamp::PoSQLTimeUnit};
+use super::{OwnedColumn, OwnedTable};
+use crate::base::{database::ColumnType, scalar::Scalar, time::timestamp::PoSQLTimeUnit};
 use arrow::{
     array::{
         Array, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, Int64Array,
@@ -144,6 +145,57 @@ pub fn make_random_test_accessor_data(
     RecordBatch::try_new(schema, columns).unwrap()
 }
 
+/// Generate a OwnedTable with random data
+///
+/// Currently, this mirrors [make_random_test_accessor_data] and is intended to replace it.
+pub fn make_random_test_accessor_owned_table<S: Scalar>(
+    rng: &mut StdRng,
+    cols: &[(&str, ColumnType)],
+    descriptor: &RandomTestAccessorDescriptor,
+) -> OwnedTable<S> {
+    let n = Uniform::new(descriptor.min_rows, descriptor.max_rows + 1).sample(rng);
+    let dist = Uniform::new(descriptor.min_value, descriptor.max_value + 1);
+
+    OwnedTable::try_from_iter(cols.iter().map(|(col_name, col_type)| {
+        let values = dist.sample_iter(&mut *rng).take(n);
+        (
+            col_name.parse().unwrap(),
+            match col_type {
+                ColumnType::Boolean => OwnedColumn::Boolean(values.map(|x| x % 2 != 0).collect()),
+                ColumnType::SmallInt => {
+                    OwnedColumn::SmallInt(
+                        values
+                            .map(|x| ((x >> 48) as i16)) // Shift right to align the lower 16 bits
+                            .collect(),
+                    )
+                }
+                ColumnType::Int => {
+                    OwnedColumn::Int(
+                        values
+                            .map(|x| ((x >> 32) as i32)) // Shift right to align the lower 32 bits
+                            .collect(),
+                    )
+                }
+                ColumnType::BigInt => OwnedColumn::BigInt(values.collect()),
+                ColumnType::Int128 => OwnedColumn::Int128(values.map(|x| x as i128).collect()),
+                ColumnType::Decimal75(precision, scale) => {
+                    OwnedColumn::Decimal75(*precision, *scale, values.map(Into::into).collect())
+                }
+                ColumnType::VarChar => OwnedColumn::VarChar(
+                    values
+                        .map(|v| "s".to_owned() + &v.to_string()[..])
+                        .collect(),
+                ),
+                ColumnType::Scalar => OwnedColumn::Scalar(values.map(Into::into).collect()),
+                ColumnType::TimestampTZ(tu, tz) => {
+                    OwnedColumn::TimestampTZ(*tu, *tz, values.collect())
+                }
+            },
+        )
+    }))
+    .unwrap()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/crates/proof-of-sql/src/sql/ast/and_expr_test.rs b/crates/proof-of-sql/src/sql/ast/and_expr_test.rs
index 8f0d2e64c..cb270c180 100644
--- a/crates/proof-of-sql/src/sql/ast/and_expr_test.rs
+++ b/crates/proof-of-sql/src/sql/ast/and_expr_test.rs
@@ -1,12 +1,11 @@
-use super::{test_utility::*, FilterExpr, ProvableExpr};
+use super::{test_utility::*, ProvableExpr};
 use crate::{
     base::{
         commitment::InnerProductProof,
         database::{
-            make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, OwnedTable,
+            make_random_test_accessor_owned_table, owned_table_utility::*, Column, ColumnType,
             OwnedTableTestAccessor, RandomTestAccessorDescriptor, TestAccessor,
         },
-        scalar::Curve25519Scalar,
     },
     sql::{
         ast::{
@@ -18,46 +17,12 @@ use crate::{
 };
 use bumpalo::Bump;
 use curve25519_dalek::ristretto::RistrettoPoint;
-use polars::prelude::*;
+use itertools::{multizip, MultiUnzip};
 use rand::{
     distributions::{Distribution, Uniform},
     rngs::StdRng,
 };
 use rand_core::SeedableRng;
-/// This function creates a TestAccessor, adds a table, and then creates a FilterExpr with the given parameters.
-/// It then executes the query, verifies the result, and returns the table.
-fn create_and_verify_test_and_expr(
-    table_ref: &str,
-    results: &[&str],
-    lhs: (&str, impl Into<Curve25519Scalar>),
-    rhs: (&str, impl Into<Curve25519Scalar>),
-    data: OwnedTable<Curve25519Scalar>,
-    offset: usize,
-) -> OwnedTable<Curve25519Scalar> {
-    let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
-    let t = table_ref.parse().unwrap();
-    accessor.add_table(t, data, offset);
-    let and_expr = and(
-        equal(column(t, lhs.0, &accessor), const_scalar(lhs.1.into())),
-        equal(column(t, rhs.0, &accessor), const_scalar(rhs.1.into())),
-    );
-    let ast = FilterExpr::new(cols_result(t, results, &accessor), tab(t), and_expr);
-    let res = VerifiableQueryResult::new(&ast, &accessor, &());
-    exercise_verification(&res, &ast, &accessor, t);
-    res.verify(&ast, &accessor, &()).unwrap().table
-}
-/// This function filters the given data using polars with the given parameters.
-fn filter_test_and_expr(
-    results: &[&str],
-    lhs: (&str, impl polars::prelude::Literal),
-    rhs: (&str, impl polars::prelude::Literal),
-    data: OwnedTable<Curve25519Scalar>,
-) -> OwnedTable<Curve25519Scalar> {
-    let df_filter = polars::prelude::col(lhs.0)
-        .eq(lit(lhs.1))
-        .and(polars::prelude::col(rhs.0).eq(lit(rhs.1)));
-    data.apply_polars_filter(results, df_filter)
-}
 
 #[test]
 fn we_can_prove_a_simple_and_query() {
@@ -67,7 +32,19 @@ fn we_can_prove_a_simple_and_query() {
         varchar("d", ["ab", "t", "efg", "g"]),
         bigint("c", [0, 2, 2, 0]),
     ]);
-    let res = create_and_verify_test_and_expr("sxt.t", &["a", "d"], ("b", 1), ("d", "t"), data, 0);
+    let t = "sxt.t".parse().unwrap();
+    let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
+    let ast = dense_filter(
+        cols_expr_plan(t, &["a", "d"], &accessor),
+        tab(t),
+        and(
+            equal(column(t, "b", &accessor), const_scalar(1)),
+            equal(column(t, "d", &accessor), const_scalar("t")),
+        ),
+    );
+    let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
+    exercise_verification(&verifiable_res, &ast, &accessor, t);
+    let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;
     let expected_res = owned_table([bigint("a", [2]), varchar("d", ["t"])]);
     assert_eq!(res, expected_res);
 }
@@ -80,7 +57,19 @@ fn we_can_prove_a_simple_and_query_with_128_bits() {
         varchar("d", ["ab", "t", "efg", "g"]),
         int128("c", [0, 2, 2, 0]),
     ]);
-    let res = create_and_verify_test_and_expr("sxt.t", &["a", "d"], ("b", 1), ("d", "t"), data, 0);
+    let t = "sxt.t".parse().unwrap();
+    let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
+    let ast = dense_filter(
+        cols_expr_plan(t, &["a", "d"], &accessor),
+        tab(t),
+        and(
+            equal(column(t, "b", &accessor), const_scalar(1)),
+            equal(column(t, "d", &accessor), const_scalar("t")),
+        ),
+    );
+    let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
+    exercise_verification(&verifiable_res, &ast, &accessor, t);
+    let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;
     let expected_res = owned_table([int128("a", [2]), varchar("d", ["t"])]);
     assert_eq!(res, expected_res);
 }
@@ -100,18 +89,50 @@ fn test_random_tables_with_given_offset(offset: usize) {
         ("d", ColumnType::VarChar),
     ];
     for _ in 0..20 {
-        let data = make_random_test_accessor_data(&mut rng, &cols, &descr);
-        let data = OwnedTable::try_from(data).unwrap();
+        let data = make_random_test_accessor_owned_table(&mut rng, &cols, &descr);
         let filter_val1 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng);
         let filter_val1 = format!("s{filter_val1}");
         let filter_val2 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng);
-        let results = &["a", "d"];
-        let lhs = ("b", filter_val1.as_str());
-        let rhs = ("c", filter_val2);
-        assert_eq!(
-            filter_test_and_expr(results, lhs, rhs, data.clone()),
-            create_and_verify_test_and_expr("sxt.t", results, lhs, rhs, data, offset)
-        )
+
+        let t = "sxt.t".parse().unwrap();
+        let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(
+            t,
+            data.clone(),
+            offset,
+            (),
+        );
+        let ast = dense_filter(
+            cols_expr_plan(t, &["a", "d"], &accessor),
+            tab(t),
+            and(
+                equal(
+                    column(t, "b", &accessor),
+                    const_scalar(filter_val1.as_str()),
+                ),
+                equal(column(t, "c", &accessor), const_scalar(filter_val2)),
+            ),
+        );
+        let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &());
+        exercise_verification(&verifiable_res, &ast, &accessor, t);
+        let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table;
+
+        let (expected_a, expected_d): (Vec<_>, Vec<_>) = multizip((
+            data["a"].i64_iter(),
+            data["b"].string_iter(),
+            data["c"].i64_iter(),
+            data["d"].string_iter(),
+        ))
+        .filter_map(|(a, b, c, d)| {
+            if b == &filter_val1 && c == &filter_val2 {
+                Some((*a, d.clone()))
+            } else {
+                None
+            }
+        })
+        .multiunzip();
+        let expected_result = owned_table([bigint("a", expected_a), varchar("d", expected_d)]);
+
+        assert_eq!(expected_result, res)
     }
 }
 
diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs
index 019f5d81c..f23925e6a 100644
--- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs
+++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs
@@ -4,7 +4,7 @@ use super::{
 };
 use crate::{
     base::{
-        database::{CommitmentAccessor, RecordBatchTestAccessor, TableRef, TestAccessor},
+        database::{CommitmentAccessor, OwnedTableTestAccessor, TableRef, TestAccessor},
         scalar::{compute_commitment_for_testing, Curve25519Scalar},
     },
     sql::proof::Indexes,
@@ -90,7 +90,7 @@ fn tamper_no_result(
         counts,
         ..Default::default()
     };
-    let accessor_p = RecordBatchTestAccessor::new_empty();
+    let accessor_p = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
     let (proof, _result) = QueryProof::new(&expr_p, &accessor_p, &());
     res_p.proof = Some(proof);
     assert!(res_p.verify(expr, accessor, &()).is_err());