diff --git a/crates/proof-of-sql/src/sql/ast/mod.rs b/crates/proof-of-sql/src/sql/ast/mod.rs index 9dbf0167a..1eae02757 100644 --- a/crates/proof-of-sql/src/sql/ast/mod.rs +++ b/crates/proof-of-sql/src/sql/ast/mod.rs @@ -7,6 +7,11 @@ pub(crate) use add_subtract_expr::AddSubtractExpr; #[cfg(all(test, feature = "blitzar"))] mod add_subtract_expr_test; +mod multiply_expr; +use multiply_expr::MultiplyExpr; +#[cfg(all(test, feature = "blitzar"))] +mod multiply_expr_test; + mod filter_expr; pub(crate) use filter_expr::FilterExpr; #[cfg(test)] @@ -59,7 +64,8 @@ pub(crate) use comparison_util::scale_and_subtract; mod numerical_util; pub(crate) use numerical_util::{ - add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types, + scale_and_add_subtract_eval, try_add_subtract_column_types, add_subtract_columns, + try_multiply_column_types, try_multiply_columns, }; mod equals_expr; diff --git a/crates/proof-of-sql/src/sql/ast/multiply_expr.rs b/crates/proof-of-sql/src/sql/ast/multiply_expr.rs new file mode 100644 index 000000000..46c01b371 --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/multiply_expr.rs @@ -0,0 +1,123 @@ +use super::{ProvableExpr, ProvableExprPlan}; +use crate::{ + base::{ + commitment::Commitment, + database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, + proof::ProofError, + }, + sql::{ + ast::{try_multiply_column_types, try_multiply_columns}, + parse::ConversionError, + proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder}, + }, +}; +use bumpalo::Bump; +use num_traits::One; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +/// Provable numerical * expression +#[derive(Debug, PartialEq, Serialize, Deserialize)] +pub struct MultiplyExpr { + lhs: Box>, + rhs: Box>, +} + +impl MultiplyExpr { + /// Create numerical * expression + pub fn new(lhs: Box>, rhs: Box>) -> Self { + Self { lhs, rhs } + } +} + +impl ProvableExpr for MultiplyExpr { + fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { + self.lhs.count(builder)?; + self.rhs.count(builder)?; + builder.count_subpolynomials(1); + builder.count_intermediate_mles(1); + builder.count_degree(3); + Ok(()) + } + + fn data_type(&self) -> ColumnType { + try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type()) + .expect("Failed to multiply column types") + } + + fn result_evaluate<'a>( + &self, + table_length: usize, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = + self.lhs.result_evaluate(table_length, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = + self.rhs.result_evaluate(table_length, alloc, accessor); + let scalars = try_multiply_columns(lhs_column, rhs_column, alloc) + .expect("Failed to multiply columns"); + Column::Scalar(scalars) + } + + #[tracing::instrument( + name = "proofs.sql.ast.and_expr.prover_evaluate", + level = "info", + skip_all + )] + fn prover_evaluate<'a>( + &self, + builder: &mut ProofBuilder<'a, C::Scalar>, + alloc: &'a Bump, + accessor: &'a dyn DataAccessor, + ) -> Column<'a, C::Scalar> { + let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor); + let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor); + let lhs_scalars = lhs_column.to_scalar_with_scaling(0); + let rhs_scalars = rhs_column.to_scalar_with_scaling(0); + let lhs_bump: &'a [C::Scalar] = alloc.alloc_slice_copy(&lhs_scalars); + let rhs_bump: &'a [C::Scalar] = alloc.alloc_slice_copy(&rhs_scalars); + + // lhs_times_rhs + let lhs_times_rhs: &'a [C::Scalar] = try_multiply_columns(lhs_column, rhs_column, alloc) + .expect("Failed to multiply columns"); + builder.produce_intermediate_mle(lhs_times_rhs); + + // subpolynomial: lhs_times_rhs - lhs * rhs + builder.produce_sumcheck_subpolynomial( + SumcheckSubpolynomialType::Identity, + vec![ + (C::Scalar::one(), vec![Box::new(lhs_times_rhs)]), + ( + -C::Scalar::one(), + vec![Box::new(lhs_bump), Box::new(rhs_bump)], + ), + ], + ); + Column::Scalar(lhs_times_rhs) + } + + fn verifier_evaluate( + &self, + builder: &mut VerificationBuilder, + accessor: &dyn CommitmentAccessor, + ) -> Result { + let lhs = self.lhs.verifier_evaluate(builder, accessor)?; + let rhs = self.rhs.verifier_evaluate(builder, accessor)?; + + // lhs_times_rhs + let lhs_times_rhs = builder.consume_intermediate_mle(); + + // subpolynomial: lhs_times_rhs - lhs * rhs + let eval = builder.mle_evaluations.random_evaluation * (lhs_times_rhs - lhs * rhs); + builder.produce_sumcheck_subpolynomial_evaluation(&eval); + + // selection + Ok(lhs_times_rhs) + } + + fn get_column_references(&self, columns: &mut HashSet) { + self.lhs.get_column_references(columns); + self.rhs.get_column_references(columns); + } +} diff --git a/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs b/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs new file mode 100644 index 000000000..c3696b80b --- /dev/null +++ b/crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs @@ -0,0 +1,62 @@ +use super::{test_utility::*, FilterExpr, ProvableExpr}; +use crate::{ + base::{ + commitment::InnerProductProof, + database::{ + make_random_test_accessor_data, Column, ColumnType, OwnedTable, OwnedTableTestAccessor, + RandomTestAccessorDescriptor, TestAccessor, + }, + scalar::Curve25519Scalar, + }, + owned_table, + sql::{ + ast::{test_utility::*, ProvableExprPlan}, + proof::{exercise_verification, VerifiableQueryResult}, + }, +}; +/// This function creates a TestAccessor, adds a table, and then creates a FilterExpr with the given parameters. +/// It then executes the query, verifies the result, and returns the table. +/// The query is `select r_0,\cdots r_k from table where res = lhs * rhs` +fn create_and_verify_test_col_mul_expr( + table_ref: &str, + results: &[&str], + lhs: &str, + rhs: &str, + res: &str, + data: OwnedTable, + offset: usize, +) -> OwnedTable { + let mut accessor = OwnedTableTestAccessor::::new_empty_with_setup(()); + let t = table_ref.parse().unwrap(); + accessor.add_table(t, data, offset); + let mul_expr = mul(column(t, lhs, &accessor), column(t, rhs, &accessor)); + let eq_expr = equal(column(t, res, &accessor), mul_expr); + let ast = FilterExpr::new(cols_result(t, results, &accessor), tab(t), eq_expr); + let res = VerifiableQueryResult::new(&ast, &accessor, &()); + exercise_verification(&res, &ast, &accessor, t); + res.verify(&ast, &accessor, &()).unwrap().table +} + +#[test] +fn we_can_prove_a_simple_mul_query() { + let data = owned_table!( + "product" => [1_i64, -4, 0, 4], + "c0" => [0_i32, -2, 0, -1], + "d" => ["ab", "t", "efg", "g"], + "c1" => [0_i16, 2, -2, 0], + ); + let res = create_and_verify_test_col_mul_expr( + "sxt.t", + &["product", "c0"], + "c0", + "c1", + "product", + data, + 0, + ); + let expected_res = owned_table!( + "product" => [-4_i64, 0], + "c0" => [-2_i32, 0] + ); + assert_eq!(res, expected_res); +} diff --git a/crates/proof-of-sql/src/sql/ast/numerical_util.rs b/crates/proof-of-sql/src/sql/ast/numerical_util.rs index 3b2daf582..91b5ece02 100644 --- a/crates/proof-of-sql/src/sql/ast/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/ast/numerical_util.rs @@ -49,6 +49,48 @@ pub(crate) fn try_add_subtract_column_types( } } +/// Determine the output type of a multiplication operation if it is possible +/// to multiply the two input types. If the types are not compatible, return +/// an error. +pub(crate) fn try_multiply_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ConversionResult { + if !lhs.is_numeric() || rhs.is_numeric() { + return Err(ConversionError::DataTypeMismatch( + lhs.to_string(), + rhs.to_string(), + )); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + return Ok(ColumnType::Scalar); + } else { + let left_precision_value = lhs.precision_value().unwrap_or(0); + let right_precision_value = rhs.precision_value().unwrap_or(0); + let precision_value = left_precision_value + right_precision_value + 1; + let precision = Precision::new(precision_value).map_err(|_| { + ConversionError::DecimalRoundingError(format!( + "Required precision {} is beyond what we can support", + precision_value + )) + })?; + let left_scale = lhs.scale().unwrap_or(0); + let right_scale = rhs.scale().unwrap_or(0); + let scale = + left_scale + .checked_add(right_scale) + .ok_or(ConversionError::DecimalScaleError(format!( + "Required scale {} is beyond what we can support", + left_scale as i16 + right_scale as i16 + )))?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + /// Add or subtract two columns together. pub(crate) fn add_subtract_columns<'a, S: Scalar>( lhs: Column<'a, S>, @@ -78,6 +120,25 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>( res } +/// Multiply two columns together. +pub(crate) fn multiply_columns<'a, S: Scalar>( + lhs: Column<'a, S>, + rhs: Column<'a, S>, + alloc: &'a Bump, +) -> &'a [S] { + let lhs_len = lhs.len(); + let rhs_len = rhs.len(); + assert!( + lhs_len == rhs_len, + "lhs and rhs should have the same length" + ); + let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); + let lhs_scalar = lhs.to_scalar_with_scaling(0); + let rhs_scalar = rhs.to_scalar_with_scaling(0); + let res = alloc.alloc_slice_fill_with(lhs_len, |i| lhs_scalar[i] * rhs_scalar[i]); + Ok(res) +} + /// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions. pub(crate) fn scale_and_add_subtract_eval( lhs_eval: S, diff --git a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs index 0d388beb2..7545336c6 100644 --- a/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs +++ b/crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs @@ -1,6 +1,6 @@ use super::{ - AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, - ProvableExpr, + AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr, + NotExpr, OrExpr, ProvableExpr, }; use crate::{ base::{ @@ -39,6 +39,8 @@ pub enum ProvableExprPlan { Add(AddSubtractExpr), /// Provable numeric `-` expression Subtract(AddSubtractExpr), + /// Provable numeric `*` expression + Multiply(MultiplyExpr), } impl ProvableExprPlan { /// Create column expression @@ -156,6 +158,26 @@ impl ProvableExprPlan { } } + /// Create a new multiply expression + pub fn try_new_multiply( + lhs: ProvableExprPlan, + rhs: ProvableExprPlan, + ) -> ConversionResult { + let lhs_datatype = lhs.data_type(); + let rhs_datatype = rhs.data_type(); + if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) { + Err(ConversionError::DataTypeMismatch( + lhs_datatype.to_string(), + rhs_datatype.to_string(), + )) + } else { + Ok(Self::Multiply(MultiplyExpr::new( + Box::new(lhs), + Box::new(rhs), + ))) + } + } + /// Check that the plan has the correct data type fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> { if self.data_type() == data_type { @@ -181,6 +203,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Add(expr) => ProvableExpr::::count(expr, builder), ProvableExprPlan::Subtract(expr) => ProvableExpr::::count(expr, builder), + ProvableExprPlan::Multiply(expr) => ProvableExpr::::count(expr, builder), } } @@ -189,6 +212,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Column(expr) => expr.data_type(), ProvableExprPlan::Add(expr) => expr.data_type(), ProvableExprPlan::Subtract(expr) => expr.data_type(), + ProvableExprPlan::Multiply(expr) => expr.data_type(), ProvableExprPlan::Literal(expr) => ProvableExpr::::data_type(expr), ProvableExprPlan::And(_) | ProvableExprPlan::Or(_) @@ -232,6 +256,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Subtract(expr) => { ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::result_evaluate(expr, table_length, alloc, accessor) + } } } @@ -269,6 +296,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Subtract(expr) => { ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::prover_evaluate(expr, builder, alloc, accessor) + } } } @@ -289,6 +319,7 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Add(expr) => expr.verifier_evaluate(builder, accessor), ProvableExprPlan::Subtract(expr) => expr.verifier_evaluate(builder, accessor), + ProvableExprPlan::Multiply(expr) => expr.verifier_evaluate(builder, accessor), } } @@ -313,6 +344,9 @@ impl ProvableExpr for ProvableExprPlan { ProvableExprPlan::Subtract(expr) => { ProvableExpr::::get_column_references(expr, columns) } + ProvableExprPlan::Multiply(expr) => { + ProvableExpr::::get_column_references(expr, columns) + } } } } diff --git a/crates/proof-of-sql/src/sql/ast/test_utility.rs b/crates/proof-of-sql/src/sql/ast/test_utility.rs index 3bc668638..5b45b6b49 100644 --- a/crates/proof-of-sql/src/sql/ast/test_utility.rs +++ b/crates/proof-of-sql/src/sql/ast/test_utility.rs @@ -77,6 +77,13 @@ pub fn subtract( ProvableExprPlan::try_new_subtract(left, right).unwrap() } +pub fn multiply( + left: ProvableExprPlan, + right: ProvableExprPlan, +) -> ProvableExprPlan { + ProvableExprPlan::try_new_multiply(left, right).unwrap() +} + pub fn const_bool(val: bool) -> ProvableExprPlan { ProvableExprPlan::new_literal(LiteralValue::Boolean(val)) } diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index d33637533..ecdbcbbf1 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -140,7 +140,12 @@ impl ProvableExprPlanBuilder<'_> { let right = self.visit_expr(right); ProvableExprPlan::try_new_subtract(left?, right?) } - BinaryOperator::Multiply | BinaryOperator::Division => Err( + BinaryOperator::Multiply => { + let left = self.visit_expr(left); + let right = self.visit_expr(right); + ProvableExprPlan::try_new_multiply(left?, right?) + } + BinaryOperator::Division => Err( ConversionError::Unprovable(format!("Binary operator {:?} is not supported", op)), ), }