diff --git a/.github/workflows/lint-and-test.yml b/.github/workflows/lint-and-test.yml index 628a8f59d..4d729345c 100644 --- a/.github/workflows/lint-and-test.yml +++ b/.github/workflows/lint-and-test.yml @@ -48,6 +48,14 @@ jobs: run: cargo check -p proof-of-sql --no-default-features --features="test" - name: Run cargo check (proof-of-sql) (just "blitzar" feature) run: cargo check -p proof-of-sql --no-default-features --features="blitzar" + - name: Run cargo check (proof-of-sql) (just "polars" feature) + run: cargo check -p proof-of-sql --no-default-features --features="polars" + - name: Run cargo check (proof-of-sql) (no "test" feature) + run: cargo check -p proof-of-sql --no-default-features --features="blitzar polars" + - name: Run cargo check (proof-of-sql) (no "blitzar" feature) + run: cargo check -p proof-of-sql --no-default-features --features="test polars" + - name: Run cargo check (proof-of-sql) (no "polars" feature) + run: cargo check -p proof-of-sql --no-default-features --features="blitzar test" test: name: Test Suite diff --git a/Cargo.toml b/Cargo.toml index 1913112a2..aa6f52a11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ bigdecimal = { version = "0.4.5", features = ["serde"] } blake3 = { version = "1.3.3" } blitzar = { version = "3.0.2" } bumpalo = { version = "3.11.0" } -bytemuck = {version = "1.14.2" } +bytemuck = {version = "1.14.2", features = ["derive"] } byte-slice-cast = { version = "1.2.1" } clap = { version = "4.5.4" } criterion = { version = "0.5.1" } @@ -35,7 +35,7 @@ derive_more = { version = "0.99" } dyn_partial_eq = { version = "0.1.2" } flexbuffers = { version = "2.0.0" } hashbrown = { version = "0.14.0" } -indexmap = { version = "2.1" } +indexmap = { version = "2.1", features = ["serde"] } itertools = { version = "0.13.0" } lalrpop-util = { version = "0.20.0" } lazy_static = { version = "1.4.0" } diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index d6a0736e9..c8157bf28 100644 --- 
a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -40,7 +40,7 @@ lazy_static = { workspace = true } merlin = { workspace = true } num-traits = { workspace = true } num-bigint = { workspace = true, default-features = false } -polars = { workspace = true, features = ["lazy", "bigidx", "dtype-decimal", "serde-lazy"] } +polars = { workspace = true, features = ["lazy", "bigidx", "dtype-decimal", "serde-lazy"], optional = true } postcard = { workspace = true, features = ["alloc"] } proof-of-sql-parser = { workspace = true } rand = { workspace = true, optional = true } @@ -59,7 +59,7 @@ clap = { workspace = true, features = ["derive"] } criterion = { workspace = true, features = ["html_reports"] } opentelemetry = { workspace = true } opentelemetry-jaeger = { workspace = true } -polars = { workspace = true, features = ["lazy"] } +polars = { workspace = true, features = ["lazy", "dtype-decimal"] } rand = { workspace = true } rand_core = { workspace = true } serde_json = { workspace = true } @@ -69,7 +69,7 @@ tracing-subscriber = { workspace = true } flexbuffers = { workspace = true } [features] -default = ["blitzar"] +default = ["blitzar", "polars"] test = ["dep:rand"] [lints] diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index fb5dd908c..2912fbf7c 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -193,12 +193,6 @@ impl<'a, S: Scalar> Column<'a, S> { } } -/// The precision for [ColumnType::INT128] values -pub const INT128_PRECISION: usize = 38; - -/// The scale for [ColumnType::INT128] values -pub const INT128_SCALE: usize = 0; - /// Represents the supported data types of a column in an in-memory, /// column-oriented database. 
/// diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index dd89afdb2..3d76a63e1 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -6,7 +6,6 @@ pub use accessor::{CommitmentAccessor, DataAccessor, MetadataAccessor, SchemaAcc mod column; pub use column::{Column, ColumnField, ColumnRef, ColumnType}; -pub(crate) use column::{INT128_PRECISION, INT128_SCALE}; mod literal_value; pub use literal_value::LiteralValue; @@ -17,7 +16,9 @@ pub use table_ref::TableRef; mod arrow_array_to_column_conversion; pub use arrow_array_to_column_conversion::{ArrayRefExt, ArrowArrayToColumnConversionError}; +#[cfg(any(test, feature = "polars"))] mod record_batch_dataframe_conversion; +#[cfg(any(test, feature = "polars"))] pub(crate) use record_batch_dataframe_conversion::{ dataframe_to_record_batch, record_batch_to_dataframe, }; @@ -59,6 +60,7 @@ mod test_accessor; #[cfg(any(test, feature = "test"))] pub use test_accessor::TestAccessor; #[cfg(test)] +#[allow(unused_imports)] pub(crate) use test_accessor::UnimplementedTestAccessor; #[cfg(any(test, feature = "test"))] diff --git a/crates/proof-of-sql/src/base/database/record_batch_test_accessor.rs b/crates/proof-of-sql/src/base/database/record_batch_test_accessor.rs index a67766ae4..cb58d2501 100644 --- a/crates/proof-of-sql/src/base/database/record_batch_test_accessor.rs +++ b/crates/proof-of-sql/src/base/database/record_batch_test_accessor.rs @@ -1,13 +1,15 @@ +#[cfg(any(test, feature = "polars"))] +use super::{dataframe_to_record_batch, record_batch_to_dataframe}; use super::{ - dataframe_to_record_batch, record_batch_to_dataframe, ArrayRefExt, Column, ColumnRef, - ColumnType, CommitmentAccessor, DataAccessor, MetadataAccessor, SchemaAccessor, TableRef, - TestAccessor, + ArrayRefExt, Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor, MetadataAccessor, + SchemaAccessor, TableRef, TestAccessor, }; use 
crate::base::scalar::{compute_commitment_for_testing, Curve25519Scalar}; use arrow::{array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; use bumpalo::Bump; use curve25519_dalek::ristretto::RistrettoPoint; use indexmap::IndexMap; +#[cfg(any(test, feature = "polars"))] use polars::prelude::DataFrame; use proof_of_sql_parser::Identifier; use std::collections::HashMap; @@ -114,6 +116,7 @@ impl TestAccessor for RecordBatchTestAccessor { impl RecordBatchTestAccessor { /// Apply a query function to table and then convert the result to a RecordBatch + #[cfg(any(test, feature = "polars"))] pub fn query_table( &self, table_ref: TableRef, diff --git a/crates/proof-of-sql/src/sql/ast/and_expr_test.rs b/crates/proof-of-sql/src/sql/ast/and_expr_test.rs index bf3f4c685..840b5c7c1 100644 --- a/crates/proof-of-sql/src/sql/ast/and_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/and_expr_test.rs @@ -104,9 +104,9 @@ fn test_random_tables_with_given_offset(offset: usize) { and( equal( column(t, "b", &accessor), - const_scalar(filter_val1.as_str()), + const_varchar(filter_val1.as_str()), ), - equal(column(t, "c", &accessor), const_scalar(filter_val2)), + equal(column(t, "c", &accessor), const_bigint(filter_val2)), ), ); let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); diff --git a/crates/proof-of-sql/src/sql/ast/column_expr_test.rs b/crates/proof-of-sql/src/sql/ast/column_expr_test.rs index 5f6b5e62d..ecfe83f79 100644 --- a/crates/proof-of-sql/src/sql/ast/column_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/column_expr_test.rs @@ -1,30 +1,23 @@ use crate::{ - base::database::{RecordBatchTestAccessor, TestAccessor}, - record_batch, - sql::ast::{test_expr::TestExprNode, test_utility::*}, + base::{ + commitment::InnerProductProof, + database::{owned_table_utility::*, OwnedTableTestAccessor}, + }, + sql::{ + ast::test_utility::*, + proof::{exercise_verification, VerifiableQueryResult}, + }, }; -use arrow::record_batch::RecordBatch; - -fn 
create_test_bool_col_expr( - table_ref: &str, - results: &[&str], - filter_col: &str, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let table_ref = table_ref.parse().unwrap(); - accessor.add_table(table_ref, data, offset); - let col_expr = column(table_ref, filter_col, &accessor); - let df_filter = polars::prelude::col(filter_col); - TestExprNode::new(table_ref, results, col_expr, df_filter, accessor) -} #[test] fn we_can_prove_a_query_with_a_single_selected_row() { - let data = record_batch!("a" => [true, false]); - let test_expr = create_test_bool_col_expr("sxt.t", &["a"], "a", data.clone(), 0); - let res = test_expr.verify_expr(); - let expected = record_batch!("a" => [true]); - assert_eq!(res, expected); + let data = owned_table([boolean("a", [true, false])]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = projection(cols_expr_plan(t, &["a"], &accessor), tab(t)); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([boolean("a", [true, false])]); + assert_eq!(res, expected_res); } diff --git a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs index d8ad821f9..8f3460388 100644 --- a/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/dense_filter_expr_test.rs @@ -1,27 +1,15 @@ -use crate::{ - base::{database::owned_table_utility::*, math::decimal::Precision}, - sql::ast::{test_utility::*, ProvableExprPlan}, -}; use crate::{ base::{ database::{ - ColumnField, ColumnRef, ColumnType, LiteralValue, OwnedTable, OwnedTableTestAccessor, - RecordBatchTestAccessor, TableRef, TestAccessor, + owned_table_utility::*, ColumnField, 
ColumnRef, ColumnType, LiteralValue, OwnedTable, + OwnedTableTestAccessor, TableRef, TestAccessor, }, + math::decimal::Precision, scalar::Curve25519Scalar, }, - record_batch, sql::{ ast::{ - // Making this explicit to ensure that we don't accidentally use the - // sparse filter for these tests - test_utility::{ - col_expr_plan, cols_expr_plan, column, const_int128, dense_filter, equal, tab, - }, - ColumnExpr, - DenseFilterExpr, - LiteralExpr, - TableExpr, + test_utility::*, ColumnExpr, DenseFilterExpr, LiteralExpr, ProvableExprPlan, TableExpr, }, proof::{ exercise_verification, ProofExpr, ProverEvaluate, ResultBuilder, VerifiableQueryResult, @@ -177,24 +165,19 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { #[test] fn we_can_prove_and_get_the_correct_result_from_a_basic_dense_filter() { - let data = record_batch!( - "a" => [1_i64, 4_i64, 5_i64, 2_i64, 5_i64], - "b" => [1_i64, 2, 3, 4, 5], - ); + let data = owned_table([ + bigint("a", [1_i64, 4_i64, 5_i64, 2_i64, 5_i64]), + bigint("b", [1_i64, 2, 3, 4, 5]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let where_clause = equal(column(t, "a", &accessor), const_int128(5_i128)); - let expr = dense_filter(cols_expr_plan(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [3_i64, 5], - ); - assert_eq!(res, expected); + let ast = dense_filter(cols_expr_plan(t, &["b"], &accessor), tab(t), where_clause); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [3_i64, 5])]); + 
assert_eq!(res, expected_res); } #[test] diff --git a/crates/proof-of-sql/src/sql/ast/equals_expr_test.rs b/crates/proof-of-sql/src/sql/ast/equals_expr_test.rs index faaf6bd23..0ab428fe5 100644 --- a/crates/proof-of-sql/src/sql/ast/equals_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/equals_expr_test.rs @@ -1,89 +1,23 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{ - make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, OwnedTable, - OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, - TestAccessor, - }, + database::{owned_table_utility::*, Column, OwnedTable, OwnedTableTestAccessor}, scalar::{Curve25519Scalar, Scalar}, }, - record_batch, - sql::ast::{test_expr::TestExprNode, test_utility::*, ProvableExpr, ProvableExprPlan}, + sql::{ + ast::{test_utility::*, ProvableExpr, ProvableExprPlan}, + proof::{exercise_verification, VerifiableQueryResult}, + }, }; -use arrow::record_batch::RecordBatch; use bumpalo::Bump; use curve25519_dalek::ristretto::RistrettoPoint; -use polars::prelude::*; +use itertools::{multizip, MultiUnzip}; use rand::{ distributions::{Distribution, Uniform}, rngs::StdRng, }; use rand_core::SeedableRng; -fn create_test_col_lit_equals_expr + Copy + Literal>( - table_ref: &str, - results: &[&str], - filter_col: &str, - filter_val: T, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, offset); - let equals_expr = equal( - column(t, filter_col, &accessor), - const_scalar(filter_val.into()), - ); - let df_filter = polars::prelude::col(filter_col).eq(lit(filter_val)); - TestExprNode::new(t, results, equals_expr, df_filter, accessor) -} - -fn create_test_col_equals_expr( - table_ref: &str, - results: &[&str], - filter_col_lhs: &str, - filter_col_rhs: &str, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = 
RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, offset); - let equals_expr = equal( - column(t, filter_col_lhs, &accessor), - column(t, filter_col_rhs, &accessor), - ); - let df_filter = polars::prelude::col(filter_col_lhs).eq(col(filter_col_rhs)); - TestExprNode::new(t, results, equals_expr, df_filter, accessor) -} - -// col_bool = (col_lhs = col_rhs) -fn create_test_complex_col_equals_expr( - table_ref: &str, - results: &[&str], - filter_col_bool: &str, - filter_col_lhs: &str, - filter_col_rhs: &str, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, offset); - let equals_expr = equal( - column(t, filter_col_bool, &accessor), - equal( - column(t, filter_col_lhs, &accessor), - column(t, filter_col_rhs, &accessor), - ), - ); - let df_filter = - polars::prelude::col(filter_col_bool).eq(col(filter_col_lhs).eq(col(filter_col_rhs))); - TestExprNode::new(t, results, equals_expr, df_filter, accessor) -} - #[test] fn we_can_prove_an_equality_query_with_no_rows() { let data: OwnedTable = owned_table([ @@ -92,20 +26,17 @@ fn we_can_prove_an_equality_query_with_no_rows() { varchar("d", [""; 0]), decimal75("e", 75, 0, [0; 0]), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["a", "d"], - "b", - 0_i64, - data.try_into().unwrap(), - 0, - ); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "a" => Vec::::new(), - "d" => Vec::::new(), + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + equal(column(t, "b", &accessor), const_bigint(0_i64)), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = 
verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [0; 0]), varchar("d", [""; 0])]); assert_eq!(res, expected_res); } @@ -117,14 +48,17 @@ fn we_can_prove_another_equality_query_with_no_rows() { varchar("d", [""; 0]), decimal75("e", 75, 0, [0; 0]), ]); - - let test_expr = - create_test_col_equals_expr("sxt.t", &["a", "d"], "a", "b", data.try_into().unwrap(), 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "a" => Vec::::new(), - "d" => Vec::::new(), + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + equal(column(t, "a", &accessor), column(t, "b", &accessor)), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [0; 0]), varchar("d", [""; 0])]); assert_eq!(res, expected_res); } @@ -137,25 +71,25 @@ fn we_can_prove_a_nested_equality_query_with_no_rows() { varchar("c", ["t"; 0]), decimal75("e", 75, 0, [0; 0]), ]); - - let test_expr = create_test_complex_col_equals_expr( - "sxt.t", - &["b", "c", "e"], - "bool", - "a", - "b", - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b", "c", "e"], &accessor), + tab(t), + equal( + column(t, "bool", &accessor), + equal(column(t, "a", &accessor), column(t, "b", &accessor)), + ), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, 
&()).unwrap().table; + let expected_res = owned_table([ bigint("b", [1; 0]), varchar("c", ["t"; 0]), decimal75("e", 75, 0, [0; 0]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -166,22 +100,17 @@ fn we_can_prove_an_equality_query_with_a_single_selected_row() { varchar("d", ["abc"]), decimal75("e", 75, 0, [0]), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["d", "a"], - "b", - 0_i64, - data.try_into().unwrap(), - 0, - ); - let res = test_expr.verify_expr(); - - let expected_res = record_batch!( - "d" => ["abc"], - "a" => [123_i64], + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["d", "a"], &accessor), + tab(t), + equal(column(t, "b", &accessor), const_bigint(0_i64)), ); - + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([varchar("d", ["abc"]), bigint("a", [123_i64])]); assert_eq!(res, expected_res); } @@ -193,16 +122,17 @@ fn we_can_prove_another_equality_query_with_a_single_selected_row() { varchar("d", ["abc"]), decimal75("e", 75, 0, [0]), ]); - - let test_expr = - create_test_col_equals_expr("sxt.t", &["d", "a"], "a", "b", data.try_into().unwrap(), 0); - let res = test_expr.verify_expr(); - - let expected_res = record_batch!( - "d" => ["abc"], - "a" => [123_i64], + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["d", "a"], &accessor), + tab(t), + equal(column(t, "a", &accessor), column(t, "b", &accessor)), ); - + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = 
verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([varchar("d", ["abc"]), bigint("a", [123_i64])]); assert_eq!(res, expected_res); } @@ -214,24 +144,22 @@ fn we_can_prove_an_equality_query_with_a_single_non_selected_row() { varchar("d", ["abc"]), decimal75("e", 75, 0, [Curve25519Scalar::MAX_SIGNED]), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["a", "d", "e"], - "b", - 0_i64, - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d", "e"], &accessor), + tab(t), + equal(column(t, "b", &accessor), const_bigint(0_i64)), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [0; 0]), varchar("d", [""; 0]), decimal75("e", 75, 0, [0; 0]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -252,24 +180,22 @@ fn we_can_prove_an_equality_query_with_multiple_rows() { ], ), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["a", "c", "e"], - "b", - 0_i64, - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "c", "e"], &accessor), + tab(t), + equal(column(t, "b", &accessor), const_bigint(0_i64)), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, 
&()).unwrap().table; + let expected_res = owned_table([ bigint("a", [1, 3]), varchar("c", ["t", "jj"]), decimal75("e", 75, 0, [0, 2]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -291,25 +217,25 @@ fn we_can_prove_a_nested_equality_query_with_multiple_rows() { ], ), ]); - - let test_expr = create_test_complex_col_equals_expr( - "sxt.t", - &["a", "c", "e"], - "bool", - "a", - "b", - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "c", "e"], &accessor), + tab(t), + equal( + column(t, "bool", &accessor), + equal(column(t, "a", &accessor), column(t, "b", &accessor)), + ), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [1, 2]), varchar("c", ["t", "ghi"]), decimal75("e", 75, 0, [0, 1]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -331,24 +257,22 @@ fn we_can_prove_an_equality_query_with_a_nonzero_comparison() { ], ), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["a", "c", "e"], - "b", - 123_u64, - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "c", "e"], &accessor), + tab(t), + equal(column(t, "b", &accessor), const_bigint(123_i64)), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, 
&accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [1, 3]), varchar("c", ["t", "jj"]), decimal75("e", 42, 10, vec![0, 2]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -371,103 +295,95 @@ fn we_can_prove_an_equality_query_with_a_string_comparison() { ], ), ]); - - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["a", "b", "e"], - "c", - "ghi", - data.try_into().unwrap(), - 0, + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "b", "e"], &accessor), + tab(t), + equal(column(t, "c", &accessor), const_varchar("ghi")), ); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [2, 5]), bigint("b", [5, 0]), decimal75("e", 42, 10, [1, -1]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); -} - -#[test] -fn verify_fails_if_data_between_prover_and_verifier_differ() { - let data = record_batch!( - "a" => [1_i64, 2, 3, 4], - "c" => ["t", "ghi", "jj", "f"], - "b" => [0_i64, 5, 0, 5], - ); - let test_expr = create_test_col_lit_equals_expr("sxt.t", &["a", "c"], "b", 0_u64, data, 0); - - let data = record_batch!( - "a" => [1_i64, 2, 3, 4], - "c" => ["t", "ghi", "jj", "f"], - "b" => [0_i64, 2, 0, 5], - ); - let tampered_test_expr = - create_test_col_lit_equals_expr("sxt.t", &["a", "c"], "b", 0_u64, data, 0); - - let res = test_expr.create_verifiable_result(); - assert!(res - .verify(&test_expr.ast, &tampered_test_expr.accessor, &()) - .is_err()); + assert_eq!(res, expected_res); } -fn 
we_can_query_random_tables_with_multiple_selected_rows_and_given_offset(offset: usize) { - let descr = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 20, - min_value: -3, - max_value: 3, - }; +fn test_random_tables_with_given_offset(offset: usize) { + let dist = Uniform::new(-3, 4); let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [ - ("aa", ColumnType::BigInt), - ("ab", ColumnType::VarChar), - ("b", ColumnType::BigInt), - ]; for _ in 0..20 { - // filtering by string value - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["aa", "ab", "b"], - "ab", - ("s".to_owned() + &filter_val.to_string()[..]).as_str(), - data, + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + bigint("a", dist.sample_iter(&mut rng).take(n)), + varchar( + "b", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + bigint("c", dist.sample_iter(&mut rng).take(n)), + varchar( + "d", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + ]); + + // Generate random values to filter by + let filter_val = format!("s{}", dist.sample(&mut rng)); + + // Create and verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), offset, + (), ); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); - - // filtering by integer value - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = create_test_col_lit_equals_expr( - "sxt.t", - &["aa", "ab", "b"], - "b", - filter_val, - data, - offset, + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + equal( + column(t, "b", 
&accessor), + const_varchar(filter_val.as_str()), + ), ); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_a, expected_d): (Vec<_>, Vec<_>) = multizip(( + data["a"].i64_iter(), + data["b"].string_iter(), + data["c"].i64_iter(), + data["d"].string_iter(), + )) + .filter_map(|(a, b, _c, d)| { + if b == &filter_val { + Some((*a, d.clone())) + } else { + None + } + }) + .multiunzip(); + let expected_result = owned_table([bigint("a", expected_a), varchar("d", expected_d)]); + + assert_eq!(expected_result, res) } } #[test] -fn we_can_query_random_tables_with_a_zero_offset() { - we_can_query_random_tables_with_multiple_selected_rows_and_given_offset(0); +fn we_can_query_random_tables_using_a_zero_offset() { + test_random_tables_with_given_offset(0); } #[test] -fn we_can_query_random_tables_with_a_non_zero_offset() { - we_can_query_random_tables_with_multiple_selected_rows_and_given_offset(121); +fn we_can_query_random_tables_using_a_non_zero_offset() { + test_random_tables_with_given_offset(121); } #[test] @@ -488,10 +404,8 @@ fn we_can_compute_the_correct_output_of_an_equals_expr_using_result_evaluate() { ], ), ]); - - let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); let t = "sxt.t".parse().unwrap(); - accessor.add_table(t, data, 0); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let equals_expr: ProvableExprPlan = equal( column(t, "e", &accessor), const_scalar(Curve25519Scalar::ZERO), diff --git a/crates/proof-of-sql/src/sql/ast/filter_expr_test.rs b/crates/proof-of-sql/src/sql/ast/filter_expr_test.rs index 424354db8..ca19eab7a 100644 --- a/crates/proof-of-sql/src/sql/ast/filter_expr_test.rs +++ 
b/crates/proof-of-sql/src/sql/ast/filter_expr_test.rs @@ -2,19 +2,20 @@ use crate::{ base::{ database::{ owned_table_utility::*, ColumnField, ColumnRef, ColumnType, LiteralValue, OwnedTable, - OwnedTableTestAccessor, RecordBatchTestAccessor, TableRef, TestAccessor, + OwnedTableTestAccessor, TableRef, TestAccessor, }, math::decimal::Precision, scalar::Curve25519Scalar, }, proof_primitive::dory::DoryCommitment, - record_batch, sql::{ ast::{ test_utility::*, ColumnExpr, FilterExpr, FilterResultExpr, LiteralExpr, ProvableExprPlan, TableExpr, }, - proof::{ProofExpr, ProverEvaluate, ResultBuilder, VerifiableQueryResult}, + proof::{ + exercise_verification, ProofExpr, ProverEvaluate, ResultBuilder, VerifiableQueryResult, + }, }, }; use arrow::datatypes::{Field, Schema}; @@ -150,24 +151,19 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { #[test] fn we_can_prove_and_get_the_correct_result_from_a_basic_filter() { - let data = record_batch!( - "a" => [1_i64, 4_i64, 5_i64, 2_i64, 5_i64], - "b" => [1_i64, 2, 3, 4, 5], - ); + let data = owned_table([ + bigint("a", [1_i64, 4_i64, 5_i64, 2_i64, 5_i64]), + bigint("b", [1_i64, 2, 3, 4, 5]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let where_clause = equal(column(t, "a", &accessor), const_int128(5)); - let expr = filter(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [3_i64, 5], - ); - assert_eq!(res, expected); + let ast = filter(cols_result(t, &["b"], &accessor), tab(t), where_clause); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, 
&accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [3_i64, 5])]); + assert_eq!(res, expected_res); } #[test] diff --git a/crates/proof-of-sql/src/sql/ast/filter_expr_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/ast/filter_expr_test_dishonest_prover.rs index 86aacd406..b48112436 100644 --- a/crates/proof-of-sql/src/sql/ast/filter_expr_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/ast/filter_expr_test_dishonest_prover.rs @@ -1,11 +1,10 @@ use super::{OstensibleFilterExpr, ProvableExpr}; use crate::{ base::{ - database::{Column, DataAccessor, RecordBatchTestAccessor, TestAccessor}, + database::{owned_table_utility::*, Column, DataAccessor, OwnedTableTestAccessor}, proof::ProofError, scalar::Curve25519Scalar, }, - record_batch, sql::{ ast::test_utility::*, proof::{ @@ -81,18 +80,18 @@ impl ProverEvaluate for DishonestFilterExpr { #[test] fn we_fail_to_verify_a_basic_filter_with_a_dishonest_prover() { - let data = record_batch!( - "a" => [1_i64, 4_i64, 5_i64, 2_i64, 5_i64], - "b" => [1_i64, 2, 3, 4, 5], - ); + let data = owned_table([ + bigint("a", [1_i64, 4_i64, 5_i64, 2_i64, 5_i64]), + bigint("b", [1_i64, 2, 3, 4, 5]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let where_clause = equal(column(t, "a", &accessor), const_int128(5_i128)); - let expr = DishonestFilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); + let ast = DishonestFilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); assert!(matches!( - res.verify(&expr, &accessor, &()), + verifiable_res.verify(&ast, &accessor, &()), Err(QueryError::ProofError(ProofError::VerificationError(_))) )); } diff 
--git a/crates/proof-of-sql/src/sql/ast/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/ast/inequality_expr_test.rs index 0d2eb25c2..5095d021a 100644 --- a/crates/proof-of-sql/src/sql/ast/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/inequality_expr_test.rs @@ -1,32 +1,30 @@ -use super::{prover_evaluate_equals_zero, prover_evaluate_or, FilterExpr, ProvableExpr}; use crate::{ base::{ bit::BitDistribution, commitment::InnerProductProof, database::{ - make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, OwnedTable, - OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, - TestAccessor, + owned_table_utility::*, Column, OwnedTable, OwnedTableTestAccessor, TestAccessor, }, math::decimal::scale_scalar, proof::{MessageLabel, TranscriptProtocol}, scalar::{Curve25519Scalar, Scalar}, }, - record_batch, sql::{ - ast::{test_expr::TestExprNode, test_utility::*, ProvableExprPlan}, + ast::{ + prover_evaluate_equals_zero, prover_evaluate_or, test_utility::*, ProvableExpr, + ProvableExprPlan, + }, parse::ConversionError, proof::{ - make_transcript, Indexes, ProofBuilder, ProofExpr, QueryProof, ResultBuilder, - VerifiableQueryResult, + exercise_verification, make_transcript, Indexes, ProofBuilder, ProofExpr, QueryProof, + ResultBuilder, VerifiableQueryResult, }, }, }; -use arrow::record_batch::RecordBatch; use bumpalo::Bump; use curve25519_dalek::RistrettoPoint; +use itertools::{multizip, MultiUnzip}; use num_traits::Zero; -use polars::prelude::*; use rand::{ distributions::{Distribution, Uniform}, rngs::StdRng, @@ -35,87 +33,71 @@ use rand_core::SeedableRng; #[test] fn we_can_compare_a_constant_column() { - let data = record_batch!( - "a" => [123_i64, 123, 123], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [123_i64, 123, 123]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let 
col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => &[] as &[i64], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(5)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [0; 0])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_varying_column_with_constant_sign() { - let data = record_batch!( - "a" => [123_i64, 567, 8], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [123_i64, 567, 8]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => &[] as &[i64], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(5)), ); - assert_eq!(res, expected); + 
let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [0; 0])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_columns_with_extreme_values() { - let data = record_batch!( - "bigint_a" => [i64::MAX, i64::MIN, i64::MAX], - "bigint_b" => [i64::MAX, i64::MAX, i64::MIN], - "int128_a" => [i128::MIN, i128::MAX, i128::MAX], - "int128_b" => [i128::MAX, i128::MIN, i128::MAX], - "boolean" => [true, false, true], - ); + let data = owned_table([ + bigint("bigint_a", [i64::MAX, i64::MIN, i64::MAX]), + bigint("bigint_b", [i64::MAX, i64::MAX, i64::MIN]), + int128("int128_a", [i128::MIN, i128::MAX, i128::MAX]), + int128("int128_b", [i128::MAX, i128::MIN, i128::MAX]), + boolean("boolean", [true, false, true]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let llhs_expr: ProvableExprPlan = column(t, "bigint_a", &accessor); - let lrhs_expr = column(t, "bigint_b", &accessor); - let rlhs_expr: ProvableExprPlan = column(t, "int128_a", &accessor); - let rrhs_expr = column(t, "int128_b", &accessor); - let bool_expr = column(t, "boolean", &accessor); - let where_clause = lte( - lte(lte(llhs_expr, lrhs_expr), gte(rlhs_expr, rrhs_expr)), - bool_expr, - ); - let expr = FilterExpr::new( - cols_result(t, &["bigint_b"], &accessor), + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["bigint_b"], &accessor), tab(t), - where_clause, - ); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "bigint_b" => [i64::MAX, i64::MIN], + lte( + lte( + lte( + column(t, "bigint_a", &accessor), + column(t, "bigint_b", 
&accessor), + ), + gte( + column(t, "int128_a", &accessor), + column(t, "int128_b", &accessor), + ), + ), + column(t, "boolean", &accessor), + ), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("bigint_b", [i64::MAX, i64::MIN])]); + assert_eq!(res, expected_res); } #[test] @@ -128,26 +110,22 @@ fn we_can_compare_columns_with_small_decimal_values_without_scale() { varchar("d", ["abc", "de"]), decimal75("e", 38, 0, [scalar_pos, scalar_neg]), ]); - - let mut accessor = RecordBatchTestAccessor::new_empty(); let t = "sxt.t".parse().unwrap(); - let batch = data.try_into().unwrap(); - accessor.add_table(t, batch, 0); - let lte_expr = lte( - column(t, "e", &accessor), - const_scalar::(0_i64), + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d", "e"], &accessor), + tab(t), + lte(column(t, "e", &accessor), const_bigint(0_i64)), ); - let df_filter = polars::prelude::col("e").lt_eq(lit(0)); - let test_expr = TestExprNode::new(t, &["a", "d", "e"], lte_expr, df_filter, accessor); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [25]), varchar("d", ["de"]), decimal75("e", 38, 0, [scalar_neg]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -161,27 +139,23 @@ fn we_can_compare_columns_with_small_decimal_values_with_scale() { decimal75("e", 38, 0, [scalar_pos, scalar_neg]), decimal75("f", 38, 38, [scalar_neg, scalar_pos]), ]); - - 
let mut accessor = RecordBatchTestAccessor::new_empty(); let t = "sxt.t".parse().unwrap(); - let batch = data.try_into().unwrap(); - accessor.add_table(t, batch, 0); - let lte_expr = lte( - column(t, "f", &accessor), - const_scalar::(0_i64), + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d", "e", "f"], &accessor), + tab(t), + lte(column(t, "f", &accessor), const_bigint(0_i64)), ); - let df_filter = polars::prelude::col("e").lt_eq(lit(0)); - let test_expr = TestExprNode::new(t, &["a", "d", "e", "f"], lte_expr, df_filter, accessor); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [123]), varchar("d", ["abc"]), decimal75("e", 38, 0, [scalar_pos]), decimal75("f", 38, 38, [scalar_neg]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -198,26 +172,22 @@ fn we_can_compare_columns_returning_extreme_decimal_values() { [Curve25519Scalar::MAX_SIGNED, scalar_min_signed], ), ]); - - let mut accessor = RecordBatchTestAccessor::new_empty(); let t = "sxt.t".parse().unwrap(); - let batch = data.try_into().unwrap(); - accessor.add_table(t, batch, 0); - let lte_expr = lte( - column(t, "b", &accessor), - const_scalar::(0_i64), + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d", "e"], &accessor), + tab(t), + lte(column(t, "b", &accessor), const_bigint(0_i64)), ); - let df_filter = polars::prelude::col("b").eq(lit(0_i64)); - let test_expr = TestExprNode::new(t, &["a", "d", "e"], lte_expr, df_filter, accessor); - let res = test_expr.verify_expr(); - - let expected_res: OwnedTable = 
owned_table([ + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ bigint("a", [25]), varchar("d", ["de"]), decimal75("e", 75, 0, [scalar_min_signed]), ]); - - assert_eq!(res, expected_res.try_into().unwrap()); + assert_eq!(res, expected_res); } #[test] @@ -234,11 +204,8 @@ fn we_cannot_compare_columns_filtering_on_extreme_decimal_values() { [Curve25519Scalar::MAX_SIGNED, scalar_min_signed], ), ]); - - let mut accessor = RecordBatchTestAccessor::new_empty(); let t = "sxt.t".parse().unwrap(); - let batch = data.try_into().unwrap(); - accessor.add_table(t, batch, 0); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); assert!(matches!( ProvableExprPlan::try_new_inequality( column(t, "e", &accessor), @@ -251,235 +218,185 @@ fn we_cannot_compare_columns_filtering_on_extreme_decimal_values() { #[test] fn we_can_compare_two_columns() { - let data = record_batch!( - "a" => [1_i64, 5, 8], - "b" => [1_i64, 7, 3], - ); + let data = owned_table([bigint("a", [1_i64, 5, 8]), bigint("b", [1_i64, 7, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let lhs_expr: ProvableExprPlan = column(t, "a", &accessor); - let rhs_expr = column(t, "b", &accessor); - let where_clause = lte(lhs_expr, rhs_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 7], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), column(t, "b", &accessor)), 
); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 7])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_varying_column_with_constant_absolute_value() { - let data = record_batch!( - "a" => [-123_i64, 123, -123], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([ + bigint("a", [-123_i64, 123, -123]), + bigint("b", [1_i64, 2, 3]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 3], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(0)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 3])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_constant_column_of_negative_columns() { - let data = record_batch!( - "a" => [-123_i64, -123, -123], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([ + bigint("a", [-123_i64, -123, -123]), + bigint("b", [1_i64, 2, 3]), + ]); let t = "sxt.t".parse().unwrap(); - let 
mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 2, 3], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(5)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 2, 3])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_varying_column_with_negative_only_signs() { - let data = record_batch!( - "a" => [-123_i64, -133, -823], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([ + bigint("a", [-123_i64, -133, -823]), + bigint("b", [1_i64, 2, 3]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 2, 3], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + 
cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(5)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 2, 3])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_column_with_varying_absolute_values_and_signs() { - let data = record_batch!( - "a" => [-1_i64, 9, 0], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [-1_i64, 9, 0]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr: ProvableExprPlan = column(t, "a", &accessor); - let lit_expr = const_bigint(1); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 3], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(1)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 3])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_column_with_greater_than_or_equal() { - let data = record_batch!( - "a" => [-1_i64, 9, 0], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [-1_i64, 9, 0]), bigint("b", [1_i64, 
2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(1); - let where_clause = gte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [2_i64], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + gte(column(t, "a", &accessor), const_bigint(1)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [2_i64])]); + assert_eq!(res, expected_res); } #[test] fn we_can_run_nested_comparison() { - let data = record_batch!( - "a" => [0_i64, 2, 4], - "b" => [1_i64, 2, 3], - "boolean" => [false, false, true], - ); + let data = owned_table([ + bigint("a", [0_i64, 2, 4]), + bigint("b", [1_i64, 2, 3]), + boolean("boolean", [false, false, true]), + ]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let lhs_expr = column(t, "a", &accessor); - let rhs_expr = column(t, "b", &accessor); - let bool_expr = column(t, "boolean", &accessor); - let where_clause = equal(gte(lhs_expr, rhs_expr), bool_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 
3_i64], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + equal( + gte(column(t, "a", &accessor), column(t, "b", &accessor)), + column(t, "boolean", &accessor), + ), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 3])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_column_with_varying_absolute_values_and_signs_and_a_constant_bit() { - let data = record_batch!( - "a" => [-2_i64, 3, 2], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [-2_i64, 3, 2]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(0)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compare_a_constant_column_of_zeros() { - 
let data = record_batch!( - "a" => [0_i64, 0, 0], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [0_i64, 0, 0]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - let res = res - .verify(&expr, &accessor, &()) - .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 2, 3], + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(0)), ); - assert_eq!(res, expected); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("b", [1_i64, 2, 3])]); + assert_eq!(res, expected_res); } #[test] fn the_sign_can_be_0_or_1_for_a_constant_column_of_zeros() { - let data = record_batch!( - "a" => [0_i64, 0, 0], - "b" => [1_i64, 2, 3], - ); + let data = owned_table([bigint("a", [0_i64, 0, 0]), bigint("b", [1_i64, 2, 3])]); let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = filter( + cols_result(t, &["b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), 
const_bigint(0)), + ); + let table_length = ast.get_length(&accessor); + let generator_offset = ast.get_offset(&accessor); let alloc = Bump::new(); let mut result_builder = ResultBuilder::new(3); @@ -488,10 +405,7 @@ fn the_sign_can_be_0_or_1_for_a_constant_column_of_zeros() { result_cols[0].result_evaluate(&mut result_builder, &accessor); let provable_result = result_builder.make_provable_query_result(); - let table_length = expr.get_length(&accessor); - let generator_offset = expr.get_offset(&accessor); - - let mut transcript = make_transcript(&expr, &provable_result, table_length, generator_offset); + let mut transcript = make_transcript(&ast, &provable_result, table_length, generator_offset); transcript.challenge_scalars::(&mut [], MessageLabel::PostResultChallenges); let mut builder = ProofBuilder::new(3, 2, Vec::new()); @@ -515,168 +429,72 @@ fn the_sign_can_be_0_or_1_for_a_constant_column_of_zeros() { let proof = QueryProof::::new_from_builder(builder, 0, transcript, &()); let res = proof - .verify(&expr, &accessor, &provable_result, &()) + .verify(&ast, &accessor, &provable_result, &()) .unwrap() - .into_record_batch(); - let expected = record_batch!( - "b" => [1_i64, 2, 3], - ); + .table; + let expected = owned_table([bigint("b", [1_i64, 2, 3])]); assert_eq!(res, expected); } -#[test] -fn verification_fails_if_commitments_dont_match_for_a_constant_column() { - let data = record_batch!( - "a" => [123_i64, 123, 123], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - - let data = record_batch!( - "a" => [321_i64, 321, 321], - "b" => [1_i64, 2, 3], - ); - let t = 
"sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - assert!(res.verify(&expr, &accessor, &()).is_err()); -} - -#[test] -fn verification_fails_if_commitments_dont_match_for_a_constant_absolute_column() { - let data = record_batch!( - "a" => [-123_i64, 123, -123], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - - let data = record_batch!( - "a" => [-321_i64, 321, -321], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(0); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - assert!(res.verify(&expr, &accessor, &()).is_err()); -} - -#[test] -fn verification_fails_if_commitments_dont_match_for_a_constant_sign_column() { - let data = record_batch!( - "a" => [193_i64, 323, 421], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = 
VerifiableQueryResult::::new(&expr, &accessor, &()); - - let data = record_batch!( - "a" => [321_i64, 321, 321], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - assert!(res.verify(&expr, &accessor, &()).is_err()); +fn test_random_tables_with_given_offset(offset: usize) { + let dist = Uniform::new(-3, 4); + let mut rng = StdRng::from_seed([0u8; 32]); + for _ in 0..20 { + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + bigint("a", dist.sample_iter(&mut rng).take(n)), + varchar( + "b", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + ]); + + // Generate random values to filter by + let filter_val = dist.sample(&mut rng); + + // Create and verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), + offset, + (), + ); + let ast = dense_filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + lte(column(t, "a", &accessor), const_bigint(filter_val)), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_a, expected_b): (Vec<_>, Vec<_>) = + multizip((data["a"].i64_iter(), data["b"].string_iter())) + .filter_map(|(a, b)| { + if a <= &filter_val { + Some((*a, b.clone())) + } else { + None + } + }) + .multiunzip(); + let expected_result = owned_table([bigint("a", expected_a), varchar("b", expected_b)]); + + assert_eq!(expected_result, res) + } } #[test] -fn 
verification_fails_if_commitments_dont_match() { - let data = record_batch!( - "a" => [-523_i64, 923, 823], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - let res = VerifiableQueryResult::::new(&expr, &accessor, &()); - - let data = record_batch!( - "a" => [-523_i64, 923, 83], - "b" => [1_i64, 2, 3], - ); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, data, 0); - let col_expr = column(t, "a", &accessor); - let lit_expr = const_bigint(5); - let where_clause = lte(col_expr, lit_expr); - let expr = FilterExpr::new(cols_result(t, &["b"], &accessor), tab(t), where_clause); - assert!(res.verify(&expr, &accessor, &()).is_err()); -} - -fn create_test_lte_expr + Copy + Literal>( - table_ref: &str, - result_col: &str, - filter_col: &str, - filter_val: T, - data: RecordBatch, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, 0); - let col_expr = column(t, filter_col, &accessor); - let lit_expr = const_scalar(filter_val.into()); - let where_clause = lte(col_expr, lit_expr); - - let df_filter = polars::prelude::col(filter_col).lt_eq(lit(filter_val)); - TestExprNode::new(t, &[result_col], where_clause, df_filter, accessor) +fn we_can_query_random_tables_using_a_zero_offset() { + test_random_tables_with_given_offset(0); } #[test] -fn we_can_query_random_data_of_varying_size() { - let descr = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 20, - min_value: -3, - max_value: 3, - }; - let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [("a", ColumnType::BigInt), ("b", 
ColumnType::BigInt)]; - for _ in 0..10 { - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = create_test_lte_expr("sxt.t", "b", "a", filter_val, data); - let res = test_expr.verify_expr(); - let expected = test_expr.query_table(); - assert_eq!(res, expected); - } +fn we_can_query_random_tables_using_a_non_zero_offset() { + test_random_tables_with_given_offset(5121); } #[test] diff --git a/crates/proof-of-sql/src/sql/ast/literal_expr_test.rs b/crates/proof-of-sql/src/sql/ast/literal_expr_test.rs index 2fd613be4..6878fa00f 100644 --- a/crates/proof-of-sql/src/sql/ast/literal_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/literal_expr_test.rs @@ -1,82 +1,126 @@ use super::{ProvableExpr, ProvableExprPlan}; use crate::{ - base::database::{ - make_random_test_accessor_data, Column, ColumnType, RandomTestAccessorDescriptor, - RecordBatchTestAccessor, TestAccessor, UnimplementedTestAccessor, + base::{ + commitment::InnerProductProof, + database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + }, + sql::{ + ast::test_utility::*, + proof::{exercise_verification, VerifiableQueryResult}, }, - record_batch, - sql::ast::{test_expr::TestExprNode, test_utility::const_bool}, }; -use arrow::record_batch::RecordBatch; use bumpalo::Bump; use curve25519_dalek::ristretto::RistrettoPoint; -use polars::prelude::*; -use rand::rngs::StdRng; +use rand::{ + distributions::{Distribution, Uniform}, + rngs::StdRng, +}; use rand_core::SeedableRng; -fn create_test_const_bool_expr( - table_ref: &str, - results: &[&str], - filter_val: bool, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let table_ref = table_ref.parse().unwrap(); - accessor.add_table(table_ref, data, offset); - let df_filter = lit(filter_val); - let const_expr = const_bool(filter_val); - TestExprNode::new(table_ref, 
results, const_expr, df_filter, accessor) -} - -fn test_random_tables_with_given_constant(value: bool) { - let descr = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 20, - min_value: -3, - max_value: 3, - }; +fn test_random_tables_with_given_offset(offset: usize) { + let dist = Uniform::new(-3, 4); let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [("a", ColumnType::BigInt), ("b", ColumnType::VarChar)]; - for _ in 0..10 { - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let test_expr = create_test_const_bool_expr("sxt.t", &["a", "b"], value, data, 0); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); + for _ in 0..20 { + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + boolean("a", dist.sample_iter(&mut rng).take(n).map(|v| v < 0)), + varchar( + "b", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + bigint("c", dist.sample_iter(&mut rng).take(n)), + ]); + + // Generate random values to filter by + let lit = dist.sample(&mut rng) < 0; + + // Create and verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), + offset, + (), + ); + let ast = dense_filter( + cols_expr_plan(t, &["a", "b", "c"], &accessor), + tab(t), + const_bool(lit), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_a, expected_b, expected_c): (Vec, Vec, Vec) = if lit { + ( + data["a"].bool_iter().cloned().collect(), + data["b"].string_iter().cloned().collect(), + data["c"].i64_iter().cloned().collect(), + ) + } else { + (vec![], vec![], vec![]) + }; + let expected_result = owned_table([ + boolean("a", expected_a), + varchar("b", 
expected_b), + bigint("c", expected_c), + ]); + + assert_eq!(expected_result, res) } } #[test] -fn we_can_prove_a_query_with_a_single_selected_row() { - let data = record_batch!("a" => [123_i64]); - let test_expr = create_test_const_bool_expr("sxt.t", &["a"], true, data.clone(), 0); - let res = test_expr.verify_expr(); - assert_eq!(res, data); +fn we_can_query_random_tables_using_a_zero_offset() { + test_random_tables_with_given_offset(0); } #[test] -fn we_can_prove_a_query_with_a_single_non_selected_row() { - let data = record_batch!("a" => [123_i64]); - let test_expr = create_test_const_bool_expr("sxt.t", &["a"], false, data, 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!("a" => Vec::::new()); - assert_eq!(res, expected_res); +fn we_can_query_random_tables_using_a_non_zero_offset() { + test_random_tables_with_given_offset(5121); } #[test] -fn we_can_select_from_tables_with_an_always_true_where_clause() { - test_random_tables_with_given_constant(true); +fn we_can_prove_a_query_with_a_single_selected_row() { + let data = owned_table([bigint("a", [123_i64])]); + let expected_res = data.clone(); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a"], &accessor), + tab(t), + const_bool(true), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + assert_eq!(res, expected_res); } #[test] -fn we_can_select_from_tables_with_an_always_false_where_clause() { - test_random_tables_with_given_constant(false); +fn we_can_prove_a_query_with_a_single_non_selected_row() { + let data = owned_table([bigint("a", [123_i64])]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a"], 
&accessor), + tab(t), + const_bool(false), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [1_i64; 0])]); + assert_eq!(res, expected_res); } #[test] fn we_can_compute_the_correct_output_of_a_literal_expr_using_result_evaluate() { - let accessor = UnimplementedTestAccessor::new_empty(); + let data = owned_table([bigint("a", [123_i64, 456, 789, 1011])]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let literal_expr: ProvableExprPlan = const_bool(true); let alloc = Bump::new(); let res = literal_expr.result_evaluate(4, &alloc, &accessor); diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index f7aa68dda..e1010654d 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -92,9 +92,6 @@ mod sign_expr_test; mod table_expr; pub(crate) use table_expr::TableExpr; -#[cfg(all(test, feature = "blitzar"))] -mod test_expr; - #[cfg(test)] pub(crate) mod test_utility; diff --git a/crates/proof-of-sql/src/sql/ast/not_expr_test.rs b/crates/proof-of-sql/src/sql/ast/not_expr_test.rs index 335f84a56..dd2569e47 100644 --- a/crates/proof-of-sql/src/sql/ast/not_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/not_expr_test.rs @@ -1,102 +1,98 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{ - make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, - OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, - TestAccessor, - }, - scalar::Curve25519Scalar, + database::{owned_table_utility::*, Column, OwnedTableTestAccessor, TestAccessor}, }, - record_batch, - sql::ast::{ - test_expr::TestExprNode, - test_utility::{column, const_int128, const_scalar, equal, not as unot}, - 
ProvableExpr, ProvableExprPlan, + sql::{ + ast::{test_utility::*, ProvableExpr, ProvableExprPlan}, + proof::{exercise_verification, VerifiableQueryResult}, }, }; -use arrow::record_batch::RecordBatch; use bumpalo::Bump; use curve25519_dalek::ristretto::RistrettoPoint; -use polars::prelude::*; +use itertools::{multizip, MultiUnzip}; use rand::{ distributions::{Distribution, Uniform}, rngs::StdRng, }; use rand_core::SeedableRng; -fn create_test_not_expr + Copy + Literal>( - table_ref: &str, - results: &[&str], - filter_col: &str, - filter_val: T, - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, offset); - let df_filter = polars::prelude::col(filter_col).neq(lit(filter_val)); - let not_expr = unot(equal( - column(t, filter_col, &accessor), - const_scalar(filter_val.into()), - )); - TestExprNode::new(t, results, not_expr, df_filter, accessor) -} - #[test] fn we_can_prove_a_not_equals_query_with_a_single_selected_row() { - let data = record_batch!( - "a" => [123_i64, 456], - "b" => [0_i64, 1], - "d" => ["alfa", "gama"] - ); - let test_expr = create_test_not_expr("sxt.t", &["a", "d"], "b", 1_i64, data, 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "a" => [123_i64], - "d" => ["alfa"] + let data = owned_table([ + bigint("a", [123_i64, 456]), + bigint("b", [0_i64, 1]), + varchar("d", ["alfa", "gama"]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + not(equal(column(t, "b", &accessor), const_bigint(1))), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", 
[123]), varchar("d", ["alfa"])]); assert_eq!(res, expected_res); } fn test_random_tables_with_given_offset(offset: usize) { - let descr = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 20, - min_value: -3, - max_value: 3, - }; + let dist = Uniform::new(-3, 4); let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [ - ("aa", ColumnType::BigInt), - ("ab", ColumnType::VarChar), - ("b", ColumnType::BigInt), - ]; for _ in 0..20 { - // filtering by string value - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = create_test_not_expr( - "sxt.t", - &["aa", "ab", "b"], - "ab", - ("s".to_owned() + &filter_val.to_string()[..]).as_str(), - data, + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + bigint("a", dist.sample_iter(&mut rng).take(n)), + varchar( + "b", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + ]); + + // Generate random values to filter by + let filter_val_a = dist.sample(&mut rng); + let filter_val_b = format!("s{}", dist.sample(&mut rng)); + + // Create and verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), offset, + (), ); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); + let ast = dense_filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + not(and( + equal(column(t, "a", &accessor), const_bigint(filter_val_a)), + equal( + column(t, "b", &accessor), + const_scalar(filter_val_b.as_str()), + ), + )), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_a, expected_b): (Vec<_>, Vec<_>) = 
+ multizip((data["a"].i64_iter(), data["b"].string_iter())) + .filter_map(|(a, b)| { + if a != &filter_val_a || b != &filter_val_b { + Some((*a, b.clone())) + } else { + None + } + }) + .multiunzip(); + let expected_result = owned_table([bigint("a", expected_a), varchar("b", expected_b)]); - // filtering by integer value - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = - create_test_not_expr("sxt.t", &["aa", "ab", "b"], "b", filter_val, data, offset); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); + assert_eq!(expected_result, res) } } @@ -121,7 +117,7 @@ fn we_can_compute_the_correct_output_of_a_not_expr_using_result_evaluate() { let t = "sxt.t".parse().unwrap(); accessor.add_table(t, data, 0); let not_expr: ProvableExprPlan = - unot(equal(column(t, "b", &accessor), const_int128(1))); + not(equal(column(t, "b", &accessor), const_int128(1))); let alloc = Bump::new(); let res = not_expr.result_evaluate(2, &alloc, &accessor); let expected_res = Column::Boolean(&[true, false]); diff --git a/crates/proof-of-sql/src/sql/ast/or_expr_test.rs b/crates/proof-of-sql/src/sql/ast/or_expr_test.rs index 5bdb59cde..869f2a667 100644 --- a/crates/proof-of-sql/src/sql/ast/or_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/or_expr_test.rs @@ -1,128 +1,159 @@ use crate::{ base::{ commitment::InnerProductProof, - database::{ - make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, - OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, - TestAccessor, - }, - scalar::Curve25519Scalar, + database::{owned_table_utility::*, Column, OwnedTableTestAccessor, TestAccessor}, + }, + sql::{ + ast::{test_utility::*, ProvableExpr, ProvableExprPlan}, + proof::{exercise_verification, VerifiableQueryResult}, }, - record_batch, - 
sql::ast::{test_expr::TestExprNode, test_utility::*, ProvableExpr, ProvableExprPlan}, }; -use arrow::record_batch::RecordBatch; use bumpalo::Bump; use curve25519_dalek::ristretto::RistrettoPoint; -use polars::prelude::*; +use itertools::{multizip, MultiUnzip}; use rand::{ distributions::{Distribution, Uniform}, rngs::StdRng, }; use rand_core::SeedableRng; -fn create_test_or_expr< - T1: Into + Copy + Literal, - T2: Into + Copy + Literal, ->( - table_ref: &str, - results: &[&str], - lhs: (&str, T1), - rhs: (&str, T2), - data: RecordBatch, - offset: usize, -) -> TestExprNode { - let mut accessor = RecordBatchTestAccessor::new_empty(); - let t = table_ref.parse().unwrap(); - accessor.add_table(t, data, offset); - let or_expr = or( - equal(column(t, lhs.0, &accessor), const_scalar(lhs.1.into())), - equal(column(t, rhs.0, &accessor), const_scalar(rhs.1.into())), - ); - let df_filter = polars::prelude::col(lhs.0) - .eq(lit(lhs.1)) - .or(polars::prelude::col(rhs.0).eq(lit(rhs.1))); - TestExprNode::new(t, results, or_expr, df_filter, accessor) -} - #[test] fn we_can_prove_a_simple_or_query() { - let data = record_batch!( - "a" => [1_i64, 2, 3, 4], - "d" => ["ab", "t", "g", "efg"], - "b" => [0_i64, 1, 0, 2], - ); - let test_expr = create_test_or_expr("sxt.t", &["a"], ("b", 1), ("d", "efgh"), data, 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "a" => [2_i64], + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4]), + varchar("d", ["ab", "t", "g", "efg"]), + bigint("b", [0_i64, 1, 0, 2]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + or( + equal(column(t, "b", &accessor), const_bigint(1)), + equal(column(t, "d", &accessor), const_varchar("g")), + ), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let 
res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [2_i64, 3]), varchar("d", ["t", "g"])]); assert_eq!(res, expected_res); } #[test] -fn we_can_prove_a_simple_or_query_with_i128_data() { - let data = record_batch!( - "a" => [1_i128, 2, 3, 4], - "d" => ["ab", "t", "g", "efg"], - "b" => [0_i128, 1, 0, 2], - ); - let test_expr = create_test_or_expr("sxt.t", &["a"], ("b", 1), ("d", "efgh"), data, 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "a" => [2_i128], +fn we_can_prove_a_simple_or_query_with_variable_integer_types() { + let data = owned_table([ + int128("a", [1_i128, 2, 3, 4]), + varchar("d", ["ab", "t", "g", "efg"]), + smallint("b", [0_i16, 1, 0, 2]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + or( + equal(column(t, "b", &accessor), const_bigint(1)), + equal(column(t, "d", &accessor), const_varchar("g")), + ), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([int128("a", [2_i64, 3]), varchar("d", ["t", "g"])]); assert_eq!(res, expected_res); } #[test] fn we_can_prove_an_or_query_where_both_lhs_and_rhs_are_true() { - let data = record_batch!( - "a" => [1_i64, 2, 3, 4], - "b" => [0_i64, 1, 0, 1], - "c" => [0_i64, 2, 2, 0], - "d" => ["ab", "t", "g", "efg"], - ); - let test_expr = create_test_or_expr("sxt.t", &["d"], ("b", 1), ("d", "g"), data, 0); - let res = test_expr.verify_expr(); - let expected_res = record_batch!( - "d" => ["t", "g", "efg"], + let data = owned_table([ + bigint("a", [1_i64, 2, 3, 4]), + int128("b", [0_i128, 1, 1, 1]), + int("c", [0_i32, 2, 2, 0]), + varchar("d", ["ab", "t", "g", "efg"]), + ]); + let t = 
"sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + or( + equal(column(t, "b", &accessor), const_bigint(1)), + equal(column(t, "d", &accessor), const_varchar("g")), + ), ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [2_i64, 3, 4]), varchar("d", ["t", "g", "efg"])]); assert_eq!(res, expected_res); } fn test_random_tables_with_given_offset(offset: usize) { - let descr = RandomTestAccessorDescriptor { - min_rows: 1, - max_rows: 20, - min_value: -3, - max_value: 3, - }; + let dist = Uniform::new(-3, 4); let mut rng = StdRng::from_seed([0u8; 32]); - let cols = [ - ("a", ColumnType::BigInt), - ("b", ColumnType::VarChar), - ("c", ColumnType::BigInt), - ("d", ColumnType::VarChar), - ]; for _ in 0..20 { - let data = make_random_test_accessor_data(&mut rng, &cols, &descr); - let filter_val1 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let filter_val2 = Uniform::new(descr.min_value, descr.max_value + 1).sample(&mut rng); - let test_expr = create_test_or_expr( - "sxt.t", - &["a", "d"], - ( + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + bigint("a", dist.sample_iter(&mut rng).take(n)), + varchar( "b", - ("s".to_owned() + &filter_val1.to_string()[..]).as_str(), + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + bigint("c", dist.sample_iter(&mut rng).take(n)), + varchar( + "d", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), ), - ("c", filter_val2), - data, + ]); + + // Generate random values to filter by + let filter_val1 = format!("s{}", dist.sample(&mut rng)); + let filter_val2 = dist.sample(&mut rng); + + // Create and 
verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), offset, + (), + ); + let ast = dense_filter( + cols_expr_plan(t, &["a", "d"], &accessor), + tab(t), + or( + equal( + column(t, "b", &accessor), + const_varchar(filter_val1.as_str()), + ), + equal(column(t, "c", &accessor), const_bigint(filter_val2)), + ), ); - let res = test_expr.verify_expr(); - let expected_res = test_expr.query_table(); - assert_eq!(res, expected_res); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_a, expected_d): (Vec<_>, Vec<_>) = multizip(( + data["a"].i64_iter(), + data["b"].string_iter(), + data["c"].i64_iter(), + data["d"].string_iter(), + )) + .filter_map(|(a, b, c, d)| { + if b == &filter_val1 || c == &filter_val2 { + Some((*a, d.clone())) + } else { + None + } + }) + .multiunzip(); + let expected_result = owned_table([bigint("a", expected_a), varchar("d", expected_d)]); + + assert_eq!(expected_result, res) } } diff --git a/crates/proof-of-sql/src/sql/ast/sign_expr_test.rs b/crates/proof-of-sql/src/sql/ast/sign_expr_test.rs index 3cb6c4e6a..ae1f3b3b1 100644 --- a/crates/proof-of-sql/src/sql/ast/sign_expr_test.rs +++ b/crates/proof-of-sql/src/sql/ast/sign_expr_test.rs @@ -1,12 +1,6 @@ use super::{count_sign, prover_evaluate_sign, verifier_evaluate_sign}; use crate::{ - base::{ - bit::BitDistribution, - database::{RecordBatchTestAccessor, TestAccessor}, - polynomial::MultilinearExtension, - scalar::Curve25519Scalar, - }, - record_batch, + base::{bit::BitDistribution, polynomial::MultilinearExtension, scalar::Curve25519Scalar}, sql::{ ast::result_evaluate_sign, proof::{ @@ -23,10 +17,6 @@ use num_traits::Zero; fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { let data = 
[123_i64, 123, 123]; let dist = BitDistribution::new::(&data); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, record_batch!("a" => data), 0); - let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = ProofBuilder::new(3, 2, Vec::new()); @@ -39,10 +29,6 @@ fn prover_evaluation_generates_the_bit_distribution_of_a_constant_column() { fn prover_evaluation_generates_the_bit_distribution_of_a_negative_constant_column() { let data = [-123_i64, -123, -123]; let dist = BitDistribution::new::(&data); - let t = "sxt.t".parse().unwrap(); - let mut accessor = RecordBatchTestAccessor::new_empty(); - accessor.add_table(t, record_batch!("a" => data), 0); - let alloc = Bump::new(); let data: Vec = data.into_iter().map(Curve25519Scalar::from).collect(); let mut builder = ProofBuilder::new(3, 2, Vec::new()); diff --git a/crates/proof-of-sql/src/sql/ast/test_expr.rs b/crates/proof-of-sql/src/sql/ast/test_expr.rs deleted file mode 100644 index 3a8334566..000000000 --- a/crates/proof-of-sql/src/sql/ast/test_expr.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::{FilterExpr, ProvableExprPlan}; -use crate::{ - base::database::{RecordBatchTestAccessor, TableRef}, - sql::{ - ast::test_utility::{cols_result, tab}, - proof::{exercise_verification, VerifiableQueryResult}, - }, -}; -use arrow::record_batch::RecordBatch; -use blitzar::proof::InnerProductProof; -use curve25519_dalek::RistrettoPoint; -use polars::prelude::{Expr, *}; - -pub struct TestExprNode { - pub table_ref: TableRef, - pub results: Vec, - pub ast: FilterExpr, - pub accessor: RecordBatchTestAccessor, - pub df_filter: Expr, -} - -impl TestExprNode { - pub fn new( - table_ref: TableRef, - results: &[&str], - filter_expr: ProvableExprPlan, - df_filter: Expr, - accessor: RecordBatchTestAccessor, - ) -> Self { - let polar_results = results - .iter() - .map(|v| polars::prelude::col(v)) - .collect::>(); - let 
ast = FilterExpr::new( - cols_result(table_ref, results, &accessor), - tab(table_ref), - filter_expr, - ); - - Self { - table_ref, - df_filter, - results: polar_results, - ast, - accessor, - } - } - - pub fn create_verifiable_result(&self) -> VerifiableQueryResult { - VerifiableQueryResult::new(&self.ast, &self.accessor, &()) - } - - pub fn verify_expr(&self) -> RecordBatch { - let res = VerifiableQueryResult::new(&self.ast, &self.accessor, &()); - exercise_verification(&res, &self.ast, &self.accessor, self.table_ref); - res.verify(&self.ast, &self.accessor, &()) - .unwrap() - .into_record_batch() - } - - pub fn query_table(&self) -> RecordBatch { - self.accessor.query_table(self.table_ref, |df| { - df.clone() - .lazy() - .filter(self.df_filter.clone()) - .select(&self.results[..]) - .collect() - .unwrap() - }) - } -} diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index d160a6ea9..cafe6642b 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -112,6 +112,8 @@ pub fn const_varchar(val: &str) -> ProvableExprPlan { ))) } +/// Create a constant scalar value. Used if we don't want to specify column types. 
+#[allow(dead_code)] pub fn const_scalar>(val: T) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::Scalar(val.into())) } diff --git a/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs index 0dda9bb26..89b4755f9 100644 --- a/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/result_expr_builder.rs @@ -36,13 +36,10 @@ impl ResultExprBuilder { .iter() .map(|aliased_expr| Expression::Column(aliased_expr.alias)) .collect(); - self.composition - .add(Box::new(SelectExpr::new_from_expressions(&exprs))); + self.composition.add(Box::new(SelectExpr::new(&exprs))); } else { self.composition - .add(Box::new(SelectExpr::new_from_aliased_result_exprs( - aliased_exprs, - ))); + .add(Box::new(SelectExpr::new(aliased_exprs))); } self } diff --git a/crates/proof-of-sql/src/sql/transform/group_by_expr.rs b/crates/proof-of-sql/src/sql/transform/group_by_expr.rs index 99ee5ad1a..87a541f70 100644 --- a/crates/proof-of-sql/src/sql/transform/group_by_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/group_by_expr.rs @@ -1,8 +1,10 @@ #[allow(deprecated)] +#[cfg(feature = "polars")] use super::DataFrameExpr; -use super::ToPolarsExpr; -use crate::base::database::{INT128_PRECISION, INT128_SCALE}; +#[cfg(feature = "polars")] +use super::{ToPolarsExpr, INT128_PRECISION, INT128_SCALE}; use dyn_partial_eq::DynPartialEq; +#[cfg(feature = "polars")] use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series}; use proof_of_sql_parser::{intermediate_ast::AliasedResultExpr, Identifier}; use serde::{Deserialize, Serialize}; @@ -11,32 +13,43 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] pub struct GroupByExpr { /// A list of aggregation column expressions - agg_exprs: Vec, + aliased_exprs: Vec, /// A list of group by column expressions - by_exprs: Vec, + by_ids: Vec, } impl GroupByExpr { /// 
Create a new group by expression containing the group by and aggregation expressions pub fn new(by_ids: &[Identifier], aliased_exprs: &[AliasedResultExpr]) -> Self { - let by_exprs = Vec::from_iter(by_ids.iter().map(|id| col(id.as_str()))); - let agg_exprs = Vec::from_iter(aliased_exprs.iter().map(ToPolarsExpr::to_polars_expr)); - assert!(!agg_exprs.is_empty(), "Agg expressions must not be empty"); assert!( - !by_exprs.is_empty(), - "Group by expressions must not be empty" + !aliased_exprs.is_empty(), + "Agg expressions must not be empty" ); + assert!(!by_ids.is_empty(), "Group by expressions must not be empty"); Self { - by_exprs, - agg_exprs, + by_ids: by_ids.to_vec(), + aliased_exprs: aliased_exprs.to_vec(), } } + + #[cfg(feature = "polars")] + fn agg_exprs(&self) -> Vec { + self.aliased_exprs + .iter() + .map(ToPolarsExpr::to_polars_expr) + .collect() + } } +#[cfg(not(feature = "polars"))] +#[typetag::serde] +impl super::RecordBatchExpr for GroupByExpr {} +#[cfg(feature = "polars")] super::impl_record_batch_expr_for_data_frame_expr!(GroupByExpr); #[allow(deprecated)] +#[cfg(feature = "polars")] impl DataFrameExpr for GroupByExpr { fn lazy_transformation(&self, lazy_frame: LazyFrame, num_input_rows: usize) -> LazyFrame { // TODO: polars currently lacks support for min/max aggregation in data frames @@ -44,23 +57,23 @@ impl DataFrameExpr for GroupByExpr { // We remove the group by clause to temporarily work around this limitation. // Issue created to track progress: https://github.com/pola-rs/polars/issues/11232 if num_input_rows == 0 { - return lazy_frame.select(&self.agg_exprs).limit(0); + return lazy_frame.select(&self.agg_exprs()).limit(0); } if num_input_rows == 1 { - return lazy_frame.select(&self.agg_exprs); + return lazy_frame.select(&self.agg_exprs()); } // Add invalid column aliases to group by expressions so that we can // exclude them from the final result. 
- let by_expr_aliases = (0..self.by_exprs.len()) + let by_expr_aliases = (0..self.by_ids.len()) .map(|pos| "#$".to_owned() + pos.to_string().as_str()) .collect::>(); let by_exprs: Vec<_> = self - .by_exprs - .clone() - .into_iter() + .by_ids + .iter() + .map(|id| col(id.as_str())) .zip(by_expr_aliases.iter()) .map(|(expr, alias)| expr.alias(alias.as_str())) // TODO: remove this mapping once Polars supports decimal columns inside group by @@ -72,11 +85,12 @@ impl DataFrameExpr for GroupByExpr { // to avoid non-deterministic results with our tests. lazy_frame .group_by_stable(&by_exprs) - .agg(&self.agg_exprs) + .agg(&self.agg_exprs()) .select(&[col("*").exclude(by_expr_aliases)]) } } +#[cfg(any(test, feature = "polars"))] pub(crate) fn group_by_map_i128_to_utf8(v: i128) -> String { // use big end to allow // skipping leading zeros @@ -100,6 +114,7 @@ pub(crate) fn group_by_map_i128_to_utf8(v: i128) -> String { // Polars doesn't support Decimal columns inside group by. // So we need to remap them to the supported UTF8 type. +#[cfg(feature = "polars")] fn group_by_map_to_utf8_if_decimal(expr: Expr) -> Expr { expr.map( |series| match series.dtype().clone() { diff --git a/crates/proof-of-sql/src/sql/transform/mod.rs b/crates/proof-of-sql/src/sql/transform/mod.rs index c7bd2001b..99b3236d4 100644 --- a/crates/proof-of-sql/src/sql/transform/mod.rs +++ b/crates/proof-of-sql/src/sql/transform/mod.rs @@ -1,4 +1,12 @@ //! This module contains postprocessing for non-provable components. 
+/// The precision for [ColumnType::INT128] values +#[cfg(feature = "polars")] +pub const INT128_PRECISION: usize = 38; + +/// The scale for [ColumnType::INT128] values +#[cfg(feature = "polars")] +pub const INT128_SCALE: usize = 0; + mod result_expr; pub use result_expr::ResultExpr; @@ -11,10 +19,13 @@ pub use composition_expr::CompositionExpr; #[cfg(test)] pub mod composition_expr_test; +#[cfg(feature = "polars")] mod data_frame_expr; #[allow(deprecated)] +#[cfg(feature = "polars")] pub(crate) use data_frame_expr::DataFrameExpr; mod record_batch_expr; +#[cfg(feature = "polars")] pub(crate) use record_batch_expr::impl_record_batch_expr_for_data_frame_expr; pub use record_batch_expr::RecordBatchExpr; @@ -47,10 +58,16 @@ pub use group_by_expr::GroupByExpr; #[cfg(test)] mod group_by_expr_test; +#[cfg(feature = "polars")] mod polars_conversions; +#[cfg(feature = "polars")] pub use polars_conversions::LiteralConversion; +#[cfg(feature = "polars")] mod polars_arithmetic; +#[cfg(feature = "polars")] pub use polars_arithmetic::SafeDivision; +#[cfg(feature = "polars")] mod to_polars_expr; +#[cfg(feature = "polars")] pub(crate) use to_polars_expr::ToPolarsExpr; diff --git a/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs b/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs index 26dce954b..3ad484d8a 100644 --- a/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs +++ b/crates/proof-of-sql/src/sql/transform/order_by_exprs.rs @@ -1,10 +1,16 @@ #[allow(deprecated)] +#[cfg(feature = "polars")] use super::DataFrameExpr; -use crate::base::database::{INT128_PRECISION, INT128_SCALE}; +#[cfg(feature = "polars")] +use super::{INT128_PRECISION, INT128_SCALE}; +#[cfg(any(test, feature = "polars"))] use arrow::datatypes::ArrowNativeTypeOp; use dyn_partial_eq::DynPartialEq; +#[cfg(feature = "polars")] use polars::prelude::{col, DataType, Expr, GetOutput, LazyFrame, NamedFrom, Series}; -use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection}; +use 
proof_of_sql_parser::intermediate_ast::OrderBy; +#[cfg(feature = "polars")] +use proof_of_sql_parser::intermediate_ast::OrderByDirection; use serde::{Deserialize, Serialize}; /// A node representing a list of `OrderBy` expressions. @@ -20,8 +26,13 @@ impl OrderByExprs { } } +#[cfg(not(feature = "polars"))] +#[typetag::serde] +impl super::RecordBatchExpr for OrderByExprs {} +#[cfg(feature = "polars")] super::impl_record_batch_expr_for_data_frame_expr!(OrderByExprs); #[allow(deprecated)] +#[cfg(feature = "polars")] impl DataFrameExpr for OrderByExprs { /// Sort the `LazyFrame` by the `OrderBy` expressions. fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { @@ -51,6 +62,7 @@ impl DataFrameExpr for OrderByExprs { /// * `a < b` if and only if `map_i128_to_utf8(a) < map_i128_to_utf8(b)`. /// * `a == b` if and only if `map_i128_to_utf8(a) == map_i128_to_utf8(b)`. /// * `a > b` if and only if `map_i128_to_utf8(a) > map_i128_to_utf8(b)`. +#[cfg(any(test, feature = "polars"))] pub(crate) fn order_by_map_i128_to_utf8(v: i128) -> String { let is_neg = v.is_negative() as u8; v.abs() @@ -78,6 +90,7 @@ pub(crate) fn order_by_map_i128_to_utf8(v: i128) -> String { // Polars doesn't support Decimal columns inside order by. // So we need to remap them to the supported UTF8 type. +#[cfg(feature = "polars")] fn order_by_map_to_utf8_if_decimal(expr: Expr) -> Expr { expr.map( |series| match series.dtype().clone() { diff --git a/crates/proof-of-sql/src/sql/transform/polars_conversions.rs b/crates/proof-of-sql/src/sql/transform/polars_conversions.rs index acb6ec060..5c7db486c 100644 --- a/crates/proof-of-sql/src/sql/transform/polars_conversions.rs +++ b/crates/proof-of-sql/src/sql/transform/polars_conversions.rs @@ -1,4 +1,4 @@ -use crate::base::database::{INT128_PRECISION, INT128_SCALE}; +use super::{INT128_PRECISION, INT128_SCALE}; use polars::prelude::{DataType, Expr, Literal, LiteralValue, Series}; /// Convert a Rust type to a Polars `Expr` type. 
diff --git a/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs b/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs index 0e59e2aea..fbdc9ff7d 100644 --- a/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/record_batch_expr.rs @@ -7,9 +7,13 @@ use std::fmt::Debug; #[dyn_partial_eq] pub trait RecordBatchExpr: Debug + Send + Sync { /// Apply the transformation to the `RecordBatch` and return the result. - fn apply_transformation(&self, record_batch: RecordBatch) -> Option; + #[allow(unused_variables)] + fn apply_transformation(&self, record_batch: RecordBatch) -> Option { + None + } } +#[cfg(feature = "polars")] macro_rules! impl_record_batch_expr_for_data_frame_expr { ($t:ty) => { #[typetag::serde] @@ -29,4 +33,5 @@ macro_rules! impl_record_batch_expr_for_data_frame_expr { }; } +#[cfg(feature = "polars")] pub(crate) use impl_record_batch_expr_for_data_frame_expr; diff --git a/crates/proof-of-sql/src/sql/transform/result_expr.rs b/crates/proof-of-sql/src/sql/transform/result_expr.rs index 244a0afec..5c8d38481 100644 --- a/crates/proof-of-sql/src/sql/transform/result_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/result_expr.rs @@ -1,9 +1,9 @@ -use crate::{ - base::database::{dataframe_to_record_batch, record_batch_to_dataframe}, - sql::transform::RecordBatchExpr, -}; +#[cfg(feature = "polars")] +use crate::base::database::{dataframe_to_record_batch, record_batch_to_dataframe}; +use crate::sql::transform::RecordBatchExpr; use arrow::record_batch::RecordBatch; use dyn_partial_eq::DynPartialEq; +#[cfg(feature = "polars")] use polars::prelude::{IntoLazy, LazyFrame}; use serde::{Deserialize, Serialize}; @@ -23,11 +23,13 @@ impl ResultExpr { } } +#[cfg(feature = "polars")] pub(super) fn record_batch_to_lazy_frame(result_batch: RecordBatch) -> Option<(LazyFrame, usize)> { let num_input_rows = result_batch.num_rows(); let df = record_batch_to_dataframe(result_batch)?; Some((df.lazy(), num_input_rows)) 
} +#[cfg(feature = "polars")] pub(super) fn lazy_frame_to_record_batch(lazy_frame: LazyFrame) -> Option { // We're currently excluding NULLs in post-processing due to a lack of // prover support, aiming to avoid future complexities. diff --git a/crates/proof-of-sql/src/sql/transform/select_expr.rs b/crates/proof-of-sql/src/sql/transform/select_expr.rs index 3c407fd40..a2300c560 100644 --- a/crates/proof-of-sql/src/sql/transform/select_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/select_expr.rs @@ -1,48 +1,86 @@ #[allow(deprecated)] +#[cfg(feature = "polars")] use super::DataFrameExpr; +use super::RecordBatchExpr; +#[cfg(feature = "polars")] use super::{ - record_batch_expr::RecordBatchExpr, result_expr::{lazy_frame_to_record_batch, record_batch_to_lazy_frame}, ToPolarsExpr, }; use arrow::record_batch::RecordBatch; use dyn_partial_eq::DynPartialEq; +#[cfg(feature = "polars")] use polars::prelude::{Expr, LazyFrame}; use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, Expression}; use serde::{Deserialize, Serialize}; +#[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] +pub enum SelectTerm { + #[cfg(test)] + #[cfg(feature = "polars")] + Polars(Expr), + AliasedResult(AliasedResultExpr), + Result(Expression), +} +#[cfg(feature = "polars")] +impl ToPolarsExpr for SelectTerm { + fn to_polars_expr(&self) -> Expr { + match self { + #[cfg(test)] + Self::Polars(s) => s.to_polars_expr(), + Self::AliasedResult(s) => s.to_polars_expr(), + Self::Result(s) => s.to_polars_expr(), + } + } +} +#[cfg(test)] +#[cfg(feature = "polars")] +impl From<&Expr> for SelectTerm { + fn from(value: &Expr) -> Self { + Self::Polars(value.clone()) + } +} +impl From<&AliasedResultExpr> for SelectTerm { + fn from(value: &AliasedResultExpr) -> Self { + Self::AliasedResult(value.clone()) + } +} +impl From<&Expression> for SelectTerm { + fn from(value: &Expression) -> Self { + Self::Result(value.clone()) + } +} +#[cfg(test)] +impl From<&Box> for SelectTerm { + fn 
from(value: &Box) -> Self { + value.as_ref().into() + } +} + /// The select expression used to select, reorder, and apply alias transformations #[derive(Debug, DynPartialEq, PartialEq, Serialize, Deserialize)] pub struct SelectExpr { /// The schema of the resulting lazy frame - result_schema: Vec, + result_schema: Vec, } impl SelectExpr { - #[cfg(test)] - pub(crate) fn new(exprs: &[impl ToPolarsExpr]) -> Self { - Self::new_from_to_polars(exprs) - } - fn new_from_to_polars(exprs: &[impl ToPolarsExpr]) -> Self { - let result_schema = Vec::from_iter(exprs.iter().map(ToPolarsExpr::to_polars_expr)); + /// Create a new select expression from any iterable whose items convert into a `SelectTerm`; panics if the iterable is empty + pub fn new(exprs: impl IntoIterator>) -> Self { + let result_schema = Vec::from_iter(exprs.into_iter().map(Into::into)); assert!(!result_schema.is_empty()); Self { result_schema } } - /// Create a new select expression from a slice of AliasedResultExpr - pub fn new_from_aliased_result_exprs(aliased_exprs: &[AliasedResultExpr]) -> Self { - Self::new_from_to_polars(aliased_exprs) - } - /// Create a new select expression from a slice of Expressions - pub fn new_from_expressions(exprs: &[Expression]) -> Self { - Self::new_from_to_polars(exprs) - } } #[allow(deprecated)] +#[cfg(feature = "polars")] impl DataFrameExpr for SelectExpr { /// Apply the select transformation to the lazy frame fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { - lazy_frame.select(&self.result_schema) + lazy_frame.select(&Vec::from_iter( + self.result_schema.iter().map(ToPolarsExpr::to_polars_expr), + )) } } @@ -52,11 +90,10 @@ impl RecordBatchExpr for SelectExpr { let easy_result: Option> = self .result_schema .iter() - .cloned() .map(|expr| match expr { - Expr::Alias(a, b) => match *a { - Expr::Column(c) if c == b => { - Some((b.to_owned(), record_batch.column_by_name(&b)?.to_owned())) + SelectTerm::AliasedResult(AliasedResultExpr { expr, alias }) => match **expr { + Expression::Column(c) if &c == 
alias => { + Some((c, record_batch.column_by_name(c.as_str())?.to_owned())) } _ => None, }, @@ -67,8 +104,15 @@ impl RecordBatchExpr for SelectExpr { if let Some(Ok(result)) = easy_result.map(RecordBatch::try_from_iter) { return Some(result); } - let (lazy_frame, num_input_rows) = record_batch_to_lazy_frame(record_batch)?; - #[allow(deprecated)] - lazy_frame_to_record_batch(self.lazy_transformation(lazy_frame, num_input_rows)) + #[cfg(feature = "polars")] + { + let (lazy_frame, num_input_rows) = record_batch_to_lazy_frame(record_batch)?; + #[allow(deprecated)] + lazy_frame_to_record_batch(self.lazy_transformation(lazy_frame, num_input_rows)) + } + #[cfg(not(feature = "polars"))] + { + None + } } } diff --git a/crates/proof-of-sql/src/sql/transform/slice_expr.rs b/crates/proof-of-sql/src/sql/transform/slice_expr.rs index 22e88af66..78d6f8b02 100644 --- a/crates/proof-of-sql/src/sql/transform/slice_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/slice_expr.rs @@ -1,6 +1,8 @@ #[allow(deprecated)] +#[cfg(feature = "polars")] use super::DataFrameExpr; use dyn_partial_eq::DynPartialEq; +#[cfg(feature = "polars")] use polars::prelude::LazyFrame; use serde::{Deserialize, Serialize}; @@ -30,8 +32,13 @@ impl SliceExpr { } } +#[cfg(not(feature = "polars"))] +#[typetag::serde] +impl super::RecordBatchExpr for SliceExpr {} +#[cfg(feature = "polars")] super::record_batch_expr::impl_record_batch_expr_for_data_frame_expr!(SliceExpr); #[allow(deprecated)] +#[cfg(feature = "polars")] impl DataFrameExpr for SliceExpr { /// Apply the slice transformation to the given `LazyFrame`. 
fn lazy_transformation(&self, lazy_frame: LazyFrame, _: usize) -> LazyFrame { diff --git a/crates/proof-of-sql/src/sql/transform/test_utility.rs b/crates/proof-of-sql/src/sql/transform/test_utility.rs index d5f1c5b30..c9ee94646 100644 --- a/crates/proof-of-sql/src/sql/transform/test_utility.rs +++ b/crates/proof-of-sql/src/sql/transform/test_utility.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{select_expr::SelectTerm, *}; use proof_of_sql_parser::intermediate_ast::*; pub fn lit_i64(literal: i64) -> Box { @@ -20,8 +20,9 @@ pub fn col(name: &str) -> Box { Box::new(Expression::Column(name.parse().unwrap())) } -pub(crate) fn select(result_schema: &[impl ToPolarsExpr]) -> Box { - #[allow(deprecated)] +pub(crate) fn select( + result_schema: impl IntoIterator>, +) -> Box { Box::new(SelectExpr::new(result_schema)) } @@ -34,9 +35,7 @@ pub fn schema(columns: &[(&str, &str)]) -> Vec { pub fn result(columns: &[(&str, &str)]) -> ResultExpr { let mut composition = CompositionExpr::default(); - composition.add(Box::new(SelectExpr::new_from_aliased_result_exprs( - &schema(columns), - ))); + composition.add(Box::new(SelectExpr::new(&schema(columns)))); ResultExpr::new(Box::new(composition)) }