Skip to content

Commit

Permalink
feat: add AddExpr & SubExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Jun 14, 2024
1 parent 1ed302b commit d439694
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 10 deletions.
8 changes: 4 additions & 4 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub const INT128_SCALE: usize = 0;
///
/// See `<https://ignite.apache.org/docs/latest/sql-reference/data-types>` for
/// a description of the native types used by Apache Ignite.
#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Deserialize, Copy)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Deserialize, Copy)]
pub enum ColumnType {
/// Mapped to bool
#[serde(alias = "BOOLEAN", alias = "boolean")]
Expand All @@ -188,12 +188,12 @@ pub enum ColumnType {
/// Mapped to String
#[serde(alias = "VARCHAR", alias = "varchar")]
VarChar,
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
/// Mapped to i256
#[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")]
Decimal75(Precision, i8),
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
}

impl ColumnType {
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/math/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal;
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Hash, Serialize, Copy)]
/// limit-enforced precision
pub struct Precision(u8);
pub(crate) const MAX_SUPPORTED_PRECISION: u8 = 75;
Expand Down
102 changes: 102 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::{
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns,
ProvableExpr, ProvableExprPlan,
};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical + / - expression
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct AddSubtractExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
}

impl<C: Commitment> AddSubtractExpr<C> {
/// Create numerical + / - expression
pub fn new(
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
) -> Self {
Self {
lhs,
rhs,
is_subtract,
}
}
}

impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
Ok(())
}

fn data_type(&self) -> ColumnType {
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

#[tracing::instrument(
name = "proofs.sql.ast.not_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract)
.expect("Failed to add/subtract columns")
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?;
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res =
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract)
.expect("Failed to scale and add/subtract");
Ok(res)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/ast/comparison_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
);
// Check if the precision is valid
let _max_precision = Precision::new(max_precision_value)
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value))?;
.map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?;
}
unchecked_subtract_impl(
alloc,
Expand Down
8 changes: 8 additions & 0 deletions crates/proof-of-sql/src/sql/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
mod filter_result_expr;
pub(crate) use filter_result_expr::FilterResultExpr;

mod add_subtract_expr;
pub(crate) use add_subtract_expr::AddSubtractExpr;

mod filter_expr;
pub(crate) use filter_expr::FilterExpr;
#[cfg(test)]
Expand Down Expand Up @@ -52,6 +55,11 @@ mod not_expr_test;
mod comparison_util;
pub(crate) use comparison_util::{scale_and_subtract, scale_and_subtract_eval};

mod numerical_util;
pub(crate) use numerical_util::{
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns,
};

mod equals_expr;
use equals_expr::*;
#[cfg(all(test, feature = "blitzar"))]
Expand Down
97 changes: 97 additions & 0 deletions crates/proof-of-sql/src/sql/ast/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use crate::{
base::{
database::{Column, ColumnType},
math::decimal::{scale_scalar, Precision},
scalar::Scalar,
},
sql::parse::{ConversionError, ConversionResult},
};
use bumpalo::Bump;

// For decimal type manipulation please refer to
// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16

/// Determine the output type of an add or subtract operation if it is possible
/// to add or subtract the two input types. If the types are not compatible, return
/// an error.
pub(crate) fn try_add_subtract_column_types(
lhs: ColumnType,
rhs: ColumnType,
) -> ConversionResult<ColumnType> {
if !lhs.is_numeric() || rhs.is_numeric() {
return Err(ConversionError::DataTypeMismatch(
lhs.to_string(),
rhs.to_string(),
));
}
if lhs.is_integer() && rhs.is_integer() {
return Ok(lhs.max(rhs));
}
if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar {
return Ok(ColumnType::Scalar);
} else {
let left_precision_value = lhs.precision_value().unwrap_or(0) as i16;
let right_precision_value = rhs.precision_value().unwrap_or(0) as i16;
let left_scale = lhs.scale().unwrap_or(0);
let right_scale = rhs.scale().unwrap_or(0);
let scale = left_scale.max(right_scale);
let precision_value: i16 = scale as i16
+ (left_precision_value - left_scale as i16)
.max(right_precision_value - right_scale as i16)
+ 1_i16;
let precision = u8::try_from(precision_value)
.map_err(|_| ConversionError::InvalidPrecision(precision_value))
.and_then(|p| {
Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16))
})?;
Ok(ColumnType::Decimal75(precision, scale))
}
}

/// Add or subtract two columns together.
///
/// If the columns are not compatible for addition/subtraction, return an error.
pub(crate) fn try_add_subtract_columns<'a, S: Scalar>(
lhs: Column<'a, S>,
rhs: Column<'a, S>,
alloc: &'a Bump,
is_subtract: bool,
) -> ConversionResult<Column<'a, S>> {
let lhs_len = lhs.len();
let rhs_len = rhs.len();
if lhs_len != rhs_len {
return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len));
}
let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len);
let left_scale = lhs.column_type().scale().unwrap_or(0);
let right_scale = rhs.column_type().scale().unwrap_or(0);
let max_scale = left_scale.max(right_scale);
let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - left_scale);
let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - right_scale);
let res = alloc.alloc_slice_fill_with(lhs_len, |i| {
if is_subtract {
lhs_scalar[i] - rhs_scalar[i]
} else {
lhs_scalar[i] + rhs_scalar[i]
}
});
Ok(Column::Scalar(res))
}

/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions.
pub(crate) fn scale_and_add_subtract_eval<S: Scalar>(
lhs_eval: S,
rhs_eval: S,
lhs_scale: i8,
rhs_scale: i8,
is_subtract: bool,
) -> ConversionResult<S> {
let max_scale = lhs_scale.max(rhs_scale);
let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?;
let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?;
if is_subtract {
Ok(scaled_lhs_eval - scaled_rhs_eval)
} else {
Ok(scaled_lhs_eval + scaled_rhs_eval)
}
}
71 changes: 70 additions & 1 deletion crates/proof-of-sql/src/sql/ast/provable_expr_plan.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{
AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr, ProvableExpr,
AddSubtractExpr, AndExpr, ColumnExpr, EqualsExpr, InequalityExpr, LiteralExpr, NotExpr, OrExpr,
ProvableExpr,
};
use crate::{
base::{
Expand Down Expand Up @@ -34,6 +35,10 @@ pub enum ProvableExprPlan<C: Commitment> {
Equals(EqualsExpr<C>),
/// Provable AST expression for an inequality expression
Inequality(InequalityExpr<C>),
/// Provable numeric + expression
Add(AddSubtractExpr<C>),
/// Provable numeric - expression
Subtract(AddSubtractExpr<C>),
}
impl<C: Commitment> ProvableExprPlan<C> {
/// Create column expression
Expand Down Expand Up @@ -109,6 +114,48 @@ impl<C: Commitment> ProvableExprPlan<C> {
}
}

/// Create a new add expression
pub fn try_new_add(
lhs: ProvableExprPlan<C>,
rhs: ProvableExprPlan<C>,
) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) {
Err(ConversionError::DataTypeMismatch(
lhs_datatype.to_string(),
rhs_datatype.to_string(),
))
} else {
Ok(Self::Add(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
false,
)))
}
}

/// Create a new subtract expression
pub fn try_new_subtract(
lhs: ProvableExprPlan<C>,
rhs: ProvableExprPlan<C>,
) -> ConversionResult<Self> {
let lhs_datatype = lhs.data_type();
let rhs_datatype = rhs.data_type();
if !type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) {
Err(ConversionError::DataTypeMismatch(
lhs_datatype.to_string(),
rhs_datatype.to_string(),
))
} else {
Ok(Self::Subtract(AddSubtractExpr::new(
Box::new(lhs),
Box::new(rhs),
true,
)))
}
}

/// Check that the plan has the correct data type
fn check_data_type(&self, data_type: ColumnType) -> ConversionResult<()> {
if self.data_type() == data_type {
Expand All @@ -132,12 +179,16 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Literal(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Equals(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Inequality(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Add(expr) => ProvableExpr::<C>::count(expr, builder),
ProvableExprPlan::Subtract(expr) => ProvableExpr::<C>::count(expr, builder),
}
}

fn data_type(&self) -> ColumnType {
match self {
ProvableExprPlan::Column(expr) => expr.data_type(),
ProvableExprPlan::Add(expr) => expr.data_type(),
ProvableExprPlan::Subtract(expr) => expr.data_type(),
ProvableExprPlan::Literal(expr) => ProvableExpr::<C>::data_type(expr),
ProvableExprPlan::And(_)
| ProvableExprPlan::Or(_)
Expand Down Expand Up @@ -175,6 +226,12 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
ProvableExprPlan::Add(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::result_evaluate(expr, table_length, alloc, accessor)
}
}
}

Expand Down Expand Up @@ -206,6 +263,12 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
ProvableExprPlan::Add(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::prover_evaluate(expr, builder, alloc, accessor)
}
}
}

Expand All @@ -224,6 +287,8 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Literal(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Equals(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Inequality(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Add(expr) => expr.verifier_evaluate(builder, accessor),
ProvableExprPlan::Subtract(expr) => expr.verifier_evaluate(builder, accessor),
}
}

Expand All @@ -244,6 +309,10 @@ impl<C: Commitment> ProvableExpr<C> for ProvableExprPlan<C> {
ProvableExprPlan::Inequality(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
ProvableExprPlan::Add(expr) => ProvableExpr::<C>::get_column_references(expr, columns),
ProvableExprPlan::Subtract(expr) => {
ProvableExpr::<C>::get_column_references(expr, columns)
}
}
}
}
Loading

0 comments on commit d439694

Please sign in to comment.