Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add AddSubtractExpr #14

Merged
merged 3 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ pub enum ColumnType {
/// Mapped to String
#[serde(alias = "VARCHAR", alias = "varchar")]
VarChar,
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
/// Mapped to i256
#[serde(rename = "Decimal75", alias = "DECIMAL75", alias = "decimal75")]
Decimal75(Precision, i8),
/// Mapped to i64
#[serde(alias = "TIMESTAMP", alias = "timestamp")]
TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone),
/// Mapped to Curve25519Scalar
#[serde(alias = "SCALAR", alias = "scalar")]
Scalar,
}

impl ColumnType {
Expand All @@ -240,6 +240,44 @@ impl ColumnType {
)
}

/// Returns the number of bits in the integer type if it is an integer type. Otherwise, return None.
fn to_integer_bits(self) -> Option<usize> {
match self {
ColumnType::SmallInt => Some(16),
ColumnType::Int => Some(32),
ColumnType::BigInt => Some(64),
ColumnType::Int128 => Some(128),
_ => None,
}
}

/// Returns the ColumnType of the integer type with the given number of bits if it is a valid integer type.
///
/// Otherwise, return None.
fn from_integer_bits(bits: usize) -> Option<Self> {
match bits {
16 => Some(ColumnType::SmallInt),
32 => Some(ColumnType::Int),
64 => Some(ColumnType::BigInt),
128 => Some(ColumnType::Int128),
_ => None,
}
}

/// Returns the larger integer type of two ColumnTypes if they are both integers.
///
/// If either of the columns is not an integer, return None.
pub fn max_integer_type(&self, other: &Self) -> Option<Self> {
// If either of the columns is not an integer, return None
if !self.is_integer() || !other.is_integer() {
return None;
}
Self::from_integer_bits(std::cmp::max(
self.to_integer_bits().unwrap(),
other.to_integer_bits().unwrap(),
))
}

/// Returns the precision of a ColumnType if it is converted to a decimal wrapped in Some(). If it can not be converted to a decimal, return None.
pub fn precision_value(&self) -> Option<u8> {
match self {
Expand Down
113 changes: 113 additions & 0 deletions crates/proof-of-sql/src/sql/ast/add_subtract_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
use super::{
add_subtract_columns, scale_and_add_subtract_eval, try_add_subtract_column_types, ProvableExpr,
ProvableExprPlan,
};
use crate::{
base::{
commitment::Commitment,
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor},
proof::ProofError,
},
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder},
};
use bumpalo::Bump;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Provable numerical `+` / `-` expression
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AddSubtractExpr<C: Commitment> {
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
}

impl<C: Commitment> AddSubtractExpr<C> {
/// Create numerical `+` / `-` expression
pub fn new(
lhs: Box<ProvableExprPlan<C>>,
rhs: Box<ProvableExprPlan<C>>,
is_subtract: bool,
) -> Self {
Self {
lhs,
rhs,
is_subtract,
}
}
}

impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> {
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> {
self.lhs.count(builder)?;
self.rhs.count(builder)?;
Ok(())
}

fn data_type(&self) -> ColumnType {
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

fn result_evaluate<'a>(
&self,
table_length: usize,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> =
self.lhs.result_evaluate(table_length, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> =
self.rhs.result_evaluate(table_length, alloc, accessor);
Column::Scalar(add_subtract_columns(
lhs_column,
rhs_column,
self.lhs.data_type().scale().unwrap_or(0),
self.rhs.data_type().scale().unwrap_or(0),
alloc,
self.is_subtract,
))
}

#[tracing::instrument(
name = "proofs.sql.ast.not_expr.prover_evaluate",
level = "info",
skip_all
)]
fn prover_evaluate<'a>(
&self,
builder: &mut ProofBuilder<'a, C::Scalar>,
alloc: &'a Bump,
accessor: &'a dyn DataAccessor<C::Scalar>,
) -> Column<'a, C::Scalar> {
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor);
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor);
Column::Scalar(add_subtract_columns(
lhs_column,
rhs_column,
self.lhs.data_type().scale().unwrap_or(0),
self.rhs.data_type().scale().unwrap_or(0),
alloc,
self.is_subtract,
))
}

fn verifier_evaluate(
&self,
builder: &mut VerificationBuilder<C>,
accessor: &dyn CommitmentAccessor<C>,
) -> Result<C::Scalar, ProofError> {
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?;
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?;
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0);
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0);
let res =
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract);
Ok(res)
}

fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) {
self.lhs.get_column_references(columns);
self.rhs.get_column_references(columns);
}
}
Loading
Loading