diff --git a/crates/starknet-types-core/Cargo.toml b/crates/starknet-types-core/Cargo.toml index 7b13548..91ab802 100644 --- a/crates/starknet-types-core/Cargo.toml +++ b/crates/starknet-types-core/Cargo.toml @@ -19,6 +19,8 @@ lambdaworks-crypto = { git = "https://github.com/lambdaclass/lambdaworks.git", r arbitrary = { version = "1.3.0", optional = true, default-features = false } num-traits = { version = "0.2.16", default-features = false } +num-bigint = {version = "0.4.4", default-features = false} +num-integer = {version = "0.1.45", default-features = false} [features] default = ["std", "serde", "curve"] diff --git a/crates/starknet-types-core/src/felt.rs b/crates/starknet-types-core/src/felt.rs index 9b5f59d..ef43312 100644 --- a/crates/starknet-types-core/src/felt.rs +++ b/crates/starknet-types-core/src/felt.rs @@ -1,7 +1,9 @@ -use core::ops::{Add, Neg}; +use core::ops::{Add, Mul, Neg}; use bitvec::array::BitArray; -use num_traits::{FromPrimitive, ToPrimitive, Zero}; +use num_bigint::BigInt; +use num_integer::Integer; +use num_traits::{FromPrimitive, One, ToPrimitive, Zero}; #[cfg(target_pointer_width = "64")] pub type BitArrayStore = [u64; 4]; @@ -32,6 +34,7 @@ use arbitrary::{self, Arbitrary, Unstructured}; pub struct Felt(pub(crate) FieldElement); /// A non-zero [Felt]. +#[derive(Debug, Clone, Copy)] pub struct NonZeroFelt(FieldElement); #[derive(Debug)] @@ -258,14 +261,47 @@ impl Felt { Self(self.0.pow(exponent.0.representative())) } - /// Modular multiplication between `self` and `rhs` modulo `p`. + // Implemention taken from Jonathan Lei's starknet-rs + // https://github.com/xJonathanLEI/starknet-rs/blob/a3a0050f80e90bd40303256a85783f4b5b18258c/starknet-crypto/src/fe_utils.rs#L20 + /// Modular multiplication between `self` and `rhs` in modulo `p`. pub fn mul_mod(&self, rhs: &Self, p: &NonZeroFelt) -> Self { - (self * rhs).div_rem(p).1 + let multiplicand = BigInt::from_bytes_be(num_bigint::Sign::Plus, &self.to_bytes_be()); + let multiplier = BigInt::from_bytes_be(num_bigint::Sign::Plus, &rhs.to_bytes_be()); + let modulus = BigInt::from_bytes_be(num_bigint::Sign::Plus, &p.0.to_bytes_be()); + + let result = multiplicand.mul(multiplier).mod_floor(&modulus); + + let (_, buffer) = result.to_bytes_be(); + let mut result = [0u8; 32]; + + result[(32 - buffer.len())..].copy_from_slice(&buffer[..]); + + // safe .unwrap() + Felt::from_bytes_be(&result).unwrap() } - /// Modular inverse of `self` modulo `p`. - pub fn inverse_mod(&self, p: &NonZeroFelt) -> Option { - self.inverse().map(|x| x.div_rem(p).1) + // Implemention taken from Jonathan Lei's starknet-rs + // https://github.com/xJonathanLEI/starknet-rs/blob/a3a0050f80e90bd40303256a85783f4b5b18258c/starknet-crypto/src/fe_utils.rs#L46 + /// Multiplicative inverse of `self` in modulo `p`. + pub fn mod_inverse(&self, p: &NonZeroFelt) -> Option { + let operand = BigInt::from_bytes_be(num_bigint::Sign::Plus, &self.0.to_bytes_be()); + let modulus = BigInt::from_bytes_be(num_bigint::Sign::Plus, &p.0.to_bytes_be()); + + let extended_gcd = operand.extended_gcd(&modulus); + if extended_gcd.gcd != BigInt::one() { + return None; + } + let result = if extended_gcd.x < BigInt::zero() { + extended_gcd.x + modulus + } else { + extended_gcd.x + }; + + let (_, buffer) = result.to_bytes_be(); + let mut result = [0u8; 32]; + result[(32 - buffer.len())..].copy_from_slice(&buffer[..]); + + Felt::from_bytes_be(&result).ok() } /// Remainder of dividing `self` by `n` as integers. @@ -917,7 +953,6 @@ mod test { use crate::felt_arbitrary::nonzero_felt; use core::ops::Shl; use proptest::prelude::*; - #[cfg(feature = "serde")] use serde_test::{assert_de_tokens, assert_ser_tokens, Configure, Token}; @@ -1149,14 +1184,17 @@ mod test { #[test] fn inverse_mod_of_zero_is_none(p in nonzero_felt()) { let nzp = NonZeroFelt(p.0); - prop_assert!(Felt::ZERO.inverse_mod(&nzp).is_none()); + prop_assert!(Felt::ZERO.mod_inverse(&nzp).is_none()); } #[test] fn inverse_mod_in_range(x in nonzero_felt(), p in nonzero_felt()) { let nzp = NonZeroFelt(p.0); - prop_assert!(x.inverse_mod(&nzp) <= Some(Felt::MAX)); - prop_assert!(x.inverse_mod(&nzp) < Some(p)); + let Some(result) = x.mod_inverse(&nzp) else { return Ok(()) }; + + prop_assert!(result <= Felt::MAX); + prop_assert!(result < p); + prop_assert!(result.mul_mod(&x, &nzp) == Felt::ONE); } #[test] @@ -1507,4 +1545,60 @@ mod test { ) ); } + + #[test] + fn inverse_and_mul_mod() { + let nzps: Vec = [ + Felt::from(5_i32).try_into().unwrap(), + Felt::from_hex("0x5").unwrap().try_into().unwrap(), + Felt::from_hex("0x1234").unwrap().try_into().unwrap(), + Felt::from_hex("0xabcdef123").unwrap().try_into().unwrap(), + Felt::from_hex("0xffffffffffffff") + .unwrap() + .try_into() + .unwrap(), + Felt::from_hex("0xfffffffffffffffffffffffffffffff") + .unwrap() + .try_into() + .unwrap(), + Felt::MAX.try_into().unwrap(), + ] + .to_vec(); + let nums = [ + Felt::from_hex("0x0").unwrap(), + Felt::from_hex("0x1").unwrap(), + Felt::from_hex("0x2").unwrap(), + Felt::from_hex("0x5").unwrap(), + Felt::from_hex("0x123abc").unwrap(), + Felt::from_hex("0xabcdef9812367312").unwrap(), + Felt::from_hex("0xffffffffffffffffffffffffffffff").unwrap(), + Felt::from_hex("0xffffffffffffffffffffffffffffffffffffffffff").unwrap(), + Felt::MAX, + ]; + + for felt in nums { + for nzp in nzps.iter() { + let result = felt.mod_inverse(nzp); + if result.is_some() { + assert_eq!(result.unwrap().mul_mod(&felt, nzp), Felt::ONE); + } + } + } + } + + #[test] + fn check_mul_mod() { + let x = Felt::from_dec_str( + "3618502788666131213697322783095070105623107215331596699973092056135872020480", + ) + .unwrap(); + let y = Felt::from_dec_str("46118400291").unwrap(); + let p: NonZeroFelt = Felt::from_dec_str("123987312893120893724347692364") + .unwrap() + .try_into() + .unwrap(); + let expected_result = Felt::from_dec_str("68082278891996790254001523512").unwrap(); + + assert_eq!(x.mul_mod(&y, &p), expected_result); + } }