Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making serialization functions specific instead of generic and removing helper code. #39

Merged
merged 12 commits into from
Aug 15, 2023
2 changes: 1 addition & 1 deletion src/kem/kyber768.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
mod compress;
mod field_element;
mod ind_cpa;
mod ntt;
mod parameters;
mod sampling;
mod serialize;
mod utils;
mod field_element;

use utils::{ArrayConversion, UpdatingArray2};

Expand Down
9 changes: 6 additions & 3 deletions src/kem/kyber768/compress.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::kem::kyber768::{parameters::{self, KyberPolynomialRingElement}, field_element::KyberFieldElement};
use crate::kem::kyber768::{
field_element::KyberFieldElement,
parameters::{self, KyberPolynomialRingElement},
};

pub fn compress(
re: KyberPolynomialRingElement,
Expand Down Expand Up @@ -30,7 +33,7 @@ fn compress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
compressed /= u32::from(KyberFieldElement::MODULUS << 1);

KyberFieldElement {
value: (compressed & (two_pow_bit_size - 1)) as u16
value: (compressed & (two_pow_bit_size - 1)) as u16,
}
}

Expand All @@ -42,6 +45,6 @@ fn decompress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement
decompressed >>= to_bit_size + 1;

KyberFieldElement {
value: decompressed as u16
value: decompressed as u16,
}
}
30 changes: 6 additions & 24 deletions src/kem/kyber768/field_element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ pub struct KyberFieldElement {
impl KyberFieldElement {
pub const MODULUS: u16 = FIELD_MODULUS;

const BARRETT_SHIFT : u32 = 24; // 2 * ceil(log_2(FIELD_MODULUS))
const BARRETT_MULTIPLIER : u32 = (1u32 << Self::BARRETT_SHIFT) / (Self::MODULUS as u32);
const BARRETT_SHIFT: u32 = 24; // 2 * ceil(log_2(FIELD_MODULUS))
const BARRETT_MULTIPLIER: u32 = (1u32 << Self::BARRETT_SHIFT) / (Self::MODULUS as u32);

pub fn barrett_reduce(value : u32) -> Self {
let product : u64 = u64::from(value) * u64::from(Self::BARRETT_MULTIPLIER);
let quotient : u32 = (product >> Self::BARRETT_SHIFT) as u32;
pub fn barrett_reduce(value: u32) -> Self {
let product: u64 = u64::from(value) * u64::from(Self::BARRETT_MULTIPLIER);
let quotient: u32 = (product >> Self::BARRETT_SHIFT) as u32;

let remainder = value - (quotient * u32::from(Self::MODULUS));
let remainder : u16 = remainder as u16;
let remainder: u16 = remainder as u16;

let remainder_minus_modulus = remainder.wrapping_sub(Self::MODULUS);

Expand All @@ -38,24 +38,6 @@ impl FieldElement for KyberFieldElement {
fn new(number: u16) -> Self {
Self::barrett_reduce(u32::from(number))
}

fn nth_bit_little_endian(&self, n: usize) -> u8 {
((self.value >> n) & 1) as u8
}
}

impl From<u8> for KyberFieldElement {
fn from(number: u8) -> Self {
Self {
value: u16::from(number)
}
}
}

impl From<KyberFieldElement> for u16 {
fn from(fe: KyberFieldElement) -> Self {
fe.value
}
}

impl ops::Add for KyberFieldElement {
Expand Down
160 changes: 73 additions & 87 deletions src/kem/kyber768/ind_cpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ use crate::kem::kyber768::{
},
parameters::{
hash_functions::{G, H, PRF, XOF},
KyberPolynomialRingElement, BITS_PER_RING_ELEMENT, BYTES_PER_RING_ELEMENT,
KyberPolynomialRingElement, BYTES_PER_RING_ELEMENT,
COEFFICIENTS_IN_RING_ELEMENT, CPA_PKE_CIPHERTEXT_SIZE, CPA_PKE_KEY_GENERATION_SEED_SIZE,
CPA_PKE_MESSAGE_SIZE, CPA_PKE_PUBLIC_KEY_SIZE, CPA_PKE_SECRET_KEY_SIZE,
CPA_SERIALIZED_KEY_LEN, RANK, REJECTION_SAMPLING_SEED_SIZE, T_AS_NTT_ENCODED_SIZE,
VECTOR_U_COMPRESSION_FACTOR, VECTOR_U_SIZE, VECTOR_V_COMPRESSION_FACTOR,
BYTES_PER_ENCODED_ELEMENT_OF_U, VECTOR_U_COMPRESSION_FACTOR, VECTOR_U_ENCODED_SIZE, VECTOR_V_COMPRESSION_FACTOR,
},
sampling::{sample_from_binomial_distribution_2, sample_from_uniform_distribution},
serialize::{
deserialize_little_endian_1, deserialize_little_endian_10, deserialize_little_endian_12,
deserialize_little_endian_4, serialize_little_endian_1, serialize_little_endian_10,
serialize_little_endian_12, serialize_little_endian_4,
},
sampling::{sample_from_binomial_distribution_with_2_coins, sample_from_uniform_distribution},
serialize::{deserialize_little_endian, serialize_little_endian, serialize_little_endian_12},
BadRejectionSamplingRandomnessError,
};

Expand Down Expand Up @@ -52,6 +56,48 @@ impl KeyPair {
}
}

#[inline(always)]
fn parse_a(
mut seed: [u8; 34],
transpose: bool,
) -> Result<[[KyberPolynomialRingElement; RANK]; RANK], BadRejectionSamplingRandomnessError> {
let mut a_transpose = [[KyberPolynomialRingElement::ZERO; RANK]; RANK];

for i in 0..RANK {
for j in 0..RANK {
seed[32] = i.as_u8();
seed[33] = j.as_u8();

let xof_bytes: [u8; REJECTION_SAMPLING_SEED_SIZE] = XOF(&seed);

// A[i][j] = A_transpose[j][i]
if transpose {
a_transpose[j][i] = sample_from_uniform_distribution(xof_bytes)?;
} else {
a_transpose[i][j] = sample_from_uniform_distribution(xof_bytes)?;
}
}
}
Ok(a_transpose)
}

#[inline(always)]
fn cbd(mut prf_input: [u8; 33]) -> ([KyberPolynomialRingElement; RANK], u8) {
let mut domain_separator = 0;
let mut re_as_ntt = [KyberPolynomialRingElement::ZERO; RANK];
for i in 0..re_as_ntt.len() {
prf_input[32] = domain_separator;
domain_separator += 1;

// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let r = sample_from_binomial_distribution_2(prf_output);
re_as_ntt[i] = ntt_representation(r);
}
(re_as_ntt, domain_separator)
}

fn encode_12(input: [KyberPolynomialRingElement; RANK]) -> [u8; RANK * BYTES_PER_RING_ELEMENT] {
let mut out = [0u8; RANK * BYTES_PER_RING_ELEMENT];

Expand Down Expand Up @@ -95,7 +141,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let secret = sample_from_binomial_distribution_with_2_coins(prf_output);
let secret = sample_from_binomial_distribution_2(prf_output);
secret_as_ntt[i] = ntt_representation(secret);
}

Expand All @@ -111,7 +157,7 @@ pub(crate) fn generate_keypair(
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let error = sample_from_binomial_distribution_with_2_coins(prf_output);
let error = sample_from_binomial_distribution_2(prf_output);
error_as_ntt[i] = ntt_representation(error);
}

Expand All @@ -133,55 +179,13 @@ pub(crate) fn generate_keypair(
))
}

#[inline(always)]
fn parse_a(
mut seed: [u8; 34],
transpose: bool,
) -> Result<[[KyberPolynomialRingElement; RANK]; RANK], BadRejectionSamplingRandomnessError> {
let mut a_transpose = [[KyberPolynomialRingElement::ZERO; RANK]; RANK];

for i in 0..RANK {
for j in 0..RANK {
seed[32] = i.as_u8();
seed[33] = j.as_u8();

let xof_bytes: [u8; REJECTION_SAMPLING_SEED_SIZE] = XOF(&seed);

// A[i][j] = A_transpose[j][i]
if transpose {
a_transpose[j][i] = sample_from_uniform_distribution(xof_bytes)?;
} else {
a_transpose[i][j] = sample_from_uniform_distribution(xof_bytes)?;
}
}
}
Ok(a_transpose)
}

#[inline(always)]
fn cbd(mut prf_input: [u8; 33]) -> ([KyberPolynomialRingElement; RANK], u8) {
let mut domain_separator = 0;
let mut r_as_ntt = [KyberPolynomialRingElement::ZERO; RANK];
for i in 0..r_as_ntt.len() {
prf_input[32] = domain_separator;
domain_separator += 1;

// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);

let r = sample_from_binomial_distribution_with_2_coins(prf_output);
r_as_ntt[i] = ntt_representation(r);
}
(r_as_ntt, domain_separator)
}

fn encode_and_compress_u(input: [KyberPolynomialRingElement; RANK]) -> Vec<u8> {
let mut out = Vec::new();
for re in input.into_iter() {
out.extend_from_slice(&serialize_little_endian(
compress(re, VECTOR_U_COMPRESSION_FACTOR),
fn encode_and_compress_u(input: [KyberPolynomialRingElement; RANK]) -> [u8; VECTOR_U_ENCODED_SIZE] {
let mut out = [0u8; VECTOR_U_ENCODED_SIZE];
for (i, re) in input.into_iter().enumerate() {
out[i * BYTES_PER_ENCODED_ELEMENT_OF_U..(i + 1)* BYTES_PER_ENCODED_ELEMENT_OF_U].copy_from_slice(&serialize_little_endian_10(compress(
re,
VECTOR_U_COMPRESSION_FACTOR,
));
)));
}

out
Expand All @@ -194,15 +198,9 @@ pub(crate) fn encrypt(
randomness: &[u8; 32],
) -> Result<CiphertextCpa, BadRejectionSamplingRandomnessError> {
// tˆ := Decode_12(pk)
let mut t_as_ntt_ring_element_bytes = public_key.chunks(BITS_PER_RING_ELEMENT / 8);
let mut t_as_ntt = [KyberPolynomialRingElement::ZERO; RANK];
for i in 0..t_as_ntt.len() {
t_as_ntt[i] = deserialize_little_endian(
12,
t_as_ntt_ring_element_bytes.next().expect(
"t_as_ntt_ring_element_bytes should have enough bytes to deserialize to t_as_ntt",
),
);
for (i, t_as_ntt_bytes) in public_key[..T_AS_NTT_ENCODED_SIZE].chunks_exact(BYTES_PER_RING_ELEMENT).enumerate() {
t_as_ntt[i] = deserialize_little_endian_12(t_as_ntt_bytes);
}

// ρ := pk + 12·k·n / 8
Expand Down Expand Up @@ -233,23 +231,23 @@ pub(crate) fn encrypt(

// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
error_1[i] = sample_from_binomial_distribution_with_2_coins(prf_output);
error_1[i] = sample_from_binomial_distribution_2(prf_output);
}

// e_2 := CBD{η2}(PRF(r, N))
prf_input[32] = domain_separator;
// 2 sampling coins * 64
let prf_output: [u8; 128] = PRF(&prf_input);
let error_2 = sample_from_binomial_distribution_with_2_coins(prf_output);
let error_2 = sample_from_binomial_distribution_2(prf_output);

// u := NTT^{-1}(AˆT ◦ rˆ) + e_1
let mut u = multiply_matrix_by_column(&A_transpose, &r_as_ntt).map(|r| invert_ntt(r));
let mut u = multiply_matrix_by_column(&A_transpose, &r_as_ntt).map(invert_ntt);
for i in 0..u.len() {
u[i] = u[i] + error_1[i];
}

// v := NTT^{−1}(tˆT ◦ rˆ) + e_2 + Decompress_q(Decode_1(m),1)
let message_as_ring_element = deserialize_little_endian(1, &message);
let message_as_ring_element = deserialize_little_endian_1(&message);
let v = invert_ntt(multiply_row_by_column(&t_as_ntt, &r_as_ntt))
+ error_2
+ decompress(message_as_ring_element, 1);
Expand All @@ -258,16 +256,10 @@ pub(crate) fn encrypt(
let c1 = encode_and_compress_u(u);

// c_2 := Encode_{dv}(Compress_q(v,d_v))
let c2 = serialize_little_endian(
compress(v, VECTOR_V_COMPRESSION_FACTOR),
VECTOR_V_COMPRESSION_FACTOR,
);
let c2 = serialize_little_endian_4(compress(v, VECTOR_V_COMPRESSION_FACTOR));

let ciphertext = c1
.into_iter()
.chain(c2.into_iter())
.collect::<Vec<u8>>()
.as_array();
let mut ciphertext : CiphertextCpa = (&c1).into_padded_array();
ciphertext[VECTOR_U_ENCODED_SIZE..].copy_from_slice(c2.as_slice());

Ok(ciphertext)
}
Expand All @@ -281,31 +273,25 @@ pub(crate) fn decrypt(
let mut secret_as_ntt = [KyberPolynomialRingElement::ZERO; RANK];

// u := Decompress_q(Decode_{d_u}(c), d_u)
for (i, u_bytes) in
(0..u_as_ntt.len()).zip(ciphertext.chunks((COEFFICIENTS_IN_RING_ELEMENT * 10) / 8))
for (i, u_bytes) in ciphertext[..VECTOR_U_ENCODED_SIZE].chunks_exact((COEFFICIENTS_IN_RING_ELEMENT * 10) / 8).enumerate()
{
let u = deserialize_little_endian(10, u_bytes);
let u = deserialize_little_endian_10(u_bytes);
u_as_ntt[i] = ntt_representation(decompress(u, 10));
}

// v := Decompress_q(Decode_{d_v}(c + d_u·k·n / 8), d_v)
let v = decompress(
deserialize_little_endian(VECTOR_V_COMPRESSION_FACTOR, &ciphertext[VECTOR_U_SIZE..]),
deserialize_little_endian_4(&ciphertext[VECTOR_U_ENCODED_SIZE..]),
VECTOR_V_COMPRESSION_FACTOR,
);

// sˆ := Decode_12(sk)
let mut secret_as_ntt_ring_element_bytes = secret_key.chunks(BITS_PER_RING_ELEMENT / 8);
for i in 0..secret_as_ntt.len() {
secret_as_ntt[i] = deserialize_little_endian(
12,
secret_as_ntt_ring_element_bytes.next().expect("secret_as_ntt_ring_element_bytes should have enough bytes to deserialize to secret_as_ntt"),
);
for (i, secret_bytes) in secret_key.chunks_exact(BYTES_PER_RING_ELEMENT).enumerate() {
secret_as_ntt[i] = deserialize_little_endian_12(secret_bytes);
}

// m := Encode_1(Compress_q(v − NTT^{−1}(sˆT ◦ NTT(u)) , 1))
let message = v - invert_ntt(multiply_row_by_column(&secret_as_ntt, &u_as_ntt));

// FIXME: remove conversion
serialize_little_endian(compress(message, 1), 1).as_array()
serialize_little_endian_1(compress(message, 1))
}
25 changes: 17 additions & 8 deletions src/kem/kyber768/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ use crate::kem::kyber768::parameters::{KyberPolynomialRingElement, RANK};
use self::kyber_polynomial_ring_element_mod::ntt_multiply;

pub(crate) mod kyber_polynomial_ring_element_mod {
use crate::kem::kyber768::field_element::KyberFieldElement;
use crate::kem::kyber768::parameters::{
self, KyberPolynomialRingElement, COEFFICIENTS_IN_RING_ELEMENT,
};
use crate::kem::kyber768::field_element::KyberFieldElement;

const ZETAS: [u16; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
Expand Down Expand Up @@ -75,11 +75,12 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
out
}

fn ntt_multiply_binomials((a0, a1): (KyberFieldElement, KyberFieldElement),
(b0, b1): (KyberFieldElement, KyberFieldElement),
zeta: u16) -> (KyberFieldElement, KyberFieldElement) {
((a0 * b0) + ((a1 * b1) * zeta),
(a0 * b1) + (a1 * b0))
fn ntt_multiply_binomials(
(a0, a1): (KyberFieldElement, KyberFieldElement),
(b0, b1): (KyberFieldElement, KyberFieldElement),
zeta: u16,
) -> (KyberFieldElement, KyberFieldElement) {
((a0 * b0) + ((a1 * b1) * zeta), (a0 * b1) + (a1 * b0))
}

pub fn ntt_multiply(
Expand All @@ -89,11 +90,19 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
let mut out = KyberPolynomialRingElement::ZERO;

for i in (0..out.coefficients.len()).step_by(4) {
let product = ntt_multiply_binomials((left[i], left[i+1]), (right[i], right[i + 1]), MOD_ROOTS[i / 2]);
let product = ntt_multiply_binomials(
(left[i], left[i + 1]),
(right[i], right[i + 1]),
MOD_ROOTS[i / 2],
);
out[i] = product.0;
out[i + 1] = product.1;

let product = ntt_multiply_binomials((left[i + 2], left[i + 3]), (right[i + 2], right[i + 3]), MOD_ROOTS[(i + 2) / 2]);
let product = ntt_multiply_binomials(
(left[i + 2], left[i + 3]),
(right[i + 2], right[i + 3]),
MOD_ROOTS[(i + 2) / 2],
);
out[i + 2] = product.0;
out[i + 3] = product.1;
}
Expand Down
Loading