Skip to content

Commit

Permalink
Optimize kyber key generation (#37)
Browse files Browse the repository at this point in the history
* Optimized ring-element encoding in PKE key generation.
* Optimized binomial sampling for PKE key generation.
* Removed generic binomial sampler.
  • Loading branch information
xvzcf authored Aug 9, 2023
1 parent 472c171 commit b4214ed
Show file tree
Hide file tree
Showing 11 changed files with 123 additions and 120 deletions.
34 changes: 20 additions & 14 deletions benches/kyber768.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};

use libcrux::kem::Algorithm;
use libcrux::drbg::Drbg;
use libcrux::digest;
use libcrux::drbg::Drbg;
use libcrux::kem::Algorithm;

pub fn comparisons_key_generation(c: &mut Criterion) {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let mut group = c.benchmark_group("Kyber768 Key Generation");

group.bench_function("libcrux reference implementation", |b| {
b.iter(
|| {
let (_secret_key, _public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
}
)
b.iter(|| {
let (_secret_key, _public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
})
});

group.bench_function("pqclean reference implementation", |b| {
Expand All @@ -30,7 +29,8 @@ pub fn comparisons_encapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let (_secret_key, public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_secret_key, public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();

(drbg, public_key)
},
Expand All @@ -50,7 +50,8 @@ pub fn comparisons_encapsulation(c: &mut Criterion) {
public_key
},
|public_key| {
let (_shared_secret, _ciphertext) = pqcrypto_kyber::kyber768::encapsulate(&public_key);
let (_shared_secret, _ciphertext) =
pqcrypto_kyber::kyber768::encapsulate(&public_key);
},
BatchSize::SmallInput,
)
Expand All @@ -64,12 +65,15 @@ pub fn comparisons_decapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();
let (secret_key, public_key) = libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_shared_secret, ciphertext) = libcrux::kem::encapsulate(Algorithm::Kyber768, &public_key, &mut drbg).unwrap();
let (secret_key, public_key) =
libcrux::kem::key_gen(Algorithm::Kyber768, &mut drbg).unwrap();
let (_shared_secret, ciphertext) =
libcrux::kem::encapsulate(Algorithm::Kyber768, &public_key, &mut drbg).unwrap();
(secret_key, ciphertext)
},
|(secret_key, ciphertext)| {
let _shared_secret = libcrux::kem::decapsulate(Algorithm::Kyber768, &ciphertext, &secret_key);
let _shared_secret =
libcrux::kem::decapsulate(Algorithm::Kyber768, &ciphertext, &secret_key);
},
BatchSize::SmallInput,
)
Expand All @@ -79,12 +83,14 @@ pub fn comparisons_decapsulation(c: &mut Criterion) {
b.iter_batched(
|| {
let (public_key, secret_key) = pqcrypto_kyber::kyber768::keypair();
let (_shared_secret, ciphertext) = pqcrypto_kyber::kyber768::encapsulate(&public_key);
let (_shared_secret, ciphertext) =
pqcrypto_kyber::kyber768::encapsulate(&public_key);

(ciphertext, secret_key)
},
|(ciphertext, secret_key)| {
let _shared_secret = pqcrypto_kyber::kyber768::decapsulate(&ciphertext, &secret_key);
let _shared_secret =
pqcrypto_kyber::kyber768::decapsulate(&ciphertext, &secret_key);
},
BatchSize::SmallInput,
)
Expand Down
11 changes: 11 additions & 0 deletions examples/kyber768_generate_keypair.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use libcrux::digest;
use libcrux::drbg::Drbg;
use libcrux::kem;

fn main() {
let mut drbg = Drbg::new(digest::Algorithm::Sha256).unwrap();

for _i in 0..100000 {
let (_secret_key, _public_key) = kem::key_gen(kem::Algorithm::Kyber768, &mut drbg).unwrap();
}
}
10 changes: 5 additions & 5 deletions src/drbg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,13 @@ impl Drbg {
/// Implementation of the [`RngCore`] trait for the [`Drbg`].
impl RngCore for Drbg {
fn next_u32(&mut self) -> u32 {
let mut bytes : [u8; 4] = [0; 4];
let mut bytes: [u8; 4] = [0; 4];
self.generate(&mut bytes).unwrap();

(bytes[0] as u32) |
(bytes[1] as u32) << 8 |
(bytes[2] as u32) << 16 |
(bytes[3] as u32) << 24
(bytes[0] as u32)
| (bytes[1] as u32) << 8
| (bytes[2] as u32) << 16
| (bytes[3] as u32) << 24
}

fn next_u64(&mut self) -> u64 {
Expand Down
4 changes: 2 additions & 2 deletions src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ mod kyber768;
// (and change the visibility of the exported functions to pub(crate)) the
// moment we have an implementation of one. This is tracked by:
// https://github.com/cryspen/libcrux/issues/36
pub use kyber768::generate_keypair as kyber768_generate_keypair_derand;
pub use kyber768::encapsulate as kyber768_encapsulate_derand;
pub use kyber768::decapsulate as kyber768_decapsulate_derand;
pub use kyber768::encapsulate as kyber768_encapsulate_derand;
pub use kyber768::generate_keypair as kyber768_generate_keypair_derand;

/// KEM Algorithms
///
Expand Down
40 changes: 21 additions & 19 deletions src/kem/kyber768/ind_cpa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::kem::kyber768::utils::{
ArrayConversion, ArrayPadding, PanickingIntegerCasts, UpdatableArray, UpdatingArray, VecUpdate,
ArrayConversion, ArrayPadding, PanickingIntegerCasts, UpdatableArray, UpdatingArray,
};

use crate::kem::kyber768::{
Expand All @@ -10,14 +10,14 @@ use crate::kem::kyber768::{
},
parameters::{
hash_functions::{G, H, PRF, XOF},
KyberPolynomialRingElement, BITS_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT,
CPA_PKE_CIPHERTEXT_SIZE, CPA_PKE_KEY_GENERATION_SEED_SIZE, CPA_PKE_MESSAGE_SIZE,
CPA_PKE_PUBLIC_KEY_SIZE, CPA_PKE_SECRET_KEY_SIZE, CPA_SERIALIZED_KEY_LEN, RANK,
REJECTION_SAMPLING_SEED_SIZE, T_AS_NTT_ENCODED_SIZE, VECTOR_U_COMPRESSION_FACTOR,
VECTOR_U_SIZE, VECTOR_V_COMPRESSION_FACTOR,
KyberPolynomialRingElement, BITS_PER_RING_ELEMENT, BYTES_PER_RING_ELEMENT,
COEFFICIENTS_IN_RING_ELEMENT, CPA_PKE_CIPHERTEXT_SIZE, CPA_PKE_KEY_GENERATION_SEED_SIZE,
CPA_PKE_MESSAGE_SIZE, CPA_PKE_PUBLIC_KEY_SIZE, CPA_PKE_SECRET_KEY_SIZE,
CPA_SERIALIZED_KEY_LEN, RANK, REJECTION_SAMPLING_SEED_SIZE, T_AS_NTT_ENCODED_SIZE,
VECTOR_U_COMPRESSION_FACTOR, VECTOR_U_SIZE, VECTOR_V_COMPRESSION_FACTOR,
},
sampling::{sample_from_binomial_distribution, sample_from_uniform_distribution},
serialize::{deserialize_little_endian, serialize_little_endian},
sampling::{sample_from_binomial_distribution_with_2_coins, sample_from_uniform_distribution},
serialize::{deserialize_little_endian, serialize_little_endian, serialize_little_endian_12},
BadRejectionSamplingRandomnessError,
};

Expand Down Expand Up @@ -52,10 +52,12 @@ impl KeyPair {
}
}

fn encode_12(input: [KyberPolynomialRingElement; RANK]) -> Vec<u8> {
let mut out = Vec::new();
for re in input.into_iter() {
out.extend_from_slice(&serialize_little_endian(re, 12));
fn encode_12(input: [KyberPolynomialRingElement; RANK]) -> [u8; RANK * BYTES_PER_RING_ELEMENT] {
let mut out = [0u8; RANK * BYTES_PER_RING_ELEMENT];

for (i, re) in input.into_iter().enumerate() {
out[i * BYTES_PER_RING_ELEMENT..(i + 1) * BYTES_PER_RING_ELEMENT]
.copy_from_slice(&serialize_little_endian_12(re));
}

out
Expand Down Expand Up @@ -93,7 +95,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let secret = sample_from_binomial_distribution(2, &prf_output[..]);
let secret = sample_from_binomial_distribution_with_2_coins(prf_output);
secret_as_ntt[i] = ntt_representation(secret);
}

Expand All @@ -109,7 +111,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let error = sample_from_binomial_distribution(2, &prf_output[..]);
let error = sample_from_binomial_distribution_with_2_coins(prf_output);
error_as_ntt[i] = ntt_representation(error);
}

Expand All @@ -120,13 +122,13 @@ pub(crate) fn generate_keypair(
}

// pk := (Encode_12(tˆ mod^{+}q) || ρ)
let public_key_serialized = encode_12(t_as_ntt).concat(seed_for_A);
let public_key_serialized = [&encode_12(t_as_ntt), seed_for_A].concat();

// sk := Encode_12(sˆ mod^{+}q)
let secret_key_serialized = encode_12(secret_as_ntt);

Ok(KeyPair::new(
secret_key_serialized.into_array(),
secret_key_serialized,
public_key_serialized.into_array(),
))
}
Expand Down Expand Up @@ -167,7 +169,7 @@ fn cbd(mut prf_input: [u8; 33]) -> ([KyberPolynomialRingElement; RANK], u8) {
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let r = sample_from_binomial_distribution(2, &prf_output);
let r = sample_from_binomial_distribution_with_2_coins(prf_output);
r_as_ntt[i] = ntt_representation(r);
}
(r_as_ntt, domain_separator)
Expand Down Expand Up @@ -231,14 +233,14 @@ pub(crate) fn encrypt(

// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
error_1[i] = sample_from_binomial_distribution(2, &prf_output);
error_1[i] = sample_from_binomial_distribution_with_2_coins(prf_output);
}

// e_2 := CBD{η2}(PRF(r, N))
prf_input[32] = domain_separator;
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
let error_2 = sample_from_binomial_distribution(2, &prf_output);
let error_2 = sample_from_binomial_distribution_with_2_coins(prf_output);

// u := NTT^{-1}(AˆT ◦ rˆ) + e_1
let mut u = multiply_matrix_by_column(&A_transpose, &r_as_ntt).map(|r| invert_ntt(r));
Expand Down
51 changes: 0 additions & 51 deletions src/kem/kyber768/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
self, KyberFieldElement, KyberPolynomialRingElement, COEFFICIENTS_IN_RING_ELEMENT,
};

/// [ pow(17, br(i), p) for 0 <= i < 128 ]
/// br(i) is the bit reversal of i regarded as a 7-bit number.
const ZETAS: [u16; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
Expand All @@ -23,8 +21,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
2154,
];

/// [ pow(17, 2 * br(i) + 1, p) for 0 <= i < 128 ]
/// br(i) is the bit reversal of i regarded as a 7-bit number.
const MOD_ROOTS: [u16; 128] = [
17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279,
Expand All @@ -39,26 +35,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {

const NTT_LAYERS: [usize; 7] = [2, 4, 8, 16, 32, 64, 128];

/// Use the Cooley–Tukey butterfly to compute an in-place NTT representation
/// of a `KyberPolynomialRingElement`.
///
/// This can be seen (see [CFRG draft]) as 128 applications of the linear map CT where
///
/// CT_i(a, b) => (a + zeta^i * b, a - zeta^i * b) mod q
///
/// for the appropriate i.
///
/// Because the Kyber base field has 256th roots of unity but not 512th roots
/// of unity, the resulting NTT representation is an element in:
///
/// ```plaintext
/// Product(i = 0 to 255) F_{3329}[x] / (x^2 - zeta^{2i+1}),
/// ```
///
/// This is isomorphic to `F_{3329}[x] / (x^{256} + 1)` by the
/// Chinese Remainder Theorem.
///
/// [CFRG draft]: <https://datatracker.ietf.org/doc/draft-cfrg-schwabe-kyber/>
pub fn ntt_representation(mut re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let mut zeta_i = 0;
for layer in NTT_LAYERS.iter().rev() {
Expand All @@ -76,17 +52,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
re
}

/// Use the Gentleman-Sande butterfly to invert, in-place, the NTT representation
/// of a `KyberPolynomialRingElement`. The inverse NTT can be computed (see [CFRG draft]) by
/// replacing CS_i by GS_j and
///
/// ```plaintext
/// GS_j(a, b) => ( (a + b) / 2, zeta^{2*j + 1} * (a - b) / 2 ) mod q
/// ```
///
/// for the appropriate j.
///
/// [CFRG draft]: https://datatracker.ietf.org/doc/draft-cfrg-schwabe-kyber/
pub fn invert_ntt(re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let inverse_of_2: KyberFieldElement =
KyberFieldElement::new((parameters::FIELD_MODULUS + 1) / 2);
Expand Down Expand Up @@ -114,22 +79,6 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
out
}

/// Two elements `a, b ∈ F_{3329}[x] / (x^2 - zeta^{2i+1})` in the Kyber NTT
/// domain:
///
/// ```plaintext
/// a = a_0 + a_1 * x
/// b = b_0 + b_1 * x
/// ```
///
/// can be multiplied as follows:
///
/// ```plaintext
/// (a_2 * x + a_1)(b_2 * x + b_1) =
/// (a_0 * b_0 + a_1 * b_1 * zeta^{2i + 1}) + (a_0 * b_1 + a_1 * b_0) * x
/// ```
///
/// for the appropriate i.
pub fn ntt_multiply(
left: &KyberPolynomialRingElement,
other: &KyberPolynomialRingElement,
Expand Down
5 changes: 4 additions & 1 deletion src/kem/kyber768/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ pub(crate) const BITS_PER_COEFFICIENT: usize = 12;
/// Coefficients per ring element
pub(crate) const COEFFICIENTS_IN_RING_ELEMENT: usize = 256;

/// Bits required per ring element
/// Bits required per (uncompressed) ring element
pub(crate) const BITS_PER_RING_ELEMENT: usize = COEFFICIENTS_IN_RING_ELEMENT * 12;

/// Bytes required per (uncompressed) ring element
pub(crate) const BYTES_PER_RING_ELEMENT: usize = BITS_PER_RING_ELEMENT / 8;

/// Seed size for rejection sampling.
///
/// See <https://eprint.iacr.org/2023/708> for some background regarding
Expand Down
39 changes: 21 additions & 18 deletions src/kem/kyber768/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::kem::kyber768::{
parameters::{self, KyberFieldElement, KyberPolynomialRingElement},
utils::bit_vector::LittleEndianBitStream,
BadRejectionSamplingRandomnessError,
};

Expand Down Expand Up @@ -37,28 +36,32 @@ pub fn sample_from_uniform_distribution(
Err(BadRejectionSamplingRandomnessError)
}

pub fn sample_from_binomial_distribution(
sampling_coins: usize,
randomness: &[u8],
pub fn sample_from_binomial_distribution_with_2_coins(
randomness: [u8; 128],
) -> KyberPolynomialRingElement {
assert_eq!(randomness.len(), sampling_coins * 64);

let mut sampled: KyberPolynomialRingElement = KyberPolynomialRingElement::ZERO;

for i in 0..sampled.len() {
let mut coin_tosses: u8 = 0;
for j in 0..sampling_coins {
coin_tosses += randomness.nth_bit(2 * i * sampling_coins + j);
}
let coin_tosses_a: KyberFieldElement = coin_tosses.into();
for (chunk_number, byte_chunk) in randomness.chunks_exact(4).enumerate() {
let random_bits_as_u32: u32 = (byte_chunk[0] as u32)
| (byte_chunk[1] as u32) << 8
| (byte_chunk[2] as u32) << 16
| (byte_chunk[3] as u32) << 24;

coin_tosses = 0;
for j in 0..sampling_coins {
coin_tosses += randomness.nth_bit(2 * i * sampling_coins + sampling_coins + j);
}
let coin_tosses_b: KyberFieldElement = coin_tosses.into();
let even_bits = random_bits_as_u32 & 0x55555555;
let odd_bits = (random_bits_as_u32 >> 1) & 0x55555555;

let coin_toss_outcomes = even_bits + odd_bits;

sampled[i] = coin_tosses_a - coin_tosses_b;
for outcome_set in (0..u32::BITS).step_by(4) {
let outcome_1: u16 = ((coin_toss_outcomes >> outcome_set) & 0x3) as u16;
let outcome_1: KyberFieldElement = outcome_1.into();

let outcome_2: u16 = ((coin_toss_outcomes >> (outcome_set + 2)) & 0x3) as u16;
let outcome_2: KyberFieldElement = outcome_2.into();

let offset = usize::try_from(outcome_set >> 2).unwrap();
sampled[8 * chunk_number + offset] = outcome_1 - outcome_2;
}
}

sampled
Expand Down
Loading

0 comments on commit b4214ed

Please sign in to comment.