Skip to content

Commit

Permalink
feat: add ColumnarValue operations (#331)
Browse files Browse the repository at this point in the history
Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [ ] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [ ] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
To prepare to remove `table_length` from `ProofExpr::result_evaluate`
and return `ColumnarValue` we need to add operations on `ColumnarValue`.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implementing
`NestedLoopJoinExec` and speeding up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- add comparison operations on `ColumnarValue`.
- add numerical operations on `ColumnarValue`.
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Will be tested in a future PR.
  • Loading branch information
iajoiner authored Nov 4, 2024
2 parents 7598b5c + 9b7b7d5 commit 36d9ae3
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 7 deletions.
120 changes: 115 additions & 5 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::{
base::{
database::Column,
database::{Column, ColumnarValue, LiteralValue},
if_rayon,
math::decimal::{DecimalError, Precision},
scalar::Scalar,
scalar::{Scalar, ScalarExt},
},
sql::parse::{type_check_binary_operation, ConversionError, ConversionResult},
};
use alloc::string::ToString;
use bumpalo::Bump;
use core::cmp;
use core::cmp::{max, Ordering};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
Expand All @@ -31,6 +31,70 @@ fn unchecked_subtract_impl<'a, S: Scalar>(
Ok(result)
}

/// Scale two literals to a common scale and return `lhs - rhs` as a scalar.
///
/// The operand with the smaller scale is multiplied by the power of ten that lifts it to
/// the larger scale, so the difference is taken between like-scaled values. This function
/// is used for comparisons.
///
/// `is_equal` selects the operator used for the type check: `BinaryOperator::Equal` when
/// `true`, `BinaryOperator::LessThanOrEqual` otherwise.
///
/// # Errors
/// - [`ConversionError::DataTypeMismatch`] if the types of `lhs` and `rhs` fail
///   `type_check_binary_operation` for the selected operator.
/// - [`ConversionError::DecimalConversionError`] if the precision needed to hold either
///   upscaled operand does not fit in a `u8` or is rejected by [`Precision::new`].
///
/// # Panics
/// Panics if the larger scale is nonzero and either operand's type reports no precision.
pub fn scale_and_subtract_literal<S: Scalar>(
    lhs: &LiteralValue<S>,
    rhs: &LiteralValue<S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_equal: bool,
) -> ConversionResult<S> {
    let lhs_type = lhs.column_type();
    let rhs_type = rhs.column_type();
    let operator = if is_equal {
        BinaryOperator::Equal
    } else {
        BinaryOperator::LessThanOrEqual
    };
    if !type_check_binary_operation(&lhs_type, &rhs_type, operator) {
        return Err(ConversionError::DataTypeMismatch {
            left_type: lhs_type.to_string(),
            right_type: rhs_type.to_string(),
        });
    }
    let max_scale = max(lhs_scale, rhs_scale);
    // `abs_diff` yields the gap as a `u8` directly; a plain `i8` subtraction would overflow
    // when the two scales are more than `i8::MAX` apart.
    let lhs_upscale = max_scale.abs_diff(lhs_scale);
    let rhs_upscale = max_scale.abs_diff(rhs_scale);
    // Only check precision overflow issues if at least one side is decimal
    if max_scale != 0 {
        let lhs_precision_value = lhs_type
            .precision_value()
            .expect("If scale is set, precision must be set");
        let rhs_precision_value = rhs_type
            .precision_value()
            .expect("If scale is set, precision must be set");
        // Widen to `u16` so that the sum cannot wrap around into a seemingly valid precision.
        let max_precision_value = max(
            u16::from(lhs_precision_value) + u16::from(lhs_upscale),
            u16::from(rhs_precision_value) + u16::from(rhs_upscale),
        );
        // The cast is guarded by the range check that short-circuits in front of it.
        #[allow(clippy::cast_possible_truncation)]
        let precision_is_valid = max_precision_value <= u16::from(u8::MAX)
            && Precision::new(max_precision_value as u8).is_ok();
        if !precision_is_valid {
            return Err(ConversionError::DecimalConversionError {
                source: DecimalError::InvalidPrecision {
                    error: max_precision_value.to_string(),
                },
            });
        }
    }
    let lhs_value = lhs.to_scalar();
    let rhs_value = rhs.to_scalar();
    Ok(match lhs_scale.cmp(&rhs_scale) {
        // `lhs` has the smaller scale, so it is the side that must be lifted.
        Ordering::Less => lhs_value * S::pow10(lhs_upscale) - rhs_value,
        Ordering::Equal => lhs_value - rhs_value,
        // `rhs` has the smaller scale, so it is the side that must be lifted.
        Ordering::Greater => lhs_value - rhs_value * S::pow10(rhs_upscale),
    })
}

#[allow(
clippy::missing_panics_doc,
reason = "precision and scale are validated prior to calling this function, ensuring no panic occurs"
Expand Down Expand Up @@ -67,7 +131,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
right_type: rhs_type.to_string(),
});
}
let max_scale = cmp::max(lhs_scale, rhs_scale);
let max_scale = max(lhs_scale, rhs_scale);
let lhs_upscale = max_scale - lhs_scale;
let rhs_upscale = max_scale - rhs_scale;
// Only check precision overflow issues if at least one side is decimal
Expand All @@ -78,7 +142,7 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
let rhs_precision_value = rhs_type
.precision_value()
.expect("If scale is set, precision must be set");
let max_precision_value = cmp::max(
let max_precision_value = max(
lhs_precision_value + (max_scale - lhs_scale) as u8,
rhs_precision_value + (max_scale - rhs_scale) as u8,
);
Expand All @@ -98,3 +162,49 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
lhs_len,
)
}

#[allow(clippy::cast_sign_loss)]
#[allow(dead_code)]
/// Subtract `rhs` from `lhs` after scaling both to a common scale; used for comparisons.
///
/// Two literals are combined with `scale_and_subtract_literal` and the result stays a
/// literal. In every other case a literal operand is first turned into a column with
/// `Column::from_literal_with_length`, using the length of the column on the other side,
/// and the pair is handed to `scale_and_subtract`; the result is a scalar column.
///
/// # Errors
/// Propagates any error returned by `scale_and_subtract` or `scale_and_subtract_literal`.
pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>(
    alloc: &'a Bump,
    lhs: ColumnarValue<'a, S>,
    rhs: ColumnarValue<'a, S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_equal: bool,
) -> ConversionResult<ColumnarValue<'a, S>> {
    // Normalise the operands to a pair of columns; the literal/literal case never
    // materialises a column and returns straight away.
    let (lhs_column, rhs_column) = match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let difference = scale_and_subtract_literal(
                &lhs_literal,
                &rhs_literal,
                lhs_scale,
                rhs_scale,
                is_equal,
            )?;
            return Ok(ColumnarValue::Literal(LiteralValue::Scalar(difference)));
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            (lhs_column, rhs_column)
        }
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Column(rhs_column)) => {
            let expanded = Column::from_literal_with_length(&lhs_literal, rhs_column.len(), alloc);
            (expanded, rhs_column)
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Literal(rhs_literal)) => {
            let expanded = Column::from_literal_with_length(&rhs_literal, lhs_column.len(), alloc);
            (lhs_column, expanded)
        }
    };
    let differences = scale_and_subtract(
        alloc, lhs_column, rhs_column, lhs_scale, rhs_scale, is_equal,
    )?;
    Ok(ColumnarValue::Column(Column::Scalar(differences)))
}
117 changes: 115 additions & 2 deletions crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
use crate::base::{
database::Column,
database::{Column, ColumnarValue, LiteralValue},
scalar::{Scalar, ScalarExt},
};
use bumpalo::Bump;
use core::cmp::Ordering;

/// Add or subtract two literals after bringing them to a common scale.
///
/// The operand with the smaller scale is multiplied by `10^|lhs_scale - rhs_scale|`, so
/// both values are expressed at the larger of the two scales. The result is `lhs - rhs`
/// when `is_subtract` is `true` and `lhs + rhs` otherwise.
pub(crate) fn add_subtract_literals<S: Scalar>(
    lhs: &LiteralValue<S>,
    rhs: &LiteralValue<S>,
    lhs_scale: i8,
    rhs_scale: i8,
    is_subtract: bool,
) -> S {
    // `abs_diff` gives the gap between the scales as a `u8` without the intermediate `i8`
    // subtraction, which would overflow when the scales are more than `i8::MAX` apart.
    let scale_gap = lhs_scale.abs_diff(rhs_scale);
    let (lhs_scaled, rhs_scaled) = match lhs_scale.cmp(&rhs_scale) {
        Ordering::Less => (lhs.to_scalar() * S::pow10(scale_gap), rhs.to_scalar()),
        Ordering::Equal => (lhs.to_scalar(), rhs.to_scalar()),
        Ordering::Greater => (lhs.to_scalar(), rhs.to_scalar() * S::pow10(scale_gap)),
    };
    if is_subtract {
        lhs_scaled - rhs_scaled
    } else {
        lhs_scaled + rhs_scaled
    }
}

#[allow(
clippy::missing_panics_doc,
Expand Down Expand Up @@ -36,9 +64,62 @@ pub(crate) fn add_subtract_columns<'a, S: Scalar>(
result
}

/// Add or subtract two [`ColumnarValue`]s.
///
/// Two literals are combined with `add_subtract_literals` and the result stays a literal.
/// In every other case a literal operand is first turned into a column with
/// `Column::from_literal_with_length`, using the length of the column on the other side,
/// and the pair is handed to `add_subtract_columns`; the result is a scalar column.
#[allow(dead_code)]
pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>(
    lhs: ColumnarValue<'a, S>,
    rhs: ColumnarValue<'a, S>,
    lhs_scale: i8,
    rhs_scale: i8,
    alloc: &'a Bump,
    is_subtract: bool,
) -> ColumnarValue<'a, S> {
    // Normalise the operands to a pair of columns; the literal/literal case never
    // materialises a column and returns straight away.
    let (lhs_column, rhs_column) = match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let combined = add_subtract_literals(
                &lhs_literal,
                &rhs_literal,
                lhs_scale,
                rhs_scale,
                is_subtract,
            );
            return ColumnarValue::Literal(LiteralValue::Scalar(combined));
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            (lhs_column, rhs_column)
        }
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Column(rhs_column)) => {
            let expanded = Column::from_literal_with_length(&lhs_literal, rhs_column.len(), alloc);
            (expanded, rhs_column)
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Literal(rhs_literal)) => {
            let expanded = Column::from_literal_with_length(&rhs_literal, lhs_column.len(), alloc);
            (lhs_column, expanded)
        }
    };
    let combined = add_subtract_columns(
        lhs_column,
        rhs_column,
        lhs_scale,
        rhs_scale,
        alloc,
        is_subtract,
    );
    ColumnarValue::Column(Column::Scalar(combined))
}

/// Multiply two columns together.
/// # Panics
/// Panics if: The lengths of `lhs` and `rhs` are not equal.`lhs.scalar_at(i)` or `rhs.scalar_at(i)` returns `None`, which occurs if the column does not have, a scalar at the given index `i`.
/// Panics if: `lhs` and `rhs` are not of the same length.
pub(crate) fn multiply_columns<'a, S: Scalar>(
lhs: &Column<'a, S>,
rhs: &Column<'a, S>,
Expand All @@ -55,6 +136,38 @@ pub(crate) fn multiply_columns<'a, S: Scalar>(
})
}

#[allow(dead_code)]
/// Multiply two [`ColumnarValue`]s.
///
/// Two literals yield a literal. Two columns are delegated to `multiply_columns`. When
/// exactly one side is a literal, its scalar is multiplied into every row of the column on
/// the other side and the products are written to a slice allocated from `alloc`.
///
/// # Panics
/// Panics if: `lhs` and `rhs` are not of the same length.
/// Also panics if `scalar_at` returns `None` for a row of the column operand in the mixed
/// literal/column cases.
pub(crate) fn multiply_columnar_values<'a, S: Scalar>(
    lhs: &ColumnarValue<'a, S>,
    rhs: &ColumnarValue<'a, S>,
    alloc: &'a Bump,
) -> ColumnarValue<'a, S> {
    let products: &'a [S] = match (lhs, rhs) {
        (ColumnarValue::Literal(lhs_literal), ColumnarValue::Literal(rhs_literal)) => {
            let product = lhs_literal.to_scalar() * rhs_literal.to_scalar();
            return ColumnarValue::Literal(LiteralValue::Scalar(product));
        }
        (ColumnarValue::Column(lhs_column), ColumnarValue::Column(rhs_column)) => {
            multiply_columns(lhs_column, rhs_column, alloc)
        }
        (ColumnarValue::Literal(literal), ColumnarValue::Column(column)) => {
            let factor = literal.to_scalar();
            alloc.alloc_slice_fill_with(column.len(), |row| {
                factor * column.scalar_at(row).unwrap()
            })
        }
        (ColumnarValue::Column(column), ColumnarValue::Literal(literal)) => {
            let factor = literal.to_scalar();
            alloc.alloc_slice_fill_with(column.len(), |row| {
                column.scalar_at(row).unwrap() * factor
            })
        }
    };
    ColumnarValue::Column(Column::Scalar(products))
}

#[allow(
clippy::missing_panics_doc,
reason = "scaling factor is guaranteed to not be negative based on input validation prior to calling this function"
Expand Down

0 comments on commit 36d9ae3

Please sign in to comment.