Skip to content

Commit

Permalink
AVX2 implementations of Kyber ciphertext compression and decompressio…
Browse files Browse the repository at this point in the history
…n. (#273)
  • Loading branch information
xvzcf authored May 13, 2024
1 parent 4884ae0 commit b8aedfd
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 57 deletions.
101 changes: 93 additions & 8 deletions polynomials-avx2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,22 +218,107 @@ fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector {
v
}

// This implementation was taken from:
// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html
//
// `mulhi_mm256_epi32` multiplies the eight 32-bit lanes of `lhs` and `rhs`
// pairwise as unsigned integers and returns, for every lane, the upper 32
// bits of the 64-bit product.
//
// TODO: Optimize this implementation if performance numbers suggest doing so.
#[inline(always)]
fn compress<const COEFFICIENT_BITS: i32>(v: SIMD256Vector) -> SIMD256Vector {
let input = portable::from_i16_array(to_i16_array(v));
let output = portable::compress::<{ COEFFICIENT_BITS }>(input);
fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
let result = unsafe {
// `_mm256_mul_epu32` only multiplies the even-indexed 32-bit lanes
// (0 and 2 of each 128-bit half), yielding full 64-bit products.
let prod02 = _mm256_mul_epu32(lhs, rhs);
// Shuffle the odd-indexed lanes (1 and 3) into the even positions so that
// the same instruction produces their 64-bit products as well.
let prod13 = _mm256_mul_epu32(
_mm256_shuffle_epi32(lhs, 0b11_11_01_01),
_mm256_shuffle_epi32(rhs, 0b11_11_01_01),
);

from_i16_array(portable::to_i16_array(output))
// Interleaving the 32-bit words of both product vectors groups the low
// words of four products in the lower 64 bits and their high words in the
// upper 64 bits; taking the upper 64 bits of each interleaving therefore
// yields the high words of products 0, 1, 2, 3 in lane order.
_mm256_unpackhi_epi64(
_mm256_unpacklo_epi32(prod02, prod13),
_mm256_unpackhi_epi32(prod02, prod13),
)
};

result
}

#[inline(always)]
fn compress<const COEFFICIENT_BITS: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
v.elements = unsafe {
let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2);
let compression_factor = _mm256_set1_epi32(10_321_340);
let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);

// Compress the first 8 coefficients
let coefficients_low = _mm256_castsi256_si128(v.elements);
let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);

let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS);
let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved);

let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor);
let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32);
let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask);

// Compress the next 8 coefficients
let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);

let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS);
let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved);

let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor);
let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32);
let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask);

// Combine them
let compressed = _mm256_packs_epi32(compressed_low, compressed_high);

_mm256_permute4x64_epi64(compressed, 0b11_01_10_00)
};

v
}

/// Decompresses every 16-bit coefficient held in `v`.
///
/// Each coefficient `c` is sign-extended to 32 bits and mapped to
/// `((c * FIELD_MODULUS * 2) + (1 << COEFFICIENT_BITS)) >> (COEFFICIENT_BITS + 1)`.
/// The vector is processed as two halves of eight coefficients each, and the
/// 32-bit results are packed back into 16-bit lanes in the original order.
#[inline(always)]
fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
v: SIMD256Vector,
mut v: SIMD256Vector,
) -> SIMD256Vector {
let input = portable::from_i16_array(to_i16_array(v));
let output = portable::decompress_ciphertext_coefficient::<{ COEFFICIENT_BITS }>(input);
v.elements = unsafe {
let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32);
let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS);

from_i16_array(portable::to_i16_array(output))
// Decompress the first 8 coefficients (the lower 128 bits of the vector),
// widened to 32-bit lanes so the intermediate products fit.
let coefficients_low = _mm256_castsi256_si128(v.elements);
let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);

// Compute 2 * c * FIELD_MODULUS + 2^COEFFICIENT_BITS.
let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus);
let decompressed_low = _mm256_slli_epi32(decompressed_low, 1);
let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits);

// We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
// of support for const generic expressions.
let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS);
let decompressed_low = _mm256_srli_epi32(decompressed_low, 1);

// Decompress the next 8 coefficients (the upper 128 bits of the vector)
// with the same sequence of operations.
let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);

let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus);
let decompressed_high = _mm256_slli_epi32(decompressed_high, 1);
let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits);

// We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
// of support for const generic expressions.
let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS);
let decompressed_high = _mm256_srli_epi32(decompressed_high, 1);

// Combine them: narrow both halves back to 16-bit lanes. Despite its name,
// `compressed` holds the decompressed coefficients.
let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high);

// The pack operates on each 128-bit half independently, so the two middle
// 64-bit groups are swapped to restore the original coefficient order.
_mm256_permute4x64_epi64(compressed, 0b11_01_10_00)
};

v
}

#[inline(always)]
Expand Down
49 changes: 0 additions & 49 deletions polynomials-avx2/src/portable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,6 @@ pub use libcrux_traits::{FIELD_ELEMENTS_IN_VECTOR, FIELD_MODULUS};

type FieldElement = i16;

/// Returns `value` with everything above its `n` least significant bits
/// cleared.
pub(crate) fn get_n_least_significant_bits(n: u8, value: u32) -> u32 {
    // hax_debug_assert!(n == 4 || n == 5 || n == 10 || n == 11 || n == MONTGOMERY_SHIFT);

    // A mask with exactly the `n` lowest bits set.
    let mask: u32 = (1 << n) - 1;

    value & mask
}

/// Compresses the field element `fe` to `coefficient_bits` bits.
///
/// Computes `(((fe << coefficient_bits) + 1664) * 10_321_340) >> 35` in
/// 64-bit arithmetic and returns the `coefficient_bits` least significant
/// bits of the result.
pub(crate) fn compress_ciphertext_coefficient(coefficient_bits: u8, fe: u16) -> FieldElement {
    // This has to be constant time due to:
    // https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/ldX0ThYJuBo/m/ovODsdY7AwAJ
    //
    // Hence the division is replaced by a multiplication followed by a shift.
    let scaled = ((fe as u64) << coefficient_bits) + 1664;
    let quotient = (scaled * 10_321_340) >> 35;

    get_n_least_significant_bits(coefficient_bits, quotient as u32) as FieldElement
}

#[derive(Clone, Copy)]
pub(crate) struct PortableVector {
elements: [FieldElement; FIELD_ELEMENTS_IN_VECTOR],
Expand All @@ -43,37 +25,6 @@ pub(crate) fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> Portable
PortableVector { elements: array }
}

/// Compresses every element of `v` to `COEFFICIENT_BITS` bits by applying
/// `compress_ciphertext_coefficient` to each one in turn.
#[inline(always)]
pub(crate) fn compress<const COEFFICIENT_BITS: i32>(mut v: PortableVector) -> PortableVector {
    for element in v.elements.iter_mut() {
        // The element is reinterpreted as unsigned before being compressed.
        *element = compress_ciphertext_coefficient(COEFFICIENT_BITS as u8, *element as u16) as i16;
    }

    v
}

/// Decompresses every element of `v`, replacing each element `c` with
/// `((c * FIELD_MODULUS * 2) + (1 << COEFFICIENT_BITS)) >> (COEFFICIENT_BITS + 1)`,
/// computed in 32-bit arithmetic.
///
/// In debug builds this checks that every input has an absolute value below
/// `1 << COEFFICIENT_BITS` and that every output has an absolute value of at
/// most `1 << 12`.
#[inline(always)]
pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
    mut v: PortableVector,
) -> PortableVector {
    debug_assert!(to_i16_array(v)
        .into_iter()
        .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS));

    for element in v.elements.iter_mut() {
        let product = *element as i32 * FIELD_MODULUS as i32;
        // Adding 2^COEFFICIENT_BITS before the final shift rounds the
        // quotient instead of truncating it.
        let rounded = (product << 1) + (1i32 << COEFFICIENT_BITS);
        let quotient = rounded >> (COEFFICIENT_BITS + 1);

        *element = quotient as i16;
    }

    debug_assert!(to_i16_array(v)
        .into_iter()
        .all(|coefficient| coefficient.abs() as u16 <= 1 << 12));

    v
}

#[inline(always)]
pub(crate) fn deserialize_5(bytes: &[u8]) -> PortableVector {
let mut v = zero();
Expand Down

0 comments on commit b8aedfd

Please sign in to comment.