-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
343 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
use super::{ProvableExpr, ProvableExprPlan}; | ||
use crate::{ | ||
base::{ | ||
commitment::Commitment, | ||
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, | ||
proof::ProofError, | ||
}, | ||
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder}, | ||
}; | ||
use bumpalo::Bump; | ||
use serde::{Deserialize, Serialize}; | ||
use std::collections::HashSet; | ||
|
||
/// Provable numerical + expression | ||
#[derive(Debug, PartialEq, Serialize, Deserialize)] | ||
pub struct AddExpr<C: Commitment> { | ||
lhs: Box<ProvableExprPlan<C>>, | ||
rhs: Box<ProvableExprPlan<C>>, | ||
} | ||
|
||
impl<C: Commitment> AddExpr<C> { | ||
/// Create numerical + expression | ||
pub fn new(lhs: Box<ProvableExprPlan<C>>, rhs: Box<ProvableExprPlan<C>>) -> Self { | ||
Self { lhs, rhs } | ||
} | ||
} | ||
|
||
impl<C: Commitment> ProvableExpr<C> for AddExpr<C> { | ||
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { | ||
self.left.count(builder)?; | ||
self.right.count(builder)?; | ||
Ok(()) | ||
} | ||
|
||
fn data_type(&self) -> ColumnType { | ||
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type()) | ||
.expect("Failed to add column types") | ||
} | ||
|
||
fn result_evaluate<'a>( | ||
&self, | ||
table_length: usize, | ||
alloc: &'a Bump, | ||
accessor: &'a dyn DataAccessor<C::Scalar>, | ||
) -> Column<'a, C::Scalar> { | ||
let lhs_column: Column<'a, C::Scalar> = | ||
self.lhs.result_evaluate(table_length, alloc, accessor); | ||
let rhs_column: Column<'a, C::Scalar> = | ||
self.rhs.result_evaluate(table_length, alloc, accessor); | ||
let scalars = try_add_subtract_columns(lhs_column, rhs_column, alloc, false) | ||
.expect("Failed to add columns"); | ||
Column::Scalar(scalars) | ||
} | ||
|
||
#[tracing::instrument( | ||
name = "proofs.sql.ast.not_expr.prover_evaluate", | ||
level = "info", | ||
skip_all | ||
)] | ||
fn prover_evaluate<'a>( | ||
&self, | ||
builder: &mut ProofBuilder<'a, C::Scalar>, | ||
alloc: &'a Bump, | ||
accessor: &'a dyn DataAccessor<C::Scalar>, | ||
) -> Column<'a, C::Scalar> { | ||
let lhs_column: Column<'a, C::Scalar> = | ||
self.lhs.result_evaluate(table_length, alloc, accessor); | ||
let rhs_column: Column<'a, C::Scalar> = | ||
self.rhs.result_evaluate(table_length, alloc, accessor); | ||
let scalars = try_add_subtract_columns(lhs_column, rhs_column, alloc, false) | ||
.expect("Failed to add columns"); | ||
Column::Scalar(scalars) | ||
} | ||
|
||
fn verifier_evaluate( | ||
&self, | ||
builder: &mut VerificationBuilder<C>, | ||
accessor: &dyn CommitmentAccessor<C>, | ||
) -> Result<C::Scalar, ProofError> { | ||
let lhs_column_eval = self.lhs.verifier_evaluate(builder, accessor)?; | ||
let rhs_column_eval = self.rhs.verifier_evaluate(builder, accessor)?; | ||
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); | ||
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); | ||
let res = scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, false) | ||
.expect("Failed to scale and add/subtract"); | ||
Ok(res) | ||
} | ||
|
||
fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) { | ||
self.lhs.get_column_references(columns); | ||
self.rhs.get_column_references(columns); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
use crate::{ | ||
base::{ | ||
database::{Column, ColumnType}, | ||
math::decimal::Precision, | ||
scalar::Scalar, | ||
}, | ||
sql::parse::{ConversionError, ConversionResult}, | ||
}; | ||
use bumpalo::Bump; | ||
use rayon::iter::{ | ||
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, | ||
}; | ||
// For decimal type manipulation please refer to | ||
// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 | ||
|
||
/// Determine the output type of an add or subtract operation if it is possible | ||
/// to add or subtract the two input types. If the types are not compatible, return | ||
/// an error. | ||
pub(crate) fn try_add_subtract_column_types( | ||
lhs: ColumnType, | ||
rhs: ColumnType, | ||
) -> ConversionResult<ColumnType> { | ||
if !lhs.is_numeric() || rhs.is_numeric() { | ||
return Err(ConversionError::DataTypeMismatch( | ||
lhs.to_string(), | ||
rhs.to_string(), | ||
)); | ||
} | ||
if lhs.is_integer() && rhs.is_integer() { | ||
return Ok(lhs.max(rhs)); | ||
} | ||
if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { | ||
return Ok(ColumnType::Scalar); | ||
} else { | ||
let left_precision_value = lhs.precision_value().unwrap_or(0); | ||
let right_precision_value = rhs.precision_value().unwrap_or(0); | ||
let left_scale = lhs.scale().unwrap_or(0); | ||
let right_scale = rhs.scale().unwrap_or(0); | ||
let scale = left_scale.max(right_scale); | ||
let precision_value = scale + (left_precision_value - left_scale).max(right_precision_value - right_scale) + 1; | ||
let precision = Precision::new(precision_value).map_err(|_| ConversionError::InvalidPrecision(precision_value))?; | ||
Ok(ColumnType::Decimal75(precision, scale)) | ||
} | ||
} | ||
|
||
/// Add or subtract two columns together. | ||
/// | ||
/// If the columns are not compatible for addition/subtraction, return an error. | ||
pub(crate) fn try_add_subtract_columns<'a, S: Scalar>( | ||
lhs: Column<'a, S>, | ||
rhs: Column<'a, S>, | ||
alloc: &'a Bump, | ||
is_subtract: bool, | ||
) -> ConversionResult<Column<'a, S>> { | ||
let lhs_len = lhs.len(); | ||
let rhs_len = rhs.len(); | ||
if lhs_len != rhs_len { | ||
return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len)); | ||
} | ||
let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); | ||
let left_scale = lhs.scale().unwrap_or(0); | ||
let right_scale = rhs.scale().unwrap_or(0); | ||
let max_scale = left_scale.max(right_scale); | ||
let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - left_scale); | ||
let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - right_scale); | ||
let res = alloc.alloc_slice_fill_with(lhs_len, |i| { | ||
if is_subtract { | ||
lhs_scalar[i] - rhs_scalar[i] | ||
} else { | ||
lhs_scalar[i] + rhs_scalar[i] | ||
} | ||
}); | ||
Ok(res) | ||
} | ||
|
||
/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions. | ||
pub(crate) fn scale_and_add_subtract_eval<S: Scalar>( | ||
lhs_eval: S, | ||
rhs_eval: S, | ||
lhs_scale: i8, | ||
rhs_scale: i8, | ||
is_subtract: bool, | ||
) -> ConversionResult<S> { | ||
let max_scale = lhs_scale.max(rhs_scale); | ||
let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?; | ||
let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?; | ||
if is_subtract { | ||
Ok(scaled_lhs_eval - scaled_rhs_eval) | ||
} else { | ||
Ok(scaled_lhs_eval + scaled_rhs_eval) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.