diff --git a/libcrux-ml-dsa/src/sample.rs b/libcrux-ml-dsa/src/sample.rs index 05fd2c9f9..9e093b9ea 100644 --- a/libcrux-ml-dsa/src/sample.rs +++ b/libcrux-ml-dsa/src/sample.rs @@ -4,28 +4,34 @@ use crate::{ hash_functions::XOF, }; -fn sample_from_uniform_distribution_next( +fn rejection_sample_less_than_field_modulus( randomness: &[u8], + sampled: &mut usize, out: &mut PolynomialRingElement, -) -> usize { - let mut sampled = 0; +) -> bool { + let mut done = false; for bytes in randomness.chunks(3) { - let b0 = bytes[0] as i32; - let b1 = bytes[1] as i32; - let b2 = bytes[2] as i32; + if !done { + let b0 = bytes[0] as i32; + let b1 = bytes[1] as i32; + let b2 = bytes[2] as i32; + + let potential_coefficient = ((b2 << 16) | (b1 << 8) | b0) & 0x00_7F_FF_FF; - let potential_coefficient = ((b2 << 16) | (b1 << 8) | b0) & 0x00_7F_FF_FF; + if potential_coefficient < FIELD_MODULUS && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + out.coefficients[*sampled] = potential_coefficient; + *sampled += 1; + } - if potential_coefficient < FIELD_MODULUS && sampled < COEFFICIENTS_IN_RING_ELEMENT { - out.coefficients[sampled] = potential_coefficient; - sampled += 1; + if *sampled == COEFFICIENTS_IN_RING_ELEMENT { + done = true; + } } } - sampled + done } - #[allow(non_snake_case)] pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingElement { let mut state = XOF::new(seed); @@ -33,11 +39,60 @@ pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingEleme let mut out = PolynomialRingElement::ZERO; - let mut sampled = sample_from_uniform_distribution_next(&randomness, &mut out); + let mut sampled = 0; + let mut done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out); + + while !done { + let randomness = XOF::squeeze_next_block(&mut state); + done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out); + } + + out +} + +fn rejection_sample_less_than_eta_equals_4( + randomness: &[u8], + sampled: &mut usize, + out: &mut PolynomialRingElement, +) -> bool { + let mut done = false; + + for byte in randomness { + if !done { + let try_0 = byte & 0xF; + let try_1 = byte >> 4; + + if try_0 < (2 * 4) + 1 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + out.coefficients[*sampled] = 4 - (try_0 as i32); + *sampled += 1; + } + + if try_1 < (2 * 4) + 1 && *sampled < COEFFICIENTS_IN_RING_ELEMENT { + out.coefficients[*sampled] = 4 - (try_1 as i32); + *sampled += 1; + } + + if *sampled == COEFFICIENTS_IN_RING_ELEMENT { + done = true; + } + } + } + + done +} +#[allow(non_snake_case)] +pub(crate) fn sample_error_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingElement { + let mut state = XOF::new(seed); + let randomness = XOF::squeeze_next_block(&mut state); + + let mut out = PolynomialRingElement::ZERO; + + let mut sampled = 0; + let mut done = rejection_sample_less_than_eta_equals_4(&randomness, &mut sampled, &mut out); - while sampled < COEFFICIENTS_IN_RING_ELEMENT { + while !done { let randomness = XOF::squeeze_next_block(&mut state); - sampled += sample_from_uniform_distribution_next(&randomness, &mut out); + done = rejection_sample_less_than_eta_equals_4(&randomness, &mut sampled, &mut out); } out