diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index b78ec8873..9e5510640 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -168,7 +168,7 @@ pub const INT128_SCALE: usize = 0; /// /// See `` for /// a description of the native types used by Apache Ignite. -#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Deserialize, Copy)] +#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Deserialize, Copy)] pub enum ColumnType { /// Mapped to bool #[serde(alias = "BOOLEAN", alias = "boolean")] @@ -188,12 +188,12 @@ pub enum ColumnType { /// Mapped to String #[serde(alias = "VARCHAR", alias = "varchar")] VarChar, - /// Mapped to Curve25519Scalar - #[serde(alias = "SCALAR", alias = "scalar")] - Scalar, /// Mapped to i256 #[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")] Decimal75(Precision, i8), + /// Mapped to Curve25519Scalar + #[serde(alias = "SCALAR", alias = "scalar")] + Scalar, } impl ColumnType { diff --git a/crates/proof-of-sql/src/base/math/decimal.rs b/crates/proof-of-sql/src/base/math/decimal.rs index 38f19013b..9a75448ab 100644 --- a/crates/proof-of-sql/src/base/math/decimal.rs +++ b/crates/proof-of-sql/src/base/math/decimal.rs @@ -6,7 +6,7 @@ use crate::{ use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal; use serde::{Deserialize, Deserializer, Serialize}; -#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)] +#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Copy)] /// limit-enforced precision pub struct Precision(u8); pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75; diff --git a/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs b/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs new file mode 100644 index 000000000..a8ea95f50 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs @@ -0,0 +1,102 @@ +use super::{ + scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns, + ProvableExpr, ProvableExprPlan, +}; +use crate::{ + base::{ + commitment::Commitment, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + proof::ProofError, + }, + sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder}, +}; +use bumpalo::Bump; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Provable numerical + / - expression +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct AddSubtractExpr { + lhs: Box>, + rhs: Box>, + is_subtract: bool, +} + +impl AddSubtractExpr { + /// Create numerical + / - expression + pub fn new( + lhs: Box>, + rhs: Box>, + is_subtract: bool, + ) -> Self { + Self { + lhs, + rhs, + is_subtract, + } + } +} + +impl ProvableExpr for AddSubtractExpr { + fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { + self.lhs.count(builder)?; + self.rhs.count(builder)?; + Ok(()) + } + + fn data_type(&self) -> ColumnType { + try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type()) + .expect("Failed to add/subtract column types") + } + + fn result_evaluate<'a>( + &self, + table_length: usize, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = + self.lhs.result_evaluate(table_length, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = + self.rhs.result_evaluate(table_length, alloc, accessor); + try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) + .expect("Failed to add/subtract columns") + } + + #[tracing::instrument( + name = "proofs.sql.ast.not_expr.prover_evaluate", + level = "info", + skip_all + )] + fn prover_evaluate<'a>( + &self, + builder: &mut ProofBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor); + try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) + .expect("Failed to add/subtract columns") + } + + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + ) -> Result { + let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?; + let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?; + let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); + let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); + let res = + scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract) + .expect("Failed to scale and add/subtract"); + Ok(res) + } + + fn get_column_references(&self, columns: &mut HashSet) { + self.lhs.get_column_references(columns); + self.rhs.get_column_references(columns); + } +} diff --git a/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs b/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs new file mode 100644 index 000000000..810b2fd51 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/add_subtract_expr_test.rs @@ -0,0 +1,181 @@ +use crate::{ + base::{ + commitment::InnerProductProof, + database::{ + make_random_test_accessor_data, owned_table_utility::*, Column, ColumnType, + OwnedTableTestAccessor, RandomTestAccessorDescriptor, RecordBatchTestAccessor, + TestAccessor, + }, + }, + record_batch, + sql::ast::{ + test_expr::TestExprNode, + test_utility::{add, column, equal, subtract}, + ProvableExpr, ProvableExprPlan, + }, +}; +use arrow::record_batch::RecordBatch; +use bumpalo::Bump; +use curve25519_dalek::ristretto::RistrettoPoint; +use polars::prelude::*; +use rand::{rngs::StdRng, Rng}; +use rand_core::SeedableRng; + +// select results from table_ref where filter_col_l = filter_col_r0 + / - filter_col_r1 +#[allow(clippy::too_many_arguments)] +fn create_test_add_subtract_expr( + table_ref: &str, + results: &[&str], + filter_col_l: &str, + filter_col_r0: &str, + filter_col_r1: &str, + data: RecordBatch, + offset: usize, + is_subtract: bool, +) -> TestExprNode { + let mut accessor = RecordBatchTestAccessor::new_empty(); + let t = table_ref.parse().unwrap(); + accessor.add_table(t, data, offset); + let df_filter = if is_subtract { + polars::prelude::col(filter_col_l).eq(col(filter_col_r0) - col(filter_col_r1)) + } else { + polars::prelude::col(filter_col_l).eq(col(filter_col_r0) + col(filter_col_r1)) + }; + let filter_expr = equal( + column(t, filter_col_l, &accessor), + if is_subtract { + subtract( + column(t, filter_col_r0, &accessor), + column(t, filter_col_r1, &accessor), + ) + } else { + add( + column(t, filter_col_r0, &accessor), + column(t, filter_col_r1, &accessor), + ) + }, + ); + TestExprNode::new(t, results, filter_expr, df_filter, accessor) +} + +#[test] +fn we_can_prove_a_equals_add_query_with_a_single_selected_row() { + let data = record_batch!( + "a" => [123_i64, 456], + "b" => [4_i64, 1], + "c" => [123_i64, 457], + "d" => ["alfa", "gama"] + ); + let test_expr = + create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, false); + let res = test_expr.verify_expr(); + let expected_res = record_batch!( + "a" => [456_i64], + "d" => ["gama"] + ); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_prove_a_equals_subtract_query_with_a_single_selected_row() { + let data = record_batch!( + "a" => [127_i64, 458], + "b" => [4_i64, 1], + "c" => [123_i64, 457], + "d" => ["alfa", "gama"] + ); + let test_expr = + create_test_add_subtract_expr("sxt.t", &["a", "d"], "c", "a", "b", data, 0, true); + let res = test_expr.verify_expr(); + let expected_res = record_batch!( + "a" => [127_i64, 458], + "d" => ["alfa", "gama"] + ); + assert_eq!(res, expected_res); +} + +fn test_random_tables_with_given_offset(offset: usize) { + let descr = RandomTestAccessorDescriptor { + min_rows: 1, + max_rows: 20, + min_value: -3, + max_value: 3, + }; + let mut rng = StdRng::from_seed([0u8; 32]); + let cols = [ + ("l", ColumnType::BigInt), + ("r0", ColumnType::BigInt), + ("r1", ColumnType::BigInt), + ("varchar", ColumnType::VarChar), + ("integer", ColumnType::BigInt), + ]; + for _ in 0..20 { + let data = make_random_test_accessor_data(&mut rng, &cols, &descr); + let is_subtract = rng.gen::(); + let test_expr = create_test_add_subtract_expr( + "sxt.t", + &["l", "varchar", "integer"], + "l", + "r0", + "r1", + data, + offset, + is_subtract, + ); + let res = test_expr.verify_expr(); + let expected_res = test_expr.query_table(); + assert_eq!(res, expected_res); + } +} + +#[test] +fn we_can_query_random_tables_with_a_zero_offset() { + test_random_tables_with_given_offset(0); +} + +#[test] +fn we_can_query_random_tables_with_a_non_zero_offset() { + test_random_tables_with_given_offset(75); +} + +#[test] +fn we_can_compute_the_correct_output_of_an_add_expr_using_result_evaluate() { + let data = owned_table([ + bigint("a", [123, 456]), + bigint("b", [3, 1]), + bigint("c", [126, 453]), + varchar("d", ["alfa", "gama"]), + ]); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + let t = "sxt.t".parse().unwrap(); + accessor.add_table(t, data, 0); + let eq_expr: ProvableExprPlan = equal( + column(t, "c", &accessor), + add(column(t, "a", &accessor), column(t, "b", &accessor)), + ); + let alloc = Bump::new(); + let res = eq_expr.result_evaluate(2, &alloc, &accessor); + let expected_res = Column::Boolean(&[true, false]); + assert_eq!(res, expected_res); +} + +#[test] +fn we_can_compute_the_correct_output_of_a_subtract_expr_using_result_evaluate() { + let data = owned_table([ + bigint("a", [123, 456]), + bigint("b", [3, 1]), + bigint("c", [126, 455]), + varchar("d", ["alfa", "gama"]), + ]); + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + let t = "sxt.t".parse().unwrap(); + accessor.add_table(t, data, 0); + let eq_expr: ProvableExprPlan = equal( + column(t, "c", &accessor), + subtract(column(t, "a", &accessor), column(t, "b", &accessor)), + ); + let alloc = Bump::new(); + let res = eq_expr.result_evaluate(2, &alloc, &accessor); + let expected_res = Column::Boolean(&[false, true]); + assert_eq!(res, expected_res); +} diff --git a/crates/proof-of-sql/src/sql/ast/comparison_util.rs b/crates/proof-of-sql/src/sql/ast/comparison_util.rs index 382bd2c70..53689d5d6 100644 --- a/crates/proof-of-sql/src/sql/ast/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/ast/comparison_util.rs @@ -72,7 +72,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( ); // Check if the precision is valid let _max_precision = Precision::new(max_precision_value) - .map_err(|_| ConversionError::InvalidPrecision(max_precision_value))?; + .map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?; } unchecked_subtract_impl( alloc, diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index af0bfe89b..cce66e2d7 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -2,6 +2,11 @@ mod filter_result_expr; pub(crate) use filter_result_expr::FilterResultExpr; +mod add_subtract_expr; +pub(crate) use add_subtract_expr::AddSubtractExpr; +#[cfg(all(test, feature = "blitzar"))] +mod add_subtract_expr_test; + mod filter_expr; pub(crate) use filter_expr::FilterExpr; #[cfg(test)] @@ -52,6 +57,11 @@ mod not_expr_test; mod comparison_util; pub(crate) use comparison_util::{scale_and_subtract, scale_and_subtract_eval}; +mod numerical_util; +pub(crate) use numerical_util::{ + scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns, +}; + mod equals_expr; use equals_expr::*; #[cfg(all(test, feature = "blitzar"))] diff --git a/crates/proof-of-sql/src/sql/ast/numerical_util.rs b/crates/proof-of-sql/src/sql/ast/numerical_util.rs new file mode 100644 index 000000000..eaffebc74 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/numerical_util.rs @@ -0,0 +1,97 @@ +use crate::{ + base::{ + database::{Column, ColumnType}, + math::decimal::{scale_scalar, Precision}, + scalar::Scalar, + }, + sql::parse::{ConversionError, ConversionResult}, +}; +use bumpalo::Bump; + +// For decimal type manipulation please refer to +// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 + +/// Determine the output type of an add or subtract operation if it is possible +/// to add or subtract the two input types. If the types are not compatible, return +/// an error. +pub(crate) fn try_add_subtract_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ConversionResult { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ConversionError::DataTypeMismatch( + lhs.to_string(), + rhs.to_string(), + )); + } + if lhs.is_integer() && rhs.is_integer() { + return Ok(lhs.max(rhs)); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = lhs.precision_value().unwrap_or(0) as i16; + let right_precision_value = rhs.precision_value().unwrap_or(0) as i16; + let left_scale = lhs.scale().unwrap_or(0); + let right_scale = rhs.scale().unwrap_or(0); + let scale = left_scale.max(right_scale); + let precision_value: i16 = scale as i16 + + (left_precision_value - left_scale as i16) + .max(right_precision_value - right_scale as i16) + + 1_i16; + let precision = u8::try_from(precision_value) + .map_err(|_| ConversionError::InvalidPrecision(precision_value)) + .and_then(|p| { + Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16)) + })?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + +/// Add or subtract two columns together. +/// +/// If the columns are not compatible for addition/subtraction, return an error. +pub(crate) fn try_add_subtract_columns<'a, S: Scalar>( + lhs: Column<'a, S>, + rhs: Column<'a, S>, + alloc: &'a Bump, + is_subtract: bool, +) -> ConversionResult> { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + if lhs_len != rhs_len { + return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len)); + } + let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); + let left_scale = lhs.column_type().scale().unwrap_or(0); + let right_scale = rhs.column_type().scale().unwrap_or(0); + let max_scale = left_scale.max(right_scale); + let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - left_scale); + let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - right_scale); + let res = alloc.alloc_slice_fill_with(lhs_len, |i| { + if is_subtract { + lhs_scalar[i] - rhs_scalar[i] + } else { + lhs_scalar[i] + rhs_scalar[i] + } + }); + Ok(Column::Scalar(res)) +} + +/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions. +pub(crate) fn scale_and_add_subtract_eval( + lhs_eval: S, + rhs_eval: S, + lhs_scale: i8, + rhs_scale: i8, + is_subtract: bool, +) -> ConversionResult { + let max_scale = lhs_scale.max(rhs_scale); + let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?; + let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?; + if is_subtract { + Ok(scaled_lhs_eval - scaled_rhs_eval) + } else { + Ok(scaled_lhs_eval + scaled_rhs_eval) + } +} diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 5ffaa6162..dcec2c374 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -1,5 +1,6 @@ use super::{ - AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, ProvableExpr, + AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, + ProvableExpr, }; use crate::{ base::{ @@ -34,6 +35,10 @@ pub enum ProvableExprPlan { Equals(EqualsExpr), /// Provable AST expression for an inequality expression Inequality(InequalityExpr), + /// Provable numeric + expression + Add(AddSubtractExpr), + /// Provable numeric - expression + Subtract(AddSubtractExpr), } impl ProvableExprPlan { /// Create column expression @@ -109,6 +114,48 @@ impl ProvableExprPlan { } } + /// Create a new add expression + pub fn try_new_add( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Add(AddSubtractExpr::new( + Box::new(lhs), + Box::new(rhs), + false, + ))) + } + } + + /// Create a new subtract expression + pub fn try_new_subtract( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Subtract(AddSubtractExpr::new( + Box::new(lhs), + Box::new(rhs), + true, + ))) + } + } + /// Check that the plan has the correct data type fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> { if self.data_type() == data_type { @@ -132,12 +179,16 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Literal(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Equals(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Inequality(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Add(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Subtract(expr) => ProvableExpr::::count(expr, builder), } } fn data_type(&self) -> ColumnType { match self { ProvableExprPlan::Column(expr) => expr.data_type(), + ProvableExprPlan::Add(expr) => expr.data_type(), + ProvableExprPlan::Subtract(expr) => expr.data_type(), ProvableExprPlan::Literal(expr) => ProvableExpr::::data_type(expr), ProvableExprPlan::And(_) | ProvableExprPlan::Or(_) @@ -175,6 +226,12 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) } + ProvableExprPlan::Add(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } } } @@ -206,6 +263,12 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) } + ProvableExprPlan::Add(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } } } @@ -224,6 +287,8 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Literal(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Equals(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Add(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Subtract(expr) => expr.verifier_evaluate(builder, accessor), } } @@ -244,6 +309,10 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => { ProvableExpr::::get_column_references(expr, columns) } + ProvableExprPlan::Add(expr) => ProvableExpr::::get_column_references(expr, columns), + ProvableExprPlan::Subtract(expr) => { + ProvableExpr::::get_column_references(expr, columns) + } } } } diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index 4ed303123..3bc668638 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -63,6 +63,20 @@ pub fn or( ProvableExprPlan::try_new_or(left, right).unwrap() } +pub fn add( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> ProvableExprPlan { + ProvableExprPlan::try_new_add(left, right).unwrap() +} + +pub fn subtract( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> ProvableExprPlan { + ProvableExprPlan::try_new_subtract(left, right).unwrap() +} + pub fn const_bool(val: bool) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::Boolean(val)) } diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 1ffc98fbf..255522b3e 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -59,8 +59,8 @@ pub enum ConversionError { PrecisionParseError(String), #[error("Decimal precision is not valid: {0}")] - /// Decimal precision exceeds the allowed limit - InvalidPrecision(u8), + /// Decimal precision is an integer but exceeds the allowed limit. We use i16 here to include all kinds of invalid precision values. + InvalidPrecision(i16), #[error("Encountered parsing error: {0}")] /// General parsing error diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index 0b08ad800..d33637533 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -73,7 +73,7 @@ impl ProvableExprPlanBuilder<'_> { Literal::Decimal(d) => { let scale = d.scale(); let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision()))?; + .map_err(|_| ConversionError::InvalidPrecision(d.precision() as i16))?; Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( precision, scale, @@ -130,13 +130,19 @@ impl ProvableExprPlanBuilder<'_> { let right = self.visit_expr(right); ProvableExprPlan::try_new_inequality(left?, right?, true) } - BinaryOperator::Add - | BinaryOperator::Subtract - | BinaryOperator::Multiply - | BinaryOperator::Division => Err(ConversionError::Unprovable(format!( - "Binary operator {:?} is not supported yet", - op - ))), + BinaryOperator::Add => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_add(left?, right?) + } + BinaryOperator::Subtract => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_subtract(left?, right?) + } + BinaryOperator::Multiply | BinaryOperator::Division => Err( + ConversionError::Unprovable(format!("Binary operator {:?} is not supported", op)), + ), } } } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index e83615da4..770fe1b6c 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -276,6 +276,40 @@ fn we_can_convert_an_ast_with_two_columns() { assert_eq!(ast, expected_ast); } +#[test] +fn we_can_convert_an_ast_with_two_columns_and_arithmetic() { + let t = "sxt.sxt_tab".parse().unwrap(); + let accessor = record_batch_to_accessor( + t, + record_batch!( + "a" => Vec::::new(), + "b" => Vec::::new(), + "c" => Vec::::new(), + ), + 0_usize, + ); + let ast = query_to_provable_ast( + t, + "select a, b from sxt_tab where c = a + b - 1", + &accessor, + ); + let expected_ast = QueryExpr::new( + dense_filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + equal( + column(t, "c", &accessor), + subtract( + add(column(t, "a", &accessor), column(t, "b", &accessor)), + const_bigint(1), + ), + ), + ), + result(&[("a", "a"), ("b", "b")]), + ); + assert_eq!(ast, expected_ast); +} + #[test] fn we_can_parse_all_result_columns_with_select_star() { let t = "sxt.sxt_tab".parse().unwrap(); @@ -717,14 +751,17 @@ fn we_can_parse_order_by_with_multiple_columns() { ); let ast = query_to_provable_ast( t, - "select a, b from sxt_tab where a = 3 order by b desc, a asc", + "select a, b from sxt_tab where a = b + 3 order by b desc, a asc", &accessor, ); let expected_ast = QueryExpr::new( dense_filter( cols_expr_plan(t, &["a", "b"], &accessor), tab(t), - equal(column(t, "a", &accessor), const_bigint(3)), + equal( + column(t, "a", &accessor), + add(column(t, "b", &accessor), const_bigint(3)), + ), ), composite_result(vec![ select(&[pc("a").alias("a"), pc("b").alias("b")]), @@ -1584,15 +1621,32 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr_plan(t, &["a", "b", "f", "h"], &accessor), + vec![ + ( + add(column(t, "a", &accessor), column(t, "b", &accessor)), + "__expr__".parse().unwrap(), + ), + ( + subtract(const_bigint(-77), column(t, "h", &accessor)), + "col".parse().unwrap(), + ), + ( + add(column(t, "a", &accessor), column(t, "f", &accessor)), + "af".parse().unwrap(), + ), + col_expr_plan(t, "a", &accessor), + col_expr_plan(t, "b", &accessor), + col_expr_plan(t, "f", &accessor), + col_expr_plan(t, "h", &accessor), + ], tab(t), const_bool(true), ), composite_result(vec![select(&[ - (pc("a") + pc("b")).alias("__expr__"), + pc("__expr__").alias("__expr__"), (lit_i64(2) * pc("f")).alias("f2"), - ((-77_i64).to_lit() - pc("h")).alias("col"), - (pc("a") + pc("f")).alias("af"), + pc("col").alias("col"), + pc("af").alias("af"), // TODO: add `a / b as a_div_b` result expr once polars properly // supports decimal division without panicking in production // (pc("a") / pc("b")).alias("a_div_b"), diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 1d56894b3..2d764414f 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -11,10 +11,7 @@ use proof_of_sql::{ }, proof_primitive::dory::{DoryCommitment, DoryEvaluationProof, DoryProverPublicSetup}, record_batch, - sql::{ - parse::{ConversionError, QueryExpr}, - proof::QueryProof, - }, + sql::{parse::QueryExpr, proof::QueryProof}, }; #[test] @@ -160,41 +157,68 @@ fn we_can_prove_a_basic_inequality_query_with_curve25519() { assert_eq!(owned_table_result, expected_result); } -//TODO: Once arithmetic is supported, this test should be updated to use arithmetic. #[test] #[cfg(feature = "blitzar")] -fn we_cannot_prove_a_query_with_arithmetic_in_where_clause_with_curve25519() { +fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_curve25519() { let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor.add_table( "sxt.table".parse().unwrap(), - owned_table([bigint("a", [1, 2, 3]), bigint("b", [1, 0, 2])]), + owned_table([bigint("a", [1, 2, 3]), bigint("b", [4, 1, 2])]), 0, ); - let res_query = QueryExpr::::try_new( + let query = QueryExpr::::try_new( "SELECT * FROM table WHERE b >= a + 1".parse().unwrap(), "sxt".parse().unwrap(), &accessor, - ); - assert!(matches!(res_query, Err(ConversionError::Unprovable(_)))); + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &()); + let owned_table_result = proof + .verify(query.proof_expr(), &accessor, &serialized_result, &()) + .unwrap() + .table; + let owned_table_result: OwnedTable = query + .result() + .transform_results(owned_table_result.try_into().unwrap()) + .unwrap() + .try_into() + .unwrap(); + let expected_result = owned_table([bigint("a", [1]), bigint("b", [4])]); + assert_eq!(owned_table_result, expected_result); } #[test] -fn we_cannot_prove_a_query_with_arithmetic_in_where_clause_with_dory() { +fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_dory() { let dory_prover_setup = DoryProverPublicSetup::rand(4, 3, &mut test_rng()); + let dory_verifier_setup = (&dory_prover_setup).into(); let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup( dory_prover_setup.clone(), ); accessor.add_table( "sxt.table".parse().unwrap(), - owned_table([bigint("a", [1, 2, 3]), bigint("b", [1, 0, 2])]), + owned_table([bigint("a", [1, -1, 3]), bigint("b", [0, 0, 2])]), 0, ); - let res_query = QueryExpr::::try_new( - "SELECT * FROM table WHERE b >= -(a)".parse().unwrap(), + let query = QueryExpr::::try_new( + "SELECT * FROM table WHERE b > 1 - a".parse().unwrap(), "sxt".parse().unwrap(), &accessor, - ); - assert!(matches!(res_query, Err(ConversionError::Unprovable(_)))); + ) + .unwrap(); + let (proof, serialized_result) = + QueryProof::::new(query.proof_expr(), &accessor, &dory_prover_setup); + let owned_table_result = proof + .verify( + query.proof_expr(), + &accessor, + &serialized_result, + &dory_verifier_setup, + ) + .unwrap() + .table; + let expected_result = owned_table([bigint("a", [3]), bigint("b", [2])]); + assert_eq!(owned_table_result, expected_result); } #[test]