diff --git a/crates/proof-of-sql/src/base/database/column_type_operation.rs b/crates/proof-of-sql/src/base/database/column_type_operation.rs new file mode 100644 index 000000000..21300dfd5 --- /dev/null +++ b/crates/proof-of-sql/src/base/database/column_type_operation.rs @@ -0,0 +1,755 @@ +use super::{ColumnOperationError, ColumnOperationResult}; +use crate::base::{ + database::ColumnType, + math::decimal::{DecimalError, Precision}, +}; +use alloc::{format, string::ToString}; +use proof_of_sql_parser::intermediate_ast::BinaryOperator; +// For decimal type manipulation please refer to +// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 + +/// Determine the output type of an add or subtract operation if it is possible +/// to add or subtract the two input types. If the types are not compatible, return +/// an error. +/// +/// # Panics +/// +/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. +/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. +pub fn try_add_subtract_column_types( + lhs: ColumnType, + rhs: ColumnType, + operator: BinaryOperator, +) -> ColumnOperationResult { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = + i16::from(lhs.precision_value().expect("Numeric types have precision")); + let right_precision_value = + i16::from(rhs.precision_value().expect("Numeric types have precision")); + let left_scale = lhs.scale().expect("Numeric types have scale"); + let right_scale = rhs.scale().expect("Numeric types have scale"); + let scale = left_scale.max(right_scale); + let precision_value: i16 = i16::from(scale) + + (left_precision_value - i16::from(left_scale)) + .max(right_precision_value - i16::from(right_scale)) + + 1_i16; + let precision = u8::try_from(precision_value) + .map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: precision_value.to_string(), + }, + }) + .and_then(|p| { + Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: p.to_string(), + }, + }) + })?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + +/// Determine the output type of a multiplication operation if it is possible +/// to multiply the two input types. If the types are not compatible, return +/// an error. +/// +/// # Panics +/// +/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. +/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. +pub fn try_multiply_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ColumnOperationResult { + if !lhs.is_numeric() || !rhs.is_numeric() { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Multiply, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { + Ok(ColumnType::Scalar) + } else { + let left_precision_value = lhs.precision_value().expect("Numeric types have precision"); + let right_precision_value = rhs.precision_value().expect("Numeric types have precision"); + let precision_value = left_precision_value + right_precision_value + 1; + let precision = Precision::new(precision_value).map_err(|_| { + ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: format!( + "Required precision {precision_value} is beyond what we can support" + ), + }, + } + })?; + let left_scale = lhs.scale().expect("Numeric types have scale"); + let right_scale = rhs.scale().expect("Numeric types have scale"); + let scale = left_scale.checked_add(right_scale).ok_or( + ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { + scale: (i16::from(left_scale) + i16::from(right_scale)).to_string(), + }, + }, + )?; + Ok(ColumnType::Decimal75(precision, scale)) + } +} + +/// Determine the output type of a division operation if it is possible +/// to multiply the two input types. If the types are not compatible, return +/// an error. +/// +/// # Panics +/// +/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. +/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. +pub fn try_divide_column_types( + lhs: ColumnType, + rhs: ColumnType, +) -> ColumnOperationResult { + if !lhs.is_numeric() + || !rhs.is_numeric() + || lhs == ColumnType::Scalar + || rhs == ColumnType::Scalar + { + return Err(ColumnOperationError::BinaryOperationInvalidColumnType { + operator: BinaryOperator::Division, + left_type: lhs, + right_type: rhs, + }); + } + if lhs.is_integer() && rhs.is_integer() { + // We can unwrap here because we know that both types are integers + return Ok(lhs.max_integer_type(&rhs).unwrap()); + } + let left_precision_value = + i16::from(lhs.precision_value().expect("Numeric types have precision")); + let right_precision_value = + i16::from(rhs.precision_value().expect("Numeric types have precision")); + let left_scale = i16::from(lhs.scale().expect("Numeric types have scale")); + let right_scale = i16::from(rhs.scale().expect("Numeric types have scale")); + let raw_scale = (left_scale + right_precision_value + 1_i16).max(6_i16); + let precision_value: i16 = left_precision_value - left_scale + right_scale + raw_scale; + let scale = + i8::try_from(raw_scale).map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { + scale: raw_scale.to_string(), + }, + })?; + let precision = u8::try_from(precision_value) + .map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: precision_value.to_string(), + }, + }) + .and_then(|p| { + Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { + error: p.to_string(), + }, + }) + })?; + Ok(ColumnType::Decimal75(precision, scale)) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn we_can_add_numeric_types() { + // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Int; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Int; + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + // lhs is a decimal with nonnegative scale and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals with nonnegative scale + let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals one of which has negative scale + let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals both with negative scale + // and with result having maximum precision + let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_add_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::VarChar; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + } + + #[test] + fn we_cannot_add_some_numeric_types_due_to_decimal_issues() { + let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4); + let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4); + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + let lhs = ColumnType::Int; + let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10); + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + } + + #[test] + fn we_can_subtract_numeric_types() { + // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Int; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Int; + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Scalar; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + // lhs is a decimal and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::SmallInt; + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals with nonnegative scale + let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals one of which has negative scale + let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals both with negative scale + // and with result having maximum precision + let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14); + let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_subtract_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::VarChar; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + } + + #[test] + fn we_cannot_subtract_some_numeric_types_due_to_decimal_issues() { + let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0); + let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1); + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + let lhs = ColumnType::Int128; + let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12); + assert!(matches!( + try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + } + + #[test] + fn we_can_multiply_numeric_types() { + // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::SmallInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Int; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Int; + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a scalar + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Scalar; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Scalar; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Scalar; + assert_eq!(expected, actual); + + // lhs is a decimal and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 2); + assert_eq!(expected, actual); + + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::SmallInt; + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 2); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals with nonnegative scale + let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(31).unwrap(), 5); + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), -2); + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), -2); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals one of which has negative scale + let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(56).unwrap(), -8); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals both with negative scale + // and with result having maximum precision + let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); + let actual = try_multiply_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -27); + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_multiply_non_numeric_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::VarChar; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + } + + #[test] + fn we_cannot_multiply_some_numeric_types_due_to_decimal_issues() { + // Invalid precision + let lhs = ColumnType::Decimal75(Precision::new(38).unwrap(), 4); + let rhs = ColumnType::Decimal75(Precision::new(37).unwrap(), 4); + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + let lhs = ColumnType::Int; + let rhs = ColumnType::Decimal75(Precision::new(65).unwrap(), 0); + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + // Invalid scale + let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -64_i8); + let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -65_i8); + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { .. } + }) + )); + + let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); + let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); + assert!(matches!( + try_multiply_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { .. } + }) + )); + } + + #[test] + fn we_can_divide_numeric_types() { + // lhs and rhs are integers with the same precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::TinyInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::TinyInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::SmallInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + // lhs and rhs are integers with different precision + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::SmallInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::SmallInt; + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Int; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Int; + assert_eq!(expected, actual); + + // lhs is a decimal with nonnegative scale and rhs is an integer + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::TinyInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 6); + assert_eq!(expected, actual); + + let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let rhs = ColumnType::SmallInt; + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 8); + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a decimal with nonnegative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 11); + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(18).unwrap(), 11); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals with nonnegative scale + let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(33).unwrap(), 14); + assert_eq!(expected, actual); + + // lhs is an integer and rhs is a decimal with negative scale + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(12).unwrap(), 11); + assert_eq!(expected, actual); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 11); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals one of which has negative scale + let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(64).unwrap(), 6); + assert_eq!(expected, actual); + + // lhs and rhs are both decimals both with negative scale + // and with result having maximum precision + let lhs = ColumnType::Decimal75(Precision::new(70).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); + let actual = try_divide_column_types(lhs, rhs).unwrap(); + let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), 6); + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_divide_non_numeric_or_scalar_types() { + let lhs = ColumnType::TinyInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::SmallInt; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::VarChar; + let rhs = ColumnType::VarChar; + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + + let lhs = ColumnType::Scalar; + let rhs = ColumnType::Scalar; + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) + )); + } + + #[test] + fn we_cannot_divide_some_numeric_types_due_to_decimal_issues() { + // Invalid precision + let lhs = ColumnType::Decimal75(Precision::new(71).unwrap(), -13); + let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + let lhs = ColumnType::Int; + let rhs = ColumnType::Decimal75(Precision::new(68).unwrap(), 67); + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidPrecision { .. } + }) + )); + + // Invalid scale + let lhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 53_i8); + let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 40_i8); + assert!(matches!( + try_divide_column_types(lhs, rhs), + Err(ColumnOperationError::DecimalConversionError { + source: DecimalError::InvalidScale { .. } + }) + )); + } +} diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 822b798ee..756e5301f 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -7,8 +7,12 @@ pub use accessor::{CommitmentAccessor, DataAccessor, MetadataAccessor, SchemaAcc mod column; pub use column::{Column, ColumnField, ColumnRef, ColumnType}; -mod column_operation; -pub use column_operation::{ +mod slice_operation; + +mod slice_decimal_operation; + +mod column_type_operation; +pub use column_type_operation::{ try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types, }; diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index f0aa1e2ce..7b021356c 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -1,15 +1,16 @@ use super::{ColumnOperationError, ColumnOperationResult}; use crate::base::{ database::{ - column_operation::{ - eq_decimal_columns, ge_decimal_columns, le_decimal_columns, slice_and, slice_eq, - slice_eq_with_casting, slice_ge, slice_ge_with_casting, slice_le, - slice_le_with_casting, slice_not, slice_or, try_add_decimal_columns, try_add_slices, - try_add_slices_with_casting, try_divide_decimal_columns, try_divide_slices, - try_divide_slices_left_upcast, try_divide_slices_right_upcast, - try_multiply_decimal_columns, try_multiply_slices, try_multiply_slices_with_casting, - try_subtract_decimal_columns, try_subtract_slices, try_subtract_slices_left_upcast, - try_subtract_slices_right_upcast, + slice_decimal_operation::{ + eq_decimal_columns, ge_decimal_columns, le_decimal_columns, try_add_decimal_columns, + try_divide_decimal_columns, try_multiply_decimal_columns, try_subtract_decimal_columns, + }, + slice_operation::{ + slice_and, slice_eq, slice_eq_with_casting, slice_ge, slice_ge_with_casting, slice_le, + slice_le_with_casting, slice_not, slice_or, try_add_slices, + try_add_slices_with_casting, try_divide_slices, try_divide_slices_left_upcast, + try_divide_slices_right_upcast, try_multiply_slices, try_multiply_slices_with_casting, + try_subtract_slices, try_subtract_slices_left_upcast, try_subtract_slices_right_upcast, }, OwnedColumn, }, diff --git a/crates/proof-of-sql/src/base/database/column_operation.rs b/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs similarity index 51% rename from crates/proof-of-sql/src/base/database/column_operation.rs rename to crates/proof-of-sql/src/base/database/slice_decimal_operation.rs index f5893be27..d18c92c23 100644 --- a/crates/proof-of-sql/src/base/database/column_operation.rs +++ b/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs @@ -1,518 +1,19 @@ -#![allow(dead_code)] use super::{ColumnOperationError, ColumnOperationResult}; use crate::base::{ - database::ColumnType, - math::decimal::{DecimalError, Precision}, + database::{ + column_type_operation::{ + try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types, + }, + ColumnType, + }, + math::decimal::Precision, scalar::{Scalar, ScalarExt}, }; -use alloc::{format, string::ToString, vec::Vec}; +use alloc::vec::Vec; use core::{cmp::Ordering, fmt::Debug}; use num_bigint::BigInt; -use num_traits::{ - ops::checked::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}, - Zero, -}; +use num_traits::Zero; use proof_of_sql_parser::intermediate_ast::BinaryOperator; - -// For decimal type manipulation please refer to -// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 - -/// Determine the output type of an add or subtract operation if it is possible -/// to add or subtract the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_add_subtract_column_types( - lhs: ColumnType, - rhs: ColumnType, - operator: BinaryOperator, -) -> ColumnOperationResult { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator, - left_type: lhs, - right_type: rhs, - }); - } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); - } - if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { - Ok(ColumnType::Scalar) - } else { - let left_precision_value = - i16::from(lhs.precision_value().expect("Numeric types have precision")); - let right_precision_value = - i16::from(rhs.precision_value().expect("Numeric types have precision")); - let left_scale = lhs.scale().expect("Numeric types have scale"); - let right_scale = rhs.scale().expect("Numeric types have scale"); - let scale = left_scale.max(right_scale); - let precision_value: i16 = i16::from(scale) - + (left_precision_value - i16::from(left_scale)) - .max(right_precision_value - i16::from(right_scale)) - + 1_i16; - let precision = u8::try_from(precision_value) - .map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: precision_value.to_string(), - }, - }) - .and_then(|p| { - Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: p.to_string(), - }, - }) - })?; - Ok(ColumnType::Decimal75(precision, scale)) - } -} - -/// Determine the output type of a multiplication operation if it is possible -/// to multiply the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_multiply_column_types( - lhs: ColumnType, - rhs: ColumnType, -) -> ColumnOperationResult { - if !lhs.is_numeric() || !rhs.is_numeric() { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Multiply, - left_type: lhs, - right_type: rhs, - }); - } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); - } - if lhs == ColumnType::Scalar || rhs == ColumnType::Scalar { - Ok(ColumnType::Scalar) - } else { - let left_precision_value = lhs.precision_value().expect("Numeric types have precision"); - let right_precision_value = rhs.precision_value().expect("Numeric types have precision"); - let precision_value = left_precision_value + right_precision_value + 1; - let precision = Precision::new(precision_value).map_err(|_| { - ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: format!( - "Required precision {precision_value} is beyond what we can support" - ), - }, - } - })?; - let left_scale = lhs.scale().expect("Numeric types have scale"); - let right_scale = rhs.scale().expect("Numeric types have scale"); - let scale = left_scale.checked_add(right_scale).ok_or( - ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { - scale: (i16::from(left_scale) + i16::from(right_scale)).to_string(), - }, - }, - )?; - Ok(ColumnType::Decimal75(precision, scale)) - } -} - -/// Determine the output type of a division operation if it is possible -/// to multiply the two input types. If the types are not compatible, return -/// an error. -/// -/// # Panics -/// -/// - Panics if `lhs` or `rhs` does not have a precision or scale when they are expected to be numeric types. -/// - Panics if `lhs` or `rhs` is an integer, and `lhs.max_integer_type(&rhs)` returns `None`. -pub fn try_divide_column_types( - lhs: ColumnType, - rhs: ColumnType, -) -> ColumnOperationResult { - if !lhs.is_numeric() - || !rhs.is_numeric() - || lhs == ColumnType::Scalar - || rhs == ColumnType::Scalar - { - return Err(ColumnOperationError::BinaryOperationInvalidColumnType { - operator: BinaryOperator::Division, - left_type: lhs, - right_type: rhs, - }); - } - if lhs.is_integer() && rhs.is_integer() { - // We can unwrap here because we know that both types are integers - return Ok(lhs.max_integer_type(&rhs).unwrap()); - } - let left_precision_value = - i16::from(lhs.precision_value().expect("Numeric types have precision")); - let right_precision_value = - i16::from(rhs.precision_value().expect("Numeric types have precision")); - let left_scale = i16::from(lhs.scale().expect("Numeric types have scale")); - let right_scale = i16::from(rhs.scale().expect("Numeric types have scale")); - let raw_scale = (left_scale + right_precision_value + 1_i16).max(6_i16); - let precision_value: i16 = left_precision_value - left_scale + right_scale + raw_scale; - let scale = - i8::try_from(raw_scale).map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { - scale: raw_scale.to_string(), - }, - })?; - let precision = u8::try_from(precision_value) - .map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: precision_value.to_string(), - }, - }) - .and_then(|p| { - Precision::new(p).map_err(|_| ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { - error: p.to_string(), - }, - }) - })?; - Ok(ColumnType::Decimal75(precision, scale)) -} - -// Unary operations - -/// Negate a slice of boolean values. -pub(super) fn slice_not(input: &[bool]) -> Vec { - input.iter().map(|l| -> bool { !*l }).collect::>() -} - -// Binary operations on slices of the same type - -/// Element-wise AND on two boolean slices of the same length. -/// -/// We do not check for length equality here. -pub(super) fn slice_and(lhs: &[bool], rhs: &[bool]) -> Vec { - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> bool { *l && *r }) - .collect::>() -} - -/// Element-wise OR on two boolean slices of the same length. -/// -/// We do not check for length equality here. -pub(super) fn slice_or(lhs: &[bool], rhs: &[bool]) -> Vec { - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> bool { *l || *r }) - .collect::>() -} - -/// Try to check whether two slices of the same length are equal element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_eq(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialEq + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> bool { *l == *r }) - .collect::>() -} - -/// Try to check whether a slice is less than or equal to another element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_le(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialOrd + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> bool { *l <= *r }) - .collect::>() -} - -/// Try to check whether a slice is greater than or equal to another element-wise. -/// -/// We do not check for length equality here. -pub(super) fn slice_ge(lhs: &[T], rhs: &[T]) -> Vec -where - T: PartialOrd + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> bool { *l >= *r }) - .collect::>() -} - -/// Try to add two slices of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_add_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedAdd + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_add(r) - .ok_or(ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer addition {l:?} + {r:?}"), - }) - }) - .collect::>>() -} - -/// Subtract one slice from another of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_subtract_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedSub + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_sub(r) - .ok_or(ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer subtraction {l:?} - {r:?}"), - }) - }) - .collect::>>() -} - -/// Multiply two slices of the same length. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_multiply_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedMul + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_mul(r) - .ok_or(ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer multiplication {l:?} * {r:?}"), - }) - }) - .collect::>>() -} - -/// Divide one slice by another of the same length. -/// -/// We do not check for length equality here. However, we do check for division by 0. -pub(super) fn try_divide_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> -where - T: CheckedDiv + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_div(r).ok_or(ColumnOperationError::DivisionByZero) - }) - .collect::>>() -} - -// Casting required for binary operations on different types - -/// Check whether two slices of the same length are equal element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_eq_with_casting( - numbers_of_smaller_type: &[SmallerType], - numbers_of_larger_type: &[LargerType], -) -> Vec -where - SmallerType: Copy + Debug + Into, - LargerType: PartialEq + Copy + Debug, -{ - numbers_of_smaller_type - .iter() - .zip(numbers_of_larger_type.iter()) - .map(|(l, r)| -> bool { Into::::into(*l) == *r }) - .collect::>() -} - -/// Check whether a slice is less than or equal to another element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_le_with_casting( - numbers_of_smaller_type: &[SmallerType], - numbers_of_larger_type: &[LargerType], -) -> Vec -where - SmallerType: Copy + Debug + Into, - LargerType: PartialOrd + Copy + Debug, -{ - numbers_of_smaller_type - .iter() - .zip(numbers_of_larger_type.iter()) - .map(|(l, r)| -> bool { Into::::into(*l) <= *r }) - .collect::>() -} - -/// Check whether a slice is greater than or equal to another element-wise. -/// -/// Note that we cast elements of the left slice to the type of the right slice. -/// Also note that we do not check for length equality here. -pub(super) fn slice_ge_with_casting( - numbers_of_smaller_type: &[SmallerType], - numbers_of_larger_type: &[LargerType], -) -> Vec -where - SmallerType: Copy + Debug + Into, - LargerType: PartialOrd + Copy + Debug, -{ - numbers_of_smaller_type - .iter() - .zip(numbers_of_larger_type.iter()) - .map(|(l, r)| -> bool { Into::::into(*l) >= *r }) - .collect::>() -} - -/// Add two slices of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_add_slices_with_casting( - numbers_of_smaller_type: &[SmallerType], - numbers_of_larger_type: &[LargerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedAdd + Copy + Debug, -{ - numbers_of_smaller_type - .iter() - .zip(numbers_of_larger_type.iter()) - .map(|(l, r)| -> ColumnOperationResult { - Into::::into(*l).checked_add(r).ok_or( - ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer addition {l:?} + {r:?}"), - }, - ) - }) - .collect() -} - -/// Subtract one slice from another of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here -pub(super) fn try_subtract_slices_left_upcast( - lhs: &[SmallerType], - rhs: &[LargerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedSub + Copy + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - Into::::into(*l).checked_sub(r).ok_or( - ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer subtraction {l:?} - {r:?}"), - }, - ) - }) - .collect() -} - -/// Subtract one slice from another of the same length, casting the right slice to the type of the left slice. -/// -/// We do not check for length equality here -pub(super) fn try_subtract_slices_right_upcast( - lhs: &[LargerType], - rhs: &[SmallerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedSub + Copy + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_sub(&Into::::into(*r)).ok_or( - ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer subtraction {l:?} - {r:?}"), - }, - ) - }) - .collect() -} - -/// Multiply two slices of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here. However, we do check for integer overflow. -pub(super) fn try_multiply_slices_with_casting( - numbers_of_smaller_type: &[SmallerType], - numbers_of_larger_type: &[LargerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedMul + Copy + Debug, -{ - numbers_of_smaller_type - .iter() - .zip(numbers_of_larger_type.iter()) - .map(|(l, r)| -> ColumnOperationResult { - Into::::into(*l).checked_mul(r).ok_or( - ColumnOperationError::IntegerOverflow { - error: format!("Overflow in integer multiplication {l:?} * {r:?}"), - }, - ) - }) - .collect() -} - -/// Divide one slice by another of the same length, casting the left slice to the type of the right slice. -/// -/// We do not check for length equality here -pub(super) fn try_divide_slices_left_upcast( - lhs: &[SmallerType], - rhs: &[LargerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedDiv + Copy + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - Into::::into(*l) - .checked_div(r) - .ok_or(ColumnOperationError::DivisionByZero) - }) - .collect() -} - -/// Divide one slice by another of the same length, casting the right slice to the type of the left slice. -/// -/// We do not check for length equality here -pub(super) fn try_divide_slices_right_upcast( - lhs: &[LargerType], - rhs: &[SmallerType], -) -> ColumnOperationResult> -where - SmallerType: Copy + Debug + Into, - LargerType: CheckedDiv + Copy + Debug, -{ - lhs.iter() - .zip(rhs.iter()) - .map(|(l, r)| -> ColumnOperationResult { - l.checked_div(&Into::::into(*r)) - .ok_or(ColumnOperationError::DivisionByZero) - }) - .collect() -} - -// Decimal operations - /// Check whether a numerical slice is equal to a decimal one. /// /// Note that we do not check for length equality here. @@ -973,639 +474,6 @@ mod test { use super::*; use crate::base::scalar::test_scalar::TestScalar; - #[test] - fn we_can_add_numeric_types() { - // lhs and rhs are integers with the same precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::TinyInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - // lhs and rhs are integers with different precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Int; - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a scalar - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - // lhs is a decimal with nonnegative scale and rhs is an integer - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); - assert_eq!(expected, actual); - - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals with nonnegative scale - let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a decimal with negative scale - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals one of which has negative scale - let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals both with negative scale - // and with result having maximum precision - let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_add_non_numeric_types() { - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::VarChar; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - } - - #[test] - fn we_cannot_add_some_numeric_types_due_to_decimal_issues() { - let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4); - let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4); - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - let lhs = ColumnType::Int; - let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10); - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - } - - #[test] - fn we_can_subtract_numeric_types() { - // lhs and rhs are integers with the same precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::TinyInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - // lhs and rhs are integers with different precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Int; - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a scalar - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - // lhs is a decimal and rhs is an integer - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); - assert_eq!(expected, actual); - - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals with nonnegative scale - let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a decimal with negative scale - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals one of which has negative scale - let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals both with negative scale - // and with result having maximum precision - let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_subtract_non_numeric_types() { - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::VarChar; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - } - - #[test] - fn we_cannot_subtract_some_numeric_types_due_to_decimal_issues() { - let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0); - let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1); - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - let lhs = ColumnType::Int128; - let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12); - assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - } - - #[test] - fn we_can_multiply_numeric_types() { - // lhs and rhs are integers with the same precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::TinyInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::TinyInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::SmallInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - // lhs and rhs are integers with different precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::SmallInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Int; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Int; - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a scalar - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Scalar; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Scalar; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Scalar; - assert_eq!(expected, actual); - - // lhs is a decimal and rhs is an integer - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::TinyInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 2); - assert_eq!(expected, actual); - - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::SmallInt; - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 2); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals with nonnegative scale - let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(31).unwrap(), 5); - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a decimal with negative scale - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), -2); - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), -2); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals one of which has negative scale - let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(56).unwrap(), -8); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals both with negative scale - // and with result having maximum precision - let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); - let actual = try_multiply_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -27); - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_multiply_non_numeric_types() { - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::VarChar; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - } - - #[test] - fn we_cannot_multiply_some_numeric_types_due_to_decimal_issues() { - // Invalid precision - let lhs = ColumnType::Decimal75(Precision::new(38).unwrap(), 4); - let rhs = ColumnType::Decimal75(Precision::new(37).unwrap(), 4); - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - let lhs = ColumnType::Int; - let rhs = ColumnType::Decimal75(Precision::new(65).unwrap(), 0); - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - // Invalid scale - let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -64_i8); - let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), -65_i8); - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { .. } - }) - )); - - let lhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); - let rhs = ColumnType::Decimal75(Precision::new(5).unwrap(), 64_i8); - assert!(matches!( - try_multiply_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { .. } - }) - )); - } - - #[test] - fn we_can_divide_numeric_types() { - // lhs and rhs are integers with the same precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::TinyInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::TinyInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::SmallInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - // lhs and rhs are integers with different precision - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::SmallInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::SmallInt; - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Int; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Int; - assert_eq!(expected, actual); - - // lhs is a decimal with nonnegative scale and rhs is an integer - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::TinyInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 6); - assert_eq!(expected, actual); - - let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let rhs = ColumnType::SmallInt; - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 8); - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a decimal with nonnegative scale - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(16).unwrap(), 11); - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(18).unwrap(), 11); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals with nonnegative scale - let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(33).unwrap(), 14); - assert_eq!(expected, actual); - - // lhs is an integer and rhs is a decimal with negative scale - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(12).unwrap(), 11); - assert_eq!(expected, actual); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(14).unwrap(), 11); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals one of which has negative scale - let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(64).unwrap(), 6); - assert_eq!(expected, actual); - - // lhs and rhs are both decimals both with negative scale - // and with result having maximum precision - let lhs = ColumnType::Decimal75(Precision::new(70).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); - let actual = try_divide_column_types(lhs, rhs).unwrap(); - let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), 6); - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_divide_non_numeric_or_scalar_types() { - let lhs = ColumnType::TinyInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::SmallInt; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::VarChar; - let rhs = ColumnType::VarChar; - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - - let lhs = ColumnType::Scalar; - let rhs = ColumnType::Scalar; - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) - )); - } - - #[test] - fn we_cannot_divide_some_numeric_types_due_to_decimal_issues() { - // Invalid precision - let lhs = ColumnType::Decimal75(Precision::new(71).unwrap(), -13); - let rhs = ColumnType::Decimal75(Precision::new(13).unwrap(), -14); - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - let lhs = ColumnType::Int; - let rhs = ColumnType::Decimal75(Precision::new(68).unwrap(), 67); - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidPrecision { .. } - }) - )); - - // Invalid scale - let lhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 53_i8); - let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 40_i8); - assert!(matches!( - try_divide_column_types(lhs, rhs), - Err(ColumnOperationError::DecimalConversionError { - source: DecimalError::InvalidScale { .. } - }) - )); - } - - // NOT - #[test] - fn we_can_negate_boolean_slices() { - let input = [true, false, true]; - let actual = slice_not(&input); - let expected = vec![false, true, false]; - assert_eq!(expected, actual); - } - - // AND - #[test] - fn we_can_and_boolean_slices() { - let lhs = [true, false, true, false]; - let rhs = [true, true, false, false]; - let actual = slice_and(&lhs, &rhs); - let expected = vec![true, false, false, false]; - assert_eq!(expected, actual); - } - - // OR - #[test] - fn we_can_or_boolean_slices() { - let lhs = [true, false, true, false]; - let rhs = [true, true, false, false]; - let actual = slice_or(&lhs, &rhs); - let expected = vec![true, true, true, false]; - assert_eq!(expected, actual); - } - - // = - #[test] - fn we_can_eq_slices() { - let lhs = [1_i16, 2, 3]; - let rhs = [1_i16, 3, 3]; - let actual = slice_eq(&lhs, &rhs); - let expected = vec![true, false, true]; - assert_eq!(expected, actual); - - // Try strings - let lhs = ["Chloe".to_string(), "Margaret".to_string()]; - let rhs = ["Chloe".to_string(), "Chloe".to_string()]; - let actual = slice_eq(&lhs, &rhs); - let expected = vec![true, false]; - assert_eq!(expected, actual); - } - - #[test] - fn we_can_eq_slices_with_cast() { - let lhs = [1_i16, 2, 3]; - let rhs = [1_i32, 3, 3]; - let actual = slice_eq_with_casting(&lhs, &rhs); - let expected = vec![true, false, true]; - assert_eq!(expected, actual); - } - #[test] fn we_can_eq_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale @@ -1719,25 +587,6 @@ mod test { assert_eq!(expected, actual); } - // <= - #[test] - fn we_can_le_slices() { - let lhs = [1_i32, 2, 3]; - let rhs = [1_i32, 3, 2]; - let actual = slice_le(&lhs, &rhs); - let expected = vec![true, true, false]; - assert_eq!(expected, actual); - } - - #[test] - fn we_can_le_slices_with_cast() { - let lhs = [1_i16, 2, 3]; - let rhs = [1_i64, 3, 2]; - let actual = slice_le_with_casting(&lhs, &rhs); - let expected = vec![true, true, false]; - assert_eq!(expected, actual); - } - #[test] fn we_can_le_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale @@ -1851,25 +700,6 @@ mod test { assert_eq!(expected, actual); } - // >= - #[test] - fn we_can_ge_slices() { - let lhs = [1_i128, 2, 3]; - let rhs = [1_i128, 3, 2]; - let actual = slice_ge(&lhs, &rhs); - let expected = vec![true, false, true]; - assert_eq!(expected, actual); - } - - #[test] - fn we_can_ge_slices_with_cast() { - let lhs = [1_i16, 2, 3]; - let rhs = [1_i64, 3, 2]; - let actual = slice_ge_with_casting(&lhs, &rhs); - let expected = vec![true, false, true]; - assert_eq!(expected, actual); - } - #[test] fn we_can_ge_decimal_columns() { // lhs is integer and rhs is decimal with nonnegative scale @@ -1983,45 +813,6 @@ mod test { assert_eq!(expected, actual); } - // + - #[test] - fn we_can_try_add_slices() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i16, -5, 6]; - let actual = try_add_slices(&lhs, &rhs).unwrap(); - let expected = vec![5_i16, -3, 9]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_add_slices_if_overflow() { - let lhs = [i16::MAX, 1]; - let rhs = [1_i16, 1]; - assert!(matches!( - try_add_slices(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - - #[test] - fn we_can_try_add_slices_with_cast() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i32, -5, 6]; - let actual = try_add_slices_with_casting(&lhs, &rhs).unwrap(); - let expected = vec![5_i32, -3, 9]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_add_slices_with_cast_if_overflow() { - let lhs = [-1_i16, 1]; - let rhs = [i32::MIN, 1]; - assert!(matches!( - try_add_slices_with_casting(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - #[allow(clippy::too_many_lines)] #[test] fn we_can_try_add_decimal_columns() { @@ -2144,64 +935,6 @@ mod test { assert_eq!(expected, actual); } - // - - #[test] - fn we_can_try_subtract_slices() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i16, -5, 6]; - let actual = try_subtract_slices(&lhs, &rhs).unwrap(); - let expected = vec![-3_i16, 7, -3]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_subtract_slices_if_overflow() { - let lhs = [i128::MIN, 1]; - let rhs = [1_i128, 1]; - assert!(matches!( - try_subtract_slices(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - - #[test] - fn we_can_try_subtract_slices_left_upcast() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i32, -5, 6]; - let actual = try_subtract_slices_left_upcast(&lhs, &rhs).unwrap(); - let expected = vec![-3_i32, 7, -3]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_subtract_slices_left_upcast_if_overflow() { - let lhs = [0_i16, 1]; - let rhs = [i32::MIN, 1]; - assert!(matches!( - try_subtract_slices_left_upcast(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - - #[test] - fn we_can_try_subtract_slices_right_upcast() { - let lhs = [1_i32, 2, 3]; - let rhs = [4_i16, -5, 6]; - let actual = try_subtract_slices_right_upcast(&lhs, &rhs).unwrap(); - let expected = vec![-3_i32, 7, -3]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_subtract_slices_right_upcast_if_overflow() { - let lhs = [i32::MIN, 1]; - let rhs = [1_i16, 1]; - assert!(matches!( - try_subtract_slices_right_upcast(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - #[allow(clippy::too_many_lines)] #[test] fn we_can_try_subtract_decimal_columns() { @@ -2324,45 +1057,6 @@ mod test { assert_eq!(expected, actual); } - // * - #[test] - fn we_can_try_multiply_slices() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i16, -5, 6]; - let actual = try_multiply_slices(&lhs, &rhs).unwrap(); - let expected = vec![4_i16, -10, 18]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_multiply_slices_if_overflow() { - let lhs = [i32::MAX, 2]; - let rhs = [2, 2]; - assert!(matches!( - try_multiply_slices(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - - #[test] - fn we_can_try_multiply_slices_with_cast() { - let lhs = [1_i16, 2, 3]; - let rhs = [4_i32, -5, 6]; - let actual = try_multiply_slices_with_casting(&lhs, &rhs).unwrap(); - let expected = vec![4_i32, -10, 18]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_multiply_slices_with_cast_if_overflow() { - let lhs = [2_i16, 2]; - let rhs = [i32::MAX, 2]; - assert!(matches!( - try_multiply_slices_with_casting(&lhs, &rhs), - Err(ColumnOperationError::IntegerOverflow { .. }) - )); - } - #[allow(clippy::too_many_lines)] #[test] fn we_can_try_multiply_decimal_columns() { @@ -2486,64 +1180,6 @@ mod test { assert_eq!(expected, actual); } - // / - #[test] - fn we_can_try_divide_slices() { - let lhs = [5_i16, -5, -7, 9]; - let rhs = [-3_i16, 3, -4, 5]; - let actual = try_divide_slices(&lhs, &rhs).unwrap(); - let expected = vec![-1_i16, -1, 1, 1]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_divide_slices_if_divide_by_zero() { - let lhs = [1_i32, 2, 3]; - let rhs = [0_i32, -5, 6]; - assert!(matches!( - try_divide_slices(&lhs, &rhs), - Err(ColumnOperationError::DivisionByZero) - )); - } - - #[test] - fn we_can_try_divide_slices_left_upcast() { - let lhs = [5_i16, -4, -9, 9]; - let rhs = [-3_i32, 3, -4, 5]; - let actual = try_divide_slices_left_upcast(&lhs, &rhs).unwrap(); - let expected = vec![-1_i32, -1, 2, 1]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_divide_slices_left_upcast_if_divide_by_zero() { - let lhs = [1_i16, 2]; - let rhs = [0_i32, 2]; - assert!(matches!( - try_divide_slices_left_upcast(&lhs, &rhs), - Err(ColumnOperationError::DivisionByZero) - )); - } - - #[test] - fn we_can_try_divide_slices_right_upcast() { - let lhs = [15_i128, -82, -7, 9]; - let rhs = [-3_i32, 3, -4, 5]; - let actual = try_divide_slices_right_upcast(&lhs, &rhs).unwrap(); - let expected = vec![-5_i128, -27, 1, 1]; - assert_eq!(expected, actual); - } - - #[test] - fn we_cannot_try_divide_slices_right_upcast_if_divide_by_zero() { - let lhs = [1_i32, 2]; - let rhs = [0_i16, 2]; - assert!(matches!( - try_divide_slices_right_upcast(&lhs, &rhs), - Err(ColumnOperationError::DivisionByZero) - )); - } - #[allow(clippy::too_many_lines)] #[test] fn we_can_try_divide_decimal_columns() { diff --git a/crates/proof-of-sql/src/base/database/slice_operation.rs b/crates/proof-of-sql/src/base/database/slice_operation.rs new file mode 100644 index 000000000..2c5b1ef73 --- /dev/null +++ b/crates/proof-of-sql/src/base/database/slice_operation.rs @@ -0,0 +1,627 @@ +use super::{ColumnOperationError, ColumnOperationResult}; +use alloc::{format, vec::Vec}; +use core::fmt::Debug; +use num_traits::ops::checked::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; + +// Unary operations + +/// Negate a slice of boolean values. +pub(super) fn slice_not(input: &[bool]) -> Vec { + input.iter().map(|l| -> bool { !*l }).collect::>() +} + +// Binary operations on slices of the same type + +/// Element-wise AND on two boolean slices of the same length. +/// +/// We do not check for length equality here. +pub(super) fn slice_and(lhs: &[bool], rhs: &[bool]) -> Vec { + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> bool { *l && *r }) + .collect::>() +} + +/// Element-wise OR on two boolean slices of the same length. +/// +/// We do not check for length equality here. +pub(super) fn slice_or(lhs: &[bool], rhs: &[bool]) -> Vec { + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> bool { *l || *r }) + .collect::>() +} + +/// Try to check whether two slices of the same length are equal element-wise. +/// +/// We do not check for length equality here. +pub(super) fn slice_eq(lhs: &[T], rhs: &[T]) -> Vec +where + T: PartialEq + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> bool { *l == *r }) + .collect::>() +} + +/// Try to check whether a slice is less than or equal to another element-wise. +/// +/// We do not check for length equality here. +pub(super) fn slice_le(lhs: &[T], rhs: &[T]) -> Vec +where + T: PartialOrd + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> bool { *l <= *r }) + .collect::>() +} + +/// Try to check whether a slice is greater than or equal to another element-wise. +/// +/// We do not check for length equality here. +pub(super) fn slice_ge(lhs: &[T], rhs: &[T]) -> Vec +where + T: PartialOrd + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> bool { *l >= *r }) + .collect::>() +} + +/// Try to add two slices of the same length. +/// +/// We do not check for length equality here. However, we do check for integer overflow. +pub(super) fn try_add_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> +where + T: CheckedAdd + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_add(r) + .ok_or(ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer addition {l:?} + {r:?}"), + }) + }) + .collect::>>() +} + +/// Subtract one slice from another of the same length. +/// +/// We do not check for length equality here. However, we do check for integer overflow. +pub(super) fn try_subtract_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> +where + T: CheckedSub + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_sub(r) + .ok_or(ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer subtraction {l:?} - {r:?}"), + }) + }) + .collect::>>() +} + +/// Multiply two slices of the same length. +/// +/// We do not check for length equality here. However, we do check for integer overflow. +pub(super) fn try_multiply_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> +where + T: CheckedMul + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_mul(r) + .ok_or(ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer multiplication {l:?} * {r:?}"), + }) + }) + .collect::>>() +} + +/// Divide one slice by another of the same length. +/// +/// We do not check for length equality here. However, we do check for division by 0. +pub(super) fn try_divide_slices(lhs: &[T], rhs: &[T]) -> ColumnOperationResult> +where + T: CheckedDiv + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_div(r).ok_or(ColumnOperationError::DivisionByZero) + }) + .collect::>>() +} + +// Casting required for binary operations on different types + +/// Check whether two slices of the same length are equal element-wise. +/// +/// Note that we cast elements of the left slice to the type of the right slice. +/// Also note that we do not check for length equality here. +pub(super) fn slice_eq_with_casting( + numbers_of_smaller_type: &[SmallerType], + numbers_of_larger_type: &[LargerType], +) -> Vec +where + SmallerType: Copy + Debug + Into, + LargerType: PartialEq + Copy + Debug, +{ + numbers_of_smaller_type + .iter() + .zip(numbers_of_larger_type.iter()) + .map(|(l, r)| -> bool { Into::::into(*l) == *r }) + .collect::>() +} + +/// Check whether a slice is less than or equal to another element-wise. +/// +/// Note that we cast elements of the left slice to the type of the right slice. +/// Also note that we do not check for length equality here. +pub(super) fn slice_le_with_casting( + numbers_of_smaller_type: &[SmallerType], + numbers_of_larger_type: &[LargerType], +) -> Vec +where + SmallerType: Copy + Debug + Into, + LargerType: PartialOrd + Copy + Debug, +{ + numbers_of_smaller_type + .iter() + .zip(numbers_of_larger_type.iter()) + .map(|(l, r)| -> bool { Into::::into(*l) <= *r }) + .collect::>() +} + +/// Check whether a slice is greater than or equal to another element-wise. +/// +/// Note that we cast elements of the left slice to the type of the right slice. +/// Also note that we do not check for length equality here. +pub(super) fn slice_ge_with_casting( + numbers_of_smaller_type: &[SmallerType], + numbers_of_larger_type: &[LargerType], +) -> Vec +where + SmallerType: Copy + Debug + Into, + LargerType: PartialOrd + Copy + Debug, +{ + numbers_of_smaller_type + .iter() + .zip(numbers_of_larger_type.iter()) + .map(|(l, r)| -> bool { Into::::into(*l) >= *r }) + .collect::>() +} + +/// Add two slices of the same length, casting the left slice to the type of the right slice. +/// +/// We do not check for length equality here. However, we do check for integer overflow. +pub(super) fn try_add_slices_with_casting( + numbers_of_smaller_type: &[SmallerType], + numbers_of_larger_type: &[LargerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedAdd + Copy + Debug, +{ + numbers_of_smaller_type + .iter() + .zip(numbers_of_larger_type.iter()) + .map(|(l, r)| -> ColumnOperationResult { + Into::::into(*l).checked_add(r).ok_or( + ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer addition {l:?} + {r:?}"), + }, + ) + }) + .collect() +} + +/// Subtract one slice from another of the same length, casting the left slice to the type of the right slice. +/// +/// We do not check for length equality here +pub(super) fn try_subtract_slices_left_upcast( + lhs: &[SmallerType], + rhs: &[LargerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedSub + Copy + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + Into::::into(*l).checked_sub(r).ok_or( + ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer subtraction {l:?} - {r:?}"), + }, + ) + }) + .collect() +} + +/// Subtract one slice from another of the same length, casting the right slice to the type of the left slice. +/// +/// We do not check for length equality here +pub(super) fn try_subtract_slices_right_upcast( + lhs: &[LargerType], + rhs: &[SmallerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedSub + Copy + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_sub(&Into::::into(*r)).ok_or( + ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer subtraction {l:?} - {r:?}"), + }, + ) + }) + .collect() +} + +/// Multiply two slices of the same length, casting the left slice to the type of the right slice. +/// +/// We do not check for length equality here. However, we do check for integer overflow. +pub(super) fn try_multiply_slices_with_casting( + numbers_of_smaller_type: &[SmallerType], + numbers_of_larger_type: &[LargerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedMul + Copy + Debug, +{ + numbers_of_smaller_type + .iter() + .zip(numbers_of_larger_type.iter()) + .map(|(l, r)| -> ColumnOperationResult { + Into::::into(*l).checked_mul(r).ok_or( + ColumnOperationError::IntegerOverflow { + error: format!("Overflow in integer multiplication {l:?} * {r:?}"), + }, + ) + }) + .collect() +} + +/// Divide one slice by another of the same length, casting the left slice to the type of the right slice. +/// +/// We do not check for length equality here +pub(super) fn try_divide_slices_left_upcast( + lhs: &[SmallerType], + rhs: &[LargerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedDiv + Copy + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + Into::::into(*l) + .checked_div(r) + .ok_or(ColumnOperationError::DivisionByZero) + }) + .collect() +} + +/// Divide one slice by another of the same length, casting the right slice to the type of the left slice. +/// +/// We do not check for length equality here +pub(super) fn try_divide_slices_right_upcast( + lhs: &[LargerType], + rhs: &[SmallerType], +) -> ColumnOperationResult> +where + SmallerType: Copy + Debug + Into, + LargerType: CheckedDiv + Copy + Debug, +{ + lhs.iter() + .zip(rhs.iter()) + .map(|(l, r)| -> ColumnOperationResult { + l.checked_div(&Into::::into(*r)) + .ok_or(ColumnOperationError::DivisionByZero) + }) + .collect() +} + +#[cfg(test)] +mod test { + use super::*; + + // NOT + #[test] + fn we_can_negate_boolean_slices() { + let input = [true, false, true]; + let actual = slice_not(&input); + let expected = vec![false, true, false]; + assert_eq!(expected, actual); + } + + // AND + #[test] + fn we_can_and_boolean_slices() { + let lhs = [true, false, true, false]; + let rhs = [true, true, false, false]; + let actual = slice_and(&lhs, &rhs); + let expected = vec![true, false, false, false]; + assert_eq!(expected, actual); + } + + // OR + #[test] + fn we_can_or_boolean_slices() { + let lhs = [true, false, true, false]; + let rhs = [true, true, false, false]; + let actual = slice_or(&lhs, &rhs); + let expected = vec![true, true, true, false]; + assert_eq!(expected, actual); + } + + // = + #[test] + fn we_can_eq_slices() { + let lhs = [1_i16, 2, 3]; + let rhs = [1_i16, 3, 3]; + let actual = slice_eq(&lhs, &rhs); + let expected = vec![true, false, true]; + assert_eq!(expected, actual); + + // Try strings + let lhs = ["Chloe".to_string(), "Margaret".to_string()]; + let rhs = ["Chloe".to_string(), "Chloe".to_string()]; + let actual = slice_eq(&lhs, &rhs); + let expected = vec![true, false]; + assert_eq!(expected, actual); + } + + #[test] + fn we_can_eq_slices_with_cast() { + let lhs = [1_i16, 2, 3]; + let rhs = [1_i32, 3, 3]; + let actual = slice_eq_with_casting(&lhs, &rhs); + let expected = vec![true, false, true]; + assert_eq!(expected, actual); + } + + // <= + #[test] + fn we_can_le_slices() { + let lhs = [1_i32, 2, 3]; + let rhs = [1_i32, 3, 2]; + let actual = slice_le(&lhs, &rhs); + let expected = vec![true, true, false]; + assert_eq!(expected, actual); + } + + #[test] + fn we_can_le_slices_with_cast() { + let lhs = [1_i16, 2, 3]; + let rhs = [1_i64, 3, 2]; + let actual = slice_le_with_casting(&lhs, &rhs); + let expected = vec![true, true, false]; + assert_eq!(expected, actual); + } + + // >= + #[test] + fn we_can_ge_slices() { + let lhs = [1_i128, 2, 3]; + let rhs = [1_i128, 3, 2]; + let actual = slice_ge(&lhs, &rhs); + let expected = vec![true, false, true]; + assert_eq!(expected, actual); + } + + #[test] + fn we_can_ge_slices_with_cast() { + let lhs = [1_i16, 2, 3]; + let rhs = [1_i64, 3, 2]; + let actual = slice_ge_with_casting(&lhs, &rhs); + let expected = vec![true, false, true]; + assert_eq!(expected, actual); + } + + // + + #[test] + fn we_can_try_add_slices() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i16, -5, 6]; + let actual = try_add_slices(&lhs, &rhs).unwrap(); + let expected = vec![5_i16, -3, 9]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_add_slices_if_overflow() { + let lhs = [i16::MAX, 1]; + let rhs = [1_i16, 1]; + assert!(matches!( + try_add_slices(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + #[test] + fn we_can_try_add_slices_with_cast() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i32, -5, 6]; + let actual = try_add_slices_with_casting(&lhs, &rhs).unwrap(); + let expected = vec![5_i32, -3, 9]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_add_slices_with_cast_if_overflow() { + let lhs = [-1_i16, 1]; + let rhs = [i32::MIN, 1]; + assert!(matches!( + try_add_slices_with_casting(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + // - + #[test] + fn we_can_try_subtract_slices() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i16, -5, 6]; + let actual = try_subtract_slices(&lhs, &rhs).unwrap(); + let expected = vec![-3_i16, 7, -3]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_subtract_slices_if_overflow() { + let lhs = [i128::MIN, 1]; + let rhs = [1_i128, 1]; + assert!(matches!( + try_subtract_slices(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + #[test] + fn we_can_try_subtract_slices_left_upcast() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i32, -5, 6]; + let actual = try_subtract_slices_left_upcast(&lhs, &rhs).unwrap(); + let expected = vec![-3_i32, 7, -3]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_subtract_slices_left_upcast_if_overflow() { + let lhs = [0_i16, 1]; + let rhs = [i32::MIN, 1]; + assert!(matches!( + try_subtract_slices_left_upcast(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + #[test] + fn we_can_try_subtract_slices_right_upcast() { + let lhs = [1_i32, 2, 3]; + let rhs = [4_i16, -5, 6]; + let actual = try_subtract_slices_right_upcast(&lhs, &rhs).unwrap(); + let expected = vec![-3_i32, 7, -3]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_subtract_slices_right_upcast_if_overflow() { + let lhs = [i32::MIN, 1]; + let rhs = [1_i16, 1]; + assert!(matches!( + try_subtract_slices_right_upcast(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + // * + #[test] + fn we_can_try_multiply_slices() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i16, -5, 6]; + let actual = try_multiply_slices(&lhs, &rhs).unwrap(); + let expected = vec![4_i16, -10, 18]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_multiply_slices_if_overflow() { + let lhs = [i32::MAX, 2]; + let rhs = [2, 2]; + assert!(matches!( + try_multiply_slices(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + #[test] + fn we_can_try_multiply_slices_with_cast() { + let lhs = [1_i16, 2, 3]; + let rhs = [4_i32, -5, 6]; + let actual = try_multiply_slices_with_casting(&lhs, &rhs).unwrap(); + let expected = vec![4_i32, -10, 18]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_multiply_slices_with_cast_if_overflow() { + let lhs = [2_i16, 2]; + let rhs = [i32::MAX, 2]; + assert!(matches!( + try_multiply_slices_with_casting(&lhs, &rhs), + Err(ColumnOperationError::IntegerOverflow { .. }) + )); + } + + // / + #[test] + fn we_can_try_divide_slices() { + let lhs = [5_i16, -5, -7, 9]; + let rhs = [-3_i16, 3, -4, 5]; + let actual = try_divide_slices(&lhs, &rhs).unwrap(); + let expected = vec![-1_i16, -1, 1, 1]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_divide_slices_if_divide_by_zero() { + let lhs = [1_i32, 2, 3]; + let rhs = [0_i32, -5, 6]; + assert!(matches!( + try_divide_slices(&lhs, &rhs), + Err(ColumnOperationError::DivisionByZero) + )); + } + + #[test] + fn we_can_try_divide_slices_left_upcast() { + let lhs = [5_i16, -4, -9, 9]; + let rhs = [-3_i32, 3, -4, 5]; + let actual = try_divide_slices_left_upcast(&lhs, &rhs).unwrap(); + let expected = vec![-1_i32, -1, 2, 1]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_divide_slices_left_upcast_if_divide_by_zero() { + let lhs = [1_i16, 2]; + let rhs = [0_i32, 2]; + assert!(matches!( + try_divide_slices_left_upcast(&lhs, &rhs), + Err(ColumnOperationError::DivisionByZero) + )); + } + + #[test] + fn we_can_try_divide_slices_right_upcast() { + let lhs = [15_i128, -82, -7, 9]; + let rhs = [-3_i32, 3, -4, 5]; + let actual = try_divide_slices_right_upcast(&lhs, &rhs).unwrap(); + let expected = vec![-5_i128, -27, 1, 1]; + assert_eq!(expected, actual); + } + + #[test] + fn we_cannot_try_divide_slices_right_upcast_if_divide_by_zero() { + let lhs = [1_i32, 2]; + let rhs = [0_i16, 2]; + assert!(matches!( + try_divide_slices_right_upcast(&lhs, &rhs), + Err(ColumnOperationError::DivisionByZero) + )); + } +}