diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index c25445340..985f6bc17 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -1,15 +1,15 @@ use crate::{ base::{ - database::Column, + database::{Column, ColumnarValue, LiteralValue}, if_rayon, math::decimal::{DecimalError, Precision}, - scalar::Scalar, + scalar::{Scalar, ScalarExt}, }, sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, }; use alloc::string::ToString; use bumpalo::Bump; -use core::cmp; +use core::cmp::{max, Ordering}; use proof_of_sql_parser::intermediate_ast::BinaryOperator; #[cfg(feature = "rayon")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; @@ -31,6 +31,70 @@ fn unchecked_subtract_impl<'a, S: Scalar>( Ok(result) } +#[allow( + clippy::missing_panics_doc, + reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs" +)] +/// Scale LHS and RHS to the same scale if at least one of them is decimal +/// and take the difference. This function is used for comparisons. +#[allow(clippy::cast_sign_loss)] +pub fn scale_and_subtract_literal( + lhs: LiteralValue, + rhs: LiteralValue, + lhs_scale: i8, + rhs_scale: i8, + is_equal: bool, +) -> ConversionResult { + let lhs_type = lhs.column_type(); + let rhs_type = rhs.column_type(); + let operator = if is_equal { + BinaryOperator::Equal + } else { + BinaryOperator::LessThanOrEqual + }; + if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + return Err(ConversionError::DataTypeMismatch { + left_type: lhs_type.to_string(), + right_type: rhs_type.to_string(), + }); + } + let max_scale = max(lhs_scale, rhs_scale); + let lhs_upscale = max_scale - lhs_scale; + let rhs_upscale = max_scale - rhs_scale; + // Only check precision overflow issues if at least one side is decimal + if max_scale != 0 { + let lhs_precision_value = lhs_type + .precision_value() + .expect("If scale is set, precision must be set"); + let rhs_precision_value = rhs_type + .precision_value() + .expect("If scale is set, precision must be set"); + let max_precision_value = max( + lhs_precision_value + (max_scale - lhs_scale) as u8, + rhs_precision_value + (max_scale - rhs_scale) as u8, + ); + // Check if the precision is valid + let _max_precision = Precision::new(max_precision_value).map_err(|_| { + ConversionError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: max_precision_value.to_string(), + }, + } + })?; + } + match lhs_scale.cmp(&rhs_scale) { + Ordering::Less => { + let upscale_factor = S::pow10(rhs_upscale as u8); + Ok(lhs.to_scalar() * upscale_factor - rhs.to_scalar()) + } + Ordering::Equal => Ok(lhs.to_scalar() - rhs.to_scalar()), + Ordering::Greater => { + let upscale_factor = S::pow10(lhs_upscale as u8); + Ok(lhs.to_scalar() - rhs.to_scalar() * upscale_factor) + } + } +} + #[allow( clippy::missing_panics_doc, reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs" @@ -67,7 +131,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( right_type: rhs_type.to_string(), }); } - let max_scale = cmp::max(lhs_scale, rhs_scale); + let max_scale = max(lhs_scale, rhs_scale); let lhs_upscale = max_scale - lhs_scale; let rhs_upscale = max_scale - rhs_scale; // Only check precision overflow issues if at least one side is decimal @@ -78,7 +142,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let rhs_precision_value = rhs_type .precision_value() .expect("If scale is set, precision must be set"); - let max_precision_value = cmp::max( + let max_precision_value = max( lhs_precision_value + (max_scale - lhs_scale) as u8, rhs_precision_value + (max_scale - rhs_scale) as u8, ); @@ -98,3 +162,49 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( lhs_len, ) } + +#[allow(clippy::cast_sign_loss)] +#[allow(dead_code)] +/// Scale LHS and RHS to the same scale if at least one of them is decimal +/// and take the difference. This function is used for comparisons. +pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>( + alloc: &'a Bump, + lhs: ColumnarValue<'a, S>, + rhs: ColumnarValue<'a, S>, + lhs_scale: i8, + rhs_scale: i8, + is_equal: bool, +) -> ConversionResult> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal, + )?))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, + Column::from_literal_with_length(&lhs, rhs.len(), alloc), + rhs, + lhs_scale, + rhs_scale, + is_equal, + )?))) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( + alloc, + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc), + lhs_scale, + rhs_scale, + is_equal, + )?))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + Ok(ColumnarValue::Literal(LiteralValue::Scalar( + scale_and_subtract_literal(&lhs, &rhs, lhs_scale, rhs_scale, is_equal)?, + ))) + } + } +} diff --git a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs index eacc03142..8d584d0a1 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs @@ -1,8 +1,36 @@ use crate::base::{ - database::Column, + database::{Column, ColumnarValue, LiteralValue}, scalar::{Scalar, ScalarExt}, }; use bumpalo::Bump; +use core::cmp::Ordering; + +#[allow(clippy::cast_sign_loss)] +/// Add or subtract two literals together. +pub(crate) fn add_subtract_literals( + lhs: &LiteralValue, + rhs: &LiteralValue, + lhs_scale: i8, + rhs_scale: i8, + is_subtract: bool, +) -> S { + let (lhs_scaled, rhs_scaled) = match lhs_scale.cmp(&rhs_scale) { + Ordering::Less => { + let scaling_factor = S::pow10((rhs_scale - lhs_scale) as u8); + (lhs.to_scalar() * scaling_factor, rhs.to_scalar()) + } + Ordering::Equal => (lhs.to_scalar(), rhs.to_scalar()), + Ordering::Greater => { + let scaling_factor = S::pow10((lhs_scale - rhs_scale) as u8); + (lhs.to_scalar(), rhs.to_scalar() * scaling_factor) + } + }; + if is_subtract { + lhs_scaled - rhs_scaled + } else { + lhs_scaled + rhs_scaled + } +} #[allow( clippy::missing_panics_doc, @@ -36,9 +64,62 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>( result } +/// Add or subtract two [`ColumnarValues`] together. +#[allow(dead_code)] +pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>( + lhs: ColumnarValue<'a, S>, + rhs: ColumnarValue<'a, S>, + lhs_scale: i8, + rhs_scale: i8, + alloc: &'a Bump, + is_subtract: bool, +) -> ColumnarValue<'a, S> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + lhs, + rhs, + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + Column::from_literal_with_length(&lhs, rhs.len(), alloc), + rhs, + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + ColumnarValue::Column(Column::Scalar(add_subtract_columns( + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc), + lhs_scale, + rhs_scale, + alloc, + is_subtract, + ))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + ColumnarValue::Literal(LiteralValue::Scalar(add_subtract_literals( + &lhs, + &rhs, + lhs_scale, + rhs_scale, + is_subtract, + ))) + } + } +} + /// Multiply two columns together. /// # Panics -/// Panics if: The lengths of `lhs` and `rhs` are not equal.`lhs.scalar_at(i)` or `rhs.scalar_at(i)` returns `None`, which occurs if the column does not have, a scalar at the given index `i`. +/// Panics if: `lhs` and `rhs` are not of the same length. pub(crate) fn multiply_columns<'a, S: Scalar>( lhs: &Column<'a, S>, rhs: &Column<'a, S>, @@ -55,6 +136,38 @@ pub(crate) fn multiply_columns<'a, S: Scalar>( }) } +#[allow(dead_code)] +/// Multiply two [`ColumnarValues`] together. +/// # Panics +/// Panics if: `lhs` and `rhs` are not of the same length. +pub(crate) fn multiply_columnar_values<'a, S: Scalar>( + lhs: &ColumnarValue<'a, S>, + rhs: &ColumnarValue<'a, S>, + alloc: &'a Bump, +) -> ColumnarValue<'a, S> { + match (lhs, rhs) { + (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { + ColumnarValue::Column(Column::Scalar(multiply_columns(lhs, rhs, alloc))) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { + let lhs_scalar = lhs.to_scalar(); + let result = + alloc.alloc_slice_fill_with(rhs.len(), |i| lhs_scalar * rhs.scalar_at(i).unwrap()); + ColumnarValue::Column(Column::Scalar(result)) + } + (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { + let rhs_scalar = rhs.to_scalar(); + let result = + alloc.alloc_slice_fill_with(lhs.len(), |i| lhs.scalar_at(i).unwrap() * rhs_scalar); + ColumnarValue::Column(Column::Scalar(result)) + } + (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { + let result = lhs.to_scalar() * rhs.to_scalar(); + ColumnarValue::Literal(LiteralValue::Scalar(result)) + } + } +} + #[allow( clippy::missing_panics_doc, reason = "scaling factor is guaranteed to not be negative based on input validation prior to calling this function"