Skip to content

Commit

Permalink
feat: add MultiplyExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Jun 20, 2024
1 parent 9c23c26 commit 1832530
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 4 deletions.
8 changes: 7 additions & 1 deletion crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ pub(crate) use add_subtract_expr::AddSubtractExpr;
#[cfg(all(test, feature = "blitzar"))]
mod add_subtract_expr_test;

mod multiply_expr;
use multiply_expr::MultiplyExpr;
#[cfg(all(test, feature = "blitzar"))]
mod multiply_expr_test;

mod filter_expr;
pub(crate) use filter_expr::FilterExpr;
#[cfg(test)]
Expand Down Expand Up @@ -59,7 +64,8 @@ pub(crate) use comparison_util::scale_and_subtract;

mod numerical_util;
pub(crate) use numerical_util::{
add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types,
scale_and_add_subtract_eval, try_add_subtract_column_types, add_subtract_columns,
try_multiply_column_types, try_multiply_columns,
};

mod equals_expr;
Expand Down
123 changes: 123 additions & 0 deletions crates/proof-of-sql/src/sql/ast/multiply_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
use super::{ProvableExpr, ProvableExprPlan};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::{
ast::{try_multiply_column_types, try_multiply_columns},
parse::ConversionError,
proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder},
},
};
use bumpalo::Bump;
use num_traits::One;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical * expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct MultiplyExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
}

impl<C: Commitment> MultiplyExpr<C> {
/// Create numerical * expression
pub fn new(lhs: Box<ProvableExprPlan<C>>, rhs: Box<ProvableExprPlan<C>>) -> Self {
Self { lhs, rhs }
}
}

impl<C: Commitment> ProvableExpr<C> for MultiplyExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
builder.count_subpolynomials(1);
builder.count_intermediate_mles(1);
builder.count_degree(3);
Ok(())
}

fn data_type(&self) -> ColumnType {
try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to multiply column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
let scalars = try_multiply_columns(lhs_column, rhs_column, alloc)
.expect("Failed to multiply columns");
Column::Scalar(scalars)
}

#[tracing::instrument(
name = "proofs.sql.ast.and_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
let lhs_scalars = lhs_column.to_scalar_with_scaling(0);
let rhs_scalars = rhs_column.to_scalar_with_scaling(0);
let lhs_bump: &'a [C::Scalar] = alloc.alloc_slice_copy(&lhs_scalars);
let rhs_bump: &'a [C::Scalar] = alloc.alloc_slice_copy(&rhs_scalars);

// lhs_times_rhs
let lhs_times_rhs: &'a [C::Scalar] = try_multiply_columns(lhs_column, rhs_column, alloc)
.expect("Failed to multiply columns");
builder.produce_intermediate_mle(lhs_times_rhs);

// subpolynomial: lhs_times_rhs - lhs * rhs
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(C::Scalar::one(), vec![Box::new(lhs_times_rhs)]),
(
-C::Scalar::one(),
vec![Box::new(lhs_bump), Box::new(rhs_bump)],
),
],
);
Column::Scalar(lhs_times_rhs)
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs = self.rhs.verifier_evaluate(builder, accessor)?;

// lhs_times_rhs
let lhs_times_rhs = builder.consume_intermediate_mle();

// subpolynomial: lhs_times_rhs - lhs * rhs
let eval = builder.mle_evaluations.random_evaluation * (lhs_times_rhs - lhs * rhs);
builder.produce_sumcheck_subpolynomial_evaluation(&eval);

// selection
Ok(lhs_times_rhs)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
62 changes: 62 additions & 0 deletions crates/proof-of-sql/src/sql/ast/multiply_expr_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use super::{test_utility::*, FilterExpr, ProvableExpr};
use crate::{
base::{
commitment::InnerProductProof,
database::{
make_random_test_accessor_data, Column, ColumnType, OwnedTable, OwnedTableTestAccessor,
RandomTestAccessorDescriptor, TestAccessor,
},
scalar::Curve25519Scalar,
},
owned_table,
sql::{
ast::{test_utility::*, ProvableExprPlan},
proof::{exercise_verification, VerifiableQueryResult},
},
};
/// This function creates a TestAccessor, adds a table, and then creates a FilterExpr with the given parameters.
/// It then executes the query, verifies the result, and returns the table.
/// The query is `select r_0,\cdots r_k from table where res = lhs * rhs`
fn create_and_verify_test_col_mul_expr(
table_ref: &str,
results: &[&str],
lhs: &str,
rhs: &str,
res: &str,
data: OwnedTable<Curve25519Scalar>,
offset: usize,
) -> OwnedTable<Curve25519Scalar> {
let mut accessor = OwnedTableTestAccessor::<InnerProductProof>::new_empty_with_setup(());
let t = table_ref.parse().unwrap();
accessor.add_table(t, data, offset);
let mul_expr = mul(column(t, lhs, &accessor), column(t, rhs, &accessor));
let eq_expr = equal(column(t, res, &accessor), mul_expr);
let ast = FilterExpr::new(cols_result(t, results, &accessor), tab(t), eq_expr);
let res = VerifiableQueryResult::new(&ast, &accessor, &());
exercise_verification(&res, &ast, &accessor, t);
res.verify(&ast, &accessor, &()).unwrap().table
}

#[test]
fn we_can_prove_a_simple_mul_query() {
let data = owned_table!(
"product" => [1_i64, -4, 0, 4],
"c0" => [0_i32, -2, 0, -1],
"d" => ["ab", "t", "efg", "g"],
"c1" => [0_i16, 2, -2, 0],
);
let res = create_and_verify_test_col_mul_expr(
"sxt.t",
&["product", "c0"],
"c0",
"c1",
"product",
data,
0,
);
let expected_res = owned_table!(
"product" => [-4_i64, 0],
"c0" => [-2_i32, 0]
);
assert_eq!(res, expected_res);
}
62 changes: 62 additions & 0 deletions crates/proof-of-sql/src/sql/ast/numerical_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,48 @@ pub(crate) fn try_add_subtract_column_types(
}
}

/// Determine the output type of a multiplication operation if it is possible
/// to multiply the two input types. If the types are not compatible, return
/// an error.
pub(crate) fn try_multiply_column_types(
lhs: ColumnType,
rhs: ColumnType,
) -> ConversionResult<ColumnType> {
if !lhs.is_numeric() || rhs.is_numeric() {
return Err(ConversionError::DataTypeMismatch(
lhs.to_string(),
rhs.to_string(),
));
}
if lhs.is_integer() && rhs.is_integer() {
// We can unwrap here because we know that both types are integers
return Ok(lhs.max_integer_type(&rhs).unwrap());
}
if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar {
return Ok(ColumnType::Scalar);
} else {
let left_precision_value = lhs.precision_value().unwrap_or(0);
let right_precision_value = rhs.precision_value().unwrap_or(0);
let precision_value = left_precision_value + right_precision_value + 1;
let precision = Precision::new(precision_value).map_err(|_| {
ConversionError::DecimalRoundingError(format!(
"Required precision {} is beyond what we can support",
precision_value
))
})?;
let left_scale = lhs.scale().unwrap_or(0);
let right_scale = rhs.scale().unwrap_or(0);
let scale =
left_scale
.checked_add(right_scale)
.ok_or(ConversionError::DecimalScaleError(format!(
"Required scale {} is beyond what we can support",
left_scale as i16 + right_scale as i16
)))?;
Ok(ColumnType::Decimal75(precision, scale))
}
}

/// Add or subtract two columns together.
///
/// If the columns are not compatible for addition/subtraction, return an error.
Expand Down Expand Up @@ -80,6 +122,26 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>(
res
}

/// Multiply two columns together.
///
/// If the columns are not compatible for multiplication, return an error.
pub(crate) fn try_multiply_columns<'a, S: Scalar>(
lhs: Column<'a, S>,
rhs: Column<'a, S>,
alloc: &'a Bump,
) -> ConversionResult<&'a [S]> {
let lhs_len = lhs.len();
let rhs_len = rhs.len();
if lhs_len != rhs_len {
return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len));
}
let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len);
let lhs_scalar = lhs.to_scalar_with_scaling(0);
let rhs_scalar = rhs.to_scalar_with_scaling(0);
let res = alloc.alloc_slice_fill_with(lhs_len, |i| lhs_scalar[i] * rhs_scalar[i]);
Ok(res)
}

/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions.
pub(crate) fn scale_and_add_subtract_eval<S: Scalar>(
lhs_eval: S,
Expand Down
38 changes: 36 additions & 2 deletions crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr,
ProvableExpr,
AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, MultiplyExpr,
NotExpr, OrExpr, ProvableExpr,
};
use crate::{
base::{
Expand Down Expand Up @@ -39,6 +39,8 @@ pub enum ProvableExprPlan<C: Commitment> {
Add(AddSubtractExpr<C>),
/// Provable numeric `-` expression
Subtract(AddSubtractExpr<C>),
/// Provable numeric `*` expression
Multiply(MultiplyExpr<C>),
}
impl<C: Commitment> ProvableExprPlan<C> {
/// Create column expression
Expand Down Expand Up @@ -156,6 +158,26 @@ impl<C: Commitment> ProvableExprPlan<C> {
}
}

/// Create a new multiply expression
pub fn try_new_multiply(
lhs: ProvableExprPlan<C>,
rhs: ProvableExprPlan<C>,
) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) {
Err(ConversionError::DataTypeMismatch(
lhs_datatype.to_string(),
rhs_datatype.to_string(),
))
} else {
Ok(Self::Multiply(MultiplyExpr::new(
Box::new(lhs),
Box::new(rhs),
)))
}
}

/// Check that the plan has the correct data type
fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> {
if self.data_type() == data_type {
Expand All @@ -181,6 +203,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Add(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Subtract(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Multiply(expr) => ProvableExpr::<C>::count(expr, builder),
}
}

Expand All @@ -189,6 +212,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Column(expr) => expr.data_type(),
ProvableExprPlan::Add(expr) => expr.data_type(),
ProvableExprPlan::Subtract(expr) => expr.data_type(),
ProvableExprPlan::Multiply(expr) => expr.data_type(),
ProvableExprPlan::Literal(expr) => ProvableExpr::<C>::data_type(expr),
ProvableExprPlan::And(_)
| ProvableExprPlan::Or(_)
Expand Down Expand Up @@ -232,6 +256,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
}
}

Expand Down Expand Up @@ -269,6 +296,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
}
}

Expand All @@ -289,6 +319,7 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Add(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Subtract(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Multiply(expr) => expr.verifier_evaluate(builder, accessor),
}
}

Expand All @@ -313,6 +344,9 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
ProvableExprPlan::Multiply(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
}
}
}
7 changes: 7 additions & 0 deletions crates/proof-of-sql/src/sql/ast/test_utility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ pub fn subtract<C: Commitment>(
ProvableExprPlan::try_new_subtract(left, right).unwrap()
}

pub fn multiply<C: Commitment>(
left: ProvableExprPlan<C>,
right: ProvableExprPlan<C>,
) -> ProvableExprPlan<C> {
ProvableExprPlan::try_new_multiply(left, right).unwrap()
}

pub fn const_bool<C: Commitment>(val: bool) -> ProvableExprPlan<C> {
ProvableExprPlan::new_literal(LiteralValue::Boolean(val))
}
Expand Down
Loading

0 comments on commit 1832530

Please sign in to comment.