From 99a8c8f84181fd4736c68a704021e4567b5334af Mon Sep 17 00:00:00 2001 From: xvzcf Date: Thu, 23 May 2024 18:20:27 -0400 Subject: [PATCH] sample_ring_element_for_A --- libcrux-ml-dsa/src/lib.rs | 5 +- libcrux-ml-dsa/src/matrix.rs | 27 ++++++++ libcrux-ml-dsa/src/ml_dsa_65.rs | 10 +-- libcrux-ml-dsa/src/ml_dsa_generic.rs | 8 ++- libcrux-ml-dsa/src/sample.rs | 53 ++++++++++++++- libcrux-ml-dsa/src/utils.rs | 8 +++ libcrux-ml-dsa/tests/kats/dilithium.py | 4 ++ libcrux-ml-dsa/tests/kats/generate_kats.py | 77 +++++++++++++--------- 8 files changed, 153 insertions(+), 39 deletions(-) create mode 100644 libcrux-ml-dsa/src/matrix.rs create mode 100644 libcrux-ml-dsa/src/utils.rs diff --git a/libcrux-ml-dsa/src/lib.rs b/libcrux-ml-dsa/src/lib.rs index 570637119..23b49717c 100644 --- a/libcrux-ml-dsa/src/lib.rs +++ b/libcrux-ml-dsa/src/lib.rs @@ -1,7 +1,10 @@ mod arithmetic; mod constants; mod hash_functions; -mod ml_dsa_generic; +mod matrix; mod sample; +mod utils; + +mod ml_dsa_generic; pub mod ml_dsa_65; diff --git a/libcrux-ml-dsa/src/matrix.rs b/libcrux-ml-dsa/src/matrix.rs new file mode 100644 index 000000000..6196b4ead --- /dev/null +++ b/libcrux-ml-dsa/src/matrix.rs @@ -0,0 +1,27 @@ +use crate::{arithmetic::PolynomialRingElement, sample::sample_ring_element_for_A}; + +#[allow(non_snake_case)] +#[inline(always)] +pub(crate) fn expand_to_A( + mut seed: [u8; 34], + transposed: bool, +) -> [[PolynomialRingElement; COLUMNS_IN_A]; ROWS_IN_A] { + let mut A = [[PolynomialRingElement::ZERO; COLUMNS_IN_A]; ROWS_IN_A]; + + for i in 0..ROWS_IN_A { + for j in 0..COLUMNS_IN_A { + seed[32] = i as u8; + seed[33] = j as u8; + + let sampled = sample_ring_element_for_A(seed); + + if transposed { + A[j][i] = sampled; + } else { + A[i][j] = sampled; + } + } + } + + A +} diff --git a/libcrux-ml-dsa/src/ml_dsa_65.rs b/libcrux-ml-dsa/src/ml_dsa_65.rs index 00c168385..94172d7a3 100644 --- a/libcrux-ml-dsa/src/ml_dsa_65.rs +++ b/libcrux-ml-dsa/src/ml_dsa_65.rs @@ -17,10 +17,12 @@ pub struct MLDSA65KeyPair { /// Generate an ML-DSA-65 Key Pair pub fn generate_key_pair(randomness: [u8; 32]) -> MLDSA65KeyPair { - let (secret_key, public_key) = - crate::ml_dsa_generic::generate_key_pair::( - randomness, - ); + let (secret_key, public_key) = crate::ml_dsa_generic::generate_key_pair::< + ROWS_IN_A, + COLUMNS_IN_A, + SECRET_KEY_SIZE, + PUBLIC_KEY_SIZE, + >(randomness); MLDSA65KeyPair { secret_key, diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index c63cc60c5..4b1d896ad 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -1,6 +1,8 @@ -use crate::hash_functions::H; +use crate::{hash_functions::H, matrix::expand_to_A, utils::into_padded_array}; +#[allow(non_snake_case)] pub(crate) fn generate_key_pair< + const ROWS_IN_A: usize, const COLUMNS_IN_A: usize, const SECRET_KEY_SIZE: usize, const PUBLIC_KEY_SIZE: usize, @@ -8,8 +10,10 @@ pub(crate) fn generate_key_pair< randomness: [u8; 32], ) -> ([u8; SECRET_KEY_SIZE], [u8; PUBLIC_KEY_SIZE]) { let seed_expanded = H::<1024>(&randomness); - let (_seed_for_matrix_a, seed_expanded) = seed_expanded.split_at(32); + let (seed_for_A, seed_expanded) = seed_expanded.split_at(32); let (_seed_for_short_vectors, _random_seed_for_signing) = seed_expanded.split_at(64); + let _A_hat = expand_to_A::(into_padded_array(seed_for_A), false); + todo!(); } diff --git a/libcrux-ml-dsa/src/sample.rs b/libcrux-ml-dsa/src/sample.rs index a0c326739..749bbe9ec 100644 --- a/libcrux-ml-dsa/src/sample.rs +++ b/libcrux-ml-dsa/src/sample.rs @@ -26,7 +26,8 @@ fn sample_from_uniform_distribution_next( sampled } -pub(super) fn sample_ring_element_uniformly(seed: [u8; 34]) -> PolynomialRingElement { +#[allow(non_snake_case)] +pub(crate) fn sample_ring_element_for_A(seed: [u8; 34]) -> PolynomialRingElement { let mut state = XOF::new(seed); let randomness = XOF::squeeze_first_five_blocks(&mut state); @@ -41,3 +42,53 @@ pub(super) fn sample_ring_element_uniformly(seed: [u8; 34]) -> PolynomialRingEle out } + +#[cfg(test)] +mod tests { + use super::*; + + use crate::arithmetic::FieldElement; + + #[test] + fn test_sample_ring_element_for_A() { + let seed: [u8; 34] = [ + 33, 192, 250, 216, 117, 61, 16, 12, 248, 51, 213, 110, 64, 57, 119, 80, 164, 83, 73, + 91, 80, 128, 195, 219, 203, 149, 170, 233, 16, 232, 209, 105, 4, 5, + ]; + + let expected_coefficients: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] = [ + 886541, 1468422, 793958, 7610434, 3986512, 913782, 2546456, 5820798, 1940159, 10062, + 3303190, 3831326, 4834267, 3500674, 16909, 8314529, 7469249, 5611755, 6181076, 269257, + 3566448, 2968856, 7556314, 6685884, 129963, 8017973, 1087829, 5842199, 6867133, 442098, + 3473053, 3812349, 556165, 55620, 4367526, 798402, 5317265, 2828265, 3808240, 3065841, + 6340895, 2710831, 715345, 5806109, 3689225, 4088547, 4258029, 2877620, 6867225, + 3275166, 4626484, 6596723, 5180488, 3836050, 1115576, 2086584, 749098, 4980044, + 7626966, 961947, 4695118, 6488634, 7898263, 841160, 1186851, 6958928, 4995591, 6829719, + 5910175, 2590788, 987365, 5983050, 7039561, 1406907, 4054912, 3093314, 237981, 6184639, + 515190, 5209488, 6460375, 4417602, 7890594, 6584284, 1729237, 5851336, 8226663, + 1843549, 5872244, 1375077, 6275711, 997136, 2593411, 5739784, 6621377, 7180456, + 1437441, 2607410, 197226, 4753353, 5086363, 6096080, 3057564, 5040851, 886178, 699532, + 3772666, 7983776, 1235995, 1960665, 1233119, 317423, 442071, 4649134, 5043634, 4164756, + 3166873, 2343835, 6256400, 6132302, 4124098, 6087733, 5371278, 3484545, 1020458, + 3688444, 7263864, 2413270, 4449757, 5561507, 7464292, 1176556, 8294481, 2892372, + 5509298, 194732, 7976046, 5907126, 4792878, 5059916, 3122481, 7009119, 5476286, + 4905623, 7374799, 7284599, 4929839, 538055, 5611660, 233595, 6125390, 7441322, 3752658, + 6655759, 4907614, 2281767, 1659504, 5490352, 4235568, 7143494, 6217399, 1581266, + 2455222, 1015526, 8366150, 2002613, 185543, 7904386, 8206829, 5380721, 2226008, + 7713547, 6961768, 7911095, 5604679, 6839785, 7573702, 1113136, 5563352, 7446030, + 6694003, 1725163, 4749689, 6474727, 7125683, 1830230, 5300491, 7927815, 5808662, + 2345184, 5462894, 5760340, 1949317, 1853703, 5060631, 5935138, 4873466, 3302619, + 5351360, 5707708, 2715882, 2050173, 52173, 5463772, 2851164, 1702574, 7167630, 1132010, + 1418205, 4182063, 4919187, 2707143, 6241533, 3241235, 2286591, 268487, 3799838, 558302, + 5882605, 6165192, 6702794, 5578115, 1893372, 7246495, 4974148, 2633723, 1522313, + 7636103, 6639058, 6765356, 3588710, 7011438, 4798122, 2329503, 4671411, 6787853, + 1838957, 306944, 5112958, 853077, 7844176, 384195, 839634, 1860349, 7289878, 4054796, + 703698, 5147821, 7632328, 5993194, 6329638, 5959986, 3073141, 675737, 7364844, 4124952, + ]; + + assert_eq!( + sample_ring_element_for_A(seed).coefficients, + expected_coefficients + ); + } +} diff --git a/libcrux-ml-dsa/src/utils.rs b/libcrux-ml-dsa/src/utils.rs new file mode 100644 index 000000000..8d4754d19 --- /dev/null +++ b/libcrux-ml-dsa/src/utils.rs @@ -0,0 +1,8 @@ +/// Pad the `slice` with `0`s at the end. +#[inline(always)] +pub(crate) fn into_padded_array(slice: &[u8]) -> [u8; LEN] { + debug_assert!(slice.len() <= LEN); + let mut out = [0u8; LEN]; + out[0..slice.len()].copy_from_slice(slice); + out +} diff --git a/libcrux-ml-dsa/tests/kats/dilithium.py b/libcrux-ml-dsa/tests/kats/dilithium.py index da3cee2c7..0544e7b0b 100644 --- a/libcrux-ml-dsa/tests/kats/dilithium.py +++ b/libcrux-ml-dsa/tests/kats/dilithium.py @@ -288,6 +288,10 @@ def rejection_sample(xof): seed = rho + bytes([j, i]) Shake128.absorb(seed) coeffs = [rejection_sample(Shake128) for _ in range(self.n)] + + self.A_rejection_sampling_seed = seed + self.A_sampled_ring_element = coeffs + return self.R(coeffs, is_ntt=is_ntt) def _sample_mask_polynomial(self, rho_prime, i, kappa, is_ntt=False): diff --git a/libcrux-ml-dsa/tests/kats/generate_kats.py b/libcrux-ml-dsa/tests/kats/generate_kats.py index 1342361eb..a770d5c01 100755 --- a/libcrux-ml-dsa/tests/kats/generate_kats.py +++ b/libcrux-ml-dsa/tests/kats/generate_kats.py @@ -6,40 +6,55 @@ import json import hashlib -for algorithm in [Dilithium2, Dilithium3, Dilithium5]: - kats_formatted = [] - entropy_input = bytes([i for i in range(48)]) - rng = AES256_CTR_DRBG(entropy_input) +def generate_matrix_A_sampling_KATs(): + algorithm = Dilithium3 - print("Generating KATs for ML-DSA-{}{}.".format(algorithm.k, algorithm.l)) + for i in range(1): + pk, sk = algorithm.keygen() + print([x for x in algorithm.A_rejection_sampling_seed]) + print([x for x in algorithm.A_sampled_ring_element]) - for i in range(100): - seed = rng.random_bytes(48) - algorithm.set_drbg_seed(seed) +def generate_nistkats(): + for algorithm in [Dilithium2, Dilithium3, Dilithium5]: + kats_formatted = [] + + entropy_input = bytes([i for i in range(48)]) + rng = AES256_CTR_DRBG(entropy_input) + + print("Generating KATs for ML-DSA-{}{}.".format(algorithm.k, algorithm.l)) + + for i in range(100): + seed = rng.random_bytes(48) + + algorithm.set_drbg_seed(seed) + + pk, sk = algorithm.keygen() + + msg_len = 33 * (i + 1) + msg = rng.random_bytes(msg_len) + sig = algorithm.sign(sk, msg) + + kats_formatted.append( + { + "key_generation_seed": bytes(algorithm.keygen_seed).hex(), + "sha3_256_hash_of_public_key": bytes( + hashlib.sha3_256(pk).digest() + ).hex(), + "sha3_256_hash_of_secret_key": bytes( + hashlib.sha3_256(sk).digest() + ).hex(), + "message": bytes(msg).hex(), + "sha3_256_hash_of_signature": bytes( + hashlib.sha3_256(sig).digest() + ).hex(), + } + ) + + with open("nistkats-{}{}.json".format(algorithm.k, algorithm.l), "w") as f: + json.dump(kats_formatted, f, ensure_ascii=False, indent=4) - pk, sk = algorithm.keygen() - msg_len = 33 * (i + 1) - msg = rng.random_bytes(msg_len) - sig = algorithm.sign(sk, msg) - - kats_formatted.append( - { - "key_generation_seed": bytes(algorithm.keygen_seed).hex(), - "sha3_256_hash_of_public_key": bytes( - hashlib.sha3_256(pk).digest() - ).hex(), - "sha3_256_hash_of_secret_key": bytes( - hashlib.sha3_256(sk).digest() - ).hex(), - "message": bytes(msg).hex(), - "sha3_256_hash_of_signature": bytes( - hashlib.sha3_256(sig).digest() - ).hex(), - } - ) - - with open("nistkats-{}{}.json".format(algorithm.k, algorithm.l), "w") as f: - json.dump(kats_formatted, f, ensure_ascii=False, indent=4) +# generate_matrix_A_sampling_KATs() +generate_nistkats()