Skip to content

Commit

Permalink
Short error vector sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed May 24, 2024
1 parent e2d2b77 commit 005720f
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 29 deletions.
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/hash_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub(crate) fn H<const OUTPUT_LENGTH: usize>(input: &[u8]) -> [u8; OUTPUT_LENGTH]
out
}

pub(crate) mod XOF {
pub(crate) mod H_128 {
use libcrux_sha3::portable::{incremental, KeccakState1};

const BLOCK_SIZE: usize = 168;
Expand Down
141 changes: 126 additions & 15 deletions libcrux-ml-dsa/src/sample.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
arithmetic::PolynomialRingElement,
constants::{COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS},
hash_functions::XOF,
hash_functions::{H, H_128},
};

fn rejection_sample_less_than_field_modulus(
Expand Down Expand Up @@ -32,24 +32,63 @@ fn rejection_sample_less_than_field_modulus(

done
}
#[allow(non_snake_case)]
pub(crate) fn sample_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingElement {
let mut state = XOF::new(seed);
let randomness = XOF::squeeze_first_five_blocks(&mut state);
let mut state = H_128::new(seed);
let randomness = H_128::squeeze_first_five_blocks(&mut state);

let mut out = PolynomialRingElement::ZERO;

let mut sampled = 0;
let mut done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out);

while !done {
let randomness = XOF::squeeze_next_block(&mut state);
let randomness = H_128::squeeze_next_block(&mut state);
done = rejection_sample_less_than_field_modulus(&randomness, &mut sampled, &mut out);
}

out
}

fn rejection_sample_less_than_eta_equals_2(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
) -> bool {
let mut done = false;

for byte in randomness {
if !done {
let try_0 = byte & 0xF;
let try_1 = byte >> 4;

if try_0 < 15 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
let try_0 = try_0 as i32;

// (try_0 * 26) >> 7 computes ⌊try_0 / 5⌋
let try_0_mod_5 = try_0 - ((try_0 * 26) >> 7) * 5;

out.coefficients[*sampled] = 2 - try_0_mod_5;

*sampled += 1;
}

if try_1 < 15 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
let try_1 = try_1 as i32;
let try_1_mod_5 = try_1 - ((try_1 * 26) >> 7) * 5;

out.coefficients[*sampled] = 2 - try_1_mod_5;

*sampled += 1;
}

if *sampled == COEFFICIENTS_IN_RING_ELEMENT {
done = true;
}
}
}

done
}
fn rejection_sample_less_than_eta_equals_4(
randomness: &[u8],
sampled: &mut usize,
Expand All @@ -62,12 +101,12 @@ fn rejection_sample_less_than_eta_equals_4(
let try_0 = byte & 0xF;
let try_1 = byte >> 4;

if try_0 < (2 * 4) + 1 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
if try_0 < 9 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
out.coefficients[*sampled] = 4 - (try_0 as i32);
*sampled += 1;
}

if try_1 < (2 * 4) + 1 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
if try_1 < 9 && *sampled < COEFFICIENTS_IN_RING_ELEMENT {
out.coefficients[*sampled] = 4 - (try_1 as i32);
*sampled += 1;
}
Expand All @@ -80,19 +119,34 @@ fn rejection_sample_less_than_eta_equals_4(

done
}

pub(crate) fn rejection_sample_less_than_eta<const ETA: usize>(
randomness: &[u8],
sampled: &mut usize,
out: &mut PolynomialRingElement,
) -> bool {
match ETA {
2 => rejection_sample_less_than_eta_equals_2(randomness, sampled, out),
4 => rejection_sample_less_than_eta_equals_4(randomness, sampled, out),
_ => unreachable!(),
}
}

#[allow(non_snake_case)]
pub(crate) fn sample_error_ring_element_uniform(seed: [u8; 34]) -> PolynomialRingElement {
let mut state = XOF::new(seed);
let randomness = XOF::squeeze_next_block(&mut state);
pub(crate) fn sample_error_ring_element_uniform<const ETA: usize>(
seed: [u8; 66],
) -> PolynomialRingElement {
// TODO: Use incremental API to squeeze one block at a time.
let randomness = H::<272>(&seed);

let mut out = PolynomialRingElement::ZERO;

let mut sampled = 0;
let mut done = rejection_sample_less_than_eta_equals_4(&randomness, &mut sampled, &mut out);
let done = rejection_sample_less_than_eta::<ETA>(&randomness, &mut sampled, &mut out);

while !done {
let randomness = XOF::squeeze_next_block(&mut state);
done = rejection_sample_less_than_eta_equals_4(&randomness, &mut sampled, &mut out);
// TODO: Remove this panic using the incremental API.
if !done {
panic!("Not enough randomness");
}

out
Expand All @@ -104,7 +158,6 @@ mod tests {

use crate::arithmetic::FieldElement;

#[allow(non_snake_case)]
#[test]
fn test_sample_ring_element_uniform() {
let seed: [u8; 34] = [
Expand Down Expand Up @@ -147,4 +200,62 @@ mod tests {
expected_coefficients
);
}

#[test]
fn test_sample_error_ring_element_when_eta_is_4() {
let seed: [u8; 66] = [
236, 4, 148, 239, 41, 178, 188, 226, 130, 212, 6, 144, 208, 180, 180, 105, 47, 148, 75,
195, 181, 177, 5, 140, 204, 68, 24, 132, 169, 19, 68, 118, 67, 203, 13, 152, 29, 194,
235, 123, 101, 109, 162, 137, 198, 164, 97, 247, 11, 44, 34, 49, 235, 251, 243, 177,
213, 141, 65, 232, 136, 163, 85, 54, 10, 0,
];

let expected_coefficients: [i32; COEFFICIENTS_IN_RING_ELEMENT] = [
2, -4, 2, -2, 1, 2, 4, 2, 4, -1, -4, 3, 2, 4, -1, 2, -3, 3, 1, -2, 0, 3, -2, 3, 4, 1,
-3, -2, 0, -4, -1, -4, 3, -4, 0, -3, -2, -3, 2, -3, -3, 3, -4, -3, -4, 1, -2, 4, -3, 4,
4, 1, -3, -3, 4, 0, -2, 2, 4, -4, 4, -4, -1, -3, 4, 3, 2, -1, 3, -2, -2, -4, -1, -1, 4,
1, 4, 0, 3, 4, -1, -3, 4, -4, 4, 1, -3, 0, -4, 2, 1, 4, -1, 0, -2, -2, -3, 3, -3, 4, 3,
2, -2, -2, -1, 2, -1, -4, 3, 0, -2, 4, -1, 0, 4, -2, 4, -3, 2, -4, 2, 3, 3, 2, -4, 2,
0, -2, 1, -4, 0, -4, -3, 2, 0, -2, -4, 1, 2, 3, 4, -4, 2, 2, 1, -4, 0, -4, -3, -2, -2,
-2, -1, 1, 4, 1, 0, -2, 2, 1, 4, -4, -1, 0, -1, -3, 2, 1, 3, 3, 4, -2, -2, 3, 1, 3, 3,
-4, -2, -1, -4, -3, 4, 1, 2, -3, -1, 3, 4, -3, 0, -1, -1, -4, -2, 1, -2, 3, -1, -2, 2,
-1, -2, 0, -2, 2, 3, 3, 2, 3, 4, 3, -3, -4, 1, 4, -3, 2, 0, -4, 4, -4, 2, 4, -2, -3,
-4, 3, 0, 1, -2, 2, -1, 4, 4, 0, -1, 1, 4, -2, -3, 2, -2, 4, 2, 1, 1, 1, -3, -2, -2, 2,
2, -4, -1, 1,
];

assert_eq!(
sample_error_ring_element_uniform::<4>(seed).coefficients,
expected_coefficients
);
}

#[test]
fn test_sample_error_ring_element_when_eta_is_2() {
let seed: [u8; 66] = [
51, 203, 133, 235, 126, 210, 169, 81, 4, 134, 147, 168, 252, 67, 176, 99, 130, 186,
254, 103, 241, 199, 173, 78, 121, 232, 12, 244, 4, 143, 8, 174, 122, 170, 124, 35, 53,
49, 202, 94, 27, 249, 200, 186, 175, 198, 169, 116, 244, 227, 133, 111, 205, 140, 233,
110, 227, 67, 35, 226, 194, 75, 130, 105, 5, 0,
];

let expected_coefficients: [i32; COEFFICIENTS_IN_RING_ELEMENT] = [
1, 0, -1, 0, 1, -2, -1, 0, -2, 2, -1, -2, 1, -2, 1, -2, 1, 2, -2, 2, -2, -1, 0, -2, -1,
-2, -2, 1, 1, -1, 1, 1, 2, -2, 2, -1, 1, 2, 0, 2, -1, 0, 2, -2, -2, 2, 0, 2, 1, 1, 2,
1, 1, -2, 1, -1, 2, -2, -2, 2, -2, -2, 0, 0, -1, 0, 2, 0, 1, 2, 0, 2, -1, 2, 0, 2, 1,
-2, -2, 0, -1, -2, 2, -2, -1, 2, 1, -1, 2, 1, -2, -1, 1, -1, -1, -1, 2, -1, -2, -2, 2,
2, 0, -1, -1, -2, 0, -1, 0, 1, 2, -2, 0, 2, 2, 1, 0, -1, -1, 0, -2, 2, 2, -2, 2, 1, -1,
-2, -1, -2, -1, 1, 2, 2, -1, 0, 1, 2, -1, 0, 0, 0, 1, 1, -1, -1, -1, -2, 2, 0, -2, 0,
2, -1, 1, 1, 2, -2, 2, -2, 1, 0, -2, 1, 0, 0, -2, -2, 2, 2, -2, -1, 2, -2, 1, 0, 0, -1,
0, -2, 2, -1, -2, 2, -1, 1, -2, -1, 0, -2, 2, 1, 2, 2, 2, 0, 2, 2, 2, 0, 2, 2, 2, -1,
-2, 1, 1, 0, -2, 1, 0, 0, -2, 1, -2, -1, 2, 0, 0, 2, 0, -2, -1, -1, 2, 2, -1, -1, -1,
-2, -2, -1, -2, 2, -2, 0, 1, 0, -2, -2, 2, 0, 1, 0, 0, -2, -1, 1, -1, 1, -1, -1, -1, 2,
2, 0,
];

assert_eq!(
sample_error_ring_element_uniform::<2>(seed).coefficients,
expected_coefficients
);
}
}
3 changes: 0 additions & 3 deletions libcrux-ml-dsa/tests/kats/dilithium.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ def rejection_sample(xof):
Shake128.absorb(seed)
coeffs = [rejection_sample(Shake128) for _ in range(self.n)]

self.A_rejection_sampling_seed = seed
self.A_sampled_ring_element = coeffs

return self.R(coeffs, is_ntt=is_ntt)

def _sample_mask_polynomial(self, rho_prime, i, kappa, is_ntt=False):
Expand Down
10 changes: 0 additions & 10 deletions libcrux-ml-dsa/tests/kats/generate_kats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@
import hashlib


def generate_matrix_A_sampling_KATs():
algorithm = Dilithium3

for i in range(1):
pk, sk = algorithm.keygen()
print([x for x in algorithm.A_rejection_sampling_seed])
print([x for x in algorithm.A_sampled_ring_element])


def generate_nistkats():
for algorithm in [Dilithium2, Dilithium3, Dilithium5]:
kats_formatted = []
Expand Down Expand Up @@ -56,5 +47,4 @@ def generate_nistkats():
json.dump(kats_formatted, f, ensure_ascii=False, indent=4)


# generate_matrix_A_sampling_KATs()
generate_nistkats()

0 comments on commit 005720f

Please sign in to comment.