Skip to content

Commit

Permalink
refactor: replace scale_scalar with ScalarExt::pow10
Browse files Browse the repository at this point in the history
  • Loading branch information
JayWhite2357 committed Oct 19, 2024
1 parent b34f7ae commit 89d54d1
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 47 deletions.
6 changes: 3 additions & 3 deletions crates/proof-of-sql/src/base/database/column.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{LiteralValue, OwnedColumn, TableRef};
use crate::base::{
math::decimal::{scale_scalar, Precision},
scalar::Scalar,
math::decimal::Precision,
scalar::{Scalar, ScalarExt},
slice_ops::slice_cast_with,
};
use alloc::{sync::Arc, vec::Vec};
Expand Down Expand Up @@ -213,7 +213,7 @@ impl<'a, S: Scalar> Column<'a, S> {
/// Convert a column to a vector of Scalar values with scaling
#[allow(clippy::missing_panics_doc)]
pub(crate) fn to_scalar_with_scaling(self, scale: i8) -> Vec<S> {
let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor");
let scale_factor = S::pow10(u8::try_from(scale).expect("Upscale factor is nonnegative"));
match self {
Self::Boolean(col) => slice_cast_with(col, |b| S::from(b) * scale_factor),
Self::Decimal75(_, _, col) => slice_cast_with(col, |s| *s * scale_factor),
Expand Down
26 changes: 15 additions & 11 deletions crates/proof-of-sql/src/base/database/column_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use super::{ColumnOperationError, ColumnOperationResult};
use crate::base::{
database::ColumnType,
math::decimal::{scale_scalar, DecimalError, Precision},
math::decimal::{DecimalError, Precision},
scalar::{Scalar, ScalarExt},
};
use alloc::{format, string::ToString, vec::Vec};
Expand Down Expand Up @@ -548,7 +548,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool { Into::<S>::into(*l) * upscale_factor == *r })
Expand All @@ -569,7 +569,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool { Into::<S>::into(*l) == *r * upscale_factor })
Expand Down Expand Up @@ -624,7 +624,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
Expand Down Expand Up @@ -652,7 +652,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
Expand Down Expand Up @@ -709,7 +709,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
Expand Down Expand Up @@ -737,7 +737,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
Expand Down Expand Up @@ -786,13 +786,15 @@ where
.expect("numeric columns have scale");
// One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary
let scalars: Vec<S> = if left_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, left_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) * upscale_factor + S::from(*r))
.collect()
} else if right_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, right_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) + upscale_factor * S::from(*r))
Expand Down Expand Up @@ -846,13 +848,15 @@ where
.expect("numeric columns have scale");
// One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary
let scalars: Vec<S> = if left_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, left_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) * upscale_factor - S::from(*r))
.collect()
} else if right_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, right_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) - upscale_factor * S::from(*r))
Expand Down
33 changes: 10 additions & 23 deletions crates/proof-of-sql/src/base/math/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Module for parsing an `IntermediateDecimal` into a `Decimal75`.
use crate::base::scalar::{Scalar, ScalarConversionError};
use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
use alloc::{
format,
string::{String, ToString},
Expand Down Expand Up @@ -125,6 +125,7 @@ impl<S: Scalar> Decimal<S> {
}
}

#[allow(clippy::missing_panics_doc)]
/// Scale the decimal to the new scale factor. Negative scaling and overflow error out.
#[allow(clippy::cast_sign_loss)]
pub fn with_precision_and_scale(
Expand All @@ -138,10 +139,12 @@ impl<S: Scalar> Decimal<S> {
error: "Scale factor must be non-negative".to_string(),
});
}
let scaled_value = scale_scalar(self.value, scale_factor)?;
let scaled_value =
self.value * S::pow10(u8::try_from(scale_factor).expect("scale_factor is nonnegative"));
Ok(Decimal::new(scaled_value, new_precision, new_scale))
}

#[allow(clippy::missing_panics_doc)]
/// Get a decimal with given precision and scale from an i64
#[allow(clippy::cast_sign_loss)]
pub fn from_i64(value: i64, precision: Precision, scale: i8) -> DecimalResult<Self> {
Expand All @@ -157,10 +160,12 @@ impl<S: Scalar> Decimal<S> {
error: "Can not scale down a decimal".to_string(),
});
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
let scaled_value =
S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative"));
Ok(Decimal::new(scaled_value, precision, scale))
}

#[allow(clippy::missing_panics_doc)]
/// Get a decimal with given precision and scale from an i128
#[allow(clippy::cast_sign_loss)]
pub fn from_i128(value: i128, precision: Precision, scale: i8) -> DecimalResult<Self> {
Expand All @@ -176,7 +181,8 @@ impl<S: Scalar> Decimal<S> {
error: "Can not scale down a decimal".to_string(),
});
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
let scaled_value =
S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative"));
Ok(Decimal::new(scaled_value, precision, scale))
}
}
Expand Down Expand Up @@ -210,25 +216,6 @@ pub(crate) fn try_into_to_scalar<S: Scalar>(
})
}

/// Scale scalar by the given scale factor. Negative scaling is not allowed.
/// Note that we do not check for overflow.
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> DecimalResult<S> {
match scale {
0 => Ok(s),
_ if scale < 0 => Err(DecimalError::RoundingError {
error: "Scale factor must be non-negative".to_string(),
}),
_ => {
let ten = S::from(10);
let mut res = s;
for _ in 0..scale {
res *= ten;
}
Ok(res)
}
}
}

#[cfg(test)]
mod scale_adjust_test {

Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/scalar/test_scalar_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ fn we_can_get_test_scalar_constants_from_z_p() {
assert_eq!(TestScalar::from(2), TestScalar::TWO);
// -1/2 == least upper bound
assert_eq!(-TestScalar::TWO.inv().unwrap(), TestScalar::MAX_SIGNED);
assert_eq!(TestScalar::from(10), TestScalar::TEN);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use crate::{
owned_table_utility::*, Column, LiteralValue, OwnedTable, OwnedTableTestAccessor,
TestAccessor,
},
math::decimal::scale_scalar,
scalar::{Curve25519Scalar, Scalar},
scalar::{Curve25519Scalar, Scalar, ScalarExt},
},
sql::{
parse::ConversionError,
Expand Down Expand Up @@ -164,7 +163,7 @@ fn we_can_compare_columns_with_extreme_values() {

#[test]
fn we_can_compare_columns_with_small_decimal_values_without_scale() {
let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE;
let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE;
let scalar_neg = -scalar_pos;
let data: OwnedTable<Curve25519Scalar> = owned_table([
bigint("a", [123, 25]),
Expand Down Expand Up @@ -192,7 +191,7 @@ fn we_can_compare_columns_with_small_decimal_values_without_scale() {

#[test]
fn we_can_compare_columns_with_small_decimal_values_with_scale() {
let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE;
let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE;
let scalar_neg = -scalar_pos;
let data: OwnedTable<Curve25519Scalar> = owned_table([
bigint("a", [123, 25]),
Expand Down Expand Up @@ -222,7 +221,7 @@ fn we_can_compare_columns_with_small_decimal_values_with_scale() {

#[test]
fn we_can_compare_columns_with_small_decimal_values_with_differing_scale_gte() {
let scalar_pos = scale_scalar(Curve25519Scalar::ONE, 38).unwrap() - Curve25519Scalar::ONE;
let scalar_pos = Curve25519Scalar::pow10(38) - Curve25519Scalar::ONE;
let scalar_neg = -scalar_pos;
let data: OwnedTable<Curve25519Scalar> = owned_table([
bigint("a", [123, 25]),
Expand Down
11 changes: 6 additions & 5 deletions crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::base::{database::Column, math::decimal::scale_scalar, scalar::Scalar};
use crate::base::{
database::Column,
scalar::{Scalar, ScalarExt},
};
use bumpalo::Bump;

#[allow(
Expand Down Expand Up @@ -65,10 +68,8 @@ pub(crate) fn scale_and_add_subtract_eval<S: Scalar>(
is_subtract: bool,
) -> S {
let max_scale = lhs_scale.max(rhs_scale);
let left_scaled_eval = scale_scalar(lhs_eval, max_scale - lhs_scale)
.expect("scaling factor should not be negative");
let right_scaled_eval = scale_scalar(rhs_eval, max_scale - rhs_scale)
.expect("scaling factor should not be negative");
let left_scaled_eval = lhs_eval * S::pow10(max_scale.abs_diff(lhs_scale));
let right_scaled_eval = rhs_eval * S::pow10(max_scale.abs_diff(rhs_scale));
if is_subtract {
left_scaled_eval - right_scaled_eval
} else {
Expand Down

0 comments on commit 89d54d1

Please sign in to comment.