diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 4bba7b38c..22890aaf3 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -131,6 +131,23 @@ impl<'a, S: Scalar> Column<'a, S> { } } + /// Returns the element at `index` as a scalar + /// + /// Returns `None` if `index` is out of bounds; the lookup is lazy so the slices are only indexed in range + pub(crate) fn scalar_at(&self, index: usize) -> Option { + (index < self.len()).then(|| match self { + Self::Boolean(col) => S::from(col[index]), + Self::SmallInt(col) => S::from(col[index]), + Self::Int(col) => S::from(col[index]), + Self::BigInt(col) => S::from(col[index]), + Self::Int128(col) => S::from(col[index]), + Self::Scalar(col) => col[index], + Self::Decimal75(_, _, col) => col[index], + Self::VarChar((_, scals)) => scals[index], + Self::TimestampTZ(_, _, col) => S::from(col[index]), + }) + } + /// Convert a column to a vector of Scalar values with scaling pub(crate) fn to_scalar_with_scaling(&self, scale: i8) -> Vec { let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor"); diff --git a/crates/proof-of-sql/src/base/math/decimal.rs b/crates/proof-of-sql/src/base/math/decimal.rs index 363cef87d..d614f2450 100644 --- a/crates/proof-of-sql/src/base/math/decimal.rs +++ b/crates/proof-of-sql/src/base/math/decimal.rs @@ -30,6 +30,11 @@ pub enum DecimalError { /// or non-positive aka InvalidPrecision InvalidPrecision(String), + #[error("Decimal scale is not valid: {0}")] + /// Decimal scale is not valid. Here we use i16 in order to include + /// invalid scale values + InvalidScale(i16), + #[error("Unsupported operation: cannot round decimal: {0}")] /// This error occurs when attempting to scale a /// decimal in such a way that a loss of precision occurs. 
diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index 9dbf0167a..1b174cb9a 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -7,6 +7,11 @@ pub(crate) use add_subtract_expr::AddSubtractExpr; #[cfg(all(test, feature = "blitzar"))] mod add_subtract_expr_test; +mod multiply_expr; +use multiply_expr::MultiplyExpr; +#[cfg(all(test, feature = "blitzar"))] +mod multiply_expr_test; + mod filter_expr; pub(crate) use filter_expr::FilterExpr; #[cfg(test)] @@ -59,7 +64,8 @@ pub(crate) use comparison_util::scale_and_subtract; mod numerical_util; pub(crate) use numerical_util::{ - add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types, + add_subtract_columns, multiply_columns, scale_and_add_subtract_eval, + try_add_subtract_column_types, try_multiply_column_types, }; mod equals_expr; diff --git a/crates/proof-of-sql/src/sql/ast/multiply_expr.rs b/crates/proof-of-sql/src/sql/ast/multiply_expr.rs new file mode 100644 index 000000000..9dc4ae3e5 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/multiply_expr.rs @@ -0,0 +1,116 @@ +use super::{ProvableExpr, ProvableExprPlan}; +use crate::{ + base::{ + commitment::Commitment, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + proof::ProofError, + }, + sql::{ + ast::{multiply_columns, try_multiply_column_types}, + proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder}, + }, +}; +use bumpalo::Bump; +use num_traits::One; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Provable numerical `*` expression +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MultiplyExpr { + lhs: Box>, + rhs: Box>, +} + +impl MultiplyExpr { + /// Create numerical `*` expression + pub fn new(lhs: Box>, rhs: Box>) -> Self { + Self { lhs, rhs } + } +} + +impl ProvableExpr for MultiplyExpr { + fn count(&self, builder: &mut CountBuilder) -> 
Result<(), ProofError> { + self.lhs.count(builder)?; + self.rhs.count(builder)?; + builder.count_subpolynomials(1); + builder.count_intermediate_mles(1); + builder.count_degree(3); + Ok(()) + } + + fn data_type(&self) -> ColumnType { + try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type()) + .expect("Failed to multiply column types") + } + + fn result_evaluate<'a>( + &self, + table_length: usize, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = + self.lhs.result_evaluate(table_length, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = + self.rhs.result_evaluate(table_length, alloc, accessor); + let scalars = multiply_columns(&lhs_column, &rhs_column, alloc); + Column::Scalar(scalars) + } + + #[tracing::instrument( + name = "proofs.sql.ast.multiply_expr.prover_evaluate", + level = "info", + skip_all + )] + fn prover_evaluate<'a>( + &self, + builder: &mut ProofBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor); + + // lhs_times_rhs: the row-wise product, registered as an intermediate MLE + let lhs_times_rhs: &'a [C::Scalar] = multiply_columns(&lhs_column, &rhs_column, alloc); + builder.produce_intermediate_mle(lhs_times_rhs); + + // subpolynomial: lhs_times_rhs - lhs * rhs + builder.produce_sumcheck_subpolynomial( + SumcheckSubpolynomialType::Identity, + vec![ + (C::Scalar::one(), vec![Box::new(lhs_times_rhs)]), + ( + -C::Scalar::one(), + vec![Box::new(lhs_column), Box::new(rhs_column)], + ), + ], + ); + Column::Scalar(lhs_times_rhs) + } + + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + ) -> Result { + let lhs = self.lhs.verifier_evaluate(builder, accessor)?; + let rhs = self.rhs.verifier_evaluate(builder, accessor)?; + + // 
lhs_times_rhs + let lhs_times_rhs = builder.consume_intermediate_mle(); + + // subpolynomial: lhs_times_rhs - lhs * rhs + let eval = builder.mle_evaluations.random_evaluation * (lhs_times_rhs - lhs * rhs); + builder.produce_sumcheck_subpolynomial_evaluation(&eval); + + // the product's evaluation is the value of this expression + Ok(lhs_times_rhs) + } + + fn get_column_references(&self, columns: &mut HashSet) { + self.lhs.get_column_references(columns); + self.rhs.get_column_references(columns); + } +} diff --git a/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs b/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs new file mode 100644 index 000000000..92ea2a5a9 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs @@ -0,0 +1,346 @@ +use crate::{ + base::{ + commitment::InnerProductProof, + database::{owned_table_utility::*, Column, OwnedTableTestAccessor}, + scalar::Curve25519Scalar, + }, + sql::{ + ast::{test_utility::*, ProofPlan, ProvableExpr, ProvableExprPlan}, + parse::ConversionError, + proof::{exercise_verification, QueryError, VerifiableQueryResult}, + }, +}; +use bumpalo::Bump; +use curve25519_dalek::ristretto::RistrettoPoint; +use itertools::{multizip, MultiUnzip}; +use rand::{ + distributions::{Distribution, Uniform}, + rngs::StdRng, +}; +use rand_core::SeedableRng; + +// select a * 2 as a, c, b * 4.5 as b, d * 3 + 4.7 as d, e from sxt.t where d * 3.9 = 8.19 +#[test] +fn we_can_prove_a_typical_multiply_query() { + let data = owned_table([ + smallint("a", [1_i16, 2, 3, 4]), + int("b", [0_i32, 1, 2, 1]), + varchar("e", ["ab", "t", "efg", "g"]), + bigint("c", [0_i64, 2, 2, 0]), + decimal75("d", 2, 1, [21_i64, 4, 21, -7]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast = dense_filter( + vec![ + ( + multiply(column(t, "a", &accessor), const_int(2)), + "a".parse().unwrap(), + ), + col_expr_plan(t, "c", &accessor), + ( + multiply(column(t, "b", &accessor), const_decimal75(2, 1, 45)), + 
"b".parse().unwrap(), + ), + ( + add( + multiply(column(t, "d", &accessor), const_smallint(3)), + const_decimal75(2, 1, 47), + ), + "d".parse().unwrap(), + ), + col_expr_plan(t, "e", &accessor), + ], + tab(t), + equal( + multiply(column(t, "d", &accessor), const_decimal75(2, 1, 39)), + const_decimal75(3, 2, 819), + ), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ + int("a", [2_i32, 6]), + bigint("c", [0_i64, 2]), + decimal75("b", 13, 1, [0_i64, 90]), + decimal75("d", 9, 1, [110_i64, 110]), + varchar("e", ["ab", "efg"]), + ]); + assert_eq!(res, expected_res); +} + +// Column type issue tests +#[test] +fn decimal_column_type_issues_error_out_when_producing_provable_ast() { + let data = owned_table([decimal75("a", 57, 2, [1_i16, 2, 3, 4])]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + assert!(matches!( + ProvableExprPlan::try_new_multiply( + column(t, "a", &accessor), + const_bigint::(1) + ), + Err(ConversionError::DataTypeMismatch(..)) + )); +} + +// Overflow tests +// select a * b as c from sxt.t where b = 2 +#[test] +fn result_expr_can_overflow() { + let data = owned_table([ + smallint("a", [i16::MAX, i16::MIN]), + smallint("b", [2_i16, 0]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast: ProofPlan = dense_filter( + vec![( + multiply(column(t, "a", &accessor), column(t, "b", &accessor)), + "c".parse().unwrap(), + )], + tab(t), + equal(column(t, "b", &accessor), const_bigint(2)), + ); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); + assert!(matches!( + verifiable_res.verify(&ast, &accessor, &()), + Err(QueryError::Overflow) + )); +} + +// select a * b as c from sxt.t where 
b == 0 +#[test] +fn overflow_in_nonselected_rows_doesnt_error_out() { + let data = owned_table([ + smallint("a", [i16::MAX, i16::MIN + 1]), + smallint("b", [2_i16, 0]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast: ProofPlan = dense_filter( + vec![( + multiply(column(t, "a", &accessor), column(t, "b", &accessor)), + "c".parse().unwrap(), + )], + tab(t), + equal(column(t, "b", &accessor), const_bigint(0)), + ); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([smallint("c", [0_i16])]); + assert_eq!(res, expected_res); +} + +// select a, b from sxt.t where a * b >= 0 +#[test] +fn overflow_in_where_clause_doesnt_error_out() { + let data = owned_table([ + bigint("a", [i64::MAX, i64::MIN + 1]), + smallint("b", [2_i16, 1]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast: ProofPlan = dense_filter( + cols_expr_plan(t, &["a", "b"], &accessor), + tab(t), + gte( + multiply(column(t, "a", &accessor), column(t, "b", &accessor)), + const_bigint(0), + ), + ); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([bigint("a", [i64::MAX]), smallint("b", [2_i16])]); + assert_eq!(res, expected_res); +} + +// select a * b as c from sxt.t +#[test] +fn result_expr_can_overflow_more() { + let data = owned_table([ + bigint("a", [i64::MAX, i64::MIN, i64::MAX, i64::MIN]), + bigint("b", [i64::MAX, i64::MAX, i64::MIN, i64::MIN]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = 
OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast: ProofPlan = dense_filter( + vec![( + multiply(column(t, "a", &accessor), column(t, "b", &accessor)), + "c".parse().unwrap(), + )], + tab(t), + const_bool(true), + ); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); + assert!(matches!( + verifiable_res.verify(&ast, &accessor, &()), + Err(QueryError::Overflow) + )); +} + +// select * from sxt.t where a * b * c * d * e = res +// Only the last row is a valid result +// The other two are due to the fact that scalars are elements of finite fields +// and that hence scalar multiplication inherently wraps around +#[test] +fn where_clause_can_wrap_around() { + let data = owned_table([ + bigint("a", [2357878470324616199_i64, 2657439699204141, 884]), + bigint("b", [31194601778911687_i64, 1644425323726039, 884]), + bigint("c", [500213946116239_i64, 1570568673569987, 884]), + bigint("d", [211980999383887_i64, 1056107792886999, 884]), + bigint("e", [927908842441_i64, 998426626609497, 884]), + bigint("res", [-20_i64, 50, 539835356263424]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let ast: ProofPlan = dense_filter( + cols_expr_plan(t, &["a", "b", "c", "d", "e", "res"], &accessor), + tab(t), + equal( + multiply( + multiply( + multiply( + multiply(column(t, "a", &accessor), column(t, "b", &accessor)), + column(t, "c", &accessor), + ), + column(t, "d", &accessor), + ), + column(t, "e", &accessor), + ), + column(t, "res", &accessor), + ), + ); + let verifiable_res: VerifiableQueryResult = + VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + let expected_res = owned_table([ + bigint("a", [2357878470324616199_i64, 2657439699204141, 884]), + bigint("b", [31194601778911687_i64, 1644425323726039, 884]), + 
bigint("c", [500213946116239_i64, 1570568673569987, 884]), + bigint("d", [211980999383887_i64, 1056107792886999, 884]), + bigint("e", [927908842441_i64, 998426626609497, 884]), + bigint("res", [-20_i64, 50, 539835356263424]), + ]); + assert_eq!(res, expected_res); +} + +fn test_random_tables_with_given_offset(offset: usize) { + let dist = Uniform::new(-3, 4); + let mut rng = StdRng::from_seed([0u8; 32]); + for _ in 0..20 { + // Generate random table + let n = Uniform::new(1, 21).sample(&mut rng); + let data = owned_table([ + bigint("a", dist.sample_iter(&mut rng).take(n)), + varchar( + "b", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + bigint("c", dist.sample_iter(&mut rng).take(n)), + varchar( + "d", + dist.sample_iter(&mut rng).take(n).map(|v| format!("s{v}")), + ), + ]); + + // Generate random values to filter by + let filter_val1 = format!("s{}", dist.sample(&mut rng)); + let filter_val2 = dist.sample(&mut rng); + + // Create and verify proof + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table( + t, + data.clone(), + offset, + (), + ); + let ast = dense_filter( + vec![ + col_expr_plan(t, "d", &accessor), + ( + add( + multiply(column(t, "a", &accessor), column(t, "c", &accessor)), + const_int128(4), + ), + "f".parse().unwrap(), + ), + ], + tab(t), + and( + equal( + column(t, "b", &accessor), + const_scalar(filter_val1.as_str()), + ), + equal(column(t, "c", &accessor), const_scalar(filter_val2)), + ), + ); + let verifiable_res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&verifiable_res, &ast, &accessor, t); + let res = verifiable_res.verify(&ast, &accessor, &()).unwrap().table; + + // Calculate/compare expected result + let (expected_f, expected_d): (Vec<_>, Vec<_>) = multizip(( + data["a"].i64_iter(), + data["b"].string_iter(), + data["c"].i64_iter(), + data["d"].string_iter(), + )) + .filter_map(|(a, b, c, d)| { + if b == &filter_val1 && c == &filter_val2 { + 
Some(((*a * *c + 4) as i128, d.clone())) + } else { + None + } + }) + .multiunzip(); + let expected_result = owned_table([varchar("d", expected_d), int128("f", expected_f)]); + + assert_eq!(expected_result, res) + } +} + +#[test] +fn we_can_query_random_tables_using_a_zero_offset() { + test_random_tables_with_given_offset(0); +} + +#[test] +fn we_can_query_random_tables_using_a_non_zero_offset() { + test_random_tables_with_given_offset(23); +} + +// b * (a - 1.5) +#[test] +fn we_can_compute_the_correct_output_of_a_multiply_expr_using_result_evaluate() { + let data = owned_table([ + smallint("a", [1_i16, 2, 3, 4]), + int("b", [0_i32, 1, 5, 1]), + varchar("d", ["ab", "t", "efg", "g"]), + bigint("c", [0_i64, 2, 2, 0]), + ]); + let t = "sxt.t".parse().unwrap(); + let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); + let multiply_expr: ProvableExprPlan = multiply( + column(t, "b", &accessor), + subtract(column(t, "a", &accessor), const_decimal75(2, 1, 15)), + ); + let alloc = Bump::new(); + let res = multiply_expr.result_evaluate(4, &alloc, &accessor); + let expected_res_scalar = [0, 5, 75, 25] + .iter() + .map(|v| Curve25519Scalar::from(*v)) + .collect::>(); + let expected_res = Column::Scalar(&expected_res_scalar); + assert_eq!(res, expected_res); +} diff --git a/crates/proof-of-sql/src/sql/ast/numerical_util.rs b/crates/proof-of-sql/src/sql/ast/numerical_util.rs index 42fa8eeec..caf899859 100644 --- a/crates/proof-of-sql/src/sql/ast/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/ast/numerical_util.rs @@ -4,10 +4,7 @@ use crate::{ math::decimal::{scale_scalar, DecimalError, Precision}, scalar::Scalar, }, - sql::{ - ast::numerical_util::DecimalError::InvalidPrecision, - parse::{ConversionError, ConversionError::DecimalConversionError, ConversionResult}, - }, + sql::parse::{ConversionError, ConversionError::DecimalConversionError, ConversionResult}, }; use bumpalo::Bump; @@ -44,15 +41,58 @@ pub(crate) fn try_add_subtract_column_types( 
.max(right_precision_value - right_scale as i16) + 1_i16; let precision = u8::try_from(precision_value) - .map_err(|_| DecimalConversionError(InvalidPrecision(precision_value.to_string()))) + .map_err(|_| { + DecimalConversionError(DecimalError::InvalidPrecision(precision_value.to_string())) + }) .and_then(|p| { - Precision::new(p) - .map_err(|_| DecimalConversionError(InvalidPrecision(p.to_string()))) + Precision::new(p).map_err(|_| { + DecimalConversionError(DecimalError::InvalidPrecision(p.to_string())) + }) })?; Ok(ColumnType::Decimal75(precision, scale)) } } +/// Determine the output type of a multiplication operation if it is possible +/// to multiply the two input types. If the types are not compatible, return +/// an error. +pub(crate) fn try_multiply_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ConversionResult { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ConversionError::DataTypeMismatch( + lhs.to_string(), + rhs.to_string(), + )); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = lhs.precision_value().unwrap_or(0); + let right_precision_value = rhs.precision_value().unwrap_or(0); + let precision_value = left_precision_value + right_precision_value + 1; + let precision = Precision::new(precision_value).map_err(|_| { + DecimalConversionError(DecimalError::InvalidPrecision(format!( + "Required precision {} is beyond what we can support", + precision_value + ))) + })?; + let left_scale = lhs.scale().unwrap_or(0); + let right_scale = rhs.scale().unwrap_or(0); + let scale = left_scale + .checked_add(right_scale) + .ok_or(DecimalConversionError(DecimalError::InvalidScale( + left_scale as i16 + right_scale as i16, + )))?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + /// Add or 
subtract two columns together. pub(crate) fn add_subtract_columns<'a, S: Scalar>( lhs: Column<'a, S>, @@ -68,7 +108,6 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>( lhs_len == rhs_len, "lhs and rhs should have the same length" ); - let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); let max_scale = lhs_scale.max(rhs_scale); let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - lhs_scale); let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - rhs_scale); @@ -82,6 +121,23 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>( res } +/// Multiply two columns together. +pub(crate) fn multiply_columns<'a, S: Scalar>( + lhs: &Column<'a, S>, + rhs: &Column<'a, S>, + alloc: &'a Bump, +) -> &'a [S] { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + assert!( + lhs_len == rhs_len, + "lhs and rhs should have the same length" + ); + alloc.alloc_slice_fill_with(lhs_len, |i| { + lhs.scalar_at(i).unwrap() * rhs.scalar_at(i).unwrap() + }) +} + /// The counterpart of `add_subtract_columns` for evaluating decimal expressions. 
pub(crate) fn scale_and_add_subtract_eval( lhs_eval: S, diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 69335f5c8..5f83c4c67 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -1,6 +1,6 @@ use super::{ - AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, - ProvableExpr, + AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr, + NotExpr, OrExpr, ProvableExpr, }; use crate::{ base::{ @@ -37,6 +37,8 @@ pub enum ProvableExprPlan { Inequality(InequalityExpr), /// Provable numeric `+` / `-` expression AddSubtract(AddSubtractExpr), + /// Provable numeric `*` expression + Multiply(MultiplyExpr), } impl ProvableExprPlan { /// Create column expression @@ -154,6 +156,26 @@ impl ProvableExprPlan { } } + /// Create a new multiply expression + pub fn try_new_multiply( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Multiply(MultiplyExpr::new( + Box::new(lhs), + Box::new(rhs), + ))) + } + } + /// Check that the plan has the correct data type fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> { if self.data_type() == data_type { @@ -178,6 +200,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Equals(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Inequality(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::AddSubtract(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Multiply(expr) => ProvableExpr::::count(expr, builder), } } @@ -185,6 +208,7 @@ 
impl ProvableExpr for ProvableExprPlan { match self { ProvableExprPlan::Column(expr) => expr.data_type(), ProvableExprPlan::AddSubtract(expr) => expr.data_type(), + ProvableExprPlan::Multiply(expr) => expr.data_type(), ProvableExprPlan::Literal(expr) => ProvableExpr::::data_type(expr), ProvableExprPlan::And(_) | ProvableExprPlan::Or(_) @@ -225,6 +249,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::AddSubtract(expr) => { ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } } } @@ -259,6 +286,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::AddSubtract(expr) => { ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } } } @@ -278,6 +308,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Equals(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::AddSubtract(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Multiply(expr) => expr.verifier_evaluate(builder, accessor), } } @@ -301,6 +332,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::AddSubtract(expr) => { ProvableExpr::::get_column_references(expr, columns) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::get_column_references(expr, columns) + } } } } diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index d8bc6ddf7..e9f7a90b2 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -78,10 +78,25 @@ pub fn subtract( ProvableExprPlan::try_new_subtract(left, right).unwrap() } +pub fn multiply( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> 
ProvableExprPlan { + ProvableExprPlan::try_new_multiply(left, right).unwrap() +} + pub fn const_bool(val: bool) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::Boolean(val)) } +pub fn const_smallint(val: i16) -> ProvableExprPlan { + ProvableExprPlan::new_literal(LiteralValue::SmallInt(val)) +} + +pub fn const_int(val: i32) -> ProvableExprPlan { + ProvableExprPlan::new_literal(LiteralValue::Int(val)) +} + pub fn const_bigint(val: i64) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::BigInt(val)) } diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index d18069ecd..271594615 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -144,9 +144,15 @@ impl ProvableExprPlanBuilder<'_> { let right = self.visit_expr(right); ProvableExprPlan::try_new_subtract(left?, right?) } - BinaryOperator::Multiply | BinaryOperator::Division => Err( - ConversionError::Unprovable(format!("Binary operator {:?} is not supported", op)), - ), + BinaryOperator::Multiply => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_multiply(left?, right?) 
+ } + BinaryOperator::Division => Err(ConversionError::Unprovable(format!( + "Binary operator {:?} is not supported at this location", + op + ))), } } } diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 0b870b1cc..1f7d4e234 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -4,7 +4,7 @@ use crate::{ database::{ColumnRef, ColumnType, SchemaAccessor, TableRef}, math::decimal::Precision, }, - sql::ast::try_add_subtract_column_types, + sql::ast::{try_add_subtract_column_types, try_multiply_column_types}, }; use proof_of_sql_parser::{ intermediate_ast::{ @@ -313,9 +313,8 @@ pub(crate) fn type_check_binary_operation( BinaryOperator::Add | BinaryOperator::Subtract => { try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok() } - BinaryOperator::Multiply | BinaryOperator::Division => { - left_dtype.is_numeric() && right_dtype.is_numeric() - } + BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(), + BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(), } } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index 770fe1b6c..193d40a4c 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -447,7 +447,7 @@ fn we_can_convert_an_ast_with_cond_or() { ); let ast = query_to_provable_ast( t, - "select a from sxt_tab where (b = 3) or (c = -2)", + "select a from sxt_tab where (b * 3 = 3) or (c = -2)", &accessor, ); let expected_ast = QueryExpr::new( @@ -455,7 +455,10 @@ fn we_can_convert_an_ast_with_cond_or() { cols_expr_plan(t, &["a"], &accessor), tab(t), or( - equal(column(t, "b", &accessor), const_bigint(3)), + equal( + multiply(column(t, "b", &accessor), const_bigint(3)), + const_bigint(3), 
+ ), equal(column(t, "c", &accessor), const_bigint(-2)), ), ), @@ -1626,6 +1629,10 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex add(column(t, "a", &accessor), column(t, "b", &accessor)), "__expr__".parse().unwrap(), ), + ( + multiply(const_bigint(2), column(t, "f", &accessor)), + "f2".parse().unwrap(), + ), ( subtract(const_bigint(-77), column(t, "h", &accessor)), "col".parse().unwrap(), @@ -1634,17 +1641,13 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex add(column(t, "a", &accessor), column(t, "f", &accessor)), "af".parse().unwrap(), ), - col_expr_plan(t, "a", &accessor), - col_expr_plan(t, "b", &accessor), - col_expr_plan(t, "f", &accessor), - col_expr_plan(t, "h", &accessor), ], tab(t), const_bool(true), ), composite_result(vec![select(&[ pc("__expr__").alias("__expr__"), - (lit_i64(2) * pc("f")).alias("f2"), + pc("f2").alias("f2"), pc("col").alias("col"), pc("af").alias("af"), // TODO: add `a / b as a_div_b` result expr once polars properly @@ -1676,14 +1679,40 @@ fn we_can_parse_multiple_arithmetic_expression_where_multiplication_has_preceden ); let expected_ast = QueryExpr::new( dense_filter( - cols_expr_plan(t, &["c", "f", "g", "h"], &accessor), + vec![ + ( + multiply( + add(const_bigint(2), column(t, "f", &accessor)), + add( + add(column(t, "c", &accessor), column(t, "g", &accessor)), + multiply(const_bigint(2), column(t, "h", &accessor)), + ), + ), + "__expr__".parse().unwrap(), + ), + ( + multiply( + add( + add( + multiply( + subtract(column(t, "h", &accessor), column(t, "g", &accessor)), + const_bigint(2), + ), + column(t, "c", &accessor), + ), + column(t, "g", &accessor), + ), + add(column(t, "f", &accessor), const_bigint(2)), + ), + "d".parse().unwrap(), + ), + ], tab(t), const_bool(true), ), composite_result(vec![select(&[ - ((lit_i64(2) + pc("f")) * (pc("c") + pc("g") + lit_i64(2) * pc("h"))).alias("__expr__"), - (((pc("h") - pc("g")) * lit_i64(2) + pc("c") + pc("g")) * 
(pc("f") + lit_i64(2))) - .alias("d"), + pc("__expr__").alias("__expr__"), + pc("d").alias("d"), ])]), ); assert_eq!(ast, expected_ast); diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 221b8882d..099c21f08 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -402,7 +402,7 @@ fn we_can_prove_a_complex_query_with_curve25519() { "sxt.table".parse().unwrap(), owned_table([ smallint("a", [1_i16, 2, 3]), - int("b", [1_i32, 0, 1]), + int("b", [1_i32, 4, 3]), bigint("c", [3_i64, 3, -3]), bigint("d", [1_i64, 2, 3]), varchar("e", ["d", "e", "f"]), @@ -413,7 +413,7 @@ fn we_can_prove_a_complex_query_with_curve25519() { 0, ); let query = QueryExpr::try_new( - "SELECT a + b + c + 1 as t, 45.7 as g, (a = b) or f as h, d0 + d1 + 1.4 as dr FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" + "SELECT a + (b * c) + 1 as t, 45.7 as g, (a = b) or f as h, d0 * d1 + 1.4 as dr FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f" .parse() .unwrap(), "sxt".parse().unwrap(), @@ -427,10 +427,10 @@ fn we_can_prove_a_complex_query_with_curve25519() { .unwrap() .table; let expected_result = owned_table([ - bigint("t", [2]), + bigint("t", [-5]), decimal75("g", 3, 1, [457]), - boolean("h", [false]), - decimal75("dr", 16, 4, [14203]), + boolean("h", [true]), + decimal75("dr", 26, 6, [1400006]), ]); assert_eq!(owned_table_result, expected_result); } @@ -452,13 +452,13 @@ fn we_can_prove_a_complex_query_with_dory() { bigint("d", [1, 2, 3]), varchar("e", ["d", "e", "f"]), boolean("f", [true, false, true]), - decimal75("d0", 12, 4, [1, 2, 3]), + decimal75("d0", 12, 4, [1, 4, 3]), decimal75("d1", 12, 2, [3, 4, 2]), ]), 0, ); let query = QueryExpr::try_new( - "SELECT 0.5 + a - b + c - d as res, 32 as g, (c >= d) and f as h, a + b + 1 + c + d + d0 - d1 + 0.5 as res2 FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f and d1 - d0 > 0.01" + "SELECT 0.5 + a * b * c - d 
as res, 32 as g, (c >= d) and f as h, (a + 1) * (b + 1 + c + d + d0 - d1 + 0.5) as res2 FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f and 100000 * d1 * d0 + a = 1.3" .parse() .unwrap(), "sxt".parse().unwrap(), @@ -480,7 +480,7 @@ fn we_can_prove_a_complex_query_with_dory() { decimal75("res", 22, 1, [25]), bigint("g", [32]), boolean("h", [true]), - decimal75("res2", 26, 4, [74701]), + decimal75("res2", 46, 4, [129402]), ]); assert_eq!(owned_table_result, expected_result); }