diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index 94932d553..39c5c4267 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -65,3 +65,7 @@ jobs: HAX_HOME=${{ github.workspace }}/hax \ PATH="${PATH}:${{ github.workspace }}/fstar/bin" \ ./hax.py prove --admit + + - name: 🏃 Extract ML-DSA crate + working-directory: libcrux-ml-dsa + run: cargo hax into fstar diff --git a/libcrux-intrinsics/src/avx2_extract.rs b/libcrux-intrinsics/src/avx2_extract.rs index f1d42e188..8afb4ab49 100644 --- a/libcrux-intrinsics/src/avx2_extract.rs +++ b/libcrux-intrinsics/src/avx2_extract.rs @@ -5,18 +5,27 @@ pub type Vec256 = u8; pub type Vec128 = u8; +pub type Vec256Float = u8; +pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) { + debug_assert_eq!(output.len(), 32); + unimplemented!() +} pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) { debug_assert_eq!(output.len(), 16); unimplemented!() } - -pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) { - debug_assert_eq!(output.len(), 32); +pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) { + debug_assert_eq!(output.len(), 8); unimplemented!() } + pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) { - // debug_assert_eq!(output.len(), 8); + debug_assert!(output.len() >= 8); + unimplemented!() +} +pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) { + debug_assert_eq!(output.len(), 4); unimplemented!() } @@ -34,15 +43,21 @@ pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 { debug_assert_eq!(input.len(), 32); unimplemented!() } - pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 { debug_assert_eq!(input.len(), 16); unimplemented!() } +pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 { + debug_assert_eq!(input.len(), 8); + unimplemented!() +} pub fn mm256_setzero_si256() -> Vec256 { unimplemented!() } +pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 { + unimplemented!() +} pub fn mm_set_epi8( byte15: u8, @@ -126,13 +141,21 @@ pub fn mm256_set_epi16( unimplemented!() } +#[inline(always)] pub fn mm_set1_epi16(constant: i16) -> Vec128 { unimplemented!() } +#[inline(always)] pub fn mm256_set1_epi32(constant: i32) -> Vec256 { unimplemented!() } + +#[inline(always)] +pub fn mm_set_epi32(input3: i32, input2: i32, input1: i32, input0: i32) -> Vec128 { + unimplemented!() +} +#[inline(always)] pub fn mm256_set_epi32( input7: i32, input6: i32, @@ -146,22 +169,40 @@ pub fn mm256_set_epi32( unimplemented!() } +#[inline(always)] pub fn mm_add_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unimplemented!() } +#[inline(always)] pub fn mm256_add_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] pub fn mm256_madd_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] pub fn mm256_add_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_add_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 { + unimplemented!() +} + +#[inline(always)] +pub fn mm256_abs_epi32(a: Vec256) -> Vec256 { + unimplemented!() +} + pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { + unimplemented!() +} + pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unimplemented!() } @@ -174,9 +215,33 @@ pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unimplemented!() } +#[inline(always)] pub fn mm256_cmpgt_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_cmpgt_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { + unimplemented!() +} +#[inline(always)] +pub fn mm256_cmpeq_epi32(a: Vec256, b: Vec256) -> Vec256 { + unimplemented!() +} + +#[inline(always)] +pub fn mm256_sign_epi32(a: Vec256, b: Vec256) -> Vec256 { + unimplemented!() +} + +#[inline(always)] +pub fn mm256_castsi256_ps(a: Vec256) -> Vec256Float { + unimplemented!() +} + +#[inline(always)] +pub fn mm256_movemask_ps(a: Vec256Float) -> i32 { + unimplemented!() +} pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 { unimplemented!() @@ -194,10 +259,25 @@ pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_mul_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { + unimplemented!() +} + +#[inline(always)] pub fn mm256_and_si256(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 { + unimplemented!() +} + +pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 { + unimplemented!() +} + pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 { unimplemented!() } @@ -220,6 +300,10 @@ pub fn mm256_srli_epi32(vector: Vec256) -> Vec256 { unimplemented!() } +pub fn mm_srli_epi64(vector: Vec128) -> Vec128 { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); + unimplemented!() +} pub fn mm256_srli_epi64(vector: Vec256) -> Vec256 { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); unimplemented!() @@ -291,19 +375,47 @@ pub fn mm256_inserti128_si256(vector: Vec256, vector_i128: V unimplemented!() } +#[inline(always)] pub fn mm256_blend_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 { debug_assert!(CONTROL >= 0 && CONTROL < 256); unimplemented!() } +#[inline(always)] +pub fn mm256_blend_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 { + debug_assert!(CONTROL >= 0 && CONTROL < 256); + unimplemented!() +} + +// This is essentially _mm256_blendv_ps adapted for use with the Vec256 type. +// It is not offered by the AVX2 instruction set. +#[inline(always)] +pub fn vec256_blendv_epi32(a: Vec256, b: Vec256, mask: Vec256) -> Vec256 { + unimplemented!() +} + +#[inline(always)] pub fn mm_movemask_epi8(vector: Vec128) -> i32 { unimplemented!() } +#[inline(always)] pub fn mm256_permutevar8x32_epi32(vector: Vec256, control: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_srlv_epi32(vector: Vec256, counts: Vec256) -> Vec256 { + unimplemented!() +} +#[inline(always)] +pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 { + unimplemented!() +} + +pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 { + unimplemented!() +} pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 { unimplemented!() } @@ -313,6 +425,12 @@ pub fn mm256_slli_epi64(x: Vec256) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_bsrli_epi128(x: Vec256) -> Vec256 { + debug_assert!(SHIFT_BY > 0 && SHIFT_BY < 16); + unimplemented!() +} + #[inline(always)] pub fn mm256_andnot_si256(a: Vec256, b: Vec256) -> Vec256 { unimplemented!() @@ -322,6 +440,10 @@ pub fn mm256_andnot_si256(a: Vec256, b: Vec256) -> Vec256 { pub fn mm256_set1_epi64x(a: i64) -> Vec256 { unimplemented!() } +#[inline(always)] +pub fn mm256_set_epi64x(input3: i64, input2: i64, input1: i64, input0: i64) -> Vec256 { + unimplemented!() +} #[inline(always)] pub fn mm256_unpacklo_epi64(a: Vec256, b: Vec256) -> Vec256 { diff --git a/libcrux-ml-dsa/examples/verify_65.rs b/libcrux-ml-dsa/examples/verify_65.rs index 3ebbd7245..3bc1289f8 100644 --- a/libcrux-ml-dsa/examples/verify_65.rs +++ b/libcrux-ml-dsa/examples/verify_65.rs @@ -14,9 +14,10 @@ fn main() { let message = random_array::<1023>(); let keypair = ml_dsa_65::generate_key_pair(key_generation_seed); - let signature = ml_dsa_65::sign(&keypair.signing_key, &message, signing_randomness); + let signature = ml_dsa_65::sign(&keypair.signing_key, &message, signing_randomness) + .expect("Rejection sampling failure probability is < 2⁻¹²⁸"); for _i in 0..100_000 { - ml_dsa_65::verify(&keypair.verification_key, &message, &signature).unwrap(); + let _ = ml_dsa_65::verify(&keypair.verification_key, &message, &signature); } } diff --git a/libcrux-ml-dsa/src/constants.rs b/libcrux-ml-dsa/src/constants.rs index ac15681aa..f67537ba9 100644 --- a/libcrux-ml-dsa/src/constants.rs +++ b/libcrux-ml-dsa/src/constants.rs @@ -27,3 +27,4 @@ pub(crate) const MESSAGE_REPRESENTATIVE_SIZE: usize = 64; pub(crate) const MASK_SEED_SIZE: usize = 64; pub(crate) const VERIFIER_CHALLENGE_SEED_SIZE: usize = 32; +pub(crate) const REJECTION_SAMPLE_BOUND: usize = 576; diff --git a/libcrux-ml-dsa/src/encoding/commitment.rs b/libcrux-ml-dsa/src/encoding/commitment.rs index 4c6c75990..f5a12e789 100644 --- a/libcrux-ml-dsa/src/encoding/commitment.rs +++ b/libcrux-ml-dsa/src/encoding/commitment.rs @@ -6,7 +6,7 @@ fn serialize( ) -> [u8; OUTPUT_SIZE] { let mut serialized = [0u8; OUTPUT_SIZE]; - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 128 => { // The commitment has coefficients in [0,15] => each coefficient occupies // 4 bits. Each SIMD unit contains 8 elements, which means each diff --git a/libcrux-ml-dsa/src/encoding/error.rs b/libcrux-ml-dsa/src/encoding/error.rs index 8a393d0a9..80080945c 100644 --- a/libcrux-ml-dsa/src/encoding/error.rs +++ b/libcrux-ml-dsa/src/encoding/error.rs @@ -8,7 +8,7 @@ pub(crate) fn serialize [u8; OUTPUT_SIZE] { let mut serialized = [0u8; OUTPUT_SIZE]; - match ETA { + match ETA as u8 { 2 => { const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 3; @@ -41,7 +41,7 @@ pub(crate) fn serialize( serialized: &[u8], ) -> PolynomialRingElement { - let mut serialized_chunks = match ETA { + let mut serialized_chunks = match ETA as u8 { 2 => serialized.chunks(3), 4 => serialized.chunks(4), _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/encoding/gamma1.rs b/libcrux-ml-dsa/src/encoding/gamma1.rs index cd35c3488..09e93f725 100644 --- a/libcrux-ml-dsa/src/encoding/gamma1.rs +++ b/libcrux-ml-dsa/src/encoding/gamma1.rs @@ -10,7 +10,7 @@ pub(crate) fn serialize< ) -> [u8; OUTPUT_BYTES] { let mut serialized = [0u8; OUTPUT_BYTES]; - match GAMMA1_EXPONENT { + match GAMMA1_EXPONENT as u8 { 17 => { const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 18; @@ -43,7 +43,7 @@ pub(crate) fn serialize< pub(crate) fn deserialize( serialized: &[u8], ) -> PolynomialRingElement { - let mut serialized_chunks = match GAMMA1_EXPONENT { + let mut serialized_chunks = match GAMMA1_EXPONENT as u8 { 17 => serialized.chunks(18), 19 => serialized.chunks(20), _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/encoding/signature.rs b/libcrux-ml-dsa/src/encoding/signature.rs index 9930430c3..233f3e224 100644 --- a/libcrux-ml-dsa/src/encoding/signature.rs +++ b/libcrux-ml-dsa/src/encoding/signature.rs @@ -36,16 +36,20 @@ impl< } let mut true_hints_seen = 0; - let hint_serialized = &mut signature[offset..]; + // Unfortunately the following does not go through hax: + // + // let hint_serialized = &mut signature[offset..]; + // + // Instead, we have to mutate signature[offset + ..] directly. for i in 0..ROWS_IN_A { for (j, hint) in self.hint[i].into_iter().enumerate() { if hint == 1 { - hint_serialized[true_hints_seen] = j as u8; + signature[offset + true_hints_seen] = j as u8; true_hints_seen += 1; } } - hint_serialized[MAX_ONES_IN_HINT + i] = true_hints_seen as u8; + signature[offset + MAX_ONES_IN_HINT + i] = true_hints_seen as u8; } signature @@ -80,50 +84,55 @@ impl< let mut previous_true_hints_seen = 0usize; - for i in 0..ROWS_IN_A { + let mut i = 0; + let mut malformed_hint = false; + + while i < ROWS_IN_A && !malformed_hint { let current_true_hints_seen = hint_serialized[MAX_ONES_IN_HINT + i] as usize; if (current_true_hints_seen < previous_true_hints_seen) || (previous_true_hints_seen > MAX_ONES_IN_HINT) { // the true hints seen should be increasing - // - // TODO: This return won't pass through hax; it'll need - // to be rewritten. See https://github.com/cryspen/libcrux/issues/341 - return Err(VerificationError::MalformedHintError); + malformed_hint = true; } - for j in previous_true_hints_seen..current_true_hints_seen { + let mut j = previous_true_hints_seen; + while !malformed_hint && j < current_true_hints_seen { if j > previous_true_hints_seen && hint_serialized[j] <= hint_serialized[j - 1] { // indices of true hints for a specific polynomial should be // increasing - // TODO: This return won't pass through hax; it'll need - // to be rewritten. See https://github.com/cryspen/libcrux/issues/341 - return Err(VerificationError::MalformedHintError); + malformed_hint = true; + } + if !malformed_hint { + hint[i][hint_serialized[j] as usize] = 1; + j += 1; } + } - hint[i][hint_serialized[j] as usize] = 1; + if !malformed_hint { + previous_true_hints_seen = current_true_hints_seen; + i += 1; } - previous_true_hints_seen = current_true_hints_seen; } - for bit in hint_serialized - .iter() - .take(MAX_ONES_IN_HINT) - .skip(previous_true_hints_seen) - { - if *bit != 0 { + i = previous_true_hints_seen; + while i < MAX_ONES_IN_HINT && !malformed_hint { + if hint_serialized[i] != 0 { // ensures padding indices are zero - // TODO: This return won't pass through hax; it'll need - // to be rewritten. See https://github.com/cryspen/libcrux/issues/341 - return Err(VerificationError::MalformedHintError); + malformed_hint = true; } + i += 1; } - Ok(Signature { - commitment_hash: commitment_hash.try_into().unwrap(), - signer_response: signer_response, - hint, - }) + if malformed_hint { + Err(VerificationError::MalformedHintError) + } else { + Ok(Signature { + commitment_hash: commitment_hash.try_into().unwrap(), + signer_response, + hint, + }) + } } } diff --git a/libcrux-ml-dsa/src/lib.rs b/libcrux-ml-dsa/src/lib.rs index 084c097b5..6bd0a6510 100644 --- a/libcrux-ml-dsa/src/lib.rs +++ b/libcrux-ml-dsa/src/lib.rs @@ -17,7 +17,10 @@ mod utils; // Public interface -pub use {ml_dsa_generic::VerificationError, types::*}; +pub use { + ml_dsa_generic::{SigningError, VerificationError}, + types::*, +}; pub use crate::constants::KEY_GENERATION_RANDOMNESS_SIZE; pub use crate::constants::SIGNING_RANDOMNESS_SIZE; diff --git a/libcrux-ml-dsa/src/ml_dsa_44.rs b/libcrux-ml-dsa/src/ml_dsa_44.rs index 2286d163f..62f3ae4d7 100644 --- a/libcrux-ml-dsa/src/ml_dsa_44.rs +++ b/libcrux-ml-dsa/src/ml_dsa_44.rs @@ -2,7 +2,7 @@ use crate::{ constants::*, ml_dsa_generic::{self, multiplexing}, types::*, - VerificationError, + SigningError, VerificationError, }; // ML-DSA-44-specific parameters @@ -97,7 +97,7 @@ macro_rules! instantiate { signing_key: &MLDSA44SigningKey, message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], - ) -> MLDSA44Signature { + ) -> Result { p::sign::< ROWS_IN_A, COLUMNS_IN_A, @@ -183,7 +183,7 @@ pub fn sign( signing_key: &MLDSA44SigningKey, message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], -) -> MLDSA44Signature { +) -> Result { multiplexing::sign::< ROWS_IN_A, COLUMNS_IN_A, diff --git a/libcrux-ml-dsa/src/ml_dsa_65.rs b/libcrux-ml-dsa/src/ml_dsa_65.rs index e5e2de977..03164928b 100644 --- a/libcrux-ml-dsa/src/ml_dsa_65.rs +++ b/libcrux-ml-dsa/src/ml_dsa_65.rs @@ -1,4 +1,4 @@ -use crate::{constants::*, types::*, VerificationError}; +use crate::{constants::*, types::*, SigningError, VerificationError}; // ML-DSA-65-specific parameters @@ -110,7 +110,7 @@ pub fn sign( signing_key: &MLDSA65SigningKey, message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], -) -> MLDSA65Signature { +) -> Result { crate::ml_dsa_generic::sign::< SIMDUnit, Shake128X4, diff --git a/libcrux-ml-dsa/src/ml_dsa_87.rs b/libcrux-ml-dsa/src/ml_dsa_87.rs index 409d6811b..eb02f21bc 100644 --- a/libcrux-ml-dsa/src/ml_dsa_87.rs +++ b/libcrux-ml-dsa/src/ml_dsa_87.rs @@ -1,4 +1,4 @@ -use crate::{constants::*, types::*, VerificationError}; +use crate::{constants::*, types::*, SigningError, VerificationError}; // ML-DSA-87 parameters @@ -113,7 +113,7 @@ pub fn sign( signing_key: &MLDSA87SigningKey, message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], -) -> MLDSA87Signature { +) -> Result { crate::ml_dsa_generic::sign::< SIMDUnit, Shake128X4, diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index ce6e0a757..7a35ca583 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -102,6 +102,11 @@ pub enum VerificationError { CommitmentHashesDontMatchError, } +#[derive(Debug)] +pub enum SigningError { + RejectionSamplingError, +} + #[allow(non_snake_case)] pub(crate) fn sign< SIMDUnit: Operations, @@ -126,7 +131,7 @@ pub(crate) fn sign< signing_key: &[u8; SIGNING_KEY_SIZE], message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], -) -> MLDSASignature { +) -> Result, SigningError> { let (seed_for_A, seed_for_signing, verification_key_hash, s1_as_ntt, s2_as_ntt, t0_as_ntt) = encoding::signing_key::deserialize_then_ntt::< SIMDUnit, @@ -166,19 +171,19 @@ pub(crate) fn sign< let mut attempt = 0; - // TODO: This style of rejection sampling, with the break and the continues, - // won't pass through hax; it'll need to be rewritten. - // See https://github.com/cryspen/libcrux/issues/341 - let (commitment_hash, signer_response, hint) = loop { + let mut commitment_hash = None; + let mut signer_response = None; + let mut hint = None; + + // Depending on the mode, one try has a chance between 1/7 and 1/4 + // of succeeding. Thus it is safe to say that 576 + // (REJECTION_SAMPLE_BOUND) iterations are enough as (6/7)⁵⁷⁶ < + // 2⁻¹²⁸[1]. + // + // [1]: https://github.com/cloudflare/circl/blob/main/sign/dilithium/mode2/internal/dilithium.go#L341 + while attempt < REJECTION_SAMPLE_BOUND { attempt += 1; - // Depending on the mode, one try has a chance between 1/7 and 1/4 - // of succeeding. Thus it is safe to say that 576 iterations - // are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸[1]. - // - // [1]: https://github.com/cloudflare/circl/blob/main/sign/dilithium/mode2/internal/dilithium.go#L341 - debug_assert!(attempt < 576); - let mask = sample_mask_vector::( into_padded_array(&mask_seed), @@ -190,7 +195,7 @@ pub(crate) fn sign< let (w0, commitment) = decompose_vector::(A_times_mask); - let mut commitment_hash = [0; COMMITMENT_HASH_SIZE]; + let mut commitment_hash_candidate = [0; COMMITMENT_HASH_SIZE]; { let commitment_serialized = encoding::commitment::serialize_vector::< SIMDUnit, @@ -203,7 +208,7 @@ pub(crate) fn sign< shake.absorb(&message_representative); let mut shake = shake.absorb_final(&commitment_serialized); - shake.squeeze(&mut commitment_hash); + shake.squeeze(&mut commitment_hash_candidate); } let verifier_challenge_as_ntt = ntt(sample_challenge_ring_element::< @@ -211,7 +216,7 @@ pub(crate) fn sign< Shake256, ONES_IN_VERIFIER_CHALLENGE, >( - commitment_hash[0..VERIFIER_CHALLENGE_SEED_SIZE] + commitment_hash_candidate[0..VERIFIER_CHALLENGE_SEED_SIZE] .try_into() .unwrap(), )); @@ -225,44 +230,63 @@ pub(crate) fn sign< &verifier_challenge_as_ntt, ); - let signer_response = add_vectors::(&mask, &challenge_times_s1); + let signer_response_candidate = + add_vectors::(&mask, &challenge_times_s1); let w0_minus_challenge_times_s2 = subtract_vectors::(&w0, &challenge_times_s2); if vector_infinity_norm_exceeds::( - signer_response, + signer_response_candidate, (1 << GAMMA1_EXPONENT) - BETA, ) { - continue; - } - if vector_infinity_norm_exceeds::( - w0_minus_challenge_times_s2, - GAMMA2 - BETA, - ) { - continue; + } else { + if vector_infinity_norm_exceeds::( + w0_minus_challenge_times_s2, + GAMMA2 - BETA, + ) { + } else { + let challenge_times_t0 = vector_times_ring_element::( + &t0_as_ntt, + &verifier_challenge_as_ntt, + ); + if vector_infinity_norm_exceeds::(challenge_times_t0, GAMMA2) { + } else { + let w0_minus_c_times_s2_plus_c_times_t0 = add_vectors::( + &w0_minus_challenge_times_s2, + &challenge_times_t0, + ); + let (hint_candidate, ones_in_hint) = make_hint::( + w0_minus_c_times_s2_plus_c_times_t0, + commitment, + ); + + if ones_in_hint > MAX_ONES_IN_HINT { + } else { + attempt = REJECTION_SAMPLE_BOUND; // exit loop now + commitment_hash = Some(commitment_hash_candidate); + signer_response = Some(signer_response_candidate); + hint = Some(hint_candidate); + } + } + } } + } - let challenge_times_t0 = vector_times_ring_element::( - &t0_as_ntt, - &verifier_challenge_as_ntt, - ); - if vector_infinity_norm_exceeds::(challenge_times_t0, GAMMA2) { - continue; - } + let commitment_hash = match commitment_hash { + Some(commitment_hash) => Ok(commitment_hash), + None => Err(SigningError::RejectionSamplingError), + }?; - let w0_minus_c_times_s2_plus_c_times_t0 = - add_vectors::(&w0_minus_challenge_times_s2, &challenge_times_t0); - let (hint, ones_in_hint) = make_hint::( - w0_minus_c_times_s2_plus_c_times_t0, - commitment, - ); - if ones_in_hint > MAX_ONES_IN_HINT { - continue; - } + let signer_response = match signer_response { + Some(signer_response) => Ok(signer_response), + None => Err(SigningError::RejectionSamplingError), + }?; - break (commitment_hash, signer_response, hint); - }; + let hint = match hint { + Some(hint) => Ok(hint), + None => Err(SigningError::RejectionSamplingError), + }?; let signature = Signature:: { commitment_hash, @@ -271,7 +295,7 @@ pub(crate) fn sign< } .serialize::(); - MLDSASignature(signature) + Ok(MLDSASignature(signature)) } #[allow(non_snake_case)] diff --git a/libcrux-ml-dsa/src/ml_dsa_generic/instantiations.rs b/libcrux-ml-dsa/src/ml_dsa_generic/instantiations.rs index 985c6e231..6149e50f3 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic/instantiations.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic/instantiations.rs @@ -1,7 +1,11 @@ macro_rules! instantiate { ($modp:ident, $simdunit:path, $shake128x4:path, $shake256:path, $shake256x4:path) => { pub mod $modp { - use crate::{constants::*, ml_dsa_generic::VerificationError, types::*}; + use crate::{ + constants::*, + ml_dsa_generic::{SigningError, VerificationError}, + types::*, + }; /// Generate key pair. pub(crate) fn generate_key_pair< @@ -48,7 +52,7 @@ macro_rules! instantiate { signing_key: &[u8; SIGNING_KEY_SIZE], message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], - ) -> MLDSASignature { + ) -> Result, SigningError> { crate::ml_dsa_generic::sign::< $simdunit, $shake128x4, diff --git a/libcrux-ml-dsa/src/ml_dsa_generic/multiplexing.rs b/libcrux-ml-dsa/src/ml_dsa_generic/multiplexing.rs index 45519e97c..e0d62ad8d 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic/multiplexing.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic/multiplexing.rs @@ -84,7 +84,7 @@ pub(crate) fn sign< signing_key: &[u8; SIGNING_KEY_SIZE], message: &[u8], randomness: [u8; SIGNING_RANDOMNESS_SIZE], -) -> MLDSASignature { +) -> Result, SigningError> { if libcrux_platform::simd256_support() { sign_avx2::< ROWS_IN_A, diff --git a/libcrux-ml-dsa/src/sample.rs b/libcrux-ml-dsa/src/sample.rs index a2bfcb198..eb2a11e40 100644 --- a/libcrux-ml-dsa/src/sample.rs +++ b/libcrux-ml-dsa/src/sample.rs @@ -210,7 +210,7 @@ pub(crate) fn rejection_sample_less_than_eta bool { - match ETA { + match ETA as u8 { 2 => rejection_sample_less_than_eta_equals_2::(randomness, sampled, out), 4 => rejection_sample_less_than_eta_equals_4::(randomness, sampled, out), _ => unreachable!(), @@ -337,7 +337,7 @@ fn sample_mask_ring_element< >( seed: [u8; 66], ) -> PolynomialRingElement { - match GAMMA1_EXPONENT { + match GAMMA1_EXPONENT as u8 { 17 => { let mut out = [0u8; 576]; Shake256::shake256::<576>(&seed, &mut out); @@ -374,7 +374,7 @@ pub(crate) fn sample_mask_vector< let seed2 = update_seed(seed, domain_separator); let seed3 = update_seed(seed, domain_separator); - match GAMMA1_EXPONENT { + match GAMMA1_EXPONENT as u8 { 17 => { let mut out0 = [0; 576]; let mut out1 = [0; 576]; diff --git a/libcrux-ml-dsa/src/samplex4.rs b/libcrux-ml-dsa/src/samplex4.rs index b8bfd5535..1173c0abf 100644 --- a/libcrux-ml-dsa/src/samplex4.rs +++ b/libcrux-ml-dsa/src/samplex4.rs @@ -374,7 +374,7 @@ pub(crate) fn matrix_A< >( seed: [u8; 34], ) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] { - match (ROWS_IN_A, COLUMNS_IN_A) { + match (ROWS_IN_A as u8, COLUMNS_IN_A as u8) { (4, 4) => matrix_A_4_by_4::(seed), (6, 5) => matrix_A_6_by_5::(seed), (8, 7) => matrix_A_8_by_7::(seed), @@ -504,7 +504,7 @@ pub(crate) fn sample_s1_and_s2< [PolynomialRingElement; S1_DIMENSION], [PolynomialRingElement; S2_DIMENSION], ) { - match (S1_DIMENSION, S2_DIMENSION) { + match (S1_DIMENSION as u8, S2_DIMENSION as u8) { (4, 4) => { sample_s1_and_s2_4_by_4::(seed) } diff --git a/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs b/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs index 033fcf05e..c8a3e40a1 100644 --- a/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs +++ b/libcrux-ml-dsa/src/simd/avx2/encoding/commitment.rs @@ -4,7 +4,7 @@ use libcrux_intrinsics::avx2::*; pub fn serialize(simd_unit: Vec256) -> [u8; OUTPUT_SIZE] { let mut serialized = [0u8; 19]; - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 4 => { let adjacent_2_combined = mm256_sllv_epi32(simd_unit, mm256_set_epi32(0, 28, 0, 28, 0, 28, 0, 28)); diff --git a/libcrux-ml-dsa/src/simd/avx2/encoding/error.rs b/libcrux-ml-dsa/src/simd/avx2/encoding/error.rs index 828bc2135..0d9095166 100644 --- a/libcrux-ml-dsa/src/simd/avx2/encoding/error.rs +++ b/libcrux-ml-dsa/src/simd/avx2/encoding/error.rs @@ -65,7 +65,7 @@ fn serialize_when_eta_is_4(simd_unit: Vec256) -> [u8; } #[inline(always)] pub fn serialize(simd_unit: Vec256) -> [u8; OUTPUT_SIZE] { - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 3 => serialize_when_eta_is_2::(simd_unit), 4 => serialize_when_eta_is_4::(simd_unit), _ => unreachable!(), @@ -118,7 +118,7 @@ fn deserialize_to_unsigned_when_eta_is_4(bytes: &[u8]) -> Vec256 { } #[inline(always)] pub(crate) fn deserialize_to_unsigned(serialized: &[u8]) -> Vec256 { - match ETA { + match ETA as u8 { 2 => deserialize_to_unsigned_when_eta_is_2(serialized), 4 => deserialize_to_unsigned_when_eta_is_4(serialized), _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/simd/avx2/encoding/gamma1.rs b/libcrux-ml-dsa/src/simd/avx2/encoding/gamma1.rs index 38296f225..80b666707 100644 --- a/libcrux-ml-dsa/src/simd/avx2/encoding/gamma1.rs +++ b/libcrux-ml-dsa/src/simd/avx2/encoding/gamma1.rs @@ -66,7 +66,7 @@ fn serialize_when_gamma1_is_2_pow_19( #[inline(always)] pub(crate) fn serialize(simd_unit: Vec256) -> [u8; OUTPUT_SIZE] { - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 18 => serialize_when_gamma1_is_2_pow_17::(simd_unit), 20 => serialize_when_gamma1_is_2_pow_19::(simd_unit), _ => unreachable!(), @@ -130,7 +130,7 @@ fn deserialize_when_gamma1_is_2_pow_19(serialized: &[u8]) -> Vec256 { #[inline(always)] pub(crate) fn deserialize(serialized: &[u8]) -> Vec256 { - match GAMMA1_EXPONENT { + match GAMMA1_EXPONENT as u8 { 17 => deserialize_when_gamma1_is_2_pow_17(serialized), 19 => deserialize_when_gamma1_is_2_pow_19(serialized), _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/simd/avx2/rejection_sample/less_than_eta.rs b/libcrux-ml-dsa/src/simd/avx2/rejection_sample/less_than_eta.rs index cd0f5b05d..052a6b855 100644 --- a/libcrux-ml-dsa/src/simd/avx2/rejection_sample/less_than_eta.rs +++ b/libcrux-ml-dsa/src/simd/avx2/rejection_sample/less_than_eta.rs @@ -7,7 +7,7 @@ use libcrux_intrinsics::avx2::*; #[inline(always)] fn shift_interval(coefficients: Vec256) -> Vec256 { - match ETA { + match ETA as u8 { 2 => { let quotient = mm256_mullo_epi32(coefficients, mm256_set1_epi32(26)); let quotient = mm256_srai_epi32::<7>(quotient); @@ -29,7 +29,7 @@ pub(crate) fn sample(input: &[u8], output: &mut [i32]) -> usiz // values that are 4-bits wide. let potential_coefficients = encoding::error::deserialize_to_unsigned::<4>(input); - let interval_boundary: i32 = match ETA { + let interval_boundary: i32 = match ETA as u8 { 2 => 15, 4 => 9, _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/simd/portable/encoding/commitment.rs b/libcrux-ml-dsa/src/simd/portable/encoding/commitment.rs index 3c6462c20..c6886ba50 100644 --- a/libcrux-ml-dsa/src/simd/portable/encoding/commitment.rs +++ b/libcrux-ml-dsa/src/simd/portable/encoding/commitment.rs @@ -4,7 +4,7 @@ use crate::simd::portable::PortableSIMDUnit; pub fn serialize(simd_unit: PortableSIMDUnit) -> [u8; OUTPUT_SIZE] { let mut serialized = [0u8; OUTPUT_SIZE]; - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 4 => { // The commitment has coefficients in [0,15] => each coefficient occupies // 4 bits. diff --git a/libcrux-ml-dsa/src/simd/portable/encoding/error.rs b/libcrux-ml-dsa/src/simd/portable/encoding/error.rs index f9d1ae7fb..d7878fbc8 100644 --- a/libcrux-ml-dsa/src/simd/portable/encoding/error.rs +++ b/libcrux-ml-dsa/src/simd/portable/encoding/error.rs @@ -43,7 +43,7 @@ fn serialize_when_eta_is_4( pub(crate) fn serialize( simd_unit: PortableSIMDUnit, ) -> [u8; OUTPUT_SIZE] { - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 3 => serialize_when_eta_is_2::(simd_unit), 4 => serialize_when_eta_is_4::(simd_unit), _ => unreachable!(), @@ -88,7 +88,7 @@ fn deserialize_when_eta_is_4(serialized: &[u8]) -> PortableSIMDUnit { } #[inline(always)] pub(crate) fn deserialize(serialized: &[u8]) -> PortableSIMDUnit { - match ETA { + match ETA as u8 { 2 => deserialize_when_eta_is_2(serialized), 4 => deserialize_when_eta_is_4(serialized), _ => unreachable!(), diff --git a/libcrux-ml-dsa/src/simd/portable/encoding/gamma1.rs b/libcrux-ml-dsa/src/simd/portable/encoding/gamma1.rs index 67899ae72..eabb2fd81 100644 --- a/libcrux-ml-dsa/src/simd/portable/encoding/gamma1.rs +++ b/libcrux-ml-dsa/src/simd/portable/encoding/gamma1.rs @@ -65,7 +65,7 @@ fn serialize_when_gamma1_is_2_pow_19( pub(crate) fn serialize( simd_unit: PortableSIMDUnit, ) -> [u8; OUTPUT_SIZE] { - match OUTPUT_SIZE { + match OUTPUT_SIZE as u8 { 18 => serialize_when_gamma1_is_2_pow_17::(simd_unit), 20 => serialize_when_gamma1_is_2_pow_19::(simd_unit), _ => unreachable!(), @@ -141,7 +141,7 @@ fn deserialize_when_gamma1_is_2_pow_19(serialized: &[u8]) -> PortableSIMDUnit { } #[inline(always)] pub(crate) fn deserialize(serialized: &[u8]) -> PortableSIMDUnit { - match GAMMA1_EXPONENT { + match GAMMA1_EXPONENT as u8 { 17 => deserialize_when_gamma1_is_2_pow_17(serialized), 19 => deserialize_when_gamma1_is_2_pow_19(serialized), _ => unreachable!(), diff --git a/libcrux-ml-dsa/tests/nistkats.rs b/libcrux-ml-dsa/tests/nistkats.rs index e37749835..78fe8dbec 100644 --- a/libcrux-ml-dsa/tests/nistkats.rs +++ b/libcrux-ml-dsa/tests/nistkats.rs @@ -57,7 +57,8 @@ macro_rules! impl_nist_known_answer_tests { let message = hex::decode(kat.message).expect("Hex-decoding the message failed."); - let signature = $sign(&key_pair.signing_key, &message, kat.signing_randomness); + let signature = $sign(&key_pair.signing_key, &message, kat.signing_randomness) + .expect("Rejection sampling failure probability is < 2⁻¹²⁸"); let signature_hash = libcrux_sha3::sha256(&signature.0); assert_eq!( diff --git a/libcrux-ml-dsa/tests/self.rs b/libcrux-ml-dsa/tests/self.rs index 8eea70f26..24faa939a 100644 --- a/libcrux-ml-dsa/tests/self.rs +++ b/libcrux-ml-dsa/tests/self.rs @@ -59,7 +59,8 @@ macro_rules! impl_consistency_test { let key_pair = $key_gen(key_generation_seed); - let signature = $sign(&key_pair.signing_key, &message, signing_randomness); + let signature = $sign(&key_pair.signing_key, &message, signing_randomness) + .expect("Rejection sampling failure probability is < 2⁻¹²⁸"); $verify(&key_pair.verification_key, &message, &signature) .expect("Verification should pass since the signature was honestly generated"); @@ -80,7 +81,8 @@ macro_rules! impl_modified_signing_key_test { modify_signing_key::<{ $signing_key_size }>(&mut key_pair.signing_key.0); - let signature = $sign(&key_pair.signing_key, &message, signing_randomness); + let signature = $sign(&key_pair.signing_key, &message, signing_randomness) + .expect("Rejection sampling failure probability is < 2⁻¹²⁸"); assert!($verify(&key_pair.verification_key, &message, &signature).is_err()); } diff --git a/libcrux-ml-dsa/tests/wycheproof_sign.rs b/libcrux-ml-dsa/tests/wycheproof_sign.rs index 53fdc78a0..144de6563 100644 --- a/libcrux-ml-dsa/tests/wycheproof_sign.rs +++ b/libcrux-ml-dsa/tests/wycheproof_sign.rs @@ -48,7 +48,8 @@ macro_rules! wycheproof_sign_test { for test in test_group.tests { let message = hex::decode(test.msg).unwrap(); - let signature = $sign(&signing_key, &message, signing_randomness); + let signature = $sign(&signing_key, &message, signing_randomness) + .expect("Rejection sampling failure probability is < 2⁻¹²⁸"); if test.result == Result::Valid { assert_eq!(