From b8aedfd3848fb764da7a5b0d4610199ed5a79656 Mon Sep 17 00:00:00 2001
From: Goutam Tamvada
Date: Mon, 13 May 2024 10:51:31 -0400
Subject: [PATCH] AVX2 implementations of Kyber ciphertext compression and decompression. (#273)

---
 polynomials-avx2/src/lib.rs      | 101 ++++++++++++++++++++++++++++---
 polynomials-avx2/src/portable.rs |  49 ---------------
 2 files changed, 93 insertions(+), 57 deletions(-)

diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs
index 39432918f..1c4cde0da 100644
--- a/polynomials-avx2/src/lib.rs
+++ b/polynomials-avx2/src/lib.rs
@@ -218,22 +218,107 @@ fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector {
     v
 }
 
+// High 32 bits of each unsigned 32x32-bit product; implementation taken from:
+// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html
+//
+// TODO: Optimize this implementation if performance numbers suggest doing so.
 #[inline(always)]
-fn compress(v: SIMD256Vector) -> SIMD256Vector {
-    let input = portable::from_i16_array(to_i16_array(v));
-    let output = portable::compress::<{ COEFFICIENT_BITS }>(input);
+fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
+    let result = unsafe {
+        let prod02 = _mm256_mul_epu32(lhs, rhs);
+        let prod13 = _mm256_mul_epu32(
+            _mm256_shuffle_epi32(lhs, 0b11_11_01_01),
+            _mm256_shuffle_epi32(rhs, 0b11_11_01_01),
+        );
 
-    from_i16_array(portable::to_i16_array(output))
+        _mm256_unpackhi_epi64(
+            _mm256_unpacklo_epi32(prod02, prod13),
+            _mm256_unpackhi_epi32(prod02, prod13),
+        )
+    };
+
+    result
+}
+
+#[inline(always)]
+fn compress(mut v: SIMD256Vector) -> SIMD256Vector {
+    v.elements = unsafe {
+        let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2);
+        let compression_factor = _mm256_set1_epi32(10_321_340);
+        let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);
+
+        // Compress the first 8 coefficients
+        let coefficients_low = _mm256_castsi256_si128(v.elements);
+        let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);
+
+        let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS);
+        let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved);
+
+        let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor);
+        let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32);
+        let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask);
+
+        // Compress the next 8 coefficients
+        let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
+        let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);
+
+        let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS);
+        let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved);
+
+        let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor);
+        let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32);
+        let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask);
+
+        // Combine them; packs works per 128-bit lane, so the permute restores the order
+        let compressed = _mm256_packs_epi32(compressed_low, compressed_high);
+
+        _mm256_permute4x64_epi64(compressed, 0b11_01_10_00)
+    };
+
+    v
 }
 
 #[inline(always)]
 fn decompress_ciphertext_coefficient(
-    v: SIMD256Vector,
+    mut v: SIMD256Vector,
 ) -> SIMD256Vector {
-    let input = portable::from_i16_array(to_i16_array(v));
-    let output = portable::decompress_ciphertext_coefficient::<{ COEFFICIENT_BITS }>(input);
+    v.elements = unsafe {
+        let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32);
+        let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS);
 
-    from_i16_array(portable::to_i16_array(output))
+        // Decompress the first 8 coefficients
+        let coefficients_low = _mm256_castsi256_si128(v.elements);
+        let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);
+
+        let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus);
+        let decompressed_low = _mm256_slli_epi32(decompressed_low, 1);
+        let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits);
+
+        // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
+        // of support for const generic expressions.
+        let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS);
+        let decompressed_low = _mm256_srli_epi32(decompressed_low, 1);
+
+        // Decompress the next 8 coefficients
+        let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
+        let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);
+
+        let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus);
+        let decompressed_high = _mm256_slli_epi32(decompressed_high, 1);
+        let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits);
+
+        // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
+        // of support for const generic expressions.
+        let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS);
+        let decompressed_high = _mm256_srli_epi32(decompressed_high, 1);
+
+        // Combine them; packs works per 128-bit lane, so the permute restores the order
+        let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high);
+
+        _mm256_permute4x64_epi64(compressed, 0b11_01_10_00)
+    };
+
+    v
 }
 
 #[inline(always)]
diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs
index ac6b93c94..b18c02d19 100644
--- a/polynomials-avx2/src/portable.rs
+++ b/polynomials-avx2/src/portable.rs
@@ -2,24 +2,6 @@ pub use libcrux_traits::{FIELD_ELEMENTS_IN_VECTOR, FIELD_MODULUS};
 
 type FieldElement = i16;
 
-pub(crate) fn get_n_least_significant_bits(n: u8, value: u32) -> u32 {
-    // hax_debug_assert!(n == 4 || n == 5 || n == 10 || n == 11 || n == MONTGOMERY_SHIFT);
-
-    value & ((1 << n) - 1)
-}
-
-pub(crate) fn compress_ciphertext_coefficient(coefficient_bits: u8, fe: u16) -> FieldElement {
-    // This has to be constant time due to:
-    // https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/ldX0ThYJuBo/m/ovODsdY7AwAJ
-    let mut compressed = (fe as u64) << coefficient_bits;
-    compressed += 1664 as u64;
-
-    compressed *= 10_321_340;
-    compressed >>= 35;
-
-    get_n_least_significant_bits(coefficient_bits, compressed as u32) as FieldElement
-}
-
 #[derive(Clone, Copy)]
 pub(crate) struct PortableVector {
     elements: [FieldElement; FIELD_ELEMENTS_IN_VECTOR],
@@ -43,37 +25,6 @@ pub(crate) fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> Portable
     PortableVector { elements: array }
 }
 
-#[inline(always)]
-pub(crate) fn compress(mut v: PortableVector) -> PortableVector {
-    for i in 0..FIELD_ELEMENTS_IN_VECTOR {
-        v.elements[i] =
-            compress_ciphertext_coefficient(COEFFICIENT_BITS as u8, v.elements[i] as u16) as i16;
-    }
-    v
-}
-
-#[inline(always)]
-pub(crate) fn decompress_ciphertext_coefficient(
-    mut v: PortableVector,
-) -> PortableVector {
-    debug_assert!(to_i16_array(v)
-        .into_iter()
-        .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS));
-
-    for i in 0..FIELD_ELEMENTS_IN_VECTOR {
-        let mut decompressed = v.elements[i] as i32 * FIELD_MODULUS as i32;
-        decompressed = (decompressed << 1) + (1i32 << COEFFICIENT_BITS);
-        decompressed = decompressed >> (COEFFICIENT_BITS + 1);
-        v.elements[i] = decompressed as i16;
-    }
-
-    debug_assert!(to_i16_array(v)
-        .into_iter()
-        .all(|coefficient| coefficient.abs() as u16 <= 1 << 12));
-
-    v
-}
-
 #[inline(always)]
 pub(crate) fn deserialize_5(bytes: &[u8]) -> PortableVector {
     let mut v = zero();