refactor!: add ScalarExt (#281)
# Rationale for this change

Some helper code for `Scalar`s is scattered a bit. This PR cleans this up.

# What changes are included in this PR?

* Moved `signed_cmp` to the `ScalarExt` trait.
* Replaced `scale_scalar` with `ScalarExt::pow10` (see the sketch below).
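
A minimal sketch of the new call sites, assuming code inside this crate so the `crate::base::scalar` imports from the diff resolve; the helpers `upscale_by` and `max_signed` are hypothetical and only illustrate the two items listed above:

```rust
use crate::base::scalar::{Scalar, ScalarExt};

// Hypothetical helper: multiply `value` by 10^scale, as the updated call sites do.
// Before: `scale_scalar(value, scale)?` returned a `DecimalResult<S>`.
// After: the caller checks nonnegativity once and multiplies by `S::pow10`.
fn upscale_by<S: Scalar>(value: S, scale: i8) -> S {
    value * S::pow10(u8::try_from(scale).expect("scale is nonnegative"))
}

// Hypothetical helper: signed maximum via the blanket `ScalarExt` trait,
// mirroring the `max_by(ScalarExt::signed_cmp)` calls in `group_by_util.rs`.
fn max_signed<S: Scalar>(values: &[S]) -> Option<S> {
    values.iter().copied().max_by(ScalarExt::signed_cmp)
}
```

The `expect` mirrors the invariant the call sites already establish: the upscale factor is checked to be nonnegative before scaling, so the `i8` to `u8` conversion cannot fail there.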

# Are these changes tested?

Yes
JayWhite2357 authored Oct 20, 2024
2 parents a30c7d6 + 89d54d1 commit 87af1c2
Showing 15 changed files with 109 additions and 77 deletions.
6 changes: 3 additions & 3 deletions crates/proof-of-sql/src/base/database/column.rs
@@ -1,7 +1,7 @@
use super::{LiteralValue, OwnedColumn, TableRef};
use crate::base::{
math::decimal::{scale_scalar, Precision},
scalar::Scalar,
math::decimal::Precision,
scalar::{Scalar, ScalarExt},
slice_ops::slice_cast_with,
};
use alloc::{sync::Arc, vec::Vec};
@@ -213,7 +213,7 @@ impl<'a, S: Scalar> Column<'a, S> {
/// Convert a column to a vector of Scalar values with scaling
#[allow(clippy::missing_panics_doc)]
pub(crate) fn to_scalar_with_scaling(self, scale: i8) -> Vec<S> {
let scale_factor = scale_scalar(S::ONE, scale).expect("Invalid scale factor");
let scale_factor = S::pow10(u8::try_from(scale).expect("Upscale factor is nonnegative"));
match self {
Self::Boolean(col) => slice_cast_with(col, |b| S::from(b) * scale_factor),
Self::Decimal75(_, _, col) => slice_cast_with(col, |s| *s * scale_factor),
28 changes: 16 additions & 12 deletions crates/proof-of-sql/src/base/database/column_operation.rs
@@ -2,8 +2,8 @@
use super::{ColumnOperationError, ColumnOperationResult};
use crate::base::{
database::ColumnType,
math::decimal::{scale_scalar, DecimalError, Precision},
scalar::Scalar,
math::decimal::{DecimalError, Precision},
scalar::{Scalar, ScalarExt},
};
use alloc::{format, string::ToString, vec::Vec};
use core::{cmp::Ordering, fmt::Debug};
@@ -548,7 +548,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool { Into::<S>::into(*l) * upscale_factor == *r })
@@ -569,7 +569,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool { Into::<S>::into(*l) == *r * upscale_factor })
@@ -624,7 +624,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
@@ -652,7 +652,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
@@ -709,7 +709,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
@@ -737,7 +737,7 @@ where
.collect::<Vec<_>>()
} else {
let upscale_factor =
scale_scalar(S::ONE, upscale).expect("Upscale factor is nonnegative");
S::pow10(u8::try_from(upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs.iter())
.map(|(l, r)| -> bool {
@@ -786,13 +786,15 @@ where
.expect("numeric columns have scale");
// One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary
let scalars: Vec<S> = if left_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, left_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) * upscale_factor + S::from(*r))
.collect()
} else if right_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, right_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) + upscale_factor * S::from(*r))
@@ -846,13 +848,15 @@ where
.expect("numeric columns have scale");
// One of left_scale and right_scale is 0 so we can avoid scaling when unnecessary
let scalars: Vec<S> = if left_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, left_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(left_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) * upscale_factor - S::from(*r))
.collect()
} else if right_upscale > 0 {
let upscale_factor = scale_scalar(S::ONE, right_upscale)?;
let upscale_factor =
S::pow10(u8::try_from(right_upscale).expect("Upscale factor is nonnegative"));
lhs.iter()
.zip(rhs)
.map(|(l, r)| S::from(*l) - upscale_factor * S::from(*r))
6 changes: 3 additions & 3 deletions crates/proof-of-sql/src/base/database/group_by_util.rs
@@ -3,7 +3,7 @@
use crate::base::{
database::{filter_util::filter_column_by_index, Column, OwnedColumn},
if_rayon,
scalar::Scalar,
scalar::{Scalar, ScalarExt},
};
use alloc::vec::Vec;
use bumpalo::Bump;
@@ -311,7 +311,7 @@ where
indexes[start..index]
.iter()
.map(|i| S::from(&slice[*i]))
.max_by(super::super::scalar::Scalar::signed_cmp)
.max_by(super::super::scalar::ScalarExt::signed_cmp)
}))
}

@@ -352,7 +352,7 @@ where
indexes[start..index]
.iter()
.map(|i| S::from(&slice[*i]))
.min_by(super::super::scalar::Scalar::signed_cmp)
.min_by(super::super::scalar::ScalarExt::signed_cmp)
}))
}

33 changes: 10 additions & 23 deletions crates/proof-of-sql/src/base/math/decimal.rs
@@ -1,5 +1,5 @@
//! Module for parsing an `IntermediateDecimal` into a `Decimal75`.
use crate::base::scalar::{Scalar, ScalarConversionError};
use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
use alloc::{
format,
string::{String, ToString},
@@ -125,6 +125,7 @@ impl<S: Scalar> Decimal<S> {
}
}

#[allow(clippy::missing_panics_doc)]
/// Scale the decimal to the new scale factor. Negative scaling and overflow error out.
#[allow(clippy::cast_sign_loss)]
pub fn with_precision_and_scale(
@@ -138,10 +139,12 @@
error: "Scale factor must be non-negative".to_string(),
});
}
let scaled_value = scale_scalar(self.value, scale_factor)?;
let scaled_value =
self.value * S::pow10(u8::try_from(scale_factor).expect("scale_factor is nonnegative"));
Ok(Decimal::new(scaled_value, new_precision, new_scale))
}

#[allow(clippy::missing_panics_doc)]
/// Get a decimal with given precision and scale from an i64
#[allow(clippy::cast_sign_loss)]
pub fn from_i64(value: i64, precision: Precision, scale: i8) -> DecimalResult<Self> {
@@ -157,10 +160,12 @@
error: "Can not scale down a decimal".to_string(),
});
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
let scaled_value =
S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative"));
Ok(Decimal::new(scaled_value, precision, scale))
}

#[allow(clippy::missing_panics_doc)]
/// Get a decimal with given precision and scale from an i128
#[allow(clippy::cast_sign_loss)]
pub fn from_i128(value: i128, precision: Precision, scale: i8) -> DecimalResult<Self> {
@@ -176,7 +181,8 @@
error: "Can not scale down a decimal".to_string(),
});
}
let scaled_value = scale_scalar(S::from(&value), scale)?;
let scaled_value =
S::from(&value) * S::pow10(u8::try_from(scale).expect("scale is nonnegative"));
Ok(Decimal::new(scaled_value, precision, scale))
}
}
@@ -210,25 +216,6 @@ pub(crate) fn try_into_to_scalar<S: Scalar>(
})
}

/// Scale scalar by the given scale factor. Negative scaling is not allowed.
/// Note that we do not check for overflow.
pub(crate) fn scale_scalar<S: Scalar>(s: S, scale: i8) -> DecimalResult<S> {
match scale {
0 => Ok(s),
_ if scale < 0 => Err(DecimalError::RoundingError {
error: "Scale factor must be non-negative".to_string(),
}),
_ => {
let ten = S::from(10);
let mut res = s;
for _ in 0..scale {
res *= ten;
}
Ok(res)
}
}
}

#[cfg(test)]
mod scale_adjust_test {

5 changes: 4 additions & 1 deletion crates/proof-of-sql/src/base/scalar/mod.rs
@@ -1,5 +1,5 @@
/// This module contains the definition of the `Scalar` trait, which is used to represent the scalar field used in Proof of SQL.
pub mod scalar;
mod scalar;
pub use scalar::Scalar;
mod error;
pub use error::ScalarConversionError;
@@ -13,3 +13,6 @@ pub(crate) use mont_scalar::MontScalar;
pub mod test_scalar;
#[cfg(test)]
mod test_scalar_test;

mod scalar_ext;
pub use scalar_ext::ScalarExt;
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/scalar/mont_scalar.rs
@@ -430,6 +430,7 @@ impl super::Scalar for Curve25519Scalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));
const TEN: Self = Self(ark_ff::MontFp!("10"));
}

impl<T> TryFrom<MontScalar<T>> for bool
15 changes: 0 additions & 15 deletions crates/proof-of-sql/src/base/scalar/mont_scalar_test.rs
@@ -4,7 +4,6 @@ use crate::base::{
};
use alloc::{format, string::ToString, vec::Vec};
use byte_slice_cast::AsByteSlice;
use core::cmp::Ordering;
use num_bigint::BigInt;
use num_traits::{Inv, One, Zero};
use rand::{
@@ -361,20 +360,6 @@ fn the_one_scalar_is_the_multiplicative_identity() {
}
}

#[test]
fn scalar_comparison_works() {
let zero = Curve25519Scalar::ZERO;
let one = Curve25519Scalar::ONE;
let two = Curve25519Scalar::TWO;
let max = Curve25519Scalar::MAX_SIGNED;
let min = max + one;
assert_eq!(max.signed_cmp(&one), Ordering::Greater);
assert_eq!(one.signed_cmp(&zero), Ordering::Greater);
assert_eq!(min.signed_cmp(&zero), Ordering::Less);
assert_eq!((two * max).signed_cmp(&zero), Ordering::Less);
assert_eq!(two * max + one, zero);
}

#[test]
fn the_empty_string_will_be_mapped_to_the_zero_scalar() {
assert_eq!(Curve25519Scalar::from(""), Curve25519Scalar::zero());
12 changes: 3 additions & 9 deletions crates/proof-of-sql/src/base/scalar/scalar.rs
@@ -2,7 +2,7 @@

use crate::base::{encode::VarInt, ref_into::RefInto, scalar::ScalarConversionError};
use alloc::string::String;
use core::{cmp::Ordering, ops::Sub};
use core::ops::Sub;
use num_bigint::BigInt;

/// A trait for the scalar field used in Proof of SQL.
@@ -69,12 +69,6 @@ pub trait Scalar:
const ONE: Self;
/// 1 + 1
const TWO: Self;
/// Compare two `Scalar`s as signed numbers.
fn signed_cmp(&self, other: &Self) -> Ordering {
match *self - *other {
x if x.is_zero() => Ordering::Equal,
x if x > Self::MAX_SIGNED => Ordering::Less,
_ => Ordering::Greater,
}
}
/// 2 + 2 + 2 + 2 + 2
const TEN: Self;
}
55 changes: 55 additions & 0 deletions crates/proof-of-sql/src/base/scalar/scalar_ext.rs
@@ -0,0 +1,55 @@
use super::Scalar;
use core::cmp::Ordering;

/// Extension trait for blanket implementations for `Scalar` types.
/// This trait exists primarily to avoid cluttering the core `Scalar` implementation with default implementations
/// and provides helper methods for `Scalar`.
pub trait ScalarExt: Scalar {
/// Compute 10^exponent for the Scalar. Note that we do not check for overflow.
fn pow10(exponent: u8) -> Self {
itertools::repeat_n(Self::TEN, exponent as usize).product()
}
/// Compare two `Scalar`s as signed numbers.
fn signed_cmp(&self, other: &Self) -> Ordering {
match *self - *other {
x if x.is_zero() => Ordering::Equal,
x if x > Self::MAX_SIGNED => Ordering::Less,
_ => Ordering::Greater,
}
}
}
impl<S: Scalar> ScalarExt for S {}

#[cfg(test)]
mod tests {
use super::*;
use crate::base::scalar::{test_scalar::TestScalar, Curve25519Scalar, MontScalar};
#[test]
fn scalar_comparison_works() {
let zero = Curve25519Scalar::ZERO;
let one = Curve25519Scalar::ONE;
let two = Curve25519Scalar::TWO;
let max = Curve25519Scalar::MAX_SIGNED;
let min = max + one;
assert_eq!(max.signed_cmp(&one), Ordering::Greater);
assert_eq!(one.signed_cmp(&zero), Ordering::Greater);
assert_eq!(min.signed_cmp(&zero), Ordering::Less);
assert_eq!((two * max).signed_cmp(&zero), Ordering::Less);
assert_eq!(two * max + one, zero);
}
#[test]
fn we_can_compute_powers_of_10() {
for i in 0..=u128::MAX.ilog10() {
assert_eq!(
TestScalar::pow10(u8::try_from(i).unwrap()),
TestScalar::from(u128::pow(10, i))
);
}
assert_eq!(
TestScalar::pow10(76),
MontScalar(ark_ff::MontFp!(
"10000000000000000000000000000000000000000000000000000000000000000000000000000"
))
);
}
}
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/scalar/test_scalar.rs
@@ -13,6 +13,7 @@ impl Scalar for TestScalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));
const TEN: Self = Self(ark_ff::MontFp!("10"));
}

pub struct TestMontConfig(pub ark_curve25519::FrConfig);
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/scalar/test_scalar_test.rs
@@ -8,4 +8,5 @@ fn we_can_get_test_scalar_constants_from_z_p() {
assert_eq!(TestScalar::from(2), TestScalar::TWO);
// -1/2 == least upper bound
assert_eq!(-TestScalar::TWO.inv().unwrap(), TestScalar::MAX_SIGNED);
assert_eq!(TestScalar::from(10), TestScalar::TEN);
}
@@ -45,6 +45,7 @@ impl Scalar for DoryScalar {
const ZERO: Self = Self(ark_ff::MontFp!("0"));
const ONE: Self = Self(ark_ff::MontFp!("1"));
const TWO: Self = Self(ark_ff::MontFp!("2"));
const TEN: Self = Self(ark_ff::MontFp!("10"));
}

#[derive(
@@ -1,5 +1,5 @@
use super::DoryScalar;
use crate::base::scalar::{Scalar, ScalarConversionError};
use crate::base::scalar::{Scalar, ScalarConversionError, ScalarExt};
use core::cmp::Ordering;
use num_bigint::BigInt;
