Skip to content

Commit

Permalink
Merge pull request #286 from cryspen/goutam/document-kyber-avx2
Browse files Browse the repository at this point in the history
Implemented deserialize_5 and documented the code some more.
  • Loading branch information
franziskuskiefer authored May 20, 2024
2 parents b013f19 + 6e2b44d commit 7de9e87
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 138 deletions.
55 changes: 53 additions & 2 deletions polynomials-avx2/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,20 @@ pub(crate) fn shift_left<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
pub(crate) fn cond_subtract_3329(vector: __m256i) -> __m256i {
let field_modulus = mm256_set1_epi16(FIELD_MODULUS);

// Compute v_i - Q and crate a mask from the sign bit of each of these
// quantities.
let v_minus_field_modulus = mm256_sub_epi16(vector, field_modulus);

let sign_mask = mm256_srai_epi16::<15>(v_minus_field_modulus);
let conditional_add_field_modulus = mm256_and_si256(sign_mask, field_modulus);

// If v_i - Q < 0 then add back Q to (v_i - Q).
let conditional_add_field_modulus = mm256_and_si256(sign_mask, field_modulus);
mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus)
}

const BARRETT_MULTIPLIER: i16 = 20159;

/// See Section 3.2 of the implementation notes document for an explanation
/// of this code.
#[inline(always)]
pub(crate) fn barrett_reduce(vector: __m256i) -> __m256i {
let t = mm256_mulhi_epi16(vector, mm256_set1_epi16(BARRETT_MULTIPLIER));
Expand Down Expand Up @@ -72,3 +76,50 @@ pub(crate) fn montgomery_multiply_by_constant(vector: __m256i, constant: i16) ->

mm256_sub_epi16(value_high, k_times_modulus)
}

#[inline(always)]
pub(crate) fn montgomery_multiply_by_constants(v: __m256i, c: __m256i) -> __m256i {
let value_low = mm256_mullo_epi16(v, c);

let k = mm256_mullo_epi16(
value_low,
mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
);
let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS));

let value_high = mm256_mulhi_epi16(v, c);

mm256_sub_epi16(value_high, k_times_modulus)
}

#[inline(always)]
pub(crate) fn montgomery_reduce_i32s(v: __m256i) -> __m256i {
let k = mm256_mullo_epi16(
v,
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32),
);
let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi32(FIELD_MODULUS as i32));

let value_high = mm256_srli_epi32::<16>(v);

let result = mm256_sub_epi16(value_high, k_times_modulus);

let result = mm256_slli_epi32::<16>(result);

mm256_srai_epi32::<16>(result)
}

#[inline(always)]
pub(crate) fn montgomery_multiply_m128i_by_constants(v: __m128i, c: __m128i) -> __m128i {
let value_low = mm_mullo_epi16(v, c);

let k = mm_mullo_epi16(
value_low,
mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
);
let k_times_modulus = mm_mulhi_epi16(k, mm_set1_epi16(FIELD_MODULUS));

let value_high = mm_mulhi_epi16(v, c);

mm_sub_epi16(value_high, k_times_modulus)
}
51 changes: 45 additions & 6 deletions polynomials-avx2/src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::intrinsics::*;
use libcrux_traits::FIELD_MODULUS;

// Multiply the 32-bit numbers contained in |lhs| and |rhs|, and store only
// the upper 32 bits of the resulting product.
// This implementation was taken from:
// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html
//
Expand Down Expand Up @@ -42,18 +44,39 @@ pub(crate) fn compress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
let compression_factor = mm256_set1_epi32(10_321_340);
let coefficient_bits_mask = mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);

// Compress the first 8 coefficients
// ---- Compress the first 8 coefficients ----

// Take the bottom 128 bits, i.e. the first 8 16-bit coefficients
let coefficients_low = mm256_castsi256_si128(vector);

// If:
//
// coefficients_low[0:15] = A
// coefficients_low[16:31] = B
// coefficients_low[32:63] = C
// and so on ...
//
// after this step:
//
// coefficients_low[0:31] = A
// coefficients_low[32:63] = B
// and so on ...
let coefficients_low = mm256_cvtepi16_epi32(coefficients_low);

let compressed_low = mm256_slli_epi32::<{ COEFFICIENT_BITS }>(coefficients_low);
let compressed_low = mm256_add_epi32(compressed_low, field_modulus_halved);

let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor);

// Due to the mulhi_mm256_epi32 we've already shifted right by 32 bits, we
// just need to shift right by 35 - 32 = 3 more.
let compressed_low = mm256_srli_epi32::<3>(compressed_low);

let compressed_low = mm256_and_si256(compressed_low, coefficient_bits_mask);

// Compress the next 8 coefficients
// ---- Compress the next 8 coefficients ----

// Take the upper 128 bits, i.e. the next 8 16-bit coefficients
let coefficients_high = mm256_extracti128_si256::<1>(vector);
let coefficients_high = mm256_cvtepi16_epi32(coefficients_high);

Expand All @@ -64,9 +87,17 @@ pub(crate) fn compress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
let compressed_high = mm256_srli_epi32::<3>(compressed_high);
let compressed_high = mm256_and_si256(compressed_high, coefficient_bits_mask);

// Combine them
// Combining them, and grouping each set of 64-bits, this function results
// in:
//
// 0: low low low low | 1: high high high high | 2: low low low low | 3: high high high high
//
// where each |low| and |high| is a 16-bit element
let compressed = mm256_packs_epi32(compressed_low, compressed_high);

// To be in the right order, we need to move the |low|s above in position 2 to
// position 1 and the |high|s in position 1 to position 2, and leave the
// rest unchanged.
mm256_permute4x64_epi64::<0b11_01_10_00>(compressed)
}

Expand All @@ -77,7 +108,7 @@ pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
let field_modulus = mm256_set1_epi32(FIELD_MODULUS as i32);
let two_pow_coefficient_bits = mm256_set1_epi32(1 << COEFFICIENT_BITS);

// Compress the first 8 coefficients
// ---- Compress the first 8 coefficients ----
let coefficients_low = mm256_castsi256_si128(vector);
let coefficients_low = mm256_cvtepi16_epi32(coefficients_low);

Expand All @@ -90,7 +121,7 @@ pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
let decompressed_low = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_low);
let decompressed_low = mm256_srli_epi32::<1>(decompressed_low);

// Compress the next 8 coefficients
// ---- Compress the next 8 coefficients ----
let coefficients_high = mm256_extracti128_si256::<1>(vector);
let coefficients_high = mm256_cvtepi16_epi32(coefficients_high);

Expand All @@ -103,8 +134,16 @@ pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
let decompressed_high = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_high);
let decompressed_high = mm256_srli_epi32::<1>(decompressed_high);

// Combine them
// Combining them, and grouping each set of 64-bits, this function results
// in:
//
// 0: low low low low | 1: high high high high | 2: low low low low | 3: high high high high
//
// where each |low| and |high| is a 16-bit element
let compressed = mm256_packs_epi32(decompressed_low, decompressed_high);

// To be in the right order, we need to move the |low|s above in position 2 to
// position 1 and the |high|s in position 1 to position 2, and leave the
// rest unchanged.
mm256_permute4x64_epi64::<0b11_01_10_00>(compressed)
}
50 changes: 32 additions & 18 deletions polynomials-avx2/src/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,27 +38,41 @@ pub(crate) fn mm256_setzero_si256() -> __m256i {
}

pub(crate) fn mm_set_epi8(
byte15: i8,
byte14: i8,
byte13: i8,
byte12: i8,
byte11: i8,
byte10: i8,
byte9: i8,
byte8: i8,
byte7: i8,
byte6: i8,
byte5: i8,
byte4: i8,
byte3: i8,
byte2: i8,
byte1: i8,
byte0: i8,
byte15: u8,
byte14: u8,
byte13: u8,
byte12: u8,
byte11: u8,
byte10: u8,
byte9: u8,
byte8: u8,
byte7: u8,
byte6: u8,
byte5: u8,
byte4: u8,
byte3: u8,
byte2: u8,
byte1: u8,
byte0: u8,
) -> __m128i {
unsafe {
_mm_set_epi8(
byte15, byte14, byte13, byte12, byte11, byte10, byte9, byte8, byte7, byte6, byte5,
byte4, byte3, byte2, byte1, byte0,
byte15 as i8,
byte14 as i8,
byte13 as i8,
byte12 as i8,
byte11 as i8,
byte10 as i8,
byte9 as i8,
byte8 as i8,
byte7 as i8,
byte6 as i8,
byte5 as i8,
byte4 as i8,
byte3 as i8,
byte2 as i8,
byte1 as i8,
byte0 as i8,
)
}
}
Expand Down
66 changes: 9 additions & 57 deletions polynomials-avx2/src/ntt.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,6 @@
use crate::intrinsics::*;

use crate::arithmetic;
use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R};

#[inline(always)]
fn montgomery_multiply_by_constants(v: __m256i, c: __m256i) -> __m256i {
let value_low = mm256_mullo_epi16(v, c);

let k = mm256_mullo_epi16(
value_low,
mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
);
let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS));

let value_high = mm256_mulhi_epi16(v, c);

mm256_sub_epi16(value_high, k_times_modulus)
}

#[inline(always)]
fn montgomery_reduce_i32s(v: __m256i) -> __m256i {
let k = mm256_mullo_epi16(
v,
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32),
);
let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi32(FIELD_MODULUS as i32));

let value_high = mm256_srli_epi32::<16>(v);

let result = mm256_sub_epi16(value_high, k_times_modulus);

let result = mm256_slli_epi32::<16>(result);

mm256_srai_epi32::<16>(result)
}

#[inline(always)]
fn montgomery_multiply_m128i_by_constants(v: __m128i, c: __m128i) -> __m128i {
let value_low = mm_mullo_epi16(v, c);

let k = mm_mullo_epi16(
value_low,
mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
);
let k_times_modulus = mm_mulhi_epi16(k, mm_set1_epi16(FIELD_MODULUS));

let value_high = mm_mulhi_epi16(v, c);

mm_sub_epi16(value_high, k_times_modulus)
}

#[inline(always)]
pub(crate) fn ntt_layer_1_step(
Expand All @@ -64,7 +16,7 @@ pub(crate) fn ntt_layer_1_step(
);

let rhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector);
let rhs = montgomery_multiply_by_constants(rhs, zetas);
let rhs = arithmetic::montgomery_multiply_by_constants(rhs, zetas);

let lhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector);

Expand All @@ -79,7 +31,7 @@ pub(crate) fn ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m25
);

let rhs = mm256_shuffle_epi32::<0b11_10_11_10>(vector);
let rhs = montgomery_multiply_by_constants(rhs, zetas);
let rhs = arithmetic::montgomery_multiply_by_constants(rhs, zetas);

let lhs = mm256_shuffle_epi32::<0b01_00_01_00>(vector);

Expand All @@ -89,7 +41,7 @@ pub(crate) fn ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m25
#[inline(always)]
pub(crate) fn ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i {
let rhs = mm256_extracti128_si256::<1>(vector);
let rhs = montgomery_multiply_m128i_by_constants(rhs, mm_set1_epi16(zeta));
let rhs = arithmetic::montgomery_multiply_m128i_by_constants(rhs, mm_set1_epi16(zeta));

let lhs = mm256_castsi256_si128(vector);

Expand Down Expand Up @@ -119,7 +71,7 @@ pub(crate) fn inv_ntt_layer_1_step(
);

let sum = mm256_add_epi16(lhs, rhs);
let sum_times_zetas = montgomery_multiply_by_constants(
let sum_times_zetas = arithmetic::montgomery_multiply_by_constants(
sum,
mm256_set_epi16(
zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0,
Expand All @@ -142,7 +94,7 @@ pub(crate) fn inv_ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> _
);

let sum = mm256_add_epi16(lhs, rhs);
let sum_times_zetas = montgomery_multiply_by_constants(
let sum_times_zetas = arithmetic::montgomery_multiply_by_constants(
sum,
mm256_set_epi16(
zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0,
Expand All @@ -161,7 +113,7 @@ pub(crate) fn inv_ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i {

let upper_coefficients = mm_sub_epi16(lhs, rhs);
let upper_coefficients =
montgomery_multiply_m128i_by_constants(upper_coefficients, mm_set1_epi16(zeta));
arithmetic::montgomery_multiply_m128i_by_constants(upper_coefficients, mm_set1_epi16(zeta));

let combined = mm256_castsi128_si256(lower_coefficients);
let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients);
Expand Down Expand Up @@ -209,7 +161,7 @@ pub(crate) fn ntt_multiply(
let left = mm256_mullo_epi32(lhs_evens, rhs_evens);

let right = mm256_mullo_epi32(lhs_odds, rhs_odds);
let right = montgomery_reduce_i32s(right);
let right = arithmetic::montgomery_reduce_i32s(right);
let right = mm256_mullo_epi32(
right,
mm256_set_epi32(
Expand All @@ -225,7 +177,7 @@ pub(crate) fn ntt_multiply(
);

let products_left = mm256_add_epi32(left, right);
let products_left = montgomery_reduce_i32s(products_left);
let products_left = arithmetic::montgomery_reduce_i32s(products_left);

// Compute the second term of the product
let rhs_adjacent_swapped = mm256_shuffle_epi8(
Expand All @@ -236,7 +188,7 @@ pub(crate) fn ntt_multiply(
),
);
let products_right = mm256_madd_epi16(lhs, rhs_adjacent_swapped);
let products_right = montgomery_reduce_i32s(products_right);
let products_right = arithmetic::montgomery_reduce_i32s(products_right);
let products_right = mm256_slli_epi32::<16>(products_right);

// Combine them into one vector
Expand Down
Loading

0 comments on commit 7de9e87

Please sign in to comment.