Skip to content

Commit

Permalink
Use canonical signed representatives and lazy modular reduction. (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf authored Aug 21, 2023
1 parent 067f3f6 commit 98299d5
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 113 deletions.
66 changes: 20 additions & 46 deletions src/kem/kyber768/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,31 @@ use crate::kem::kyber768::parameters::{COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODUL

pub(crate) type KyberFieldElement = i16;

const BARRETT_SHIFT: u32 = 24; // 2 * ceil(log_2(FIELD_MODULUS))
const BARRETT_MULTIPLIER: u64 = (1u64 << BARRETT_SHIFT) / (FIELD_MODULUS as u64);
const BARRETT_SHIFT: i32 = 26;
const BARRETT_R: i32 = 1i32 << BARRETT_SHIFT;
const BARRETT_MULTIPLIER: i32 = 20159; // floor((BARRETT_R / FIELD_MODULUS) + 0.5)

pub(crate) fn barrett_reduce(value: i32) -> KyberFieldElement {
let product: u64 = (value as u64) * BARRETT_MULTIPLIER;
let quotient: u32 = (product >> BARRETT_SHIFT) as u32;
pub(crate) fn barrett_reduce(value: i16) -> KyberFieldElement {
let quotient = (i32::from(value) * BARRETT_MULTIPLIER) + (BARRETT_R >> 1);
let quotient = (quotient >> BARRETT_SHIFT) as i16;

// TODO: Justify in the comments (and subsequently in the proofs) that these
// operations do not lead to overflow/underflow.
let remainder = (value as u32) - (quotient * (FIELD_MODULUS as u32));
let remainder: i16 = remainder as i16;

let remainder_minus_modulus = remainder - FIELD_MODULUS;

// TODO: Check if LLVM detects this and optimizes it away into a
// conditional.
let selector = remainder_minus_modulus >> 15;

(selector & remainder) | (!selector & remainder_minus_modulus)
}

pub(crate) fn fe_add(lhs: KyberFieldElement, rhs: KyberFieldElement) -> KyberFieldElement {
let sum: i16 = lhs + rhs;
let difference: i16 = sum - FIELD_MODULUS;

let mask = difference >> 15;

(mask & sum) | (!mask & difference)
}

pub(crate) fn fe_sub(lhs: KyberFieldElement, rhs: KyberFieldElement) -> KyberFieldElement {
let difference = lhs - rhs;
let difference_plus_modulus: i16 = difference + FIELD_MODULUS;

let mask = difference >> 15;

(!mask & difference) | (mask & difference_plus_modulus)
value - (quotient * FIELD_MODULUS)
}

pub(crate) fn fe_mul(lhs: KyberFieldElement, rhs: KyberFieldElement) -> KyberFieldElement {
// TODO: This will shortly be replaced by an implementation of
// montgomery reduction.
let product: i32 = i32::from(lhs) * i32::from(rhs);

barrett_reduce(product)
let reduced = (product % i32::from(FIELD_MODULUS)) as i16;

if reduced > FIELD_MODULUS / 2 {
reduced - FIELD_MODULUS
} else if reduced < -FIELD_MODULUS / 2 {
reduced + FIELD_MODULUS
} else {
reduced
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand All @@ -58,14 +40,6 @@ impl KyberPolynomialRingElement {
pub const ZERO: Self = Self {
coefficients: [0i16; COEFFICIENTS_IN_RING_ELEMENT],
};

pub fn new(coefficients: [KyberFieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
Self { coefficients }
}

pub fn coefficients(&self) -> &[KyberFieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
&self.coefficients
}
}

impl Index<usize> for KyberPolynomialRingElement {
Expand Down Expand Up @@ -97,7 +71,7 @@ impl ops::Add for KyberPolynomialRingElement {
fn add(self, other: Self) -> Self {
let mut result = KyberPolynomialRingElement::ZERO;
for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
result.coefficients[i] = fe_add(self.coefficients[i], other.coefficients[i]);
result.coefficients[i] = self.coefficients[i] + other.coefficients[i];
}
result
}
Expand All @@ -109,7 +83,7 @@ impl ops::Sub for KyberPolynomialRingElement {
fn sub(self, other: Self) -> Self {
let mut result = KyberPolynomialRingElement::ZERO;
for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
result.coefficients[i] = fe_sub(self.coefficients[i], other.coefficients[i]);
result.coefficients[i] = self.coefficients[i] - other.coefficients[i];
}
result
}
Expand Down
28 changes: 16 additions & 12 deletions src/kem/kyber768/compress.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
use crate::kem::kyber768::{
arithmetic::{KyberFieldElement, KyberPolynomialRingElement},
parameters,
parameters::{self, FIELD_MODULUS},
};

pub fn compress(
re: KyberPolynomialRingElement,
mut re: KyberPolynomialRingElement,
bits_per_compressed_coefficient: usize,
) -> KyberPolynomialRingElement {
KyberPolynomialRingElement::new(
re.coefficients()
.map(|coefficient| compress_q(coefficient, bits_per_compressed_coefficient)),
)
re.coefficients = re
.coefficients
.map(|coefficient| compress_q(coefficient, bits_per_compressed_coefficient));
re
}

pub fn decompress(
re: KyberPolynomialRingElement,
mut re: KyberPolynomialRingElement,
bits_per_compressed_coefficient: usize,
) -> KyberPolynomialRingElement {
KyberPolynomialRingElement::new(
re.coefficients()
.map(|coefficient| decompress_q(coefficient, bits_per_compressed_coefficient)),
)
re.coefficients = re
.coefficients
.map(|coefficient| decompress_q(coefficient, bits_per_compressed_coefficient));
re
}

fn compress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
debug_assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);

let two_pow_bit_size = 1u32 << to_bit_size;

let mut compressed = (fe as u32) * (two_pow_bit_size << 1);
// Convert from canonical signed representative to canonical unsigned
// representative.
let fe_unsigned = fe + ((fe >> 15) & FIELD_MODULUS);

let mut compressed = (fe_unsigned as u32) * (two_pow_bit_size << 1);
compressed += parameters::FIELD_MODULUS as u32;
compressed /= (parameters::FIELD_MODULUS << 1) as u32;

Expand Down
80 changes: 43 additions & 37 deletions src/kem/kyber768/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,38 @@
use crate::kem::kyber768::arithmetic::KyberPolynomialRingElement;
use crate::kem::kyber768::parameters::RANK;
use crate::kem::kyber768::{
arithmetic::{barrett_reduce, KyberPolynomialRingElement},
parameters::RANK,
};

use self::kyber_polynomial_ring_element_mod::ntt_multiply;

pub(crate) mod kyber_polynomial_ring_element_mod {
use crate::kem::kyber768::{
arithmetic::{fe_add, fe_mul, fe_sub, KyberFieldElement, KyberPolynomialRingElement},
parameters::{self, COEFFICIENTS_IN_RING_ELEMENT},
arithmetic::{barrett_reduce, fe_mul, KyberFieldElement, KyberPolynomialRingElement},
parameters::COEFFICIENTS_IN_RING_ELEMENT,
};

const ZETAS: [i16; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
289, 331, 3253, 1756, 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594,
2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885,
2154,
1, -1600, -749, -40, -687, 630, -1432, 848, 1062, -1410, 193, 797, -543, -69, 569, -1583,
296, -882, 1339, 1476, -283, 56, -1089, 1333, 1426, -1235, 535, -447, -936, -450, -1355,
821, 289, 331, -76, -1573, 1197, -1025, -1052, -1274, 650, -1352, -816, 632, -464, 33,
1320, -1414, -1010, 1435, 807, 452, 1438, -461, 1534, -927, -682, -712, 1481, 648, -855,
-219, 1227, 910, 17, -568, 583, -680, 1637, 723, -1041, 1100, 1409, -667, -48, 233, 756,
-1173, -314, -279, -1626, 1651, -540, -1540, -1482, 952, 1461, -642, 939, -1021, -892,
-941, 733, -992, 268, 641, 1584, -1031, -1292, -109, 375, -780, -1239, 1645, 1063, 319,
-556, 757, -1230, 561, -863, -735, -525, 1092, 403, 1026, 1143, -1179, -554, 886, -1607,
1212, -1455, 1029, -1219, -394, 885, -1175,
];

const MOD_ROOTS: [i16; 128] = [
17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573, 2156, 1173, 3015, 314, 3050, 279,
1703, 1626, 1651, 1678, 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687,
642, 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641,
2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239,
1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863,
2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 2775, 554,
886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 2110, 1219, 2935, 394, 885,
2444, 2154, 1175,
17, -17, -568, 568, 583, -583, -680, 680, 1637, -1637, 723, -723, -1041, 1041, 1100, -1100,
1409, -1409, -667, 667, -48, 48, 233, -233, 756, -756, -1173, 1173, -314, 314, -279, 279,
-1626, 1626, 1651, -1651, -540, 540, -1540, 1540, -1482, 1482, 952, -952, 1461, -1461,
-642, 642, 939, -939, -1021, 1021, -892, 892, -941, 941, 733, -733, -992, 992, 268, -268,
641, -641, 1584, -1584, -1031, 1031, -1292, 1292, -109, 109, 375, -375, -780, 780, -1239,
1239, 1645, -1645, 1063, -1063, 319, -319, -556, 556, 757, -757, -1230, 1230, 561, -561,
-863, 863, -735, 735, -525, 525, 1092, -1092, 403, -403, 1026, -1026, 1143, -1143, -1179,
1179, -554, 554, 886, -886, -1607, 1607, 1212, -1212, -1455, 1455, 1029, -1029, -1219,
1219, -394, 394, 885, -885, -1175, 1175,
];

const NTT_LAYERS: [usize; 7] = [2, 4, 8, 16, 32, 64, 128];
Expand All @@ -43,21 +45,18 @@ pub(crate) mod kyber_polynomial_ring_element_mod {

for j in offset..offset + layer {
let t = fe_mul(re[j + layer], ZETAS[zeta_i]);
re[j + layer] = fe_sub(re[j], t);
re[j] = fe_add(re[j], t);
re[j + layer] = re[j] - t;
re[j] += t;
}
}
}
re.coefficients = re.coefficients.map(barrett_reduce);

re
}

pub fn invert_ntt(re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let inverse_of_2: i16 = (parameters::FIELD_MODULUS + 1) >> 1;

let mut out = KyberPolynomialRingElement::ZERO;
for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
out[i] = re[i];
}
pub fn invert_ntt(mut re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let inverse_of_2: i16 = -1664;

let mut zeta_i = COEFFICIENTS_IN_RING_ELEMENT / 2;

Expand All @@ -66,14 +65,15 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
zeta_i -= 1;

for j in offset..offset + layer {
let a_minus_b = fe_sub(out[j + layer], out[j]);
out[j] = fe_mul(fe_add(out[j], out[j + layer]), inverse_of_2);
out[j + layer] = fe_mul(fe_mul(a_minus_b, ZETAS[zeta_i]), inverse_of_2);
let a_minus_b = re[j + layer] - re[j];
re[j] = fe_mul(re[j] + re[j + layer], inverse_of_2);
re[j + layer] = fe_mul(fe_mul(a_minus_b, ZETAS[zeta_i]), inverse_of_2);
}
}
}
re.coefficients = re.coefficients.map(barrett_reduce);

out
re
}

fn ntt_multiply_binomials(
Expand All @@ -82,8 +82,8 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
zeta: i16,
) -> (KyberFieldElement, KyberFieldElement) {
(
fe_add(fe_mul(a0, b0), fe_mul(fe_mul(a1, b1), zeta)),
fe_add(fe_mul(a0, b1), fe_mul(a1, b0)),
fe_mul(a0, b0) + fe_mul(fe_mul(a1, b1), zeta),
fe_mul(a0, b1) + fe_mul(a1, b0),
)
}

Expand All @@ -110,6 +110,8 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
out[i + 2] = product.0;
out[i + 3] = product.1;
}
out.coefficients = out.coefficients.map(barrett_reduce);

out
}
}
Expand All @@ -125,7 +127,9 @@ pub(crate) fn multiply_matrix_by_column(
let product = ntt_multiply(matrix_element, &vector[j]);
result[i] = result[i] + product;
}
result[i].coefficients = result[i].coefficients.map(barrett_reduce);
}

result
}

Expand All @@ -139,5 +143,7 @@ pub(crate) fn multiply_row_by_column(
result = result + ntt_multiply(row_element, column_element);
}

result.coefficients = result.coefficients.map(barrett_reduce);

result
}
20 changes: 9 additions & 11 deletions src/kem/kyber768/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::kem::kyber768::{
arithmetic::{fe_sub, KyberPolynomialRingElement},
arithmetic::KyberPolynomialRingElement,
parameters::{COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS, REJECTION_SAMPLING_SEED_SIZE},
BadRejectionSamplingRandomnessError,
};
Expand All @@ -11,21 +11,19 @@ pub fn sample_from_uniform_distribution(
let mut out: KyberPolynomialRingElement = KyberPolynomialRingElement::ZERO;

for bytes in randomness.chunks(3) {
let b = i16::from(bytes[0]);
let b1 = i16::from(bytes[1]);
let b2 = i16::from(bytes[2]);
let b1 = i16::from(bytes[0]);
let b2 = i16::from(bytes[1]);
let b3 = i16::from(bytes[2]);

let d1 = b + (256 * (b1 % 16));

// Integer division is flooring in Rust.
let d2 = (b1 / 16) + (16 * b2);
let d1 = ((b2 & 0xF) << 8) | b1;
let d2 = (b3 << 4) | (b2 >> 4);

if d1 < FIELD_MODULUS && sampled_coefficients < COEFFICIENTS_IN_RING_ELEMENT {
out[sampled_coefficients] = d1 as i16;
out[sampled_coefficients] = d1;
sampled_coefficients += 1
}
if d2 < FIELD_MODULUS && sampled_coefficients < COEFFICIENTS_IN_RING_ELEMENT {
out[sampled_coefficients] = d2 as i16;
out[sampled_coefficients] = d2;
sampled_coefficients += 1;
}

Expand Down Expand Up @@ -89,7 +87,7 @@ pub fn sample_from_binomial_distribution_2(randomness: [u8; 128]) -> KyberPolyno
let outcome_2 = ((coin_toss_outcomes >> (outcome_set + 2)) & 0x3) as i16;

let offset = (outcome_set >> 2) as usize;
sampled[8 * chunk_number + offset] = fe_sub(outcome_1, outcome_2);
sampled[8 * chunk_number + offset] = outcome_1 - outcome_2;
}
}

Expand Down
9 changes: 6 additions & 3 deletions src/kem/kyber768/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::kem::kyber768::{
arithmetic::KyberPolynomialRingElement,
parameters::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT},
parameters::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS},
};

/// This file contains instantiations of the functions
Expand Down Expand Up @@ -135,8 +135,11 @@ pub fn serialize_little_endian_12(re: KyberPolynomialRingElement) -> [u8; BYTES_
let mut serialized = [0u8; BYTES_PER_RING_ELEMENT];

for (i, chunks) in re.coefficients.chunks_exact(2).enumerate() {
let coefficient1 = chunks[0];
let coefficient2 = chunks[1];
let mut coefficient1 = chunks[0];
coefficient1 += (coefficient1 >> 15) & FIELD_MODULUS;

let mut coefficient2 = chunks[1];
coefficient2 += (coefficient2 >> 15) & FIELD_MODULUS;

serialized[3 * i] = (coefficient1 & 0xFF) as u8;
serialized[3 * i + 1] = ((coefficient1 >> 8) | ((coefficient2 & 0xF) << 4)) as u8;
Expand Down
Loading

0 comments on commit 98299d5

Please sign in to comment.