Skip to content

Commit

Permalink
feat: add ColumnarValue operations
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Oct 31, 2024
1 parent 130d56b commit 966efb9
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 35 deletions.
108 changes: 108 additions & 0 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,73 @@ fn unchecked_subtract_impl<'a, S: Scalar>(
Ok(result)
}

#[allow(
clippy::missing_panics_doc,
reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs"
)]
/// Scale LHS and RHS to the same scale if at least one of them is decimal
/// and take the difference. This function is used for comparisons.
#[allow(clippy::cast_sign_loss)]
pub fn scale_and_subtract_literal<'a, S: Scalar>(
alloc: &'a Bump,
lhs: LiteralValue<S>,
rhs: LiteralValue<S>,
lhs_scale: i8,
rhs_scale: i8,
is_equal: bool,
) -> ConversionResult<S> {
let lhs_type = lhs.column_type();
let rhs_type = rhs.column_type();
let operator = if is_equal {
BinaryOperator::Equal
} else {
BinaryOperator::LessThanOrEqual
};
if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
return Err(ConversionError::DataTypeMismatch {
left_type: lhs_type.to_string(),
right_type: rhs_type.to_string(),
});
}
let max_scale = cmp::max(lhs_scale, rhs_scale);
let lhs_upscale = max_scale - lhs_scale;
let rhs_upscale = max_scale - rhs_scale;
// Only check precision overflow issues if at least one side is decimal
if max_scale != 0 {
let lhs_precision_value = lhs_type
.precision_value()
.expect("If scale is set, precision must be set");
let rhs_precision_value = rhs_type
.precision_value()
.expect("If scale is set, precision must be set");
let max_precision_value = cmp::max(
lhs_precision_value + (max_scale - lhs_scale) as u8,
rhs_precision_value + (max_scale - rhs_scale) as u8,
);
// Check if the precision is valid
let _max_precision = Precision::new(max_precision_value).map_err(|_| {
ConversionError::DecimalConversionError {
source: DecimalError::InvalidPrecision {
error: max_precision_value.to_string(),
},
}
})?;
}
let lhs_scalar = lhs.to_scalar();
let rhs_scalar = rhs.to_scalar();
if left_upscale > 0 {
let upscale_factor =
S::pow10(u8::try_from(lhs_upscale).expect("Upscale factor is nonnegative"));
Ok(lhs_scalar * upscale_factor - rhs_scalar)
} else if right_upscale > 0 {
let upscale_factor =
S::pow10(u8::try_from(rhs_upscale).expect("Upscale factor is nonnegative"));
Ok(lhs_scalar - rhs_scalar * upscale_factor)
} else {
Ok(lhs_scalar - rhs_scalar)
}
}

#[allow(
clippy::missing_panics_doc,
reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs"
Expand Down Expand Up @@ -98,3 +165,44 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
lhs_len,
)
}

#[allow(clippy::cast_sign_loss)]
pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>(
alloc: &'a Bump,
lhs: ColumnarValue<'a, S>,
rhs: ColumnarValue<'a, S>,
lhs_scale: i8,
rhs_scale: i8,
is_equal: bool,
) -> ConversionResult<ColumnarValue<'a, S>> {
match (lhs, rhs) {
(ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => {
Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract(
alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal,
)?)))
}
(ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => {
Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract(
alloc,
Column::from_literal_with_length(lhs, rhs.len(), alloc),
rhs,
lhs_scale,
rhs_scale,
is_equal,
)?)))
}
(ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => {
Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract(
alloc,
lhs,
Column::from_literal_with_length(rhs, lhs.len(), alloc),
lhs_scale,
rhs_scale,
is_equal,
)?)))
}
(ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => Ok(ColumnarValue::Scalar(
scale_and_subtract_literal(alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal)?,
)),
}
}
112 changes: 77 additions & 35 deletions crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,35 @@
use crate::base::{
database::{Column, ColumnarValue},
database::{Column, ColumnarValue, LiteralValue},
scalar::{Scalar, ScalarExt},
};
use bumpalo::Bump;
use core::cmp::Ordering;

/// Add or subtract two literals together.
pub(crate) fn add_subtract_literals<S: Scalar>(
lhs: LiteralValue<S>,
rhs: LiteralValue<S>,
lhs_scale: i8,
rhs_scale: i8,
is_subtract: bool,
) -> S {
let (lhs_scaled, rhs_scaled) = match lhs_scale.cmp(rhs_scale) {
Ordering::Less => {
let scaling_factor = S::pow10(rhs_scale - lhs_scale);
(lhs.to_scalar() * scaling_factor, rhs.to_scalar())
}
Ordering::Equal => (lhs.to_scalar(), rhs.to_scalar()),
Ordering::Greater => {
let scaling_factor = S::pow10(lhs_scale - rhs_scale);
(lhs.to_scalar(), rhs.to_scalar() * scaling_factor)
}
};
if is_subtract {
lhs_scaled - rhs_scaled
} else {
lhs_scaled + rhs_scaled
}
}

#[allow(
clippy::missing_panics_doc,
Expand Down Expand Up @@ -46,42 +73,31 @@ pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>(
is_subtract: bool,
) -> ColumnarValue<'a, S> {
match (lhs, rhs) {
(ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => {
ColumnarValue::Column(add_subtract_columns(lhs, rhs, lhs_scale, rhs_scale, alloc, is_subtract))
}
(ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => ColumnarValue::Column(
add_subtract_columns(lhs, rhs, lhs_scale, rhs_scale, alloc, is_subtract),
),
(ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => {
let lhs_len = rhs.len();
let lhs_scalar = lhs.to_scalar_with_scaling(lhs_scale);
let result = alloc.alloc_slice_fill_with(lhs_len, |i| {
if is_subtract {
lhs_scalar - rhs.scalar_at(i).unwrap()
} else {
lhs_scalar + rhs.scalar_at(i).unwrap()
}
});
ColumnarValue::Column(result)
ColumnarValue::Column(add_subtract_columns(
Column::from_literal_with_length(lhs, rhs.len(), alloc),
rhs,
lhs_scale,
rhs_scale,
alloc,
is_subtract,
))
}
(ColumnarValue::Column(lhs), ColumnarValue::Scalar(rhs)) => {
let rhs_len = lhs.len();
let rhs_scalar = rhs.to_scalar_with_scaling(rhs_scale);
let result = alloc.alloc_slice_fill_with(rhs_len, |i| {
if is_subtract {
lhs.scalar_at(i).unwrap() - rhs_scalar
} else {
lhs.scalar_at(i).unwrap() + rhs_scalar
}
});
ColumnarValue::Column(result)
(ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => {
ColumnarValue::Column(add_subtract_columns(
lhs,
Column::from_literal_with_length(rhs, lhs.len(), alloc),
lhs_scale,
rhs_scale,
alloc,
is_subtract,
))
}
(ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
let lhs_scalar = lhs.to_scalar_with_scaling(lhs_scale);
let rhs_scalar = rhs.to_scalar_with_scaling(rhs_scale);
let result = if is_subtract {
lhs_scalar - rhs_scalar
} else {
lhs_scalar + rhs_scalar
};
ColumnarValue::Scalar(result)
(ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => {
ColumnarValue::Literal(add_subtract_literals(lhs, rhs, lhs_scale, rhs_scale, is_subtract))
}
}
}
Expand All @@ -105,7 +121,33 @@ pub(crate) fn multiply_columns<'a, S: Scalar>(
})
}


pub(crate) fn multiply_columnar_values<'a, S: Scalar>(
lhs: &ColumnarValue<'a, S>,
rhs: &ColumnarValue<'a, S>,
alloc: &'a Bump,
) -> ColumnarValue<'a, S> {
match (lhs, rhs) {
(ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => {
ColumnarValue::Column(multiply_columns(lhs, rhs, alloc))
}
(ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => {
let lhs_scalar = lhs.to_scalar();
let result =
alloc.alloc_slice_fill_with(rhs.len(), |i| lhs_scalar * rhs.scalar_at(i).unwrap());
ColumnarValue::Column(result)
}
(ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => {
let rhs_scalar = rhs.to_scalar();
let result =
alloc.alloc_slice_fill_with(lhs.len(), |i| lhs.scalar_at(i).unwrap() * rhs_scalar);
ColumnarValue::Column(result)
}
(ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => {
let result = lhs.to_scalar() * rhs.to_scalar();
ColumnarValue::Scalar(result)
}
}
}

#[allow(
clippy::missing_panics_doc,
Expand Down

0 comments on commit 966efb9

Please sign in to comment.