diff --git a/crates/stark-felt/Cargo.toml b/crates/stark-felt/Cargo.toml index 03e1b43..d2968dc 100644 --- a/crates/stark-felt/Cargo.toml +++ b/crates/stark-felt/Cargo.toml @@ -13,12 +13,15 @@ readme = "README.md" [dependencies] bitvec = { version = "1.0.1", default-features = false } serde = { version = "1.0.163", optional = true, default-features = false } -lambdaworks-math = { version = "0.1.3", default-features = false } +lambdaworks-math = { version = "0.2.0", default_features = false } +arbitrary = { version = "1.3.0", optional = true, default-features = false } +num-traits = { version = "0.2.16", default-features = false } [features] default = ["std", "serde"] std = [] alloc = ["serde?/alloc"] +arbitrary = ["std", "dep:arbitrary"] [dev-dependencies] proptest = "1.1.0" diff --git a/crates/stark-felt/src/arbitrary.rs b/crates/stark-felt/src/arbitrary_proptest.rs similarity index 98% rename from crates/stark-felt/src/arbitrary.rs rename to crates/stark-felt/src/arbitrary_proptest.rs index 28ae9e5..49ea020 100644 --- a/crates/stark-felt/src/arbitrary.rs +++ b/crates/stark-felt/src/arbitrary_proptest.rs @@ -1,4 +1,5 @@ use lambdaworks_math::{field::element::FieldElement, unsigned_integer::element::UnsignedInteger}; +use num_traits::Zero; use proptest::prelude::*; use crate::Felt; diff --git a/crates/stark-felt/src/lib.rs b/crates/stark-felt/src/lib.rs index ee190b7..12c0328 100644 --- a/crates/stark-felt/src/lib.rs +++ b/crates/stark-felt/src/lib.rs @@ -1,9 +1,12 @@ #![cfg_attr(not(feature = "std"), no_std)] +use core::ops::{Add, Neg}; + use bitvec::array::BitArray; +use num_traits::{FromPrimitive, ToPrimitive, Zero}; #[cfg(test)] -mod arbitrary; +mod arbitrary_proptest; #[cfg(target_pointer_width = "64")] pub type BitArrayStore = [u64; 4]; @@ -25,8 +28,11 @@ use lambdaworks_math::{ unsigned_integer::element::UnsignedInteger, }; +#[cfg(feature = "arbitrary")] +use arbitrary::{self, Arbitrary, Unstructured}; + /// Definition of the Field Element type. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Felt(FieldElement); /// A non-zero [Felt]. @@ -124,25 +130,27 @@ impl Felt { BitArray::new(limbs) } - /// Checks if `self` is equal to [Felt::Zero]. - pub fn is_zero(&self) -> bool { - *self == Felt::ZERO - } /// Finite field division. pub fn field_div(&self, rhs: &NonZeroFelt) -> Self { Self(self.0 / rhs.0) } - /// Floor division. + /// Truncated quotient between `self` and `rhs`. pub fn floor_div(&self, rhs: &NonZeroFelt) -> Self { Self(FieldElement::from( &(self.0.representative().div_rem(&rhs.0.representative())).0, )) } - /// Multiplicative inverse. + /// Quotient and remainder between `self` and `rhs`. + pub fn div_rem(&self, rhs: &NonZeroFelt) -> (Self, Self) { + let (q, r) = self.0.representative().div_rem(&rhs.0.representative()); + (Self(FieldElement::from(&q)), Self(FieldElement::from(&r))) + } + + /// Multiplicative inverse inside field. pub fn inverse(&self) -> Option { - Some(Self(self.0.inv())) + self.0.inv().map(Self).ok() } /// Finds the square root. There may be 2 roots for each square, and the lower one is returned. @@ -157,15 +165,80 @@ impl Felt { } /// Raises `self` to the power of `exponent`. - pub fn pow(&self, exponent: u128) -> Self { - Self(self.0.pow(exponent)) + pub fn pow(&self, exponent: impl Into) -> Self { + Self(self.0.pow(exponent.into())) } - /// Performs self modulo n, with n being smaller that the stark field prime. - pub fn mod_floor(&self, n: &Self) -> Self { - Self(FieldElement::from( - &(self.0).representative().div_rem(&n.0.representative()).1, - )) + /// Raises `self` to the power of `exponent`. + pub fn pow_felt(&self, exponent: &Felt) -> Self { + Self(self.0.pow(exponent.0.representative())) + } + + /// Modular multiplication between `self` and `rhs` modulo `p`. + pub fn mul_mod(&self, rhs: &Self, p: &NonZeroFelt) -> Self { + (self * rhs).div_rem(p).1 + } + + /// Modular inverse of `self` modulo `p`. + pub fn inverse_mod(&self, p: &NonZeroFelt) -> Option { + self.inverse().map(|x| x.div_rem(p).1) + } + + /// Remainder of dividing `self` by `n` as integers. + pub fn mod_floor(&self, n: &NonZeroFelt) -> Self { + self.div_rem(n).1 + } + + /// Parse a hex-encoded number into `Felt`. + pub fn from_hex(hex_string: &str) -> Result { + FieldElement::from_hex(hex_string) + .map(Self) + .map_err(|_| FromStrError) + } + + /// Parse a decimal-encoded number into `Felt`. + pub fn from_dec_str(dec_string: &str) -> Result { + if dec_string.starts_with('-') { + UnsignedInteger::from_dec_str(dec_string.strip_prefix('-').unwrap()) + .map(|x| Self(FieldElement::from(&x)).neg()) + .map_err(|_| FromStrError) + } else { + UnsignedInteger::from_dec_str(dec_string) + .map(|x| Self(FieldElement::from(&x))) + .map_err(|_| FromStrError) + } + } + + /// Convert `self`'s representative into an array of `u64` digits, + /// least significant digits first. + pub fn to_le_digits(&self) -> [u64; 4] { + let mut limbs = self.0.representative().limbs; + limbs.reverse(); + limbs + } + + /// Convert `self`'s representative into an array of `u64` digits, + /// most significant digits first. + pub fn to_be_digits(&self) -> [u64; 4] { + self.0.representative().limbs + } + + /// Count the minimum number of bits needed to express `self`'s representative. + pub fn bits(&self) -> usize { + self.0.representative().bits_le() + } +} + +#[cfg(feature = "arbitrary")] +impl<'a> Arbitrary<'a> for Felt { + // Creates an arbitrary `Felt` from unstructured input for fuzzing. + // It uses the default implementation to create the internal limbs and then + // uses the usual constructors from `lambdaworks-math`. + fn arbitrary(u: &mut Unstructured) -> arbitrary::Result { + let limbs = <[u64; 4]>::arbitrary(u)?; + let uint = UnsignedInteger::from_limbs(limbs); + let felt = FieldElement::new(uint); + Ok(Felt(felt)) } } @@ -224,8 +297,145 @@ impl TryFrom<&Felt> for NonZeroFelt { } } +impl From for Felt { + fn from(value: u128) -> Felt { + Self(FieldElement::from(&UnsignedInteger::from(value))) + } +} + +impl From for Felt { + fn from(value: i128) -> Felt { + let mut res = Self(FieldElement::from(&UnsignedInteger::from( + value.unsigned_abs(), + ))); + if value.is_negative() { + res = -res; + } + res + } +} + +macro_rules! impl_from { + ($from:ty, $with:ty) => { + impl From<$from> for Felt { + fn from(value: $from) -> Self { + (value as $with).into() + } + } + }; +} + +impl_from!(u8, u128); +impl_from!(u16, u128); +impl_from!(u32, u128); +impl_from!(u64, u128); +impl_from!(usize, u128); +impl_from!(i8, i128); +impl_from!(i16, i128); +impl_from!(i32, i128); +impl_from!(i64, i128); +impl_from!(isize, i128); + +impl FromPrimitive for Felt { + fn from_i64(value: i64) -> Option { + Some(value.into()) + } + + fn from_u64(value: u64) -> Option { + Some(value.into()) + } + + fn from_i128(value: i128) -> Option { + Some(value.into()) + } + + fn from_u128(value: u128) -> Option { + Some(value.into()) + } +} + +// TODO: we need to decide whether we want conversions to signed primitives +// will support converting the high end of the field to negative. +impl ToPrimitive for Felt { + fn to_u64(&self) -> Option { + self.to_u128().and_then(|x| u64::try_from(x).ok()) + } + + fn to_i64(&self) -> Option { + self.to_u128().and_then(|x| i64::try_from(x).ok()) + } + + fn to_u128(&self) -> Option { + match self.0.representative().limbs { + [0, 0, hi, lo] => Some((lo as u128) | ((hi as u128) << 64)), + _ => None, + } + } + + fn to_i128(&self) -> Option { + self.to_u128().and_then(|x| i128::try_from(x).ok()) + } +} + +impl Zero for Felt { + fn is_zero(&self) -> bool { + *self == Felt::ZERO + } + + fn zero() -> Felt { + Felt::ZERO + } +} + +impl Add<&Felt> for u64 { + type Output = Option; + + fn add(self, rhs: &Felt) -> Option { + const PRIME_DIGITS_BE_HI: [u64; 3] = + [0x0800000000000011, 0x0000000000000000, 0x0000000000000000]; + const PRIME_MINUS_U64_MAX_DIGITS_BE_HI: [u64; 3] = + [0x0800000000000010, 0xffffffffffffffff, 0xffffffffffffffff]; + + // Match with the 64 bits digits in big-endian order to + // characterize how the sum will behave. + match rhs.to_be_digits() { + // All digits are `0`, so the sum is simply `self`. + [0, 0, 0, 0] => Some(self), + // A single digit means this is effectively the sum of two `u64` numbers. + [0, 0, 0, low] => self.checked_add(low), + // Now we need to compare the 3 most significant digits. + // There are two relevant cases from now on, either `rhs` behaves like a + // substraction of a `u64` or the result of the sum falls out of range. + + // The 3 MSB only match the prime for Felt::max_value(), which is -1 + // in the signed field, so this is equivalent to substracting 1 to `self`. + [hi @ .., _] if hi == PRIME_DIGITS_BE_HI => self.checked_sub(1), + + // For the remaining values between `[-u64::MAX..0]` (where `{0, -1}` have + // already been covered) the MSB matches that of `PRIME - u64::MAX`. + // Because we're in the negative number case, we count down. Because `0` + // and `-1` correspond to different MSBs, `0` and `1` in the LSB are less + // than `-u64::MAX`, the smallest value we can add to (read, substract its + // magnitude from) a `u64` number, meaning we exclude them from the valid + // case. + // For the remaining range, we take the absolute value module-2 while + // correcting by substracting `1` (note we actually substract `2` because + // the absolute value itself requires substracting `1`. + [hi @ .., low] if hi == PRIME_MINUS_U64_MAX_DIGITS_BE_HI && low >= 2 => { + (self).checked_sub(u64::MAX - (low - 2)) + } + // Any other case will result in an addition that is out of bounds, so + // the addition fails, returning `None`. + _ => None, + } + } +} + mod arithmetic { - use core::{iter, ops}; + use core::{ + iter, + ops::{self, Neg}, + }; use super::*; @@ -279,6 +489,24 @@ mod arithmetic { } } + /// Field addition. Never overflows/underflows. + impl ops::Add for Felt { + type Output = Felt; + + fn add(self, rhs: u64) -> Self::Output { + self + Felt::from(rhs) + } + } + + /// Field addition. Never overflows/underflows. + impl ops::Add for &Felt { + type Output = Felt; + + fn add(self, rhs: u64) -> Self::Output { + self + Felt::from(rhs) + } + } + /// Field subtraction. Never overflows/underflows. impl ops::SubAssign for Felt { fn sub_assign(&mut self, rhs: Felt) { @@ -329,6 +557,40 @@ mod arithmetic { } } + /// Field subtraction. Never overflows/underflows. + #[allow(clippy::suspicious_arithmetic_impl)] + impl ops::Sub for u64 { + type Output = Option; + fn sub(self, rhs: Felt) -> Self::Output { + self + &rhs.neg() + } + } + + /// Field subtraction. Never overflows/underflows. + #[allow(clippy::suspicious_arithmetic_impl)] + impl ops::Sub<&Felt> for u64 { + type Output = Option; + fn sub(self, rhs: &Felt) -> Self::Output { + self + &rhs.neg() + } + } + + /// Field subtraction. Never overflows/underflows. + impl ops::Sub for Felt { + type Output = Felt; + fn sub(self, rhs: u64) -> Self::Output { + self - Self::from(rhs) + } + } + + /// Field subtraction. Never overflows/underflows. + impl ops::Sub for &Felt { + type Output = Felt; + fn sub(self, rhs: u64) -> Self::Output { + self - Felt::from(rhs) + } + } + /// Field multiplication. Never overflows/underflows. impl ops::MulAssign for Felt { fn mul_assign(&mut self, rhs: Felt) { @@ -567,12 +829,11 @@ mod errors { #[cfg(test)] mod test { use super::alloc::{format, string::String, vec::Vec}; - use crate::arbitrary::nonzero_felt; - use core::ops::Shl; - use super::*; - + use crate::arbitrary_proptest::nonzero_felt; + use core::ops::Shl; use proptest::prelude::*; + #[cfg(feature = "serde")] use serde_test::{assert_de_tokens, assert_ser_tokens, Configure, Token}; @@ -686,8 +947,9 @@ mod test { } #[test] - fn mod_floor_in_range(x in any::(), n in any::()) { - let x_mod_n = x.mod_floor(&n); + fn mod_floor_in_range(x in any::(), n in nonzero_felt()) { + let nzn = NonZeroFelt(n.0); + let x_mod_n = x.mod_floor(&nzn); prop_assert!(x_mod_n <= Felt::MAX); prop_assert!(x_mod_n < n); } @@ -769,6 +1031,26 @@ mod test { prop_assert_eq!(x * x.inverse().unwrap(), Felt::ONE ) } + #[test] + fn inverse_mod_of_zero_is_none(p in nonzero_felt()) { + let nzp = NonZeroFelt(p.0); + prop_assert!(Felt::ZERO.inverse_mod(&nzp).is_none()); + } + + #[test] + fn inverse_mod_in_range(x in nonzero_felt(), p in nonzero_felt()) { + let nzp = NonZeroFelt(p.0); + prop_assert!(x.inverse_mod(&nzp) <= Some(Felt::MAX)); + prop_assert!(x.inverse_mod(&nzp) < Some(p)); + } + + #[test] + fn mul_mod_in_range(x in any::(), y in any::(), p in nonzero_felt()) { + let nzp = NonZeroFelt(p.0); + prop_assert!(x.mul_mod(&y, &nzp) <= Felt::MAX); + prop_assert!(x.mul_mod(&y, &nzp) < p); + } + #[test] fn non_zero_felt_new_is_ok_when_not_zero(x in nonzero_felt()) { prop_assert!(NonZeroFelt::try_from(x).is_ok()); @@ -869,14 +1151,14 @@ mod test { #[test] fn pow_operations() { - assert_eq!(Felt::ONE.pow(5), Felt::ONE); - assert_eq!(Felt::ZERO.pow(5), Felt::ZERO); - assert_eq!(Felt::THREE.pow(0), Felt::ONE); + assert_eq!(Felt::ONE.pow(5u32), Felt::ONE); + assert_eq!(Felt::ZERO.pow(5u32), Felt::ZERO); + assert_eq!(Felt::THREE.pow(0u32), Felt::ONE); assert_eq!( - Felt(FieldElement::from(200)).pow(4), + Felt(FieldElement::from(200)).pow(4u32), Felt(FieldElement::from(1600000000)) ); - assert_eq!(Felt::MAX.pow(9), Felt::MAX); + assert_eq!(Felt::MAX.pow(9u32), Felt::MAX); } #[test]