Skip to content

Commit

Permalink
feat!: add MultiplyExpr (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner authored Jun 26, 2024
1 parent 88f3d40 commit cb50067
Show file tree
Hide file tree
Showing 12 changed files with 666 additions and 37 deletions.
17 changes: 17 additions & 0 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,23 @@ impl<'a, S: Scalar> Column<'a, S> {
}
}

/// Returns element at index as scalar
///
/// Note that if index is out of bounds, this function will return None
pub(crate) fn scalar_at(&self, index: usize) -> Option<S> {
(index < self.len()).then_some(match self {
Self::Boolean(col) => S::from(col[index]),
Self::SmallInt(col) => S::from(col[index]),
Self::Int(col) => S::from(col[index]),
Self::BigInt(col) => S::from(col[index]),
Self::Int128(col) => S::from(col[index]),
Self::Scalar(col) => col[index],
Self::Decimal75(_, _, col) => col[index],
Self::VarChar((_, scals)) => scals[index],
Self::TimestampTZ(_, _, col) => S::from(col[index]),
})
}

/// Convert a column to a vector of Scalar values with scaling
pub(crate) fn to_scalar_with_scaling(&self, scale: i8) -> Vec<S> {
let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor");
Expand Down
5 changes: 5 additions & 0 deletions crates/proof-of-sql/src/base/math/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ pub enum DecimalError {
/// or non-positive aka InvalidPrecision
InvalidPrecision(String),

#[error("Decimal scale is not valid: {0}")]
/// Decimal scale is not valid. Here we use i16 in order to include
/// invalid scale values
InvalidScale(i16),

#[error("Unsupported operation: cannot round decimal: {0}")]
/// This error occurs when attempting to scale a
/// decimal in such a way that a loss of precision occurs.
Expand Down
8 changes: 7 additions & 1 deletion crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ pub(crate) use add_subtract_expr::AddSubtractExpr;
#[cfg(all(test, feature = "blitzar"))]
mod add_subtract_expr_test;

mod multiply_expr;
use multiply_expr::MultiplyExpr;
#[cfg(all(test, feature = "blitzar"))]
mod multiply_expr_test;

mod filter_expr;
pub(crate) use filter_expr::FilterExpr;
#[cfg(test)]
Expand Down Expand Up @@ -59,7 +64,8 @@ pub(crate) use comparison_util::scale_and_subtract;

mod numerical_util;
pub(crate) use numerical_util::{
add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types,
add_subtract_columns, multiply_columns, scale_and_add_subtract_eval,
try_add_subtract_column_types, try_multiply_column_types,
};

mod equals_expr;
Expand Down
116 changes: 116 additions & 0 deletions crates/proof-of-sql/src/sql/ast/multiply_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use super::{ProvableExpr, ProvableExprPlan};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::{
ast::{multiply_columns, try_multiply_column_types},
proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder},
},
};
use bumpalo::Bump;
use num_traits::One;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical * expression
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MultiplyExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
}

impl<C: Commitment> MultiplyExpr<C> {
/// Create numerical `*` expression
pub fn new(lhs: Box<ProvableExprPlan<C>>, rhs: Box<ProvableExprPlan<C>>) -> Self {
Self { lhs, rhs }
}
}

impl<C: Commitment> ProvableExpr<C> for MultiplyExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
builder.count_subpolynomials(1);
builder.count_intermediate_mles(1);
builder.count_degree(3);
Ok(())
}

fn data_type(&self) -> ColumnType {
try_multiply_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to multiply column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
let scalars = multiply_columns(&lhs_column, &rhs_column, alloc);
Column::Scalar(scalars)
}

#[tracing::instrument(
name = "proofs.sql.ast.and_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);

// lhs_times_rhs
let lhs_times_rhs: &'a [C::Scalar] = multiply_columns(&lhs_column, &rhs_column, alloc);
builder.produce_intermediate_mle(lhs_times_rhs);

// subpolynomial: lhs_times_rhs - lhs * rhs
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::Identity,
vec![
(C::Scalar::one(), vec![Box::new(lhs_times_rhs)]),
(
-C::Scalar::one(),
vec![Box::new(lhs_column), Box::new(rhs_column)],
),
],
);
Column::Scalar(lhs_times_rhs)
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs = self.rhs.verifier_evaluate(builder, accessor)?;

// lhs_times_rhs
let lhs_times_rhs = builder.consume_intermediate_mle();

// subpolynomial: lhs_times_rhs - lhs * rhs
let eval = builder.mle_evaluations.random_evaluation * (lhs_times_rhs - lhs * rhs);
builder.produce_sumcheck_subpolynomial_evaluation(&eval);

// selection
Ok(lhs_times_rhs)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
Loading

0 comments on commit cb50067

Please sign in to comment.