diff --git a/crates/proof-of-sql/src/base/math/permutation.rs b/crates/proof-of-sql/src/base/math/permutation.rs index f5466e422..ac52b0874 100644 --- a/crates/proof-of-sql/src/base/math/permutation.rs +++ b/crates/proof-of-sql/src/base/math/permutation.rs @@ -1,4 +1,9 @@ +use crate::base::if_rayon; use alloc::{format, string::String, vec::Vec}; +use core::cmp::Ordering; +use itertools::Itertools; +#[cfg(feature = "rayon")] +use rayon::prelude::ParallelSliceMut; use snafu::Snafu; /// An error that occurs when working with permutations @@ -23,12 +28,19 @@ pub struct Permutation { } impl Permutation { - /// Create a new permutation without checks - /// - /// Warning: This function does not check if the permutation is valid. - /// Only use this function if you are sure that the permutation is valid. - pub(crate) fn unchecked_new(permutation: Vec) -> Self { - Self { permutation } + /// Create a new permutation from a comparison function with the given length + pub(crate) fn unchecked_new_from_cmp(length: usize, cmp: F) -> Self + where + F: Fn(&usize, &usize) -> Ordering + Sync, + { + let mut indexes = (0..length).collect_vec(); + if_rayon!( + indexes.par_sort_unstable_by(cmp), + indexes.sort_unstable_by(cmp) + ); + Self { + permutation: indexes, + } } /// Create a new permutation. If the permutation is invalid, return an error. diff --git a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs index 6bda616c3..4f8b3f5f5 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs @@ -3,14 +3,11 @@ use crate::base::{ database::{ order_by_util::compare_indexes_by_owned_columns_with_direction, OwnedColumn, OwnedTable, }, - if_rayon, math::permutation::Permutation, scalar::Scalar, }; use alloc::{string::ToString, vec::Vec}; use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection}; -#[cfg(feature = "rayon")] -use rayon::prelude::ParallelSliceMut; use serde::{Deserialize, Serialize}; /// A node representing a list of `OrderBy` expressions. @@ -30,7 +27,6 @@ impl OrderByPostprocessing { impl PostprocessingStep for OrderByPostprocessing { /// Apply the slice transformation to the given `OwnedTable`. fn apply(&self, owned_table: OwnedTable) -> PostprocessingResult> { - let mut indexes = (0..owned_table.num_rows()).collect::>(); // Evaluate the columns by which we order // Once we allow OrderBy for general aggregation-free expressions here we will need to call eval() let order_by_pairs: Vec<(OwnedColumn, OrderByDirection)> = self @@ -52,15 +48,9 @@ impl PostprocessingStep for OrderByPostprocessing { ) .collect::, OrderByDirection)>>>()?; // Define the ordering - if_rayon!( - indexes.par_sort_unstable_by(|&a, &b| { - compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b) - }), - indexes.sort_unstable_by(|&a, &b| { - compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b) - }) - ); - let permutation = Permutation::unchecked_new(indexes); + let permutation = Permutation::unchecked_new_from_cmp(owned_table.num_rows(), |&a, &b| { + compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b) + }); // Apply the ordering Ok( OwnedTable::::try_from_iter(owned_table.into_inner().into_iter().map( diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index 25495a4f2..92601f512 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -1,36 +1,17 @@ use crate::{ base::{ database::{Column, ColumnarValue, LiteralValue}, - if_rayon, math::decimal::{DecimalError, Precision}, scalar::{Scalar, ScalarExt}, + slice_ops, }, sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, }; use alloc::string::ToString; use bumpalo::Bump; use core::cmp::{max, Ordering}; -#[cfg(feature = "rayon")] -use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use sqlparser::ast::BinaryOperator; -#[allow(clippy::unnecessary_wraps)] -fn unchecked_subtract_impl<'a, S: Scalar>( - alloc: &'a Bump, - lhs: &[S], - rhs: &[S], - table_length: usize, -) -> ConversionResult<&'a [S]> { - let result = alloc.alloc_slice_fill_default(table_length); - if_rayon!(result.par_iter_mut(), result.iter_mut()) - .zip(lhs) - .zip(rhs) - .for_each(|((a, l), r)| { - *a = *l - *r; - }); - Ok(result) -} - /// Scale LHS and RHS to the same scale if at least one of them is decimal /// and take the difference. This function is used for comparisons. /// @@ -155,12 +136,13 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( } })?; } - unchecked_subtract_impl( - alloc, + let result = alloc.alloc_slice_fill_default(lhs_len); + slice_ops::sub( + result, &lhs.to_scalar_with_scaling(lhs_upscale), &rhs.to_scalar_with_scaling(rhs_upscale), - lhs_len, - ) + ); + Ok(result) } #[allow(clippy::cast_sign_loss)]