diff --git a/p256/src/arithmetic/scalar.rs b/p256/src/arithmetic/scalar.rs
index 3a766cb4c..3d741bdeb 100644
--- a/p256/src/arithmetic/scalar.rs
+++ b/p256/src/arithmetic/scalar.rs
@@ -13,7 +13,7 @@ use core::{
 };
 use elliptic_curve::{
     Curve,
-    bigint::{Limb, U256, prelude::*},
+    bigint::{Limb, U128, U256, U384, U512, prelude::*},
     group::ff::{self, Field, PrimeField},
     ops::{Invert, Reduce, ReduceNonZero},
     rand_core::TryRngCore,
@@ -635,6 +635,20 @@ impl Reduce<U256> for Scalar {
     }
 }
 
+impl Reduce<U512> for Scalar {
+    fn reduce(w: &U512) -> Self {
+        let (lo, hi) = w.split();
+        let w_reduced = barrett_reduce(lo, hi);
+        Self(w_reduced)
+    }
+}
+
+impl Reduce<U384> for Scalar {
+    fn reduce(w: &U384) -> Self {
+        <Self as Reduce<U512>>::reduce(&w.concat_mixed(&U128::ZERO))
+    }
+}
+
 impl ReduceNonZero<U256> for Scalar {
     fn reduce_nonzero(w: &U256) -> Self {
         const ORDER_MINUS_ONE: U256 = NistP256::ORDER.wrapping_sub(&U256::ONE);
@@ -651,6 +665,29 @@ impl ReduceNonZero<U256> for Scalar {
     }
 }
 
+impl ReduceNonZero<U512> for Scalar {
+    fn reduce_nonzero(w: &U512) -> Self {
+        // Reduce U512 to U256 first, then apply the non-zero adjustment
+        let reduced = <Self as Reduce<U512>>::reduce(w);
+        const ORDER_MINUS_ONE: U256 = NistP256::ORDER.wrapping_sub(&U256::ONE);
+        let (r, underflow) = reduced.0.borrowing_sub(&ORDER_MINUS_ONE, Limb::ZERO);
+        let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
+        Self(U256::conditional_select(&reduced.0, &r, !underflow).wrapping_add(&U256::ONE))
+    }
+}
+
+impl ReduceNonZero<U384> for Scalar {
+    fn reduce_nonzero(w: &U384) -> Self {
+        // Convert U384 to U512 by zero-padding the high bits
+        let w_bytes = w.to_be_bytes();
+        let mut w512_bytes = [0u8; 64];
+        // Copy the 48 big-endian U384 bytes into the low 384 bits of the U512
+        w512_bytes[16..64].copy_from_slice(&w_bytes);
+        let w512 = U512::from_be_byte_array(w512_bytes.into());
+        <Self as ReduceNonZero<U512>>::reduce_nonzero(&w512)
+    }
+}
+
 impl Sum for Scalar {
     fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
         iter.reduce(Add::add).unwrap_or(Self::ZERO)
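Review note (editorial, not part of the patch): the `ReduceNonZero<U512>` impl computes `((w mod n) mod (n - 1)) + 1`, so the result always lands in `[1, n - 1]`. For inputs `>= n` this differs from the existing `U256` impl, which conditionally subtracts `n - 1` from the raw input without first reducing mod `n`: the tests below pin `reduce_nonzero(&NistP256::ORDER)` to 2 on the `U256` path, whereas the same value fed through the `U512` path would come out as 1. A minimal standalone sketch of the invariant both paths share, assuming this patch is applied and using the `p256` crate's `elliptic_curve` re-export:

```rust
// Sketch only: asserts that wide non-zero reduction never yields zero,
// including for inputs that are exact multiples of the group order.
use p256::{
    Scalar,
    elliptic_curve::{bigint::U512, group::ff::Field, ops::ReduceNonZero},
};

fn main() {
    for w in [U512::ZERO, U512::ONE, U512::MAX] {
        let s = <Scalar as ReduceNonZero<U512>>::reduce_nonzero(&w);
        // Field::is_zero returns a constant-time Choice; it must be false here.
        assert!(!bool::from(s.is_zero()));
    }
    println!("non-zero invariant holds on sample inputs");
}
```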
@@ -719,7 +756,7 @@ mod tests {
     use crate::{FieldBytes, NistP256, NonZeroScalar, SecretKey};
     use elliptic_curve::{
         Curve,
-        array::Array,
+        bigint::ArrayEncoding,
         group::ff::{Field, PrimeField},
         ops::{BatchInvert, ReduceNonZero},
     };
@@ -785,40 +822,102 @@ mod tests {
 
     #[test]
     fn reduce_nonzero() {
-        assert_eq!(Scalar::reduce_nonzero(&Array::default()).0, U256::ONE,);
-        assert_eq!(Scalar::reduce_nonzero(&U256::ONE).0, U256::from_u8(2),);
         assert_eq!(
-            Scalar::reduce_nonzero(&U256::from_u8(2)).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(&U256::ZERO).0,
+            U256::ONE,
+        );
+        assert_eq!(
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(&U256::ONE).0,
+            U256::from_u8(2),
+        );
+        assert_eq!(
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(&U256::from_u8(2)).0,
             U256::from_u8(3),
         );
-        assert_eq!(Scalar::reduce_nonzero(&NistP256::ORDER).0, U256::from_u8(2),);
         assert_eq!(
-            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(1))).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(&NistP256::ORDER).0,
+            U256::from_u8(2),
+        );
+        assert_eq!(
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(
+                &NistP256::ORDER.wrapping_sub(&U256::from_u8(1))
+            )
+            .0,
             U256::ONE,
         );
         assert_eq!(
-            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(2))).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(
+                &NistP256::ORDER.wrapping_sub(&U256::from_u8(2))
+            )
+            .0,
             NistP256::ORDER.wrapping_sub(&U256::ONE),
         );
         assert_eq!(
-            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_sub(&U256::from_u8(3))).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(
+                &NistP256::ORDER.wrapping_sub(&U256::from_u8(3))
+            )
+            .0,
             NistP256::ORDER.wrapping_sub(&U256::from_u8(2)),
         );
         assert_eq!(
-            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_add(&U256::ONE)).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(
+                &NistP256::ORDER.wrapping_add(&U256::ONE)
+            )
+            .0,
             U256::from_u8(3),
         );
         assert_eq!(
-            Scalar::reduce_nonzero(&NistP256::ORDER.wrapping_add(&U256::from_u8(2))).0,
+            <Scalar as ReduceNonZero<U256>>::reduce_nonzero(
+                &NistP256::ORDER.wrapping_add(&U256::from_u8(2))
+            )
+            .0,
             U256::from_u8(4),
         );
     }
 
+    #[test]
+    fn reduce_nonzero_u384() {
+        use elliptic_curve::bigint::{ArrayEncoding, U384};
+
+        // Test with 48 zero bytes (384 bits)
+        let zero_u384 = U384::ZERO;
+        assert_eq!(
+            <Scalar as ReduceNonZero<U384>>::reduce_nonzero(&zero_u384).0,
+            U256::ONE
+        );
+
+        // Test with small values
+        let mut bytes = [0u8; 48];
+        bytes[47] = 1; // set the least significant byte to 1
+        let u384_val = U384::from_be_byte_array(bytes.into());
+        assert_eq!(
+            <Scalar as ReduceNonZero<U384>>::reduce_nonzero(&u384_val).0,
+            U256::from_u8(2)
+        );
+
+        bytes[47] = 2;
+        let u384_val2 = U384::from_be_byte_array(bytes.into());
+        assert_eq!(
+            <Scalar as ReduceNonZero<U384>>::reduce_nonzero(&u384_val2).0,
+            U256::from_u8(3)
+        );
+
+        // Test with a value that spans the full 384 bits
+        let large_value = U384::from_be_hex(
+            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
+        );
+        let reduced = <Scalar as ReduceNonZero<U384>>::reduce_nonzero(&large_value);
+        // The result must be non-zero and canonical (less than the group order)
+        assert_ne!(reduced.0, U256::ZERO);
+        assert!(reduced.0 < NistP256::ORDER);
+    }
+
     prop_compose! {
         fn non_zero_scalar()(bytes in any::<[u8; 32]>()) -> NonZeroScalar {
-            NonZeroScalar::reduce_nonzero(&FieldBytes::from(bytes))
+            let uint = U256::from_be_byte_array(bytes.into());
+            <NonZeroScalar as ReduceNonZero<U256>>::reduce_nonzero(&uint)
         }
     }
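Usage sketch (editorial, not part of the patch): the expected consumer of the new `Reduce<U384>` impl is wide reduction of 48 bytes of digest/XOF output into a scalar, where the 128 bits of headroom above the 256-bit group order keep the modular bias on the order of 2^-128. The sketch below assumes this patch is applied and uses the same `ArrayEncoding` helpers as the tests; the input bytes are arbitrary example data:

```rust
use p256::{
    Scalar,
    elliptic_curve::{
        bigint::{ArrayEncoding, U384},
        ops::Reduce,
    },
};

fn main() {
    // 48 bytes standing in for wide hash output (e.g. from an XOF).
    let wide_bytes = [0xABu8; 48];
    let w = U384::from_be_byte_array(wide_bytes.into());

    // Reduce the 384-bit value to a canonical P-256 scalar in [0, n).
    let s = <Scalar as Reduce<U384>>::reduce(&w);
    println!("scalar: {:?}", s);
}
```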