diff --git a/libcrux-ml-dsa/src/arithmetic.rs b/libcrux-ml-dsa/src/arithmetic.rs index 250ac6895..7417646a6 100644 --- a/libcrux-ml-dsa/src/arithmetic.rs +++ b/libcrux-ml-dsa/src/arithmetic.rs @@ -43,10 +43,6 @@ pub(crate) type FieldElementTimesMontgomeryR = i32; const MONTGOMERY_SHIFT: u8 = 32; const INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: u64 = 58_728_449; // FIELD_MODULUS^{-1} mod 2^32 -/// This is calculated as (MONTGOMERY_R)^2 mod FIELD_MODULUS -/// where MONTGOMERY_R = 1 << MONTGOMERY_SHIFT -const MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS: i32 = 66; - pub(crate) fn montgomery_reduce(value: i64) -> MontgomeryFieldElement { let t = get_n_least_significant_bits(MONTGOMERY_SHIFT, value as u64) * INVERSE_OF_MODULUS_MOD_MONTGOMERY_R; @@ -79,7 +75,7 @@ pub(crate) fn montgomery_multiply_fe_by_fer( // // This approach has been taken from: // https://github.com/cloudflare/circl/blob/main/sign/dilithium/internal/common/field.go#L35 -fn power2round(t: i32) -> (i32, i32) { +pub(crate) fn power2round(t: i32) -> (i32, i32) { // -floor(N / 2) = -4,190,208 // floor((N - 1) / 2) = 4,190,208 debug_assert!(t >= -4_190_208 && t <= 4_190_208); @@ -107,27 +103,6 @@ fn power2round(t: i32) -> (i32, i32) { (t0, t1) } -pub(crate) fn power2round_vector( - t: [PolynomialRingElement; ROWS_IN_A], -) -> ( - [PolynomialRingElement; ROWS_IN_A], - [PolynomialRingElement; ROWS_IN_A], -) { - let mut vector_t0 = [PolynomialRingElement::ZERO; ROWS_IN_A]; - let mut vector_t1 = [PolynomialRingElement::ZERO; ROWS_IN_A]; - - for i in 0..ROWS_IN_A { - for (j, coefficient) in t[i].coefficients.into_iter().enumerate() { - let (c0, c1) = power2round(coefficient); - - vector_t0[i].coefficients[j] = c0; - vector_t1[i].coefficients[j] = c1; - } - } - - (vector_t0, vector_t1) -} - pub(crate) fn t0_to_unsigned_representative(t0: i32) -> i32 { (1 << (BITS_IN_LOWER_PART_OF_T - 1)) - t0 } diff --git a/libcrux-ml-dsa/src/constants.rs b/libcrux-ml-dsa/src/constants.rs index 3d601edf9..da31927d9 100644 --- a/libcrux-ml-dsa/src/constants.rs +++ b/libcrux-ml-dsa/src/constants.rs @@ -5,13 +5,15 @@ pub(crate) const COEFFICIENTS_IN_RING_ELEMENT: usize = 256; pub(crate) const FIELD_MODULUS_MINUS_ONE_BIT_LENGTH: usize = 23; pub(crate) const BITS_IN_LOWER_PART_OF_T: usize = 13; +pub(crate) const BYTES_FOR_RING_ELEMENT_OF_T0S: usize = + (BITS_IN_LOWER_PART_OF_T * COEFFICIENTS_IN_RING_ELEMENT) / 8; pub(crate) const BITS_IN_UPPER_PART_OF_T: usize = FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T; -pub(crate) const BYTES_IN_RING_ELEMENT_OF_T1S: usize = +pub(crate) const BYTES_FOR_RING_ELEMENT_OF_T1S: usize = (BITS_IN_UPPER_PART_OF_T * COEFFICIENTS_IN_RING_ELEMENT) / 8; pub(crate) const SEED_FOR_A_SIZE: usize = 32; pub(crate) const SEED_FOR_ERROR_VECTORS_SIZE: usize = 64; -pub(crate) const HASH_OF_PUBLIC_KEY_SIZE: usize = 64; +pub(crate) const BYTES_FOR_VERIFICATION_KEY_HASH: usize = 64; pub(crate) const SEED_FOR_SIGNING_SIZE: usize = 32; diff --git a/libcrux-ml-dsa/src/matrix.rs b/libcrux-ml-dsa/src/matrix.rs index 59ae89767..33451c0fe 100644 --- a/libcrux-ml-dsa/src/matrix.rs +++ b/libcrux-ml-dsa/src/matrix.rs @@ -1,9 +1,30 @@ use crate::{ - arithmetic::{add_to_ring_element, PolynomialRingElement}, + arithmetic::{add_to_ring_element, power2round, PolynomialRingElement}, ntt::{invert_ntt_montgomery, ntt, ntt_multiply_montgomery}, sample::{sample_error_ring_element_uniform, sample_ring_element_uniform}, }; +pub(crate) fn power2round_vector( + t: [PolynomialRingElement; ROWS_IN_A], +) -> ( + [PolynomialRingElement; ROWS_IN_A], + [PolynomialRingElement; ROWS_IN_A], +) { + let mut vector_t0 = [PolynomialRingElement::ZERO; ROWS_IN_A]; + let mut vector_t1 = [PolynomialRingElement::ZERO; ROWS_IN_A]; + + for i in 0..ROWS_IN_A { + for (j, coefficient) in t[i].coefficients.into_iter().enumerate() { + let (c0, c1) = power2round(coefficient); + + vector_t0[i].coefficients[j] = c0; + vector_t1[i].coefficients[j] = c1; + } + } + + (vector_t0, vector_t1) +} + #[inline(always)] pub(crate) fn sample_error_vector( mut seed: [u8; 66], @@ -25,7 +46,6 @@ pub(crate) fn sample_error_vector( #[inline(always)] pub(crate) fn expand_to_A( mut seed: [u8; 34], - transposed: bool, ) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] { let mut A = [[PolynomialRingElement::ZERO; COLUMNS_IN_A]; ROWS_IN_A]; @@ -34,13 +54,7 @@ pub(crate) fn expand_to_A( seed[32] = j as u8; seed[33] = i as u8; - let sampled = sample_ring_element_uniform(seed); - - if transposed { - A[j][i] = sampled; - } else { - A[i][j] = sampled; - } + A[i][j] = sample_ring_element_uniform(seed); } } diff --git a/libcrux-ml-dsa/src/ml_dsa_65.rs b/libcrux-ml-dsa/src/ml_dsa_65.rs index b4dbb3e0a..fc50abad0 100644 --- a/libcrux-ml-dsa/src/ml_dsa_65.rs +++ b/libcrux-ml-dsa/src/ml_dsa_65.rs @@ -8,17 +8,20 @@ const COLUMNS_IN_A: usize = 5; const ETA: usize = 4; const TWO_TIMES_ETA_BIT_SIZE: usize = 4; // ⌊log_2(4)⌋ + 1 +const BYTES_FOR_ERROR_RING_ELEMENT: usize = + (TWO_TIMES_ETA_BIT_SIZE * COEFFICIENTS_IN_RING_ELEMENT) / 8; + const VERIFICATION_KEY_SIZE: usize = SEED_FOR_A_SIZE + (COEFFICIENTS_IN_RING_ELEMENT * ROWS_IN_A * (FIELD_MODULUS_MINUS_ONE_BIT_LENGTH - BITS_IN_LOWER_PART_OF_T)) / 8; -const SIGNING_KEY_SIZE: usize = (SEED_FOR_A_SIZE + SEED_FOR_SIGNING_SIZE + HASH_OF_PUBLIC_KEY_SIZE) - + (COEFFICIENTS_IN_RING_ELEMENT - * (((ROWS_IN_A + COLUMNS_IN_A) * TWO_TIMES_ETA_BIT_SIZE) - + (BITS_IN_LOWER_PART_OF_T * ROWS_IN_A))) - / 8; +const SIGNING_KEY_SIZE: usize = SEED_FOR_A_SIZE + + SEED_FOR_SIGNING_SIZE + + BYTES_FOR_VERIFICATION_KEY_HASH + + (ROWS_IN_A + COLUMNS_IN_A) * BYTES_FOR_ERROR_RING_ELEMENT + + ROWS_IN_A * BYTES_FOR_RING_ELEMENT_OF_T0S; pub struct MLDSA65KeyPair { pub signing_key: [u8; SIGNING_KEY_SIZE], @@ -31,6 +34,7 @@ pub fn generate_key_pair(randomness: [u8; 32]) -> MLDSA65KeyPair { ROWS_IN_A, COLUMNS_IN_A, ETA, + BYTES_FOR_ERROR_RING_ELEMENT, SIGNING_KEY_SIZE, VERIFICATION_KEY_SIZE, >(randomness); diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index a4e7fb972..14bf832e1 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -1,15 +1,21 @@ use crate::{ - arithmetic::{power2round_vector, PolynomialRingElement}, - constants::{BYTES_IN_RING_ELEMENT_OF_T1S, SEED_FOR_A_SIZE, SEED_FOR_ERROR_VECTORS_SIZE}, + arithmetic::PolynomialRingElement, + constants::{ + BYTES_FOR_RING_ELEMENT_OF_T0S, BYTES_FOR_RING_ELEMENT_OF_T1S, + BYTES_FOR_VERIFICATION_KEY_HASH, SEED_FOR_A_SIZE, SEED_FOR_ERROR_VECTORS_SIZE, + SEED_FOR_SIGNING_SIZE, + }, hash_functions::H, - matrix::{compute_As1_plus_s2, expand_to_A, sample_error_vector}, - serialize::serialize_ring_element_of_t1s, + matrix::{compute_As1_plus_s2, expand_to_A, power2round_vector, sample_error_vector}, + serialize::{ + serialize_error_ring_element, serialize_ring_element_of_t0s, serialize_ring_element_of_t1s, + }, utils::into_padded_array, }; #[allow(non_snake_case)] #[inline(always)] -pub(super) fn serialize_verification_key< +pub(super) fn generate_serialized_verification_key< const ROWS_IN_A: usize, const VERIFICATION_KEY_SIZE: usize, >( @@ -20,32 +26,87 @@ pub(super) fn serialize_verification_key< verification_key_serialized[0..SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A); for i in 0..ROWS_IN_A { - let offset = SEED_FOR_A_SIZE + (i * BYTES_IN_RING_ELEMENT_OF_T1S); - verification_key_serialized[offset..offset + BYTES_IN_RING_ELEMENT_OF_T1S] + let offset = SEED_FOR_A_SIZE + (i * BYTES_FOR_RING_ELEMENT_OF_T1S); + verification_key_serialized[offset..offset + BYTES_FOR_RING_ELEMENT_OF_T1S] .copy_from_slice(&serialize_ring_element_of_t1s(t1[i])); } verification_key_serialized } +#[allow(non_snake_case)] +#[inline(always)] +pub(super) fn generate_serialized_signing_key< + const ROWS_IN_A: usize, + const COLUMNS_IN_A: usize, + const ETA: usize, + const BYTES_FOR_ERROR_RING_ELEMENT: usize, + const SIGNING_KEY_SIZE: usize, +>( + seed_for_A: &[u8], + seed_for_signing: &[u8], + verification_key: &[u8], + s1: [PolynomialRingElement; COLUMNS_IN_A], + s2: [PolynomialRingElement; ROWS_IN_A], + t0: [PolynomialRingElement; ROWS_IN_A], +) -> [u8; SIGNING_KEY_SIZE] { + let mut signing_key_serialized = [0u8; SIGNING_KEY_SIZE]; + let mut offset = 0; + + signing_key_serialized[offset..offset + SEED_FOR_A_SIZE].copy_from_slice(&seed_for_A); + offset += SEED_FOR_A_SIZE; + + signing_key_serialized[offset..offset + SEED_FOR_SIGNING_SIZE] + .copy_from_slice(&seed_for_signing); + offset += SEED_FOR_SIGNING_SIZE; + + let verification_key_hash = H::(verification_key); + signing_key_serialized[offset..offset + BYTES_FOR_VERIFICATION_KEY_HASH] + .copy_from_slice(&verification_key_hash); + offset += BYTES_FOR_VERIFICATION_KEY_HASH; + + for i in 0..COLUMNS_IN_A { + signing_key_serialized[offset..offset + BYTES_FOR_ERROR_RING_ELEMENT].copy_from_slice( + &serialize_error_ring_element::(s1[i]), + ); + offset += BYTES_FOR_ERROR_RING_ELEMENT; + } + + for i in 0..ROWS_IN_A { + signing_key_serialized[offset..offset + BYTES_FOR_ERROR_RING_ELEMENT].copy_from_slice( + &serialize_error_ring_element::(s2[i]), + ); + offset += BYTES_FOR_ERROR_RING_ELEMENT; + } + + for i in 0..ROWS_IN_A { + signing_key_serialized[offset..offset + BYTES_FOR_RING_ELEMENT_OF_T0S] + .copy_from_slice(&serialize_ring_element_of_t0s(t0[i])); + offset += BYTES_FOR_RING_ELEMENT_OF_T0S; + } + + signing_key_serialized +} + #[allow(non_snake_case)] pub(crate) fn generate_key_pair< const ROWS_IN_A: usize, const COLUMNS_IN_A: usize, const ETA: usize, + const BYTES_FOR_ERROR_RING_ELEMENT: usize, const SIGNING_KEY_SIZE: usize, const VERIFICATION_KEY_SIZE: usize, >( randomness: [u8; 32], ) -> ([u8; SIGNING_KEY_SIZE], [u8; VERIFICATION_KEY_SIZE]) { - let seed_expanded = H::<1024>(&randomness); + let seed_expanded = H::<128>(&randomness); let (seed_for_A, seed_expanded) = seed_expanded.split_at(SEED_FOR_A_SIZE); - let (seed_for_error_vectors, _random_seed_for_signing) = + let (seed_for_error_vectors, seed_for_signing) = seed_expanded.split_at(SEED_FOR_ERROR_VECTORS_SIZE); let mut domain_separator: u16 = 0; - let A_as_ntt = expand_to_A::(into_padded_array(seed_for_A), false); + let A_as_ntt = expand_to_A::(into_padded_array(seed_for_A)); let s1 = sample_error_vector::( into_padded_array(seed_for_error_vectors), @@ -61,8 +122,22 @@ pub(crate) fn generate_key_pair< let (t0, t1) = power2round_vector::(t); let verification_key_serialized = - serialize_verification_key::(seed_for_A, t1); - let signing_key_serialized = [0u8; SIGNING_KEY_SIZE]; + generate_serialized_verification_key::(seed_for_A, t1); + + let signing_key_serialized = generate_serialized_signing_key::< + ROWS_IN_A, + COLUMNS_IN_A, + ETA, + BYTES_FOR_ERROR_RING_ELEMENT, + SIGNING_KEY_SIZE, + >( + seed_for_A, + seed_for_signing, + &verification_key_serialized, + s1, + s2, + t0, + ); (signing_key_serialized, verification_key_serialized) } diff --git a/libcrux-ml-dsa/src/serialize.rs b/libcrux-ml-dsa/src/serialize.rs index 7087d696d..7f8796a49 100644 --- a/libcrux-ml-dsa/src/serialize.rs +++ b/libcrux-ml-dsa/src/serialize.rs @@ -1,8 +1,13 @@ -use crate::arithmetic::{t0_to_unsigned_representative, PolynomialRingElement}; +use crate::{ + arithmetic::{t0_to_unsigned_representative, PolynomialRingElement}, + constants::{BYTES_FOR_RING_ELEMENT_OF_T0S, BYTES_FOR_RING_ELEMENT_OF_T1S}, +}; #[inline(always)] -fn serialize_ring_element_of_t0s(re: PolynomialRingElement) -> [u8; 416] { - let mut serialized = [0u8; 416]; +pub(crate) fn serialize_ring_element_of_t0s( + re: PolynomialRingElement, +) -> [u8; BYTES_FOR_RING_ELEMENT_OF_T0S] { + let mut serialized = [0u8; BYTES_FOR_RING_ELEMENT_OF_T0S]; for (i, coefficients) in re.coefficients.chunks_exact(8).enumerate() { let coefficient0 = t0_to_unsigned_representative(coefficients[0]); @@ -52,8 +57,10 @@ fn serialize_ring_element_of_t0s(re: PolynomialRingElement) -> [u8; 416] { } #[inline(always)] -pub(crate) fn serialize_ring_element_of_t1s(re: PolynomialRingElement) -> [u8; 320] { - let mut serialized = [0u8; 320]; +pub(crate) fn serialize_ring_element_of_t1s( + re: PolynomialRingElement, +) -> [u8; BYTES_FOR_RING_ELEMENT_OF_T1S] { + let mut serialized = [0u8; BYTES_FOR_RING_ELEMENT_OF_T1S]; for (i, coefficients) in re.coefficients.chunks_exact(4).enumerate() { serialized[5 * i] = (coefficients[0] & 0xFF) as u8; @@ -69,6 +76,29 @@ pub(crate) fn serialize_ring_element_of_t1s(re: PolynomialRingElement) -> [u8; 3 serialized } +#[inline(always)] +fn serialize_error_ring_element_when_eta_is_4( + re: PolynomialRingElement, +) -> [u8; BYTES_FOR_OUTPUT] { + let mut serialized = [0u8; BYTES_FOR_OUTPUT]; + + for (i, coefficients) in re.coefficients.chunks_exact(2).enumerate() { + serialized[i] = ((coefficients[1] as u8) << 4) | (coefficients[0] as u8); + } + + serialized +} + +pub(crate) fn serialize_error_ring_element( + re: PolynomialRingElement, +) -> [u8; BYTES_FOR_OUTPUT] { + match ETA { + 2 => todo!(), + 4 => serialize_error_ring_element_when_eta_is_4::(re), + _ => unreachable!(), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/libcrux-ml-dsa/tests/nistkats.rs b/libcrux-ml-dsa/tests/nistkats.rs index 20995d6ef..ceef22b30 100644 --- a/libcrux-ml-dsa/tests/nistkats.rs +++ b/libcrux-ml-dsa/tests/nistkats.rs @@ -3,8 +3,6 @@ use serde_json; use std::{fs::File, io::BufReader, path::Path}; -use libcrux_sha3::sha256; - #[derive(Debug, Deserialize)] struct MlDsaNISTKAT { #[serde(with = "hex::serde")] @@ -40,5 +38,11 @@ fn ml_dsa_65_nist_known_answer_tests() { let verification_key_hash = libcrux_sha3::sha256(&key_pair.verification_key); assert_eq!(verification_key_hash, kat.sha3_256_hash_of_verification_key); + + let signing_key_hash = libcrux_sha3::sha256(&key_pair.signing_key); + assert_eq!( + signing_key_hash, kat.sha3_256_hash_of_signing_key, + "signing_key_hash != kat.sha3_256_hash_of_signing_key" + ); } }