From 89d54d176ba4fa472a4bfd7d794e48f0d8501185 Mon Sep 17 00:00:00 2001 From: Jay White Date: Sat, 19 Oct 2024 16:00:27 -0400 Subject: [PATCH] refactor: replace `scale_scalar` with `ScalarExt::pow10` --- .../proof-of-sql/src/base/database/column.rs | 6 ++-- .../src/base/database/column_operation.rs | 26 ++++++++------- crates/proof-of-sql/src/base/math/decimal.rs | 33 ++++++------------- .../src/base/scalar/test_scalar_test.rs | 1 + .../sql/proof_exprs/inequality_expr_test.rs | 9 +++-- .../src/sql/proof_exprs/numerical_util.rs | 11 ++++--- 6 files changed, 39 insertions(+), 47 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 382806a47..3d3b11372 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -1,7 +1,7 @@ use super::{LiteralValue, OwnedColumn, TableRef}; use crate::base::{ - math::decimal::{scale_scalar, Precision}, - scalar::Scalar, + math::decimal::Precision, + scalar::{Scalar, ScalarExt}, slice_ops::slice_cast_with, }; use alloc::{sync::Arc, vec::Vec}; @@ -213,7 +213,7 @@ impl<'a, S: Scalar> Column<'a, S> { /// Convert a column to a vector of Scalar values with scaling #[allow(clippy::missing_panics_doc)] pub(crate) fn to_scalar_with_scaling(self, scale: i8) -> Vec { - let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor"); + let scale_factor = S::pow10(u8::try_from(scale).expect("Upscale factor is nonnegative")); match self { Self::Boolean(col) => slice_cast_with(col, |b| S::from(b) * scale_factor), Self::Decimal75(_, _, col) => slice_cast_with(col, |s| *s * scale_factor), diff --git a/crates/proof-of-sql/src/base/database/column_operation.rs b/crates/proof-of-sql/src/base/database/column_operation.rs index 3c45e2be9..765fc3895 100644 --- a/crates/proof-of-sql/src/base/database/column_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_operation.rs @@ -2,7 +2,7 @@ use super::{ColumnOperationError, ColumnOperationResult}; use crate::base::{ database::ColumnType, - math::decimal::{scale_scalar, DecimalError, Precision}, + math::decimal::{DecimalError, Precision}, scalar::{Scalar, ScalarExt}, }; use alloc::{format, string::ToString, vec::Vec}; @@ -548,7 +548,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { Into::::into(*l) * upscale_factor == *r }) @@ -569,7 +569,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { Into::::into(*l) == *r * upscale_factor }) @@ -624,7 +624,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { @@ -652,7 +652,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { @@ -709,7 +709,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { @@ -737,7 +737,7 @@ where .collect::>() } else { let upscale_factor = - scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative"); + S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs.iter()) .map(|(l, r)| -> bool { @@ -786,13 +786,15 @@ where .expect("numeric columns have scale"); // One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary let scalars: Vec = if left_upscale > 0 { - let upscale_factor = scale_scalar(S::ONE, left_upscale)?; + let upscale_factor = + S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs) .map(|(l, r)| S::from(*l) * upscale_factor + S::from(*r)) .collect() } else if right_upscale > 0 { - let upscale_factor = scale_scalar(S::ONE, right_upscale)?; + let upscale_factor = + S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs) .map(|(l, r)| S::from(*l) + upscale_factor * S::from(*r)) @@ -846,13 +848,15 @@ where .expect("numeric columns have scale"); // One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary let scalars: Vec = if left_upscale > 0 { - let upscale_factor = scale_scalar(S::ONE, left_upscale)?; + let upscale_factor = + S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs) .map(|(l, r)| S::from(*l) * upscale_factor - S::from(*r)) .collect() } else if right_upscale > 0 { - let upscale_factor = scale_scalar(S::ONE, right_upscale)?; + let upscale_factor = + S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative")); lhs.iter() .zip(rhs) .map(|(l, r)| S::from(*l) - upscale_factor * S::from(*r)) diff --git a/crates/proof-of-sql/src/base/math/decimal.rs b/crates/proof-of-sql/src/base/math/decimal.rs index dda6bc116..6d286915e 100644 --- a/crates/proof-of-sql/src/base/math/decimal.rs +++ b/crates/proof-of-sql/src/base/math/decimal.rs @@ -1,5 +1,5 @@ //! Module for parsing an `IntermediateDecimal` into a `Decimal75`. -use crate::base::scalar::{Scalar, ScalarConversionError}; +use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt}; use alloc::{ format, string::{String, ToString}, @@ -125,6 +125,7 @@ impl Decimal { } } + #[allow(clippy::missing_panics_doc)] /// Scale the decimal to the new scale factor. Negative scaling and overflow error out. #[allow(clippy::cast_sign_loss)] pub fn with_precision_and_scale( @@ -138,10 +139,12 @@ impl Decimal { error: "Scale factor must be non-negative".to_string(), }); } - let scaled_value = scale_scalar(self.value, scale_factor)?; + let scaled_value = + self.value * S::pow10(u8::try_from(scale_factor).expect("scale_factor is nonnegative")); Ok(Decimal::new(scaled_value, new_precision, new_scale)) } + #[allow(clippy::missing_panics_doc)] /// Get a decimal with given precision and scale from an i64 #[allow(clippy::cast_sign_loss)] pub fn from_i64(value: i64, precision: Precision, scale: i8) -> DecimalResult { @@ -157,10 +160,12 @@ impl Decimal { error: "Can not scale down a decimal".to_string(), }); } - let scaled_value = scale_scalar(S::from(&value), scale)?; + let scaled_value = + S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative")); Ok(Decimal::new(scaled_value, precision, scale)) } + #[allow(clippy::missing_panics_doc)] /// Get a decimal with given precision and scale from an i128 #[allow(clippy::cast_sign_loss)] pub fn from_i128(value: i128, precision: Precision, scale: i8) -> DecimalResult { @@ -176,7 +181,8 @@ impl Decimal { error: "Can not scale down a decimal".to_string(), }); } - let scaled_value = scale_scalar(S::from(&value), scale)?; + let scaled_value = + S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative")); Ok(Decimal::new(scaled_value, precision, scale)) } } @@ -210,25 +216,6 @@ pub(crate) fn try_into_to_scalar( }) } -/// Scale scalar by the given scale factor. Negative scaling is not allowed. -/// Note that we do not check for overflow. -pub(crate) fn scale_scalar(s: S, scale: i8) -> DecimalResult { - match scale { - 0 => Ok(s), - _ if scale < 0 => Err(DecimalError::RoundingError { - error: "Scale factor must be non-negative".to_string(), - }), - _ => { - let ten = S::from(10); - let mut res = s; - for _ in 0..scale { - res *= ten; - } - Ok(res) - } - } -} - #[cfg(test)] mod scale_adjust_test { diff --git a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs index fe8c19d9b..439e833a6 100644 --- a/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs +++ b/crates/proof-of-sql/src/base/scalar/test_scalar_test.rs @@ -8,4 +8,5 @@ fn we_can_get_test_scalar_constants_from_z_p() { assert_eq!(TestScalar::from(2), TestScalar::TWO); // -1/2 == least upper bound assert_eq!(-TestScalar::TWO.inv().unwrap(), TestScalar::MAX_SIGNED); + assert_eq!(TestScalar::from(10), TestScalar::TEN); } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs index a452da701..34605872c 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs @@ -5,8 +5,7 @@ use crate::{ owned_table_utility::*, Column, LiteralValue, OwnedTable, OwnedTableTestAccessor, TestAccessor, }, - math::decimal::scale_scalar, - scalar::{Curve25519Scalar, Scalar}, + scalar::{Curve25519Scalar, Scalar, ScalarExt}, }, sql::{ parse::ConversionError, @@ -164,7 +163,7 @@ fn we_can_compare_columns_with_extreme_values() { #[test] fn we_can_compare_columns_with_small_decimal_values_without_scale() { - let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE; + let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE; let scalar_neg = -scalar_pos; let data: OwnedTable = owned_table([ bigint("a", [123, 25]), @@ -192,7 +191,7 @@ fn we_can_compare_columns_with_small_decimal_values_without_scale() { #[test] fn we_can_compare_columns_with_small_decimal_values_with_scale() { - let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE; + let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE; let scalar_neg = -scalar_pos; let data: OwnedTable = owned_table([ bigint("a", [123, 25]), @@ -222,7 +221,7 @@ fn we_can_compare_columns_with_small_decimal_values_with_scale() { #[test] fn we_can_compare_columns_with_small_decimal_values_with_differing_scale_gte() { - let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE; + let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE; let scalar_neg = -scalar_pos; let data: OwnedTable = owned_table([ bigint("a", [123, 25]), diff --git a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs index 305be291a..eacc03142 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs @@ -1,4 +1,7 @@ -use crate::base::{database::Column, math::decimal::scale_scalar, scalar::Scalar}; +use crate::base::{ + database::Column, + scalar::{Scalar, ScalarExt}, +}; use bumpalo::Bump; #[allow( @@ -65,10 +68,8 @@ pub(crate) fn scale_and_add_subtract_eval( is_subtract: bool, ) -> S { let max_scale = lhs_scale.max(rhs_scale); - let left_scaled_eval = scale_scalar(lhs_eval, max_scale - lhs_scale) - .expect("scaling factor should not be negative"); - let right_scaled_eval = scale_scalar(rhs_eval, max_scale - rhs_scale) - .expect("scaling factor should not be negative"); + let left_scaled_eval = lhs_eval * S::pow10(max_scale.abs_diff(lhs_scale)); + let right_scaled_eval = rhs_eval * S::pow10(max_scale.abs_diff(rhs_scale)); if is_subtract { left_scaled_eval - right_scaled_eval } else {