Skip to content

Commit

Permalink
feat: add AddSubtractExpr (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner authored Jun 21, 2024
1 parent 0e675aa commit 37f93f6
Show file tree
Hide file tree
Showing 15 changed files with 921 additions and 94 deletions.
44 changes: 41 additions & 3 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ pub enum ColumnType {
/// Mapped to String
#[serde(alias = "VARCHAR", alias = "varchar")]
VarChar,
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
/// Mapped to i256
#[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")]
Decimal75(Precision, i8),
/// Mapped to i64
#[serde(alias = "TIMESTAMP", alias = "timestamp")]
TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone),
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
}

impl ColumnType {
Expand All @@ -240,6 +240,44 @@ impl ColumnType {
)
}

/// Returns the number of bits in the integer type if it is an integer type. Otherwise, return None.
fn to_integer_bits(self) -> Option<usize> {
match self {
ColumnType::SmallInt => Some(16),
ColumnType::Int => Some(32),
ColumnType::BigInt => Some(64),
ColumnType::Int128 => Some(128),
_ => None,
}
}

/// Returns the ColumnType of the integer type with the given number of bits if it is a valid integer type.
///
/// Otherwise, return None.
fn from_integer_bits(bits: usize) -> Option<Self> {
match bits {
16 => Some(ColumnType::SmallInt),
32 => Some(ColumnType::Int),
64 => Some(ColumnType::BigInt),
128 => Some(ColumnType::Int128),
_ => None,
}
}

/// Returns the larger integer type of two ColumnTypes if they are both integers.
///
/// If either of the columns is not an integer, return None.
pub fn max_integer_type(&self, other: &Self) -> Option<Self> {
// If either of the columns is not an integer, return None
if !self.is_integer() || !other.is_integer() {
return None;
}
Self::from_integer_bits(std::cmp::max(
self.to_integer_bits().unwrap(),
other.to_integer_bits().unwrap(),
))
}

/// Returns the precision of a ColumnType if it is converted to a decimal wrapped in Some(). If it can not be converted to a decimal, return None.
pub fn precision_value(&self) -> Option<u8> {
match self {
Expand Down
113 changes: 113 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use super::{
add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types, ProvableExpr,
ProvableExprPlan,
};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical `+` / `-` expression
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AddSubtractExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
}

impl<C: Commitment> AddSubtractExpr<C> {
/// Create numerical `+` / `-` expression
pub fn new(
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
) -> Self {
Self {
lhs,
rhs,
is_subtract,
}
}
}

impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
Ok(())
}

fn data_type(&self) -> ColumnType {
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
Column::Scalar(add_subtract_columns(
lhs_column,
rhs_column,
self.lhs.data_type().scale().unwrap_or(0),
self.rhs.data_type().scale().unwrap_or(0),
alloc,
self.is_subtract,
))
}

#[tracing::instrument(
name = "proofs.sql.ast.not_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
Column::Scalar(add_subtract_columns(
lhs_column,
rhs_column,
self.lhs.data_type().scale().unwrap_or(0),
self.rhs.data_type().scale().unwrap_or(0),
alloc,
self.is_subtract,
))
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?;
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res =
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract);
Ok(res)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
Loading

0 comments on commit 37f93f6

Please sign in to comment.