Skip to content

Commit

Permalink
fix: implement TryFrom<BigInt> for MontScalar<T> properly (#274)
Browse files Browse the repository at this point in the history
# Rationale for this change

`TryFrom<BigInt> for MontScalar<T>` is implemented improperly, and never
overflows

# What changes are included in this PR?

The implementation is fixed.

# Are these changes tested?

Yes. Tests are added.
  • Loading branch information
JayWhite2357 authored Oct 17, 2024
2 parents c749566 + 44e6aea commit b1179f6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 27 deletions.
49 changes: 23 additions & 26 deletions crates/proof-of-sql/src/base/scalar/mont_scalar.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use super::{Scalar, ScalarConversionError};
use crate::base::math::decimal::MAX_SUPPORTED_PRECISION;
use alloc::{format, string::String, vec::Vec};
use alloc::{
format,
string::{String, ToString},
vec::Vec,
};
use ark_ff::{BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use bytemuck::TransparentWrapper;
Expand Down Expand Up @@ -217,35 +220,29 @@ impl<T: MontConfig<4>> MontScalar<T> {
}
}

impl<T: MontConfig<4>> TryFrom<BigInt> for MontScalar<T> {
impl<T> TryFrom<BigInt> for MontScalar<T>
where
T: MontConfig<4>,
MontScalar<T>: Scalar,
{
type Error = ScalarConversionError;

fn try_from(value: BigInt) -> Result<Self, Self::Error> {
// Obtain the absolute value to ignore the sign when counting digits
let value_abs = value.abs();

// Extract digits and check the number of digits directly
let (_, digits) = value_abs.to_u64_digits();

// Check if the number of digits exceeds the maximum precision allowed
if digits.len() > MAX_SUPPORTED_PRECISION.into() {
return Err(ScalarConversionError::Overflow{ error: format!(
"Attempted to parse a number with {} digits, which exceeds the max supported precision of {}",
digits.len(),
MAX_SUPPORTED_PRECISION
)});
if value.abs() > BigInt::from(<MontScalar<T>>::MAX_SIGNED) {
return Err(ScalarConversionError::Overflow {
error: "BigInt too large for Scalar".to_string(),
});
}

// Continue with the previous logic
assert!(digits.len() <= 4); // This should not happen if the precision check is correct
let mut data = [0u64; 4];
data[..digits.len()].copy_from_slice(&digits);
let result = Self::from_bigint(data);
match value.sign() {
// Updated to use value.sign() for clarity
num_bigint::Sign::Minus => Ok(-result),
_ => Ok(result),
}
let (sign, digits) = value.to_u64_digits();
assert!(digits.len() <= 4); // This should not happen if the above check is correct
let mut limbs = [0u64; 4];
limbs[..digits.len()].copy_from_slice(&digits);
let result = Self::from(limbs);
Ok(match sign {
num_bigint::Sign::Minus => -result,
num_bigint::Sign::Plus | num_bigint::Sign::NoSign => result,
})
}
}
impl<T: MontConfig<4>> From<[u64; 4]> for MontScalar<T> {
Expand Down
41 changes: 40 additions & 1 deletion crates/proof-of-sql/src/base/scalar/mont_scalar_test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::base::{
map::IndexSet,
scalar::{Curve25519Scalar, Scalar, ScalarConversionError},
scalar::{test_scalar::TestScalar, Curve25519Scalar, Scalar, ScalarConversionError},
};
use alloc::{format, string::ToString, vec::Vec};
use byte_slice_cast::AsByteSlice;
Expand Down Expand Up @@ -471,3 +471,42 @@ fn the_string_hash_implementation_uses_the_full_range_of_bits() {
}
}
}

#[test]
fn test_bigint_to_scalar_overflow() {
assert_eq!(
TestScalar::try_from(
"3618502788666131106986593281521497120428558179689953803000975469142727125494"
.parse::<BigInt>()
.unwrap()
)
.unwrap(),
TestScalar::MAX_SIGNED
);
assert_eq!(
TestScalar::try_from(
"-3618502788666131106986593281521497120428558179689953803000975469142727125494"
.parse::<BigInt>()
.unwrap()
)
.unwrap(),
-TestScalar::MAX_SIGNED
);

assert!(matches!(
TestScalar::try_from(
"3618502788666131106986593281521497120428558179689953803000975469142727125495"
.parse::<BigInt>()
.unwrap()
),
Err(ScalarConversionError::Overflow { .. })
));
assert!(matches!(
TestScalar::try_from(
"-3618502788666131106986593281521497120428558179689953803000975469142727125495"
.parse::<BigInt>()
.unwrap()
),
Err(ScalarConversionError::Overflow { .. })
));
}

0 comments on commit b1179f6

Please sign in to comment.