-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
286 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
use super::{ | ||
scale_and_add_subtract_eval, try_add_subtract_column_types, try_add_subtract_columns, | ||
ProvableExpr, ProvableExprPlan, | ||
}; | ||
use crate::{ | ||
base::{ | ||
commitment::Commitment, | ||
database::{Column, ColumnRef, ColumnType, CommitmentAccessor, DataAccessor}, | ||
proof::ProofError, | ||
}, | ||
sql::proof::{CountBuilder, ProofBuilder, VerificationBuilder}, | ||
}; | ||
use bumpalo::Bump; | ||
use serde::{Deserialize, Serialize}; | ||
use std::collections::HashSet; | ||
|
||
/// Provable numerical + / - expression | ||
#[derive(Debug, PartialEq, Serialize, Deserialize)] | ||
pub struct AddSubtractExpr<C: Commitment> { | ||
lhs: Box<ProvableExprPlan<C>>, | ||
rhs: Box<ProvableExprPlan<C>>, | ||
is_subtract: bool, | ||
} | ||
|
||
impl<C: Commitment> AddSubtractExpr<C> { | ||
/// Create numerical + / - expression | ||
pub fn new( | ||
lhs: Box<ProvableExprPlan<C>>, | ||
rhs: Box<ProvableExprPlan<C>>, | ||
is_subtract: bool, | ||
) -> Self { | ||
Self { | ||
lhs, | ||
rhs, | ||
is_subtract, | ||
} | ||
} | ||
} | ||
|
||
impl<C: Commitment> ProvableExpr<C> for AddSubtractExpr<C> { | ||
fn count(&self, builder: &mut CountBuilder) -> Result<(), ProofError> { | ||
self.lhs.count(builder)?; | ||
self.rhs.count(builder)?; | ||
Ok(()) | ||
} | ||
|
||
fn data_type(&self) -> ColumnType { | ||
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type()) | ||
.expect("Failed to add/subtract column types") | ||
} | ||
|
||
fn result_evaluate<'a>( | ||
&self, | ||
table_length: usize, | ||
alloc: &'a Bump, | ||
accessor: &'a dyn DataAccessor<C::Scalar>, | ||
) -> Column<'a, C::Scalar> { | ||
let lhs_column: Column<'a, C::Scalar> = | ||
self.lhs.result_evaluate(table_length, alloc, accessor); | ||
let rhs_column: Column<'a, C::Scalar> = | ||
self.rhs.result_evaluate(table_length, alloc, accessor); | ||
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) | ||
.expect("Failed to add/subtract columns") | ||
} | ||
|
||
#[tracing::instrument( | ||
name = "proofs.sql.ast.not_expr.prover_evaluate", | ||
level = "info", | ||
skip_all | ||
)] | ||
fn prover_evaluate<'a>( | ||
&self, | ||
builder: &mut ProofBuilder<'a, C::Scalar>, | ||
alloc: &'a Bump, | ||
accessor: &'a dyn DataAccessor<C::Scalar>, | ||
) -> Column<'a, C::Scalar> { | ||
let lhs_column: Column<'a, C::Scalar> = self.lhs.prover_evaluate(builder, alloc, accessor); | ||
let rhs_column: Column<'a, C::Scalar> = self.rhs.prover_evaluate(builder, alloc, accessor); | ||
try_add_subtract_columns(lhs_column, rhs_column, alloc, self.is_subtract) | ||
.expect("Failed to add/subtract columns") | ||
} | ||
|
||
fn verifier_evaluate( | ||
&self, | ||
builder: &mut VerificationBuilder<C>, | ||
accessor: &dyn CommitmentAccessor<C>, | ||
) -> Result<C::Scalar, ProofError> { | ||
let lhs_eval = self.lhs.verifier_evaluate(builder, accessor)?; | ||
let rhs_eval = self.rhs.verifier_evaluate(builder, accessor)?; | ||
let lhs_scale = self.lhs.data_type().scale().unwrap_or(0); | ||
let rhs_scale = self.rhs.data_type().scale().unwrap_or(0); | ||
let res = | ||
scale_and_add_subtract_eval(lhs_eval, rhs_eval, lhs_scale, rhs_scale, self.is_subtract) | ||
.expect("Failed to scale and add/subtract"); | ||
Ok(res) | ||
} | ||
|
||
fn get_column_references(&self, columns: &mut HashSet<ColumnRef>) { | ||
self.lhs.get_column_references(columns); | ||
self.rhs.get_column_references(columns); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
use crate::{ | ||
base::{ | ||
database::{Column, ColumnType}, | ||
math::decimal::{scale_scalar, Precision}, | ||
scalar::Scalar, | ||
}, | ||
sql::parse::{ConversionError, ConversionResult}, | ||
}; | ||
use bumpalo::Bump; | ||
|
||
// For decimal type manipulation please refer to | ||
// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 | ||
|
||
/// Determine the output type of an add or subtract operation if it is possible | ||
/// to add or subtract the two input types. If the types are not compatible, return | ||
/// an error. | ||
pub(crate) fn try_add_subtract_column_types( | ||
lhs: ColumnType, | ||
rhs: ColumnType, | ||
) -> ConversionResult<ColumnType> { | ||
if !lhs.is_numeric() || rhs.is_numeric() { | ||
return Err(ConversionError::DataTypeMismatch( | ||
lhs.to_string(), | ||
rhs.to_string(), | ||
)); | ||
} | ||
if lhs.is_integer() && rhs.is_integer() { | ||
return Ok(lhs.max(rhs)); | ||
} | ||
if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { | ||
return Ok(ColumnType::Scalar); | ||
} else { | ||
let left_precision_value = lhs.precision_value().unwrap_or(0) as i16; | ||
let right_precision_value = rhs.precision_value().unwrap_or(0) as i16; | ||
let left_scale = lhs.scale().unwrap_or(0); | ||
let right_scale = rhs.scale().unwrap_or(0); | ||
let scale = left_scale.max(right_scale); | ||
let precision_value: i16 = scale as i16 | ||
+ (left_precision_value - left_scale as i16) | ||
.max(right_precision_value - right_scale as i16) | ||
+ 1_i16; | ||
let precision = u8::try_from(precision_value) | ||
.map_err(|_| ConversionError::InvalidPrecision(precision_value)) | ||
.and_then(|p| { | ||
Precision::new(p).map_err(|_| ConversionError::InvalidPrecision(p as i16)) | ||
})?; | ||
Ok(ColumnType::Decimal75(precision, scale)) | ||
} | ||
} | ||
|
||
/// Add or subtract two columns together. | ||
/// | ||
/// If the columns are not compatible for addition/subtraction, return an error. | ||
pub(crate) fn try_add_subtract_columns<'a, S: Scalar>( | ||
lhs: Column<'a, S>, | ||
rhs: Column<'a, S>, | ||
alloc: &'a Bump, | ||
is_subtract: bool, | ||
) -> ConversionResult<Column<'a, S>> { | ||
let lhs_len = lhs.len(); | ||
let rhs_len = rhs.len(); | ||
if lhs_len != rhs_len { | ||
return Err(ConversionError::DifferentColumnLength(lhs_len, rhs_len)); | ||
} | ||
let _res: &mut [S] = alloc.alloc_slice_fill_default(lhs_len); | ||
let left_scale = lhs.column_type().scale().unwrap_or(0); | ||
let right_scale = rhs.column_type().scale().unwrap_or(0); | ||
let max_scale = left_scale.max(right_scale); | ||
let lhs_scalar = lhs.to_scalar_with_scaling(max_scale - left_scale); | ||
let rhs_scalar = rhs.to_scalar_with_scaling(max_scale - right_scale); | ||
let res = alloc.alloc_slice_fill_with(lhs_len, |i| { | ||
if is_subtract { | ||
lhs_scalar[i] - rhs_scalar[i] | ||
} else { | ||
lhs_scalar[i] + rhs_scalar[i] | ||
} | ||
}); | ||
Ok(Column::Scalar(res)) | ||
} | ||
|
||
/// The counterpart of `try_add_subtract_columns` for evaluating decimal expressions. | ||
pub(crate) fn scale_and_add_subtract_eval<S: Scalar>( | ||
lhs_eval: S, | ||
rhs_eval: S, | ||
lhs_scale: i8, | ||
rhs_scale: i8, | ||
is_subtract: bool, | ||
) -> ConversionResult<S> { | ||
let max_scale = lhs_scale.max(rhs_scale); | ||
let scaled_lhs_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)?; | ||
let scaled_rhs_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)?; | ||
if is_subtract { | ||
Ok(scaled_lhs_eval - scaled_rhs_eval) | ||
} else { | ||
Ok(scaled_lhs_eval + scaled_rhs_eval) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.