Skip to content

Commit

Permalink
feat: add ColumnarValue operations
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Nov 1, 2024
1 parent 7598b5c commit d32a469
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 7 deletions.
120 changes: 115 additions & 5 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::{
base::{
database::Column,
database::{Column, ColumnarValue, LiteralValue},
if_rayon,
math::decimal::{DecimalError, Precision},
scalar::Scalar,
scalar::{Scalar, ScalarExt},
},
sql::parse::{type_check_binary_operation, ConversionError, ConversionResult},
};
use alloc::string::ToString;
use bumpalo::Bump;
use core::cmp;
use core::cmp::{max, Ordering};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
Expand All @@ -31,6 +31,70 @@ fn unchecked_subtract_impl<'a, S: Scalar>(
Ok(result)
}

/// Scale two literals to a common scale and return their difference `lhs - rhs`.
///
/// The operand with the smaller scale is multiplied by the power of ten needed to
/// bring it up to the larger scale; the other operand is used as-is. This function
/// is used for comparisons.
///
/// * `lhs`, `rhs` - the literal operands.
/// * `lhs_scale`, `rhs_scale` - the decimal scale of each operand.
/// * `is_equal` - when `true` the operand types are checked against
///   [`BinaryOperator::Equal`], otherwise against [`BinaryOperator::LessThanOrEqual`].
///
/// # Errors
/// * [`ConversionError::DataTypeMismatch`] if the operand types fail the type check
///   for the selected operator.
/// * [`ConversionError::DecimalConversionError`] if the precision required to hold
///   both operands at the common scale is rejected by [`Precision::new`].
#[allow(
    clippy::missing_panics_doc,
    reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs"
)]
#[allow(clippy::cast_sign_loss)]
pub fn scale_and_subtract_literal<S: Scalar>(
    lhs: &LiteralValue<S>,
    rhs: &LiteralValue<S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_equal: bool,
) -> ConversionResult<S> {
    let lhs_type = lhs.column_type();
    let rhs_type = rhs.column_type();
    let operator = if is_equal {
        BinaryOperator::Equal
    } else {
        BinaryOperator::LessThanOrEqual
    };
    if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
        return Err(ConversionError::DataTypeMismatch {
            left_type: lhs_type.to_string(),
            right_type: rhs_type.to_string(),
        });
    }
    // Each upscale is non-negative by construction: exactly one of them is zero
    // unless the scales are equal, in which case both are zero.
    let max_scale = max(lhs_scale, rhs_scale);
    let lhs_upscale = max_scale - lhs_scale;
    let rhs_upscale = max_scale - rhs_scale;
    // Only check precision overflow issues if at least one side is decimal
    if max_scale != 0 {
        let lhs_precision_value = lhs_type
            .precision_value()
            .expect("If scale is set, precision must be set");
        let rhs_precision_value = rhs_type
            .precision_value()
            .expect("If scale is set, precision must be set");
        let max_precision_value = max(
            lhs_precision_value + lhs_upscale as u8,
            rhs_precision_value + rhs_upscale as u8,
        );
        // Check if the precision is valid
        let _max_precision = Precision::new(max_precision_value).map_err(|_| {
            ConversionError::DecimalConversionError {
                source: DecimalError::InvalidPrecision {
                    error: max_precision_value.to_string(),
                },
            }
        })?;
    }
    match lhs_scale.cmp(&rhs_scale) {
        // `lhs` has the smaller scale, so `lhs` is the side that must be scaled up.
        Ordering::Less => {
            let upscale_factor = S::pow10(lhs_upscale as u8);
            Ok(lhs.to_scalar() * upscale_factor - rhs.to_scalar())
        }
        Ordering::Equal => Ok(lhs.to_scalar() - rhs.to_scalar()),
        // `rhs` has the smaller scale, so `rhs` is the side that must be scaled up.
        Ordering::Greater => {
            let upscale_factor = S::pow10(rhs_upscale as u8);
            Ok(lhs.to_scalar() - rhs.to_scalar() * upscale_factor)
        }
    }
}

#[allow(
clippy::missing_panics_doc,
reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs"
Expand Down Expand Up @@ -67,7 +131,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
right_type: rhs_type.to_string(),
});
}
let max_scale = cmp::max(lhs_scale, rhs_scale);
let max_scale = max(lhs_scale, rhs_scale);
let lhs_upscale = max_scale - lhs_scale;
let rhs_upscale = max_scale - rhs_scale;
// Only check precision overflow issues if at least one side is decimal
Expand All @@ -78,7 +142,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
let rhs_precision_value = rhs_type
.precision_value()
.expect("If scale is set, precision must be set");
let max_precision_value = cmp::max(
let max_precision_value = max(
lhs_precision_value + (max_scale - lhs_scale) as u8,
rhs_precision_value + (max_scale - rhs_scale) as u8,
);
Expand All @@ -98,3 +162,49 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
lhs_len,
)
}

#[allow(clippy::cast_sign_loss)]
#[allow(dead_code)]
/// Bring both operands to a common scale and compute `lhs - rhs`. This function
/// is used for comparisons.
///
/// When both operands are literals the result is a literal scalar produced by
/// [`scale_and_subtract_literal`]. In every other case a literal operand is first
/// turned into a column via [`Column::from_literal_with_length`], using the length
/// of the column operand, and the result is a scalar column produced by
/// [`scale_and_subtract`].
///
/// # Errors
/// Propagates any error returned by [`scale_and_subtract`] or
/// [`scale_and_subtract_literal`].
pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>(
    alloc: &'a Bump,
    lhs: ColumnarValue<'a, S>,
    rhs: ColumnarValue<'a, S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_equal: bool,
) -> ConversionResult<ColumnarValue<'a, S>> {
    // Normalise the operands to a pair of columns; the literal/literal case has
    // no column length to borrow, so it is answered directly.
    let (lhs_column, rhs_column) = match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let difference = scale_and_subtract_literal(
                &lhs_literal,
                &rhs_literal,
                lhs_scale,
                rhs_scale,
                is_equal,
            )?;
            return Ok(ColumnarValue::Literal(LiteralValue::Scalar(difference)));
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            (lhs_column, rhs_column)
        }
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Column(rhs_column)) => {
            let length = rhs_column.len();
            (
                Column::from_literal_with_length(&lhs_literal, length, alloc),
                rhs_column,
            )
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Literal(rhs_literal)) => {
            let length = lhs_column.len();
            (
                lhs_column,
                Column::from_literal_with_length(&rhs_literal, length, alloc),
            )
        }
    };
    let difference = scale_and_subtract(
        alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, is_equal,
    )?;
    Ok(ColumnarValue::Column(Column::Scalar(difference)))
}
117 changes: 115 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
use crate::base::{
database::Column,
database::{Column, ColumnarValue, LiteralValue},
scalar::{Scalar, ScalarExt},
};
use bumpalo::Bump;
use core::cmp::Ordering;

#[allow(clippy::cast_sign_loss)]
/// Add or subtract two literals after bringing them to a common scale.
///
/// The operand with the smaller scale is multiplied by ten raised to the
/// difference of the two scales; the other operand is left untouched. Returns
/// `lhs - rhs` when `is_subtract` is `true` and `lhs + rhs` otherwise.
pub(crate) fn add_subtract_literals<S: Scalar>(
    lhs: &LiteralValue<S>,
    rhs: &LiteralValue<S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_subtract: bool,
) -> S {
    let lhs_value = lhs.to_scalar();
    let rhs_value = rhs.to_scalar();
    let (lhs_aligned, rhs_aligned) = match lhs_scale.cmp(&rhs_scale) {
        Ordering::Equal => (lhs_value, rhs_value),
        // `lhs` has fewer fractional digits: scale it up to match `rhs`.
        Ordering::Less => (
            lhs_value * S::pow10((rhs_scale - lhs_scale) as u8),
            rhs_value,
        ),
        // `rhs` has fewer fractional digits: scale it up to match `lhs`.
        Ordering::Greater => (
            lhs_value,
            rhs_value * S::pow10((lhs_scale - rhs_scale) as u8),
        ),
    };
    if is_subtract {
        return lhs_aligned - rhs_aligned;
    }
    lhs_aligned + rhs_aligned
}

#[allow(
clippy::missing_panics_doc,
Expand Down Expand Up @@ -36,9 +64,62 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>(
result
}

/// Add or subtract two [`ColumnarValue`]s.
///
/// When both operands are literals the result is a literal scalar produced by
/// [`add_subtract_literals`]. In every other case a literal operand is first
/// turned into a column via [`Column::from_literal_with_length`], using the length
/// of the column operand, and the result is a scalar column produced by
/// [`add_subtract_columns`].
#[allow(dead_code)]
pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>(
    lhs: ColumnarValue<'a, S>,
    rhs: ColumnarValue<'a, S>,
    lhs_scale: i8,
    rhs_scale: i8,
    alloc: &'a Bump,
    is_subtract: bool,
) -> ColumnarValue<'a, S> {
    // Normalise the operands to a pair of columns; the literal/literal case has
    // no column length to borrow, so it is answered directly.
    let (lhs_column, rhs_column) = match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let combined = add_subtract_literals(
                &lhs_literal,
                &rhs_literal,
                lhs_scale,
                rhs_scale,
                is_subtract,
            );
            return ColumnarValue::Literal(LiteralValue::Scalar(combined));
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            (lhs_column, rhs_column)
        }
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Column(rhs_column)) => {
            let length = rhs_column.len();
            (
                Column::from_literal_with_length(&lhs_literal, length, alloc),
                rhs_column,
            )
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Literal(rhs_literal)) => {
            let length = lhs_column.len();
            (
                lhs_column,
                Column::from_literal_with_length(&rhs_literal, length, alloc),
            )
        }
    };
    let combined = add_subtract_columns(
        lhs_column,
        rhs_column,
        lhs_scale,
        rhs_scale,
        alloc,
        is_subtract,
    );
    ColumnarValue::Column(Column::Scalar(combined))
}

/// Multiply two columns together.
/// # Panics
/// Panics if: The lengths of `lhs` and `rhs` are not equal.`lhs.scalar_at(i)` or `rhs.scalar_at(i)` returns `None`, which occurs if the column does not have, a scalar at the given index `i`.
/// Panics if: `lhs` and `rhs` are not of the same length.
pub(crate) fn multiply_columns<'a, S: Scalar>(
lhs: &Column<'a, S>,
rhs: &Column<'a, S>,
Expand All @@ -55,6 +136,38 @@ pub(crate) fn multiply_columns<'a, S: Scalar>(
})
}

#[allow(dead_code)]
/// Multiply two [`ColumnarValue`]s together.
///
/// Two literals yield a literal scalar. A literal paired with a column yields a
/// scalar column, allocated in `alloc`, in which every row of the column has been
/// multiplied by the literal. Two columns are delegated to [`multiply_columns`].
///
/// # Panics
/// Panics if:
/// * both operands are columns and they are not of the same length, or
/// * `scalar_at` returns `None` for a row index below the column's length.
pub(crate) fn multiply_columnar_values<'a, S: Scalar>(
    lhs: &ColumnarValue<'a, S>,
    rhs: &ColumnarValue<'a, S>,
    alloc: &'a Bump,
) -> ColumnarValue<'a, S> {
    match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let product = lhs_literal.to_scalar() * rhs_literal.to_scalar();
            ColumnarValue::Literal(LiteralValue::Scalar(product))
        }
        (ColumnarValue::Literal(literal), ColumnarValue::Column(column)) => {
            let factor = literal.to_scalar();
            let products = alloc.alloc_slice_fill_with(column.len(), |row| {
                factor * column.scalar_at(row).unwrap()
            });
            ColumnarValue::Column(Column::Scalar(products))
        }
        (ColumnarValue::Column(column), ColumnarValue::Literal(literal)) => {
            let factor = literal.to_scalar();
            let products = alloc.alloc_slice_fill_with(column.len(), |row| {
                column.scalar_at(row).unwrap() * factor
            });
            ColumnarValue::Column(Column::Scalar(products))
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            let products = multiply_columns(lhs_column, rhs_column, alloc);
            ColumnarValue::Column(Column::Scalar(products))
        }
    }
}

#[allow(
clippy::missing_panics_doc,
reason = "scaling factor is guaranteed to not be negative based on input validation prior to calling this function"
Expand Down

0 comments on commit d32a469

Please sign in to comment.