From e7a77bac9200b3313efeb552afeb60dd74336a8e Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:19:21 -0500 Subject: [PATCH 1/4] refactor: simplify arithmetic in `owned_column_operation.rs` --- .../database/column_arithmetic_operation.rs | 291 +++++++ crates/proof-of-sql/src/base/database/mod.rs | 3 + .../base/database/owned_column_operation.rs | 812 +----------------- .../src/base/database/slice_operation.rs | 8 +- 4 files changed, 318 insertions(+), 796 deletions(-) create mode 100644 crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs diff --git a/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs b/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs new file mode 100644 index 000000000..0c5e6fc13 --- /dev/null +++ b/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs @@ -0,0 +1,291 @@ +use super::{ColumnOperationError, ColumnOperationResult}; +use crate::base::{ + database::{ + slice_decimal_operation::{ + try_add_decimal_columns, try_divide_decimal_columns, try_multiply_decimal_columns, + try_subtract_decimal_columns, + }, + slice_operation::{ + try_add, try_div, try_mul, try_slice_binary_op, try_slice_binary_op_left_upcast, + try_slice_binary_op_right_upcast, try_sub, + }, + ColumnType, OwnedColumn, + }, + math::decimal::Precision, + scalar::Scalar, +}; +use alloc::vec::Vec; +use core::fmt::Debug; +use num_bigint::BigInt; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; +use proof_of_sql_parser::intermediate_ast::BinaryOperator; + +pub trait ArithmeticOp { + fn op(l: &T, r: &T) -> ColumnOperationResult + where + T: Debug + CheckedDiv + CheckedMul + CheckedAdd + CheckedSub; + fn decimal_op( + lhs: &[T0], + rhs: &[T1], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> ColumnOperationResult<(Precision, i8, Vec)> + where + S: Scalar + From + From, + T0: Copy + Debug + Into, + T1: Copy + Debug + Into; + + #[allow(clippy::too_many_lines)] + fn owned_column_element_wise_arithmetic( + lhs: &OwnedColumn, + rhs: &OwnedColumn, + ) -> ColumnOperationResult> { + if lhs.len() != rhs.len() { + return Err(ColumnOperationError::DifferentColumnLength { + len_a: lhs.len(), + len_b: rhs.len(), + }); + } + match (&lhs, &rhs) { + (OwnedColumn::TinyInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::TinyInt( + try_slice_binary_op(lhs, rhs, Self::op)?, + )), + (OwnedColumn::TinyInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::SmallInt( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::TinyInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::TinyInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::TinyInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::TinyInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + + (OwnedColumn::SmallInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::SmallInt( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::SmallInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::SmallInt( + try_slice_binary_op(lhs, rhs, Self::op)?, + )), + (OwnedColumn::SmallInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::SmallInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::SmallInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::SmallInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + + (OwnedColumn::Int(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::Int( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::Int( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int(lhs), OwnedColumn::Int(rhs)) => { + Ok(OwnedColumn::Int(try_slice_binary_op(lhs, rhs, Self::op)?)) + } + (OwnedColumn::Int(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + + (OwnedColumn::BigInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::BigInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::BigInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::BigInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt( + try_slice_binary_op(lhs, rhs, Self::op)?, + )), + (OwnedColumn::BigInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::BigInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + + (OwnedColumn::Int128(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int128(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int128(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int128(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int128(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128( + try_slice_binary_op(lhs, rhs, Self::op)?, + )), + (OwnedColumn::Int128(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::TinyInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::SmallInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int(rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::BigInt(rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int128(rhs_values)) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + ( + OwnedColumn::Decimal75(_, _, lhs_values), + OwnedColumn::Decimal75(_, _, rhs_values), + ) => { + let (new_precision, new_scale, new_values) = + Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?; + Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) + } + _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Add, + left_type: lhs.column_type(), + right_type: rhs.column_type(), + }), + } + } +} + +pub struct AddOp {} +impl ArithmeticOp for AddOp { + fn op(l: &T, r: &T) -> ColumnOperationResult + where + T: CheckedAdd + Debug, + { + try_add(l, r) + } + + fn decimal_op( + lhs: &[T0], + rhs: &[T1], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> ColumnOperationResult<(Precision, i8, Vec)> + where + S: Scalar + From + From, + T0: Copy, + T1: Copy, + { + try_add_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } +} + +pub struct SubOp {} +impl ArithmeticOp for SubOp { + fn op(l: &T, r: &T) -> ColumnOperationResult + where + T: CheckedSub + Debug, + { + try_sub(l, r) + } + + fn decimal_op( + lhs: &[T0], + rhs: &[T1], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> ColumnOperationResult<(Precision, i8, Vec)> + where + S: Scalar + From + From, + T0: Copy, + T1: Copy, + { + try_subtract_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } +} + +pub struct MulOp {} +impl ArithmeticOp for MulOp { + fn op(l: &T, r: &T) -> ColumnOperationResult + where + T: CheckedMul + Debug, + { + try_mul(l, r) + } + + fn decimal_op( + lhs: &[T0], + rhs: &[T1], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> ColumnOperationResult<(Precision, i8, Vec)> + where + S: Scalar + From + From, + T0: Copy, + T1: Copy, + { + try_multiply_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } +} + +pub struct DivOp {} +impl ArithmeticOp for DivOp { + fn op(l: &T, r: &T) -> ColumnOperationResult + where + T: CheckedDiv + Debug, + { + try_div(l, r) + } + + fn decimal_op( + lhs: &[T0], + rhs: &[T1], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> ColumnOperationResult<(Precision, i8, Vec)> + where + S: Scalar + From + From, + T0: Copy + Debug + Into, + T1: Copy + Debug + Into, + { + try_divide_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } +} diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 5516fb61b..c38e4b010 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -17,6 +17,9 @@ pub use column_type_operation::{ try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types, }; +mod column_arithmetic_operation; +pub(super) use column_arithmetic_operation::{AddOp, ArithmeticOp, DivOp, MulOp, SubOp}; + mod column_operation_error; pub use column_operation_error::{ColumnOperationError, ColumnOperationResult}; diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index a8a495115..6eebbe207 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -1,16 +1,12 @@ -use super::{ColumnOperationError, ColumnOperationResult}; +use super::{ + AddOp, ArithmeticOp, ColumnOperationError, ColumnOperationResult, DivOp, MulOp, SubOp, +}; use crate::base::{ database::{ - slice_decimal_operation::{ - eq_decimal_columns, ge_decimal_columns, le_decimal_columns, try_add_decimal_columns, - try_divide_decimal_columns, try_multiply_decimal_columns, try_subtract_decimal_columns, - }, + slice_decimal_operation::{eq_decimal_columns, ge_decimal_columns, le_decimal_columns}, slice_operation::{ slice_and, slice_eq, slice_eq_with_casting, slice_ge, slice_ge_with_casting, slice_le, - slice_le_with_casting, slice_not, slice_or, try_add_slices, - try_add_slices_with_casting, try_divide_slices, try_divide_slices_left_upcast, - try_divide_slices_right_upcast, try_multiply_slices, try_multiply_slices_with_casting, - try_subtract_slices, try_subtract_slices_left_upcast, try_subtract_slices_right_upcast, + slice_le_with_casting, slice_not, slice_or, }, OwnedColumn, }, @@ -610,793 +606,25 @@ impl OwnedColumn { }), } } -} - -impl Add for OwnedColumn { - type Output = ColumnOperationResult; - - #[allow(clippy::too_many_lines)] - fn add(self, rhs: Self) -> Self::Output { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (&self, &rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::TinyInt(try_add_slices(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::SmallInt(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_add_slices(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Int(try_add_slices(lhs, rhs)?)), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::BigInt(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => Ok(Self::BigInt(try_add_slices(lhs, rhs)?)), - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(lhs, rhs)?)) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Int128(try_add_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::Int128(rhs)) => Ok(Self::Int128(try_add_slices(lhs, rhs)?)), - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_add_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "+".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } + /// Element-wise addition for two columns + pub fn element_wise_add(&self, rhs: &OwnedColumn) -> ColumnOperationResult> { + AddOp::owned_column_element_wise_arithmetic(self, rhs) } -} - -impl Sub for OwnedColumn { - type Output = ColumnOperationResult; - - #[allow(clippy::too_many_lines)] - fn sub(self, rhs: Self) -> Self::Output { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (&self, &rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::TinyInt(try_subtract_slices(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::SmallInt(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_subtract_slices(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Int(try_subtract_slices(lhs, rhs)?)), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::BigInt(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_subtract_slices(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_subtract_slices_left_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } + /// Element-wise subtraction for two columns + pub fn element_wise_sub(&self, rhs: &OwnedColumn) -> ColumnOperationResult> { + SubOp::owned_column_element_wise_arithmetic(self, rhs) + } - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Int128(try_subtract_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_subtract_slices(lhs, rhs)?)) - } - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } + /// Element-wise multiplication for two columns + pub fn element_wise_mul(&self, rhs: &OwnedColumn) -> ColumnOperationResult> { + MulOp::owned_column_element_wise_arithmetic(self, rhs) + } - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_subtract_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "-".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } - } -} - -impl Mul for OwnedColumn { - type Output = ColumnOperationResult; - - #[allow(clippy::too_many_lines)] - fn mul(self, rhs: Self) -> Self::Output { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (&self, &rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::TinyInt(try_multiply_slices(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::SmallInt(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_multiply_slices(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Int(try_multiply_slices(lhs, rhs)?)), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::BigInt(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_multiply_slices(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(lhs, rhs)?)) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Int128(try_multiply_slices_with_casting(rhs, lhs)?)) - } - (Self::Int128(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_multiply_slices(lhs, rhs)?)) - } - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_multiply_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "*".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } - } -} - -impl Div for OwnedColumn { - type Output = ColumnOperationResult; - - #[allow(clippy::too_many_lines)] - fn div(self, rhs: Self) -> Self::Output { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (&self, &rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::TinyInt(try_divide_slices(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::SmallInt(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::SmallInt(try_divide_slices(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Int(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Int(try_divide_slices(lhs, rhs)?)), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::BigInt(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::BigInt(try_divide_slices(lhs, rhs)?)) - } - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_divide_slices_left_upcast(lhs, rhs)?)) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Int128(try_divide_slices_right_upcast(lhs, rhs)?)) - } - (Self::Int128(lhs), Self::Int128(rhs)) => { - Ok(Self::Int128(try_divide_slices(lhs, rhs)?)) - } - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - let (new_precision, new_scale, new_values) = try_divide_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - )?; - Ok(Self::Decimal75(new_precision, new_scale, new_values)) - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "/".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } + /// Element-wise division for two columns + pub fn element_wise_div(&self, rhs: &OwnedColumn) -> ColumnOperationResult> { + DivOp::owned_column_element_wise_arithmetic(self, rhs) } } @@ -1897,7 +1125,7 @@ mod test { } #[test] - fn we_can_try_add_decimal_columns() { + fn we_can_decimal_op() { // lhs and rhs have the same precision and scale let lhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); diff --git a/crates/proof-of-sql/src/base/database/slice_operation.rs b/crates/proof-of-sql/src/base/database/slice_operation.rs index 689178b9d..dd83edb90 100644 --- a/crates/proof-of-sql/src/base/database/slice_operation.rs +++ b/crates/proof-of-sql/src/base/database/slice_operation.rs @@ -4,7 +4,7 @@ use core::fmt::Debug; use num_traits::ops::checked::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; /// Function for checked addition with overflow error handling -fn try_add(l: &T, r: &T) -> ColumnOperationResult +pub(crate) fn try_add(l: &T, r: &T) -> ColumnOperationResult where T: CheckedAdd + Debug, { @@ -15,7 +15,7 @@ where } /// Function for checked subtraction with overflow error handling -fn try_sub(l: &T, r: &T) -> ColumnOperationResult +pub(crate) fn try_sub(l: &T, r: &T) -> ColumnOperationResult where T: CheckedSub + Debug, { @@ -26,7 +26,7 @@ where } /// Function for checked multiplication with overflow error handling -fn try_mul(l: &T, r: &T) -> ColumnOperationResult +pub(crate) fn try_mul(l: &T, r: &T) -> ColumnOperationResult where T: CheckedMul + Debug, { @@ -37,7 +37,7 @@ where } /// Function for checked division with division by zero error handling -fn try_div(l: &T, r: &T) -> ColumnOperationResult +pub(crate) fn try_div(l: &T, r: &T) -> ColumnOperationResult where T: CheckedDiv + Debug, { From ffdcf8f864cf0daca95d5862e385d2287aacb0a5 Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Sun, 10 Nov 2024 22:12:39 -0500 Subject: [PATCH 2/4] refactor: simplify comparison in `owned_column_operation.rs` --- .../database/column_comparison_operation.rs | 363 ++++++++++++ crates/proof-of-sql/src/base/database/mod.rs | 5 + .../base/database/owned_column_operation.rs | 548 +----------------- 3 files changed, 376 insertions(+), 540 deletions(-) create mode 100644 crates/proof-of-sql/src/base/database/column_comparison_operation.rs diff --git a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs new file mode 100644 index 000000000..8329ca277 --- /dev/null +++ b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs @@ -0,0 +1,363 @@ +use super::{ColumnOperationError, ColumnOperationResult}; +use crate::base::{ + database::{ + slice_decimal_operation::{eq_decimal_columns, ge_decimal_columns, le_decimal_columns}, + slice_operation::{ + slice_binary_op, slice_binary_op_left_upcast, slice_binary_op_right_upcast, + }, + ColumnType, OwnedColumn, + }, + scalar::Scalar, +}; +use core::{cmp::Ord, fmt::Debug}; +use num_traits::Zero; +use proof_of_sql_parser::intermediate_ast::BinaryOperator; + +pub trait ComparisonOp { + fn op(l: &T, r: &T) -> bool + where + T: Debug + Ord; + + fn decimal_op_left_upcast( + lhs: &[T], + rhs: &[S], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into; + + fn decimal_op_right_upcast( + lhs: &[S], + rhs: &[T], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into; + + /// Return an error if op is not implemented for string + fn string_op(lhs: &[String], rhs: &[String]) -> ColumnOperationResult>; + + #[allow(clippy::too_many_lines)] + fn owned_column_element_wise_comparison( + lhs: &OwnedColumn, + rhs: &OwnedColumn, + ) -> ColumnOperationResult> { + if lhs.len() != rhs.len() { + return Err(ColumnOperationError::DifferentColumnLength { + len_a: lhs.len(), + len_b: rhs.len(), + }); + } + let result = match (&lhs, &rhs) { + (OwnedColumn::TinyInt(lhs), OwnedColumn::TinyInt(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } + (OwnedColumn::TinyInt(lhs), OwnedColumn::SmallInt(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::TinyInt(lhs), OwnedColumn::Int(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::TinyInt(lhs), OwnedColumn::BigInt(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::TinyInt(lhs), OwnedColumn::Int128(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::TinyInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + + (OwnedColumn::SmallInt(lhs), OwnedColumn::TinyInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::SmallInt(lhs), OwnedColumn::SmallInt(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } + (OwnedColumn::SmallInt(lhs), OwnedColumn::Int(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::SmallInt(lhs), OwnedColumn::BigInt(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::SmallInt(lhs), OwnedColumn::Int128(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::SmallInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + + (OwnedColumn::Int(lhs), OwnedColumn::TinyInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int(lhs), OwnedColumn::SmallInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int(lhs), OwnedColumn::Int(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } + (OwnedColumn::Int(lhs), OwnedColumn::BigInt(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int(lhs), OwnedColumn::Int128(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + + (OwnedColumn::BigInt(lhs), OwnedColumn::TinyInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::BigInt(lhs), OwnedColumn::SmallInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::BigInt(lhs), OwnedColumn::Int(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::BigInt(lhs), OwnedColumn::BigInt(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } + (OwnedColumn::BigInt(lhs), OwnedColumn::Int128(rhs)) => { + Ok(slice_binary_op_left_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::BigInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + + (OwnedColumn::Int128(lhs), OwnedColumn::TinyInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int128(lhs), OwnedColumn::SmallInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int128(lhs), OwnedColumn::Int(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int128(lhs), OwnedColumn::BigInt(rhs)) => { + Ok(slice_binary_op_right_upcast(lhs, rhs, Self::op)) + } + (OwnedColumn::Int128(lhs), OwnedColumn::Int128(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } + (OwnedColumn::Int128(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => { + Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::TinyInt(rhs_values)) => { + Ok(Self::decimal_op_right_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::SmallInt(rhs_values)) => { + Ok(Self::decimal_op_right_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int(rhs_values)) => { + Ok(Self::decimal_op_right_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::BigInt(rhs_values)) => { + Ok(Self::decimal_op_right_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + (OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int128(rhs_values)) => { + Ok(Self::decimal_op_right_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )) + } + ( + OwnedColumn::Decimal75(_, _, lhs_values), + OwnedColumn::Decimal75(_, _, rhs_values), + ) => Ok(Self::decimal_op_left_upcast( + lhs_values, + rhs_values, + lhs.column_type(), + rhs.column_type(), + )), + + (OwnedColumn::VarChar(lhs), OwnedColumn::VarChar(rhs)) => Self::string_op(lhs, rhs), + _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Add, + left_type: lhs.column_type(), + right_type: rhs.column_type(), + }), + }?; + Ok(OwnedColumn::Boolean(result)) + } +} + +pub struct EqualOp {} +impl ComparisonOp for EqualOp { + fn op(l: &T, r: &T) -> bool + where + T: Debug + PartialEq, + { + l == r + } + + fn decimal_op_left_upcast( + lhs: &[T], + rhs: &[S], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + PartialEq + Zero + Into, + { + eq_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } + + fn decimal_op_right_upcast( + lhs: &[S], + rhs: &[T], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + PartialEq + Zero + Into, + { + eq_decimal_columns(rhs, lhs, right_column_type, left_column_type) + } + + fn string_op(lhs: &[String], rhs: &[String]) -> ColumnOperationResult> { + Ok(lhs.iter().zip(rhs.iter()).map(|(l, r)| l == r).collect()) + } +} + +pub struct GreaterThanOrEqualOp {} +impl ComparisonOp for GreaterThanOrEqualOp { + fn op(l: &T, r: &T) -> bool + where + T: Debug + Ord, + { + l >= r + } + + fn decimal_op_left_upcast( + lhs: &[T], + rhs: &[S], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into, + { + ge_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } + + fn decimal_op_right_upcast( + lhs: &[S], + rhs: &[T], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into, + { + le_decimal_columns(rhs, lhs, right_column_type, left_column_type) + } + + fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { + Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Add, + left_type: ColumnType::VarChar, + right_type: ColumnType::VarChar, + }) + } +} + +pub struct LessThanOrEqualOp {} +impl ComparisonOp for LessThanOrEqualOp { + fn op(l: &T, r: &T) -> bool + where + T: Debug + Ord, + { + l <= r + } + + fn decimal_op_left_upcast( + lhs: &[T], + rhs: &[S], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into, + { + le_decimal_columns(lhs, rhs, left_column_type, right_column_type) + } + + fn decimal_op_right_upcast( + lhs: &[S], + rhs: &[T], + left_column_type: ColumnType, + right_column_type: ColumnType, + ) -> Vec + where + S: Scalar, + T: Copy + Debug + Ord + Zero + Into, + { + ge_decimal_columns(rhs, lhs, right_column_type, left_column_type) + } + + fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { + Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Add, + left_type: ColumnType::VarChar, + right_type: ColumnType::VarChar, + }) + } +} diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index c38e4b010..b45f3ea6a 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -20,6 +20,11 @@ pub use column_type_operation::{ mod column_arithmetic_operation; pub(super) use column_arithmetic_operation::{AddOp, ArithmeticOp, DivOp, MulOp, SubOp}; +mod column_comparison_operation; +pub(super) use column_comparison_operation::{ + ComparisonOp, EqualOp, GreaterThanOrEqualOp, LessThanOrEqualOp, +}; + mod column_operation_error; pub use column_operation_error::{ColumnOperationError, ColumnOperationResult}; diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index 6eebbe207..43a33460b 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -1,13 +1,10 @@ use super::{ - AddOp, ArithmeticOp, ColumnOperationError, ColumnOperationResult, DivOp, MulOp, SubOp, + AddOp, ArithmeticOp, ColumnOperationError, ColumnOperationResult, ComparisonOp, DivOp, EqualOp, + GreaterThanOrEqualOp, LessThanOrEqualOp, MulOp, SubOp, }; use crate::base::{ database::{ - slice_decimal_operation::{eq_decimal_columns, ge_decimal_columns, le_decimal_columns}, - slice_operation::{ - slice_and, slice_eq, slice_eq_with_casting, slice_ge, slice_ge_with_casting, slice_le, - slice_le_with_casting, slice_not, slice_or, - }, + slice_operation::{slice_and, slice_not, slice_or}, OwnedColumn, }, scalar::Scalar, @@ -64,547 +61,18 @@ impl OwnedColumn { } /// Element-wise equality check for two columns - #[allow(clippy::too_many_lines)] pub fn element_wise_eq(&self, rhs: &Self) -> ColumnOperationResult { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (self, rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(lhs, rhs))) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_eq_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int128(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(eq_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - (Self::Boolean(lhs), Self::Boolean(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::Scalar(lhs), Self::Scalar(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::VarChar(lhs), Self::VarChar(rhs)) => Ok(Self::Boolean(slice_eq(lhs, rhs))), - (Self::TimestampTZ(_, _, _), Self::TimestampTZ(_, _, _)) => { - todo!("Implement equality check for TimeStampTZ") - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "=".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } + EqualOp::owned_column_element_wise_comparison(self, rhs) } - /// Element-wise <= check for two columns - #[allow(clippy::too_many_lines)] + /// Element-wise less than or equal to check for two columns pub fn element_wise_le(&self, rhs: &Self) -> ColumnOperationResult { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (self, rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(lhs, rhs))) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int128(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - (Self::Boolean(lhs), Self::Boolean(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::Scalar(lhs), Self::Scalar(rhs)) => Ok(Self::Boolean(slice_le(lhs, rhs))), - (Self::TimestampTZ(_, _, _), Self::TimestampTZ(_, _, _)) => { - todo!("Implement inequality check for TimeStampTZ") - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: "<=".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } + LessThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) } - /// Element-wise >= check for two columns - #[allow(clippy::too_many_lines)] + /// Element-wise greater than or equal to check for two columns pub fn element_wise_ge(&self, rhs: &Self) -> ColumnOperationResult { - if self.len() != rhs.len() { - return Err(ColumnOperationError::DifferentColumnLength { - len_a: self.len(), - len_b: rhs.len(), - }); - } - match (self, rhs) { - (Self::TinyInt(lhs), Self::TinyInt(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::TinyInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::TinyInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::SmallInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::SmallInt(lhs), Self::SmallInt(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::SmallInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::SmallInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int(lhs), Self::Int(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::Int(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::Int(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::Int(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::BigInt(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::BigInt(lhs), Self::BigInt(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::BigInt(lhs), Self::Int128(rhs)) => { - Ok(Self::Boolean(slice_ge_with_casting(lhs, rhs))) - } - (Self::BigInt(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Int128(lhs), Self::TinyInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::SmallInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::BigInt(rhs)) => { - Ok(Self::Boolean(slice_le_with_casting(rhs, lhs))) - } - (Self::Int128(lhs), Self::Int128(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::Int128(lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - - (Self::Decimal75(_, _, lhs_values), Self::TinyInt(rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::SmallInt(rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int(rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::BigInt(rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Int128(rhs_values)) => { - Ok(Self::Boolean(le_decimal_columns( - rhs_values, - lhs_values, - rhs.column_type(), - self.column_type(), - ))) - } - (Self::Decimal75(_, _, lhs_values), Self::Decimal75(_, _, rhs_values)) => { - Ok(Self::Boolean(ge_decimal_columns( - lhs_values, - rhs_values, - self.column_type(), - rhs.column_type(), - ))) - } - (Self::Boolean(lhs), Self::Boolean(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::Scalar(lhs), Self::Scalar(rhs)) => Ok(Self::Boolean(slice_ge(lhs, rhs))), - (Self::TimestampTZ(_, _, _), Self::TimestampTZ(_, _, _)) => { - todo!("Implement inequality check for TimeStampTZ") - } - _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: ">=".to_string(), - left_type: self.column_type(), - right_type: rhs.column_type(), - }), - } + GreaterThanOrEqualOp::owned_column_element_wise_comparison(self, rhs) } /// Element-wise addition for two columns From 0dad376159bfe9f2a35c82bdd3547cbf09eaab8e Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Sun, 10 Nov 2024 22:38:56 -0500 Subject: [PATCH 3/4] refactor!: remove unused functions --- .../database/column_arithmetic_operation.rs | 5 +- .../database/column_comparison_operation.rs | 14 +- .../base/database/expression_evaluation.rs | 8 +- .../base/database/owned_column_operation.rs | 139 +++++----- .../src/base/database/slice_operation.rs | 247 ++---------------- 5 files changed, 106 insertions(+), 307 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs b/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs index 0c5e6fc13..728badffe 100644 --- a/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs @@ -14,11 +14,10 @@ use crate::base::{ math::decimal::Precision, scalar::Scalar, }; -use alloc::vec::Vec; +use alloc::{string::ToString, vec::Vec}; use core::fmt::Debug; use num_bigint::BigInt; use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; pub trait ArithmeticOp { fn op(l: &T, r: &T) -> ColumnOperationResult @@ -186,7 +185,7 @@ pub trait ArithmeticOp { Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values)) } _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Add, + operator: "ArithmeticOp".to_string(), left_type: lhs.column_type(), right_type: rhs.column_type(), }), diff --git a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs index 8329ca277..1f5c32833 100644 --- a/crates/proof-of-sql/src/base/database/column_comparison_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_comparison_operation.rs @@ -9,9 +9,12 @@ use crate::base::{ }, scalar::Scalar, }; +use alloc::{ + string::{String, ToString}, + vec::Vec, +}; use core::{cmp::Ord, fmt::Debug}; use num_traits::Zero; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; pub trait ComparisonOp { fn op(l: &T, r: &T) -> bool @@ -223,9 +226,12 @@ pub trait ComparisonOp { rhs.column_type(), )), + (OwnedColumn::Boolean(lhs), OwnedColumn::Boolean(rhs)) => { + Ok(slice_binary_op(lhs, rhs, Self::op)) + } (OwnedColumn::VarChar(lhs), OwnedColumn::VarChar(rhs)) => Self::string_op(lhs, rhs), _ => Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Add, + operator: "ComparisonOp".to_string(), left_type: lhs.column_type(), right_type: rhs.column_type(), }), @@ -311,7 +317,7 @@ impl ComparisonOp for GreaterThanOrEqualOp { fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Add, + operator: ">=".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) @@ -355,7 +361,7 @@ impl ComparisonOp for LessThanOrEqualOp { fn string_op(_lhs: &[String], _rhs: &[String]) -> ColumnOperationResult> { Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Add, + operator: "<=".to_string(), left_type: ColumnType::VarChar, right_type: ColumnType::VarChar, }) diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index f4e81424c..9b4da0b90 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -91,10 +91,10 @@ impl OwnedTable { BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?), BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?), BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?), - BinaryOperator::Add => Ok((left + right)?), - BinaryOperator::Subtract => Ok((left - right)?), - BinaryOperator::Multiply => Ok((left * right)?), - BinaryOperator::Division => Ok((left / right)?), + BinaryOperator::Add => Ok(left.element_wise_add(&right)?), + BinaryOperator::Subtract => Ok(left.element_wise_sub(&right)?), + BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?), + BinaryOperator::Division => Ok(left.element_wise_div(&right)?), } } } diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index 43a33460b..48eca7027 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -10,7 +10,6 @@ use crate::base::{ scalar::Scalar, }; use alloc::string::ToString; -use core::ops::{Add, Div, Mul, Sub}; impl OwnedColumn { /// Element-wise NOT operation for a column @@ -100,6 +99,7 @@ impl OwnedColumn { mod test { use super::*; use crate::base::{math::decimal::Precision, scalar::test_scalar::TestScalar}; + use alloc::vec; #[test] fn we_cannot_do_binary_operation_on_columns_with_different_lengths() { @@ -132,7 +132,7 @@ mod test { let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); let rhs = OwnedColumn::::TinyInt(vec![1, 2]); - let result = lhs.clone() + rhs.clone(); + let result = lhs.element_wise_add(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) @@ -140,25 +140,25 @@ mod test { let lhs = OwnedColumn::::SmallInt(vec![1, 2, 3]); let rhs = OwnedColumn::::SmallInt(vec![1, 2]); - let result = lhs.clone() + rhs.clone(); + let result = lhs.element_wise_add(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.clone() - rhs.clone(); + let result = lhs.element_wise_sub(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs.clone() * rhs.clone(); + let result = lhs.element_wise_mul(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) )); - let result = lhs / rhs; + let result = lhs.element_wise_div(&rhs); assert!(matches!( result, Err(ColumnOperationError::DifferentColumnLength { .. }) @@ -530,25 +530,25 @@ mod test { TestScalar::from(2), TestScalar::from(3), ]); - let result = lhs.clone() + rhs.clone(); + let result = lhs.element_wise_add(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.clone() - rhs.clone(); + let result = lhs.element_wise_sub(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs.clone() * rhs.clone(); + let result = lhs.element_wise_mul(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); - let result = lhs / rhs; + let result = lhs.element_wise_div(&rhs); assert!(matches!( result, Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) @@ -560,46 +560,40 @@ mod test { // lhs and rhs have the same precision let lhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); - let result = lhs + rhs; - assert_eq!( - result, - Ok(OwnedColumn::::TinyInt(vec![2_i8, 4, 6])) - ); + let result = lhs.element_wise_add(&rhs).unwrap(); + assert_eq!(result, OwnedColumn::::TinyInt(vec![2_i8, 4, 6])); let lhs = OwnedColumn::::SmallInt(vec![1_i16, 2, 3]); let rhs = OwnedColumn::::SmallInt(vec![1_i16, 2, 3]); - let result = lhs + rhs; + let result = lhs.element_wise_add(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::SmallInt(vec![2_i16, 4, 6])) + OwnedColumn::::SmallInt(vec![2_i16, 4, 6]) ); // lhs and rhs have different precisions let lhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); - let result = lhs + rhs; - assert_eq!( - result, - Ok(OwnedColumn::::Int(vec![2_i32, 4, 6])) - ); + let result = lhs.element_wise_add(&rhs).unwrap(); + assert_eq!(result, OwnedColumn::::Int(vec![2_i32, 4, 6])); let lhs = OwnedColumn::::Int128(vec![1_i128, 2, 3]); let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); - let result = lhs + rhs; + let result = lhs.element_wise_add(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::Int128(vec![2_i128, 4, 6])) + OwnedColumn::::Int128(vec![2_i128, 4, 6]) ); } #[test] - fn we_can_decimal_op() { + fn we_can_add_decimal_columns() { // lhs and rhs have the same precision and scale let lhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs + rhs).unwrap(); + let result = lhs.element_wise_add(&rhs).unwrap(); let expected_scalars = [2, 4, 6].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -611,7 +605,7 @@ mod test { let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(51).unwrap(), 3, rhs_scalars); - let result = (lhs + rhs).unwrap(); + let result = lhs.element_wise_add(&rhs).unwrap(); let expected_scalars = [11, 22, 33].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -622,7 +616,7 @@ mod test { let lhs = OwnedColumn::::TinyInt(vec![1, 2, 3]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs + rhs).unwrap(); + let result = lhs.element_wise_add(&rhs).unwrap(); let expected_scalars = [101, 202, 303].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -632,7 +626,7 @@ mod test { let lhs = OwnedColumn::::Int(vec![1, 2, 3]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs + rhs).unwrap(); + let result = lhs.element_wise_add(&rhs).unwrap(); let expected_scalars = [101, 202, 303].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -641,50 +635,47 @@ mod test { } #[test] - fn we_can_try_subtract_integer_columns() { + fn we_can_subtract_integer_columns() { // lhs and rhs have the same precision let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, 2]); let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); - let result = lhs - rhs; + let result = lhs.element_wise_sub(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::TinyInt(vec![3_i8, 3, -1])) + OwnedColumn::::TinyInt(vec![3_i8, 3, -1]) ); let lhs = OwnedColumn::::Int(vec![4_i32, 5, 2]); let rhs = OwnedColumn::::Int(vec![1_i32, 2, 3]); - let result = lhs - rhs; - assert_eq!( - result, - Ok(OwnedColumn::::Int(vec![3_i32, 3, -1])) - ); + let result = lhs.element_wise_sub(&rhs).unwrap(); + assert_eq!(result, OwnedColumn::::Int(vec![3_i32, 3, -1])); // lhs and rhs have different precisions let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, 2]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 5]); - let result = lhs - rhs; + let result = lhs.element_wise_sub(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::BigInt(vec![3_i64, 3, -3])) + OwnedColumn::::BigInt(vec![3_i64, 3, -3]) ); let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 5]); - let result = lhs - rhs; + let result = lhs.element_wise_sub(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::BigInt(vec![2_i64, 0, -2])) + OwnedColumn::::BigInt(vec![2_i64, 0, -2]) ); } #[test] - fn we_can_try_subtract_decimal_columns() { + fn we_can_subtract_decimal_columns() { // lhs and rhs have the same precision and scale let lhs_scalars = [4, 5, 2].iter().map(TestScalar::from).collect(); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs - rhs).unwrap(); + let result = lhs.element_wise_sub(&rhs).unwrap(); let expected_scalars = [3, 3, -1].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -696,7 +687,7 @@ mod test { let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(25).unwrap(), 2, lhs_scalars); let rhs = OwnedColumn::::Decimal75(Precision::new(51).unwrap(), 3, rhs_scalars); - let result = (lhs - rhs).unwrap(); + let result = lhs.element_wise_sub(&rhs).unwrap(); let expected_scalars = [39, 48, 17].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -707,7 +698,7 @@ mod test { let lhs = OwnedColumn::::TinyInt(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs - rhs).unwrap(); + let result = lhs.element_wise_sub(&rhs).unwrap(); let expected_scalars = [399, 498, 197].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -717,7 +708,7 @@ mod test { let lhs = OwnedColumn::::Int(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs - rhs).unwrap(); + let result = lhs.element_wise_sub(&rhs).unwrap(); let expected_scalars = [399, 498, 197].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -726,50 +717,50 @@ mod test { } #[test] - fn we_can_try_multiply_integer_columns() { + fn we_can_multiply_integer_columns() { // lhs and rhs have the same precision let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, -2]); let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); - let result = lhs * rhs; + let result = lhs.element_wise_mul(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::TinyInt(vec![4_i8, 10, -6])) + OwnedColumn::::TinyInt(vec![4_i8, 10, -6]) ); let lhs = OwnedColumn::::BigInt(vec![4_i64, 5, -2]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 3]); - let result = lhs * rhs; + let result = lhs.element_wise_mul(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::BigInt(vec![4_i64, 10, -6])) + OwnedColumn::::BigInt(vec![4_i64, 10, -6]) ); // lhs and rhs have different precisions let lhs = OwnedColumn::::TinyInt(vec![3_i8, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); - let result = lhs * rhs; + let result = lhs.element_wise_mul(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::Int128(vec![3_i128, 4, 15])) + OwnedColumn::::Int128(vec![3_i128, 4, 15]) ); let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); - let result = lhs * rhs; + let result = lhs.element_wise_mul(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::Int128(vec![3_i128, 4, 15])) + OwnedColumn::::Int128(vec![3_i128, 4, 15]) ); } #[test] - fn we_can_try_multiply_decimal_columns() { + fn we_can_multiply_decimal_columns() { // lhs and rhs are both decimals let lhs_scalars = [4, 5, 2].iter().map(TestScalar::from).collect(); let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, lhs_scalars); let rhs_scalars = [-1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs * rhs).unwrap(); + let result = lhs.element_wise_mul(&rhs).unwrap(); let expected_scalars = [-4, 10, 6].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -780,7 +771,7 @@ mod test { let lhs = OwnedColumn::::TinyInt(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs * rhs).unwrap(); + let result = lhs.element_wise_mul(&rhs).unwrap(); let expected_scalars = [4, 10, 6].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -790,7 +781,7 @@ mod test { let lhs = OwnedColumn::::Int(vec![4, 5, 2]); let rhs_scalars = [1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs * rhs).unwrap(); + let result = lhs.element_wise_mul(&rhs).unwrap(); let expected_scalars = [4, 10, 6].iter().map(TestScalar::from).collect(); assert_eq!( result, @@ -799,39 +790,33 @@ mod test { } #[test] - fn we_can_try_divide_integer_columns() { + fn we_can_divide_integer_columns() { // lhs and rhs have the same precision let lhs = OwnedColumn::::TinyInt(vec![4_i8, 5, -2]); let rhs = OwnedColumn::::TinyInt(vec![1_i8, 2, 3]); - let result = lhs / rhs; - assert_eq!( - result, - Ok(OwnedColumn::::TinyInt(vec![4_i8, 2, 0])) - ); + let result = lhs.element_wise_div(&rhs).unwrap(); + assert_eq!(result, OwnedColumn::::TinyInt(vec![4_i8, 2, 0])); let lhs = OwnedColumn::::BigInt(vec![4_i64, 5, -2]); let rhs = OwnedColumn::::BigInt(vec![1_i64, 2, 3]); - let result = lhs / rhs; - assert_eq!( - result, - Ok(OwnedColumn::::BigInt(vec![4_i64, 2, 0])) - ); + let result = lhs.element_wise_div(&rhs).unwrap(); + assert_eq!(result, OwnedColumn::::BigInt(vec![4_i64, 2, 0])); // lhs and rhs have different precisions let lhs = OwnedColumn::::TinyInt(vec![3_i8, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); - let result = lhs / rhs; + let result = lhs.element_wise_div(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::Int128(vec![3_i128, 1, 0])) + OwnedColumn::::Int128(vec![3_i128, 1, 0]) ); let lhs = OwnedColumn::::Int(vec![3_i32, 2, 3]); let rhs = OwnedColumn::::Int128(vec![1_i128, 2, 5]); - let result = lhs / rhs; + let result = lhs.element_wise_div(&rhs).unwrap(); assert_eq!( result, - Ok(OwnedColumn::::Int128(vec![3_i128, 1, 0])) + OwnedColumn::::Int128(vec![3_i128, 1, 0]) ); } @@ -842,7 +827,7 @@ mod test { let lhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, lhs_scalars); let rhs_scalars = [-1, 2, 4].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(5).unwrap(), 2, rhs_scalars); - let result = (lhs / rhs).unwrap(); + let result = lhs.element_wise_div(&rhs).unwrap(); let expected_scalars = [-400_000_000_i128, 250_000_000, 75_000_000] .iter() .map(TestScalar::from) @@ -856,7 +841,7 @@ mod test { let lhs = OwnedColumn::::TinyInt(vec![4, 5, 3]); let rhs_scalars = [-1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(3).unwrap(), 2, rhs_scalars); - let result = (lhs / rhs).unwrap(); + let result = lhs.element_wise_div(&rhs).unwrap(); let expected_scalars = [-400_000_000, 250_000_000, 100_000_000] .iter() .map(TestScalar::from) @@ -869,7 +854,7 @@ mod test { let lhs = OwnedColumn::::SmallInt(vec![4, 5, 3]); let rhs_scalars = [-1, 2, 3].iter().map(TestScalar::from).collect(); let rhs = OwnedColumn::::Decimal75(Precision::new(3).unwrap(), 2, rhs_scalars); - let result = (lhs / rhs).unwrap(); + let result = lhs.element_wise_div(&rhs).unwrap(); let expected_scalars = [-400_000_000, 250_000_000, 100_000_000] .iter() .map(TestScalar::from) diff --git a/crates/proof-of-sql/src/base/database/slice_operation.rs b/crates/proof-of-sql/src/base/database/slice_operation.rs index dd83edb90..b008fa56d 100644 --- a/crates/proof-of-sql/src/base/database/slice_operation.rs +++ b/crates/proof-of-sql/src/base/database/slice_operation.rs @@ -142,201 +142,10 @@ pub(super) fn slice_or(lhs: &[bool], rhs: &[bool]) -> Vec { slice_binary_op(lhs, rhs, |l, r| -> bool { *l || *r }) } -/// Try to check whether two slices of the same length are equal element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_eq(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialEq + Debug, -{ - slice_binary_op(lhs, rhs, PartialEq::eq) -} - -/// Try to check whether a slice is less than or equal to another element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_le(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialOrd + Debug, -{ - slice_binary_op(lhs, rhs, PartialOrd::le) -} - -/// Try to check whether a slice is greater than or equal to another element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_ge(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialOrd + Debug, -{ - slice_binary_op(lhs, rhs, PartialOrd::ge) -} - -/// Try to add two slices of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_add_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedAdd + Debug, -{ - try_slice_binary_op(lhs, rhs, try_add) -} - -/// Subtract one slice from another of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_subtract_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedSub + Debug, -{ - try_slice_binary_op(lhs, rhs, try_sub) -} - -/// Multiply two slices of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_multiply_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedMul + Debug, -{ - try_slice_binary_op(lhs, rhs, try_mul) -} - -/// Divide one slice by another of the same length. -/// -/// We do not check for length equality here. However, we do check for division by 0. -pub(super) fn try_divide_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedDiv + Debug, -{ - try_slice_binary_op(lhs, rhs, try_div) -} - -// Casting required for binary operations on different types - -/// Check whether two slices of the same length are equal element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_eq_with_casting(lhs: &[S], rhs: &[T]) -> Vec -where - S: Copy + Debug + Into, - T: PartialEq + Copy + Debug, -{ - slice_binary_op_left_upcast(lhs, rhs, PartialEq::eq) -} - -/// Check whether a slice is less than or equal to another element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_le_with_casting(lhs: &[S], rhs: &[T]) -> Vec -where - S: Copy + Debug + Into, - T: PartialOrd + Copy + Debug, -{ - slice_binary_op_left_upcast(lhs, rhs, PartialOrd::le) -} - -/// Check whether a slice is greater than or equal to another element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_ge_with_casting(lhs: &[S], rhs: &[T]) -> Vec -where - S: Copy + Debug + Into, - T: PartialOrd + Copy + Debug, -{ - slice_binary_op_left_upcast(lhs, rhs, PartialOrd::ge) -} - -/// Add two slices of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_add_slices_with_casting( - lhs: &[S], - rhs: &[T], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedAdd + Copy + Debug, -{ - try_slice_binary_op_left_upcast(lhs, rhs, try_add) -} - -/// Subtract one slice from another of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here -pub(super) fn try_subtract_slices_left_upcast( - lhs: &[S], - rhs: &[T], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedSub + Copy + Debug, -{ - try_slice_binary_op_left_upcast(lhs, rhs, try_sub) -} - -/// Subtract one slice from another of the same length, casting the right slice to the type of the left slice. -/// -/// We do not check for length equality here -pub(super) fn try_subtract_slices_right_upcast( - lhs: &[T], - rhs: &[S], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedSub + Copy + Debug, -{ - try_slice_binary_op_right_upcast(lhs, rhs, try_sub) -} - -/// Multiply two slices of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_multiply_slices_with_casting( - lhs: &[S], - rhs: &[T], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedMul + Copy + Debug, -{ - try_slice_binary_op_left_upcast(lhs, rhs, try_mul) -} - -/// Divide one slice by another of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here -pub(super) fn try_divide_slices_left_upcast( - lhs: &[S], - rhs: &[T], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedDiv + Copy + Debug, -{ - try_slice_binary_op_left_upcast(lhs, rhs, try_div) -} - -/// Divide one slice by another of the same length, casting the right slice to the type of the left slice. -/// -/// We do not check for length equality here -pub(super) fn try_divide_slices_right_upcast( - lhs: &[T], - rhs: &[S], -) -> ColumnOperationResult> -where - S: Copy + Debug + Into, - T: CheckedDiv + Copy + Debug, -{ - try_slice_binary_op_right_upcast(lhs, rhs, try_div) -} - #[cfg(test)] mod test { use super::*; + use core::cmp::{PartialEq, PartialOrd}; // NOT #[test] @@ -372,14 +181,14 @@ mod test { fn we_can_eq_slices() { let lhs = [1_i16, 2, 3]; let rhs = [1_i16, 3, 3]; - let actual = slice_eq(&lhs, &rhs); + let actual = slice_binary_op(&lhs, &rhs, PartialEq::eq); let expected = vec![true, false, true]; assert_eq!(expected, actual); // Try strings let lhs = ["Chloe".to_string(), "Margaret".to_string()]; let rhs = ["Chloe".to_string(), "Chloe".to_string()]; - let actual = slice_eq(&lhs, &rhs); + let actual = slice_binary_op(&lhs, &rhs, PartialEq::eq); let expected = vec![true, false]; assert_eq!(expected, actual); } @@ -388,7 +197,7 @@ mod test { fn we_can_eq_slices_with_cast() { let lhs = [1_i16, 2, 3]; let rhs = [1_i32, 3, 3]; - let actual = slice_eq_with_casting(&lhs, &rhs); + let actual = slice_binary_op_left_upcast(&lhs, &rhs, PartialEq::eq); let expected = vec![true, false, true]; assert_eq!(expected, actual); } @@ -398,7 +207,7 @@ mod test { fn we_can_le_slices() { let lhs = [1_i32, 2, 3]; let rhs = [1_i32, 3, 2]; - let actual = slice_le(&lhs, &rhs); + let actual = slice_binary_op(&lhs, &rhs, PartialOrd::le); let expected = vec![true, true, false]; assert_eq!(expected, actual); } @@ -407,7 +216,7 @@ mod test { fn we_can_le_slices_with_cast() { let lhs = [1_i16, 2, 3]; let rhs = [1_i64, 3, 2]; - let actual = slice_le_with_casting(&lhs, &rhs); + let actual = slice_binary_op_left_upcast(&lhs, &rhs, PartialOrd::le); let expected = vec![true, true, false]; assert_eq!(expected, actual); } @@ -417,7 +226,7 @@ mod test { fn we_can_ge_slices() { let lhs = [1_i128, 2, 3]; let rhs = [1_i128, 3, 2]; - let actual = slice_ge(&lhs, &rhs); + let actual = slice_binary_op(&lhs, &rhs, PartialOrd::ge); let expected = vec![true, false, true]; assert_eq!(expected, actual); } @@ -426,7 +235,7 @@ mod test { fn we_can_ge_slices_with_cast() { let lhs = [1_i16, 2, 3]; let rhs = [1_i64, 3, 2]; - let actual = slice_ge_with_casting(&lhs, &rhs); + let actual = slice_binary_op_left_upcast(&lhs, &rhs, PartialOrd::ge); let expected = vec![true, false, true]; assert_eq!(expected, actual); } @@ -436,7 +245,7 @@ mod test { fn we_can_try_add_slices() { let lhs = [1_i16, 2, 3]; let rhs = [4_i16, -5, 6]; - let actual = try_add_slices(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op(&lhs, &rhs, try_add).unwrap(); let expected = vec![5_i16, -3, 9]; assert_eq!(expected, actual); } @@ -446,7 +255,7 @@ mod test { let lhs = [i16::MAX, 1]; let rhs = [1_i16, 1]; assert!(matches!( - try_add_slices(&lhs, &rhs), + try_slice_binary_op(&lhs, &rhs, try_add), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -455,7 +264,7 @@ mod test { fn we_can_try_add_slices_with_cast() { let lhs = [1_i16, 2, 3]; let rhs = [4_i32, -5, 6]; - let actual = try_add_slices_with_casting(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_left_upcast(&lhs, &rhs, try_add).unwrap(); let expected = vec![5_i32, -3, 9]; assert_eq!(expected, actual); } @@ -465,7 +274,7 @@ mod test { let lhs = [-1_i16, 1]; let rhs = [i32::MIN, 1]; assert!(matches!( - try_add_slices_with_casting(&lhs, &rhs), + try_slice_binary_op_left_upcast(&lhs, &rhs, try_add), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -475,7 +284,7 @@ mod test { fn we_can_try_subtract_slices() { let lhs = [1_i16, 2, 3]; let rhs = [4_i16, -5, 6]; - let actual = try_subtract_slices(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op(&lhs, &rhs, try_sub).unwrap(); let expected = vec![-3_i16, 7, -3]; assert_eq!(expected, actual); } @@ -485,7 +294,7 @@ mod test { let lhs = [i128::MIN, 1]; let rhs = [1_i128, 1]; assert!(matches!( - try_subtract_slices(&lhs, &rhs), + try_slice_binary_op(&lhs, &rhs, try_sub), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -494,7 +303,7 @@ mod test { fn we_can_try_subtract_slices_left_upcast() { let lhs = [1_i16, 2, 3]; let rhs = [4_i32, -5, 6]; - let actual = try_subtract_slices_left_upcast(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_left_upcast(&lhs, &rhs, try_sub).unwrap(); let expected = vec![-3_i32, 7, -3]; assert_eq!(expected, actual); } @@ -504,7 +313,7 @@ mod test { let lhs = [0_i16, 1]; let rhs = [i32::MIN, 1]; assert!(matches!( - try_subtract_slices_left_upcast(&lhs, &rhs), + try_slice_binary_op_left_upcast(&lhs, &rhs, try_sub), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -513,7 +322,7 @@ mod test { fn we_can_try_subtract_slices_right_upcast() { let lhs = [1_i32, 2, 3]; let rhs = [4_i16, -5, 6]; - let actual = try_subtract_slices_right_upcast(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_right_upcast(&lhs, &rhs, try_sub).unwrap(); let expected = vec![-3_i32, 7, -3]; assert_eq!(expected, actual); } @@ -523,7 +332,7 @@ mod test { let lhs = [i32::MIN, 1]; let rhs = [1_i16, 1]; assert!(matches!( - try_subtract_slices_right_upcast(&lhs, &rhs), + try_slice_binary_op_right_upcast(&lhs, &rhs, try_sub), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -533,7 +342,7 @@ mod test { fn we_can_try_multiply_slices() { let lhs = [1_i16, 2, 3]; let rhs = [4_i16, -5, 6]; - let actual = try_multiply_slices(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op(&lhs, &rhs, try_mul).unwrap(); let expected = vec![4_i16, -10, 18]; assert_eq!(expected, actual); } @@ -543,7 +352,7 @@ mod test { let lhs = [i32::MAX, 2]; let rhs = [2, 2]; assert!(matches!( - try_multiply_slices(&lhs, &rhs), + try_slice_binary_op(&lhs, &rhs, try_mul), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -552,7 +361,7 @@ mod test { fn we_can_try_multiply_slices_with_cast() { let lhs = [1_i16, 2, 3]; let rhs = [4_i32, -5, 6]; - let actual = try_multiply_slices_with_casting(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_left_upcast(&lhs, &rhs, try_mul).unwrap(); let expected = vec![4_i32, -10, 18]; assert_eq!(expected, actual); } @@ -562,7 +371,7 @@ mod test { let lhs = [2_i16, 2]; let rhs = [i32::MAX, 2]; assert!(matches!( - try_multiply_slices_with_casting(&lhs, &rhs), + try_slice_binary_op_left_upcast(&lhs, &rhs, try_mul), Err(ColumnOperationError::IntegerOverflow { .. }) )); } @@ -572,7 +381,7 @@ mod test { fn we_can_try_divide_slices() { let lhs = [5_i16, -5, -7, 9]; let rhs = [-3_i16, 3, -4, 5]; - let actual = try_divide_slices(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op(&lhs, &rhs, try_div).unwrap(); let expected = vec![-1_i16, -1, 1, 1]; assert_eq!(expected, actual); } @@ -582,7 +391,7 @@ mod test { let lhs = [1_i32, 2, 3]; let rhs = [0_i32, -5, 6]; assert!(matches!( - try_divide_slices(&lhs, &rhs), + try_slice_binary_op(&lhs, &rhs, try_div), Err(ColumnOperationError::DivisionByZero) )); } @@ -591,7 +400,7 @@ mod test { fn we_can_try_divide_slices_left_upcast() { let lhs = [5_i16, -4, -9, 9]; let rhs = [-3_i32, 3, -4, 5]; - let actual = try_divide_slices_left_upcast(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_left_upcast(&lhs, &rhs, try_div).unwrap(); let expected = vec![-1_i32, -1, 2, 1]; assert_eq!(expected, actual); } @@ -601,7 +410,7 @@ mod test { let lhs = [1_i16, 2]; let rhs = [0_i32, 2]; assert!(matches!( - try_divide_slices_left_upcast(&lhs, &rhs), + try_slice_binary_op_left_upcast(&lhs, &rhs, try_div), Err(ColumnOperationError::DivisionByZero) )); } @@ -610,7 +419,7 @@ mod test { fn we_can_try_divide_slices_right_upcast() { let lhs = [15_i128, -82, -7, 9]; let rhs = [-3_i32, 3, -4, 5]; - let actual = try_divide_slices_right_upcast(&lhs, &rhs).unwrap(); + let actual = try_slice_binary_op_right_upcast(&lhs, &rhs, try_div).unwrap(); let expected = vec![-5_i128, -27, 1, 1]; assert_eq!(expected, actual); } @@ -620,7 +429,7 @@ mod test { let lhs = [1_i32, 2]; let rhs = [0_i16, 2]; assert!(matches!( - try_divide_slices_right_upcast(&lhs, &rhs), + try_slice_binary_op_right_upcast(&lhs, &rhs, try_div), Err(ColumnOperationError::DivisionByZero) )); } From 11596efbe44b30b8238ccfd4ce03c9dea33077d5 Mon Sep 17 00:00:00 2001 From: Ian Joiner <14581281+iajoiner@users.noreply.github.com> Date: Mon, 11 Nov 2024 09:16:29 -0500 Subject: [PATCH 4/4] fix: restrict `try_add` etc to super --- crates/proof-of-sql/src/base/database/slice_operation.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/slice_operation.rs b/crates/proof-of-sql/src/base/database/slice_operation.rs index b008fa56d..a64776a7b 100644 --- a/crates/proof-of-sql/src/base/database/slice_operation.rs +++ b/crates/proof-of-sql/src/base/database/slice_operation.rs @@ -4,7 +4,7 @@ use core::fmt::Debug; use num_traits::ops::checked::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; /// Function for checked addition with overflow error handling -pub(crate) fn try_add(l: &T, r: &T) -> ColumnOperationResult +pub(super) fn try_add(l: &T, r: &T) -> ColumnOperationResult where T: CheckedAdd + Debug, { @@ -15,7 +15,7 @@ where } /// Function for checked subtraction with overflow error handling -pub(crate) fn try_sub(l: &T, r: &T) -> ColumnOperationResult +pub(super) fn try_sub(l: &T, r: &T) -> ColumnOperationResult where T: CheckedSub + Debug, { @@ -26,7 +26,7 @@ where } /// Function for checked multiplication with overflow error handling -pub(crate) fn try_mul(l: &T, r: &T) -> ColumnOperationResult +pub(super) fn try_mul(l: &T, r: &T) -> ColumnOperationResult where T: CheckedMul + Debug, { @@ -37,7 +37,7 @@ where } /// Function for checked division with division by zero error handling -pub(crate) fn try_div(l: &T, r: &T) -> ColumnOperationResult +pub(super) fn try_div(l: &T, r: &T) -> ColumnOperationResult where T: CheckedDiv + Debug, {