Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reworking field operations. #38

Merged
merged 3 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/kem/kyber768.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod parameters;
mod sampling;
mod serialize;
mod utils;
mod field_element;

use utils::{ArrayConversion, UpdatingArray2};

Expand Down
33 changes: 16 additions & 17 deletions src/kem/kyber768/compress.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::kem::kyber768::parameters::{self, KyberFieldElement, KyberPolynomialRingElement};
use crate::kem::kyber768::{parameters::{self, KyberPolynomialRingElement}, field_element::KyberFieldElement};

pub fn compress(
re: KyberPolynomialRingElement,
Expand All @@ -21,28 +21,27 @@ pub fn decompress(
}

fn compress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);
debug_assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);

let two_pow_bit_size = 2u32.pow(to_bit_size.try_into().unwrap_or_else(|_| {
panic!(
"Conversion should work since to_bit_size is never greater than {}.",
parameters::BITS_PER_COEFFICIENT
)
}));
let two_pow_bit_size = 1u32 << to_bit_size;

let compressed = ((u32::from(fe.value) * 2 * two_pow_bit_size)
+ u32::from(KyberFieldElement::MODULUS))
/ u32::from(2 * KyberFieldElement::MODULUS);
let mut compressed = u32::from(fe.value) * (two_pow_bit_size << 1);
compressed += u32::from(KyberFieldElement::MODULUS);
compressed /= u32::from(KyberFieldElement::MODULUS << 1);

(compressed % two_pow_bit_size).into()
KyberFieldElement {
value: (compressed & (two_pow_bit_size - 1)).try_into().unwrap()
}
}

fn decompress_q(fe: KyberFieldElement, to_bit_size: usize) -> KyberFieldElement {
assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);
debug_assert!(to_bit_size <= parameters::BITS_PER_COEFFICIENT);

let decompressed = (2 * u32::from(fe.value) * u32::from(KyberFieldElement::MODULUS)
+ (1 << to_bit_size))
>> (to_bit_size + 1);
let mut decompressed = u32::from(fe.value) * u32::from(KyberFieldElement::MODULUS);
decompressed = (decompressed << 1) + (1 << to_bit_size);
decompressed >>= to_bit_size + 1;

decompressed.into()
KyberFieldElement {
value: decompressed.try_into().unwrap()
}
}
111 changes: 111 additions & 0 deletions src/kem/kyber768/field_element.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use std::ops;

use crate::kem::kyber768::{parameters::FIELD_MODULUS, utils::field::FieldElement};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct KyberFieldElement {
pub value: u16,
}

impl KyberFieldElement {
pub const MODULUS: u16 = FIELD_MODULUS;

const BARRETT_SHIFT : u32 = 24; // 2 * ceil(log_2(FIELD_MODULUS))
const BARRETT_MULTIPLIER : u32 = (1u32 << Self::BARRETT_SHIFT) / (Self::MODULUS as u32);

pub fn barrett_reduce(value : u32) -> Self {
let product : u64 = u64::from(value) * u64::from(Self::BARRETT_MULTIPLIER);
let quotient : u32 = (product >> Self::BARRETT_SHIFT).try_into().unwrap();
xvzcf marked this conversation as resolved.
Show resolved Hide resolved

let remainder = value - (quotient * u32::from(Self::MODULUS));
let remainder : u16 = remainder.try_into().unwrap();

let remainder_minus_modulus = remainder.wrapping_sub(Self::MODULUS);

// TODO: Check if LLVM detects this and optimizes it away into a
// conditional.
let selector = 0u16.wrapping_sub((remainder_minus_modulus >> 15) & 1);

Self {
value: (selector & remainder) | (!selector & remainder_minus_modulus),
}
}
}

impl FieldElement for KyberFieldElement {
const ZERO: Self = Self { value: 0 };

fn new(number: u16) -> Self {
Self::barrett_reduce(u32::from(number))
}

fn nth_bit_little_endian(&self, n: usize) -> u8 {
((self.value >> n) & 1) as u8
}
}

impl From<u8> for KyberFieldElement {
fn from(number: u8) -> Self {
Self {
value: u16::from(number)
}
}
}

impl From<KyberFieldElement> for u16 {
fn from(fe: KyberFieldElement) -> Self {
fe.value
}
}

impl ops::Add for KyberFieldElement {
type Output = Self;

fn add(self, other: Self) -> Self {
let sum: u16 = self.value + other.value;
let difference: u16 = sum.wrapping_sub(Self::MODULUS);

let mask = 0u16.wrapping_sub((difference >> 15) & 1);

Self {
value: (mask & sum) | (!mask & difference),
}
}
}
impl ops::Sub for KyberFieldElement {
type Output = Self;

fn sub(self, other: Self) -> Self {
let lhs = self.value;
let rhs = Self::MODULUS - other.value;

let sum: u16 = lhs + rhs;
let difference: u16 = sum.wrapping_sub(Self::MODULUS);

let mask = 0u16.wrapping_sub((difference >> 15) & 1);

Self {
value: (mask & sum) | (!mask & difference),
}
}
}

impl ops::Mul for KyberFieldElement {
type Output = Self;

fn mul(self, other: Self) -> Self {
xvzcf marked this conversation as resolved.
Show resolved Hide resolved
let product: u32 = u32::from(self.value) * u32::from(other.value);

Self::barrett_reduce(product)
}
}

impl ops::Mul<u16> for KyberFieldElement {
type Output = Self;

fn mul(self, other: u16) -> Self {
let product: u32 = u32::from(self.value) * u32::from(other);

Self::barrett_reduce(product)
}
}
42 changes: 21 additions & 21 deletions src/kem/kyber768/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use crate::kem::kyber768::parameters::{KyberPolynomialRingElement, RANK};
use self::kyber_polynomial_ring_element_mod::ntt_multiply;

pub(crate) mod kyber_polynomial_ring_element_mod {
use crate::kem::kyber768::utils::field::FieldElement;

use crate::kem::kyber768::parameters::{
self, KyberFieldElement, KyberPolynomialRingElement, COEFFICIENTS_IN_RING_ELEMENT,
self, KyberPolynomialRingElement, COEFFICIENTS_IN_RING_ELEMENT,
};
use crate::kem::kyber768::field_element::KyberFieldElement;

const ZETAS: [u16; 128] = [
1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746,
Expand Down Expand Up @@ -40,10 +39,9 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
for layer in NTT_LAYERS.iter().rev() {
for offset in (0..(COEFFICIENTS_IN_RING_ELEMENT - layer)).step_by(2 * layer) {
zeta_i += 1;
let zeta: KyberFieldElement = ZETAS[zeta_i].into();

for j in offset..offset + layer {
let t = zeta * re[j + layer];
let t = re[j + layer] * ZETAS[zeta_i];
re[j + layer] = re[j] - t;
re[j] = re[j] + t;
}
Expand All @@ -53,8 +51,7 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
}

pub fn invert_ntt(re: KyberPolynomialRingElement) -> KyberPolynomialRingElement {
let inverse_of_2: KyberFieldElement =
KyberFieldElement::new((parameters::FIELD_MODULUS + 1) / 2);
let inverse_of_2: u16 = (parameters::FIELD_MODULUS + 1) >> 1;

let mut out = KyberPolynomialRingElement::ZERO;
for i in 0..re.len() {
Expand All @@ -66,36 +63,39 @@ pub(crate) mod kyber_polynomial_ring_element_mod {
for layer in NTT_LAYERS {
for offset in (0..(COEFFICIENTS_IN_RING_ELEMENT - layer)).step_by(2 * layer) {
zeta_i -= 1;
let zeta: KyberFieldElement = ZETAS[zeta_i].into();

for j in offset..offset + layer {
let a_minus_b = out[j + layer] - out[j];
out[j] = inverse_of_2 * (out[j] + out[j + layer]);
out[j + layer] = inverse_of_2 * zeta * a_minus_b;
out[j] = (out[j] + out[j + layer]) * inverse_of_2;
out[j + layer] = (a_minus_b * ZETAS[zeta_i]) * inverse_of_2;
}
}
}

out
}

fn ntt_multiply_binomials((a0, a1): (KyberFieldElement, KyberFieldElement),
(b0, b1): (KyberFieldElement, KyberFieldElement),
zeta: u16) -> (KyberFieldElement, KyberFieldElement) {
((a0 * b0) + ((a1 * b1) * zeta),
(a0 * b1) + (a1 * b0))
}

pub fn ntt_multiply(
left: &KyberPolynomialRingElement,
other: &KyberPolynomialRingElement,
right: &KyberPolynomialRingElement,
) -> KyberPolynomialRingElement {
let mut out = KyberPolynomialRingElement::ZERO;

for i in (0..COEFFICIENTS_IN_RING_ELEMENT).step_by(2) {
let mod_root: KyberFieldElement = MOD_ROOTS[i / 2].into();

let a0_times_b0 = left[i] * other[i];
let a1_times_b1 = left[i + 1] * other[i + 1];

let a0_times_b1 = left[i + 1] * other[i];
let a1_times_b0 = left[i] * other[i + 1];
for i in (0..out.coefficients.len()).step_by(4) {
let product = ntt_multiply_binomials((left[i], left[i+1]), (right[i], right[i + 1]), MOD_ROOTS[i / 2]);
out[i] = product.0;
out[i + 1] = product.1;

out[i] = a0_times_b0 + (a1_times_b1 * mod_root);
out[i + 1] = a0_times_b1 + a1_times_b0;
let product = ntt_multiply_binomials((left[i + 2], left[i + 3]), (right[i + 2], right[i + 3]), MOD_ROOTS[(i + 2) / 2]);
out[i + 2] = product.0;
out[i + 3] = product.1;
}
out
}
Expand Down
6 changes: 2 additions & 4 deletions src/kem/kyber768/parameters.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::kem::kyber768::utils::{field::PrimeFieldElement, ring::PolynomialRingElement};
use crate::kem::kyber768::field_element::KyberFieldElement;
use crate::kem::kyber768::utils::ring::PolynomialRingElement;

/// Field modulus: 3329
pub(crate) const FIELD_MODULUS: u16 = 3329;
Expand Down Expand Up @@ -79,9 +80,6 @@ pub(crate) mod hash_functions {
}
}

/// A Kyber field element.
pub(crate) type KyberFieldElement = PrimeFieldElement<FIELD_MODULUS>;

/// A Kyber ring element
pub(crate) type KyberPolynomialRingElement =
PolynomialRingElement<KyberFieldElement, COEFFICIENTS_IN_RING_ELEMENT>;
11 changes: 6 additions & 5 deletions src/kem/kyber768/sampling.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::kem::kyber768::{
parameters::{self, KyberFieldElement, KyberPolynomialRingElement},
parameters::{self, KyberPolynomialRingElement},
BadRejectionSamplingRandomnessError,
};
use crate::kem::kyber768::field_element::KyberFieldElement;

pub fn sample_from_uniform_distribution(
randomness: [u8; parameters::REJECTION_SAMPLING_SEED_SIZE],
Expand All @@ -20,11 +21,11 @@ pub fn sample_from_uniform_distribution(
let d2 = (b1 / 16) + (16 * b2);

if d1 < parameters::FIELD_MODULUS && sampled_coefficients < out.len() {
out[sampled_coefficients] = d1.into();
out[sampled_coefficients] = KyberFieldElement { value : d1 };
sampled_coefficients += 1
}
if d2 < parameters::FIELD_MODULUS && sampled_coefficients < out.len() {
out[sampled_coefficients] = d2.into();
out[sampled_coefficients] = KyberFieldElement { value : d2 };
sampled_coefficients += 1;
}

Expand Down Expand Up @@ -53,10 +54,10 @@ pub fn sample_from_binomial_distribution_with_2_coins(
let coin_toss_outcomes = even_bits + odd_bits;

for outcome_set in (0..u32::BITS).step_by(4) {
let outcome_1: u16 = ((coin_toss_outcomes >> outcome_set) & 0x3) as u16;
let outcome_1: u8 = ((coin_toss_outcomes >> outcome_set) & 0x3) as u8;
let outcome_1: KyberFieldElement = outcome_1.into();

let outcome_2: u16 = ((coin_toss_outcomes >> (outcome_set + 2)) & 0x3) as u16;
let outcome_2: u8 = ((coin_toss_outcomes >> (outcome_set + 2)) & 0x3) as u8;
let outcome_2: KyberFieldElement = outcome_2.into();

let offset = usize::try_from(outcome_set >> 2).unwrap();
Expand Down
8 changes: 6 additions & 2 deletions src/kem/kyber768/serialize.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::kem::kyber768::utils::{bit_vector::BitVector, ring::LittleEndianBitStream};

use crate::kem::kyber768::parameters::{
KyberFieldElement, KyberPolynomialRingElement, BITS_PER_COEFFICIENT, BYTES_PER_RING_ELEMENT,
KyberPolynomialRingElement, BITS_PER_COEFFICIENT, BYTES_PER_RING_ELEMENT,
};

use crate::kem::kyber768::field_element::KyberFieldElement;

pub fn serialize_little_endian(
re: KyberPolynomialRingElement,
bits_per_coefficient: usize,
Expand Down Expand Up @@ -48,7 +50,9 @@ fn field_element_from_little_endian_bit_vector(bit_vector: BitVector) -> KyberFi
value |= u16::from(bit) << i;
}

value.into()
KyberFieldElement {
value
}
}

pub fn deserialize_little_endian(
Expand Down
Loading