From d6afc5f9c49fdd0e0d4ac3aba8a037b7f4ca0976 Mon Sep 17 00:00:00 2001 From: Goutam Tamvada Date: Mon, 13 May 2024 03:02:04 -0400 Subject: [PATCH] AVX2 implementations of Barrett reduction, the NTT, and deserialization for some parameters. (#270) * NTT and Barrett reduce * Inverse NTT functions * inv_ntt_layer_1_step * deserialize_1 * deserialize_4 * decompress -> decompress_ciphertext_coefficient for clarity * deserialize_12 * deserialize_10 * ntt_multiply --- libcrux-ml-kem/src/serialize.rs | 8 +- polynomials-aarch64/src/lib.rs | 4 +- polynomials-aarch64/src/simd128ops.rs | 4 +- polynomials-avx2/src/debug.rs | 14 +- polynomials-avx2/src/lib.rs | 479 ++++++++++++++++++++--- polynomials-avx2/src/portable.rs | 522 +------------------------- polynomials/src/lib.rs | 8 +- traits/src/lib.rs | 2 +- 8 files changed, 461 insertions(+), 580 deletions(-) diff --git a/libcrux-ml-kem/src/serialize.rs b/libcrux-ml-kem/src/serialize.rs index bc8ca9597..0d8db5d69 100644 --- a/libcrux-ml-kem/src/serialize.rs +++ b/libcrux-ml-kem/src/serialize.rs @@ -217,7 +217,7 @@ fn deserialize_then_decompress_10( cloop! { for (i, bytes) in serialized.chunks_exact(20).enumerate() { let coefficient = Vector::deserialize_10(bytes); - re.coefficients[i] = Vector::decompress::<10>(coefficient); + re.coefficients[i] = Vector::decompress_ciphertext_coefficient::<10>(coefficient); } } re @@ -234,7 +234,7 @@ fn deserialize_then_decompress_11( cloop! { for (i, bytes) in serialized.chunks_exact(22).enumerate() { let coefficient = Vector::deserialize_11(bytes); - re.coefficients[i] = Vector::decompress::<11>(coefficient); + re.coefficients[i] = Vector::decompress_ciphertext_coefficient::<11>(coefficient); } } @@ -266,7 +266,7 @@ fn deserialize_then_decompress_4( cloop! { for (i, bytes) in serialized.chunks_exact(8).enumerate() { let coefficient = Vector::deserialize_4(bytes); - re.coefficients[i] = Vector::decompress::<4>(coefficient); + re.coefficients[i] = Vector::decompress_ciphertext_coefficient::<4>(coefficient); } } re @@ -283,7 +283,7 @@ fn deserialize_then_decompress_5( cloop! { for (i, bytes) in serialized.chunks_exact(10).enumerate() { re.coefficients[i] = Vector::deserialize_5(bytes); - re.coefficients[i] = Vector::decompress::<5>(re.coefficients[i]); + re.coefficients[i] = Vector::decompress_ciphertext_coefficient::<5>(re.coefficients[i]); } } re diff --git a/polynomials-aarch64/src/lib.rs b/polynomials-aarch64/src/lib.rs index 49af38cf3..38a2dc578 100644 --- a/polynomials-aarch64/src/lib.rs +++ b/polynomials-aarch64/src/lib.rs @@ -70,8 +70,8 @@ impl Operations for SIMD128Vector { compress::(v) } - fn decompress(v: Self) -> Self { - decompress::(v) + fn decompress_ciphertext_coefficient(v: Self) -> Self { + decompress_ciphertext_coefficient::(v) } fn ntt_layer_1_step(a: Self, zeta1: i16, zeta2: i16, zeta3: i16, zeta4: i16) -> Self { diff --git a/polynomials-aarch64/src/simd128ops.rs b/polynomials-aarch64/src/simd128ops.rs index a6fb55a46..3f1b03a16 100644 --- a/polynomials-aarch64/src/simd128ops.rs +++ b/polynomials-aarch64/src/simd128ops.rs @@ -281,7 +281,9 @@ fn decompress_uint32x4_t(v: uint32x4_t) -> uint32x4 } #[inline(always)] -pub(crate) fn decompress(mut v: SIMD128Vector) -> SIMD128Vector { +pub(crate) fn decompress_ciphertext_coefficient( + mut v: SIMD128Vector, +) -> SIMD128Vector { let mask16 = _vdupq_n_u32(0xffff); let low0 = _vandq_u32(_vreinterpretq_u32_s16(v.low), mask16); let low1 = _vshrq_n_u32::<16>(_vreinterpretq_u32_s16(v.low)); diff --git a/polynomials-avx2/src/debug.rs b/polynomials-avx2/src/debug.rs index 1fce72113..c49a167b8 100644 --- a/polynomials-avx2/src/debug.rs +++ b/polynomials-avx2/src/debug.rs @@ -4,19 +4,25 @@ use core::arch::x86::*; use core::arch::x86_64::*; #[allow(dead_code)] -fn print_m256i_as_i16s(a: __m256i, prefix: &'static str) { +pub(crate) fn print_m256i_as_i16s(a: __m256i, prefix: &'static str) { let mut a_bytes = [0i16; 16]; unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; - println!("{}: {:04x?}", prefix, a_bytes); + println!("{}: {:?}", prefix, a_bytes); +} +#[allow(dead_code)] +pub(crate) fn print_m256i_as_i32s(a: __m256i, prefix: &'static str) { + let mut a_bytes = [0i32; 8]; + unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; + println!("{}: {:?}", prefix, a_bytes); } #[allow(dead_code)] -fn print_m128i_as_i16s(a: __m128i, prefix: &'static str) { +pub(crate) fn print_m128i_as_i16s(a: __m128i, prefix: &'static str) { let mut a_bytes = [0i16; 8]; unsafe { _mm_store_si128(a_bytes.as_mut_ptr() as *mut __m128i, a) }; println!("{}: {:?}", prefix, a_bytes); } #[allow(dead_code)] -fn print_m128i_as_i8s(a: __m128i, prefix: &'static str) { +pub(crate) fn print_m128i_as_i8s(a: __m128i, prefix: &'static str) { let mut a_bytes = [0i8; 16]; unsafe { _mm_store_si128(a_bytes.as_mut_ptr() as *mut __m128i, a) }; println!("{}: {:?}", prefix, a_bytes); diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index de41a8d41..39432918f 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -7,6 +7,8 @@ use libcrux_traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMER mod debug; mod portable; +const BARRETT_MULTIPLIER: i16 = 20159; + #[derive(Clone, Copy)] pub struct SIMD256Vector { elements: __m256i, @@ -103,11 +105,20 @@ fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { } #[inline(always)] -fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::barrett_reduce(input); +fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { + let t = _mm256_mulhi_epi16(v.elements, _mm256_set1_epi16(BARRETT_MULTIPLIER)); + let t = _mm256_add_epi16(t, _mm256_set1_epi16(512)); - from_i16_array(portable::to_i16_array(output)) + let quotient = _mm256_srai_epi16(t, 10); + + let quotient_times_field_modulus = + _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); + + _mm256_sub_epi16(v.elements, quotient_times_field_modulus) + }; + + v } #[inline(always)] @@ -130,6 +141,64 @@ fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vecto v } +#[inline(always)] +fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { + v = unsafe { + let value_low = _mm256_mullo_epi16(v, c); + + let k = _mm256_mullo_epi16( + value_low, + _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); + + let value_high = _mm256_mulhi_epi16(v, c); + + _mm256_sub_epi16(value_high, k_times_modulus) + }; + + v +} + +#[inline(always)] +fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { + v = unsafe { + let k = _mm256_mullo_epi16( + v, + _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), + ); + let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); + + let value_high = _mm256_srli_epi32(v, 16); + + let result = _mm256_sub_epi16(value_high, k_times_modulus); + + let result = _mm256_slli_epi32(result, 16); + _mm256_srai_epi32(result, 16) + }; + + v +} + +#[inline(always)] +fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { + v = unsafe { + let value_low = _mm_mullo_epi16(v, c); + + let k = _mm_mullo_epi16( + value_low, + _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = _mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); + + let value_high = _mm_mulhi_epi16(v, c); + + _mm_sub_epi16(value_high, k_times_modulus) + }; + + v +} + #[inline(always)] fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector { v.elements = unsafe { @@ -158,71 +227,156 @@ fn compress(v: SIMD256Vector) -> SIMD256Vector { } #[inline(always)] -fn decompress(v: SIMD256Vector) -> SIMD256Vector { +fn decompress_ciphertext_coefficient( + v: SIMD256Vector, +) -> SIMD256Vector { let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::decompress::<{ COEFFICIENT_BITS }>(input); + let output = portable::decompress_ciphertext_coefficient::<{ COEFFICIENT_BITS }>(input); from_i16_array(portable::to_i16_array(output)) } #[inline(always)] fn ntt_layer_1_step( - v: SIMD256Vector, + mut v: SIMD256Vector, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::ntt_layer_1_step(input, zeta0, zeta1, zeta2, zeta3); + v.elements = unsafe { + let zetas = _mm256_set_epi16( + -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, + zeta1, -zeta0, -zeta0, zeta0, zeta0, + ); - from_i16_array(portable::to_i16_array(output)) + let rhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + + _mm256_add_epi16(lhs, rhs) + }; + + v } #[inline(always)] -fn ntt_layer_2_step(v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::ntt_layer_2_step(input, zeta0, zeta1); +fn ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { + v.elements = unsafe { + let zetas = _mm256_set_epi16( + -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, + -zeta0, zeta0, zeta0, zeta0, zeta0, + ); - from_i16_array(portable::to_i16_array(output)) + let rhs = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00); + + _mm256_add_epi16(lhs, rhs) + }; + + v } #[inline(always)] -fn ntt_layer_3_step(v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::ntt_layer_3_step(input, zeta); +fn ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { + v.elements = unsafe { + let rhs = _mm256_extracti128_si256(v.elements, 1); + let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); - from_i16_array(portable::to_i16_array(output)) + let lhs = _mm256_castsi256_si128(v.elements); + + let lower_coefficients = _mm_add_epi16(lhs, rhs); + let upper_coefficients = _mm_sub_epi16(lhs, rhs); + + let combined = _mm256_castsi128_si256(lower_coefficients); + let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + + combined + }; + + v } #[inline(always)] fn inv_ntt_layer_1_step( - v: SIMD256Vector, + mut v: SIMD256Vector, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::inv_ntt_layer_1_step(input, zeta0, zeta1, zeta2, zeta3); + v.elements = unsafe { + let lhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); - from_i16_array(portable::to_i16_array(output)) + let rhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + let rhs = _mm256_mullo_epi16( + rhs, + _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), + ); + + let sum = _mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + _mm256_set_epi16( + zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, + ), + ); + + let sum = barrett_reduce(SIMD256Vector { elements: sum }).elements; + + _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) + }; + + v } #[inline(always)] -fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::inv_ntt_layer_2_step(input, zeta0, zeta1); +fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { + v.elements = unsafe { + let lhs = _mm256_permute4x64_epi64(v.elements, 0b11_11_01_01); - from_i16_array(portable::to_i16_array(output)) + let rhs = _mm256_permute4x64_epi64(v.elements, 0b10_10_00_00); + let rhs = _mm256_mullo_epi16( + rhs, + _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), + ); + + let sum = _mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + _mm256_set_epi16( + zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, + ), + ); + + _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) + }; + + v } #[inline(always)] -fn inv_ntt_layer_3_step(v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - let input = portable::from_i16_array(to_i16_array(v)); - let output = portable::inv_ntt_layer_3_step(input, zeta); +fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { + v.elements = unsafe { + let lhs = _mm256_extracti128_si256(v.elements, 1); + let rhs = _mm256_castsi256_si128(v.elements); - from_i16_array(portable::to_i16_array(output)) + let lower_coefficients = _mm_add_epi16(lhs, rhs); + + let upper_coefficients = _mm_sub_epi16(lhs, rhs); + let upper_coefficients = + montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); + + let combined = _mm256_castsi128_si256(lower_coefficients); + let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + + combined + }; + + v } #[inline(always)] @@ -234,12 +388,73 @@ fn ntt_multiply( zeta2: i16, zeta3: i16, ) -> SIMD256Vector { - let input0 = portable::from_i16_array(to_i16_array(*lhs)); - let input1 = portable::from_i16_array(to_i16_array(*rhs)); + let products = unsafe { + // Compute the first term of the product + let shuffle_with = _mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, + 12, 9, 8, 5, 4, 1, 0, + ); + const PERMUTE_WITH: i32 = 0b11_01_10_00; + + // Prepare the left hand side + let lhs_shuffled = _mm256_shuffle_epi8(lhs.elements, shuffle_with); + let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); + + let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); + let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); + + let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); + let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); + + // Prepare the right hand side + let rhs_shuffled = _mm256_shuffle_epi8(rhs.elements, shuffle_with); + let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); + + let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); + let rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); + + let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); + let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); + + // Start operating with them + let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); + + let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); + let right = montgomery_reduce_i32s(right); + let right = _mm256_mullo_epi32( + right, + _mm256_set_epi32( + -(zeta3 as i32), + zeta3 as i32, + -(zeta2 as i32), + zeta2 as i32, + -(zeta1 as i32), + zeta1 as i32, + -(zeta0 as i32), + zeta0 as i32, + ), + ); - let output = portable::ntt_multiply(&input0, &input1, zeta0, zeta1, zeta2, zeta3); + let products_left = _mm256_add_epi32(left, right); + let products_left = montgomery_reduce_i32s(products_left); - from_i16_array(portable::to_i16_array(output)) + // Compute the second term of the product + let rhs_adjacent_swapped = _mm256_shuffle_epi8( + rhs.elements, + _mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, + 5, 4, 7, 6, 1, 0, 3, 2, + ), + ); + let products_right = _mm256_madd_epi16(lhs.elements, rhs_adjacent_swapped); + let products_right = montgomery_reduce_i32s(products_right); + let products_right = _mm256_slli_epi32(products_right, 16); + + // Combine them into one vector + _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) + }; + + SIMD256Vector { elements: products } } #[inline(always)] @@ -264,10 +479,55 @@ fn serialize_1(v: SIMD256Vector) -> [u8; 2] { } #[inline(always)] -fn deserialize_1(a: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_1(a); +fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsb_to_msb = _mm256_set_epi16( + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + ); - from_i16_array(portable::to_i16_array(output)) + let coefficients = _mm256_set_epi16( + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); + let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); + + _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) + }; + + SIMD256Vector { + elements: deserialized, + } } #[inline(always)] @@ -318,10 +578,55 @@ fn serialize_4(v: SIMD256Vector) -> [u8; 8] { } #[inline(always)] -fn deserialize_4(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_4(v); +fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); - from_i16_array(portable::to_i16_array(output)) + let coefficients = _mm256_set_epi16( + bytes[7] as i16, + bytes[7] as i16, + bytes[6] as i16, + bytes[6] as i16, + bytes[5] as i16, + bytes[5] as i16, + bytes[4] as i16, + bytes[4] as i16, + bytes[3] as i16, + bytes[3] as i16, + bytes[2] as i16, + bytes[2] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); + + _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) + }; + + SIMD256Vector { + elements: deserialized, + } } #[inline(always)] @@ -434,9 +739,50 @@ fn serialize_10(v: SIMD256Vector) -> [u8; 20] { #[inline(always)] fn deserialize_10(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_10(v); + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + ); - from_i16_array(portable::to_i16_array(output)) + let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_shuffle_epi8( + lower_coefficients, + _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), + ); + let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(4) as *const __m128i); + let upper_coefficients = _mm_shuffle_epi8( + upper_coefficients, + _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), + ); + + let coefficients = _mm256_castsi128_si256(lower_coefficients); + let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); + + let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = _mm256_srli_epi16(coefficients, 6); + let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); + + coefficients + }; + + SIMD256Vector { + elements: deserialized, + } } #[inline(always)] @@ -506,9 +852,50 @@ fn serialize_12(v: SIMD256Vector) -> [u8; 24] { #[inline(always)] fn deserialize_12(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_12(v); + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); - from_i16_array(portable::to_i16_array(output)) + let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_shuffle_epi8( + lower_coefficients, + _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), + ); + let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(8) as *const __m128i); + let upper_coefficients = _mm_shuffle_epi8( + upper_coefficients, + _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), + ); + + let coefficients = _mm256_castsi128_si256(lower_coefficients); + let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); + + let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = _mm256_srli_epi16(coefficients, 4); + let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1)); + + coefficients + }; + + SIMD256Vector { + elements: deserialized, + } } #[inline(always)] @@ -573,8 +960,8 @@ impl Operations for SIMD256Vector { compress::(v) } - fn decompress(v: Self) -> Self { - decompress::(v) + fn decompress_ciphertext_coefficient(v: Self) -> Self { + decompress_ciphertext_coefficient::(v) } fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs index 7738b3142..ac6b93c94 100644 --- a/polynomials-avx2/src/portable.rs +++ b/polynomials-avx2/src/portable.rs @@ -1,15 +1,6 @@ -use libcrux_traits::INVERSE_OF_MODULUS_MOD_MONTGOMERY_R; pub use libcrux_traits::{FIELD_ELEMENTS_IN_VECTOR, FIELD_MODULUS}; type FieldElement = i16; -type MontgomeryFieldElement = i16; -type FieldElementTimesMontgomeryR = i16; -const MONTGOMERY_SHIFT: u8 = 16; -const MONTGOMERY_R: i32 = 1 << MONTGOMERY_SHIFT; - -const BARRETT_SHIFT: i32 = 26; -const BARRETT_R: i32 = 1 << BARRETT_SHIFT; -const BARRETT_MULTIPLIER: i32 = 20159; pub(crate) fn get_n_least_significant_bits(n: u8, value: u32) -> u32 { // hax_debug_assert!(n == 4 || n == 5 || n == 10 || n == 11 || n == MONTGOMERY_SHIFT); @@ -17,29 +8,6 @@ pub(crate) fn get_n_least_significant_bits(n: u8, value: u32) -> u32 { value & ((1 << n) - 1) } -pub(crate) fn montgomery_reduce_element(value: i32) -> MontgomeryFieldElement { - // This forces hax to extract code for MONTGOMERY_R before it extracts code - // for this function. The removal of this line is being tracked in: - // https://github.com/cryspen/libcrux/issues/134 - let _ = MONTGOMERY_R; - - let k = (value as i16) as i32 * (INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32); - let k_times_modulus = (k as i16 as i32) * (FIELD_MODULUS as i32); - - let c = (k_times_modulus >> MONTGOMERY_SHIFT) as i16; - let value_high = (value >> MONTGOMERY_SHIFT) as i16; - - value_high - c -} - -#[inline(always)] -pub(crate) fn montgomery_multiply_fe_by_fer( - fe: FieldElement, - fer: FieldElementTimesMontgomeryR, -) -> FieldElement { - montgomery_reduce_element((fe as i32) * (fer as i32)) -} - pub(crate) fn compress_ciphertext_coefficient(coefficient_bits: u8, fe: u16) -> FieldElement { // This has to be constant time due to: // https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/ldX0ThYJuBo/m/ovODsdY7AwAJ @@ -75,34 +43,6 @@ pub(crate) fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> Portable PortableVector { elements: array } } -pub(crate) fn barrett_reduce_element(value: FieldElement) -> FieldElement { - // hax_debug_assert!( - // i32::from(value) > -BARRETT_R && i32::from(value) < BARRETT_R, - // "value is {value}" - // ); - - let t = (i32::from(value) * BARRETT_MULTIPLIER) + (BARRETT_R >> 1); - let quotient = (t >> BARRETT_SHIFT) as i16; - - let result = value - (quotient * FIELD_MODULUS); - - // hax_debug_assert!( - // result > -FIELD_MODULUS && result < FIELD_MODULUS, - // "value is {value}" - // ); - - result -} - -#[inline(always)] -pub(crate) fn barrett_reduce(mut v: PortableVector) -> PortableVector { - for i in 0..FIELD_ELEMENTS_IN_VECTOR { - v.elements[i] = barrett_reduce_element(v.elements[i]); - } - - v -} - #[inline(always)] pub(crate) fn compress(mut v: PortableVector) -> PortableVector { for i in 0..FIELD_ELEMENTS_IN_VECTOR { @@ -113,7 +53,9 @@ pub(crate) fn compress(mut v: PortableVector) -> Po } #[inline(always)] -pub(crate) fn decompress(mut v: PortableVector) -> PortableVector { +pub(crate) fn decompress_ciphertext_coefficient( + mut v: PortableVector, +) -> PortableVector { debug_assert!(to_i16_array(v) .into_iter() .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS)); @@ -132,386 +74,6 @@ pub(crate) fn decompress(mut v: PortableVector) -> v } -#[inline(always)] -pub(crate) fn ntt_layer_1_step( - mut v: PortableVector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> PortableVector { - // First 8 elements. - let t = montgomery_multiply_fe_by_fer(v.elements[2], zeta0); - v.elements[2] = v.elements[0] - t; - v.elements[0] = v.elements[0] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[3], zeta0); - v.elements[3] = v.elements[1] - t; - v.elements[1] = v.elements[1] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[6], zeta1); - v.elements[6] = v.elements[4] - t; - v.elements[4] = v.elements[4] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[7], zeta1); - v.elements[7] = v.elements[5] - t; - v.elements[5] = v.elements[5] + t; - - // Next 8 elements. - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 2], zeta2); - v.elements[8 + 2] = v.elements[8 + 0] - t; - v.elements[8 + 0] = v.elements[8 + 0] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 3], zeta2); - v.elements[8 + 3] = v.elements[8 + 1] - t; - v.elements[8 + 1] = v.elements[8 + 1] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 6], zeta3); - v.elements[8 + 6] = v.elements[8 + 4] - t; - v.elements[8 + 4] = v.elements[8 + 4] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 7], zeta3); - v.elements[8 + 7] = v.elements[8 + 5] - t; - v.elements[8 + 5] = v.elements[8 + 5] + t; - - v -} - -#[inline(always)] -pub(crate) fn ntt_layer_3_step(mut v: PortableVector, zeta: i16) -> PortableVector { - let t = montgomery_multiply_fe_by_fer(v.elements[8], zeta); - v.elements[8] = v.elements[0] - t; - v.elements[0] = v.elements[0] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[9], zeta); - v.elements[9] = v.elements[1] - t; - v.elements[1] = v.elements[1] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[10], zeta); - v.elements[10] = v.elements[2] - t; - v.elements[2] = v.elements[2] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[11], zeta); - v.elements[11] = v.elements[3] - t; - v.elements[3] = v.elements[3] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[12], zeta); - v.elements[12] = v.elements[4] - t; - v.elements[4] = v.elements[4] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[13], zeta); - v.elements[13] = v.elements[5] - t; - v.elements[5] = v.elements[5] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[14], zeta); - v.elements[14] = v.elements[6] - t; - v.elements[6] = v.elements[6] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[15], zeta); - v.elements[15] = v.elements[7] - t; - v.elements[7] = v.elements[7] + t; - - v -} - -#[inline(always)] -pub(crate) fn ntt_layer_2_step(mut v: PortableVector, zeta0: i16, zeta1: i16) -> PortableVector { - // First 8 elements. - let t = montgomery_multiply_fe_by_fer(v.elements[4], zeta0); - v.elements[4] = v.elements[0] - t; - v.elements[0] = v.elements[0] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[5], zeta0); - v.elements[5] = v.elements[1] - t; - v.elements[1] = v.elements[1] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[6], zeta0); - v.elements[6] = v.elements[2] - t; - v.elements[2] = v.elements[2] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[7], zeta0); - v.elements[7] = v.elements[3] - t; - v.elements[3] = v.elements[3] + t; - - // Next 8 elements. - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 4], zeta1); - v.elements[8 + 4] = v.elements[8 + 0] - t; - v.elements[8 + 0] = v.elements[8 + 0] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 5], zeta1); - v.elements[8 + 5] = v.elements[8 + 1] - t; - v.elements[8 + 1] = v.elements[8 + 1] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 6], zeta1); - v.elements[8 + 6] = v.elements[8 + 2] - t; - v.elements[8 + 2] = v.elements[8 + 2] + t; - - let t = montgomery_multiply_fe_by_fer(v.elements[8 + 7], zeta1); - v.elements[8 + 7] = v.elements[8 + 3] - t; - v.elements[8 + 3] = v.elements[8 + 3] + t; - - v -} - -#[inline(always)] -pub(crate) fn inv_ntt_layer_1_step( - mut v: PortableVector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> PortableVector { - // First 8 elements. - let a_minus_b = v.elements[2] - v.elements[0]; - v.elements[0] = barrett_reduce_element(v.elements[0] + v.elements[2]); - v.elements[2] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - let a_minus_b = v.elements[3] - v.elements[1]; - v.elements[1] = barrett_reduce_element(v.elements[1] + v.elements[3]); - v.elements[3] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - let a_minus_b = v.elements[6] - v.elements[4]; - v.elements[4] = barrett_reduce_element(v.elements[4] + v.elements[6]); - v.elements[6] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - let a_minus_b = v.elements[7] - v.elements[5]; - v.elements[5] = barrett_reduce_element(v.elements[5] + v.elements[7]); - v.elements[7] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - // Next 8 elements. - let a_minus_b = v.elements[8 + 2] - v.elements[8 + 0]; - v.elements[8 + 0] = barrett_reduce_element(v.elements[8 + 0] + v.elements[8 + 2]); - v.elements[8 + 2] = montgomery_multiply_fe_by_fer(a_minus_b, zeta2); - - let a_minus_b = v.elements[8 + 3] - v.elements[8 + 1]; - v.elements[8 + 1] = barrett_reduce_element(v.elements[8 + 1] + v.elements[8 + 3]); - v.elements[8 + 3] = montgomery_multiply_fe_by_fer(a_minus_b, zeta2); - - let a_minus_b = v.elements[8 + 6] - v.elements[8 + 4]; - v.elements[8 + 4] = barrett_reduce_element(v.elements[8 + 4] + v.elements[8 + 6]); - v.elements[8 + 6] = montgomery_multiply_fe_by_fer(a_minus_b, zeta3); - - let a_minus_b = v.elements[8 + 7] - v.elements[8 + 5]; - v.elements[8 + 5] = barrett_reduce_element(v.elements[8 + 5] + v.elements[8 + 7]); - v.elements[8 + 7] = montgomery_multiply_fe_by_fer(a_minus_b, zeta3); - - v -} - -#[inline(always)] -pub(crate) fn inv_ntt_layer_2_step( - mut v: PortableVector, - zeta0: i16, - zeta1: i16, -) -> PortableVector { - // First 8 elements. - let a_minus_b = v.elements[4] - v.elements[0]; - v.elements[0] = v.elements[0] + v.elements[4]; - v.elements[4] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - let a_minus_b = v.elements[5] - v.elements[1]; - v.elements[1] = v.elements[1] + v.elements[5]; - v.elements[5] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - let a_minus_b = v.elements[6] - v.elements[2]; - v.elements[2] = v.elements[2] + v.elements[6]; - v.elements[6] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - let a_minus_b = v.elements[7] - v.elements[3]; - v.elements[3] = v.elements[3] + v.elements[7]; - v.elements[7] = montgomery_multiply_fe_by_fer(a_minus_b, zeta0); - - // Next 8 elements. - let a_minus_b = v.elements[8 + 4] - v.elements[8 + 0]; - v.elements[8 + 0] = v.elements[8 + 0] + v.elements[8 + 4]; - v.elements[8 + 4] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - let a_minus_b = v.elements[8 + 5] - v.elements[8 + 1]; - v.elements[8 + 1] = v.elements[8 + 1] + v.elements[8 + 5]; - v.elements[8 + 5] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - let a_minus_b = v.elements[8 + 6] - v.elements[8 + 2]; - v.elements[8 + 2] = v.elements[8 + 2] + v.elements[8 + 6]; - v.elements[8 + 6] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - let a_minus_b = v.elements[8 + 7] - v.elements[8 + 3]; - v.elements[8 + 3] = v.elements[8 + 3] + v.elements[8 + 7]; - v.elements[8 + 7] = montgomery_multiply_fe_by_fer(a_minus_b, zeta1); - - v -} - -#[inline(always)] -pub(crate) fn inv_ntt_layer_3_step(mut v: PortableVector, zeta: i16) -> PortableVector { - let a_minus_b = v.elements[8] - v.elements[0]; - v.elements[0] = v.elements[0] + v.elements[8]; - v.elements[8] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[9] - v.elements[1]; - v.elements[1] = v.elements[1] + v.elements[9]; - v.elements[9] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[10] - v.elements[2]; - v.elements[2] = v.elements[2] + v.elements[10]; - v.elements[10] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[11] - v.elements[3]; - v.elements[3] = v.elements[3] + v.elements[11]; - v.elements[11] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[12] - v.elements[4]; - v.elements[4] = v.elements[4] + v.elements[12]; - v.elements[12] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[13] - v.elements[5]; - v.elements[5] = v.elements[5] + v.elements[13]; - v.elements[13] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[14] - v.elements[6]; - v.elements[6] = v.elements[6] + v.elements[14]; - v.elements[14] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - let a_minus_b = v.elements[15] - v.elements[7]; - v.elements[7] = v.elements[7] + v.elements[15]; - v.elements[15] = montgomery_multiply_fe_by_fer(a_minus_b, zeta); - - v -} - -#[inline(always)] -pub(crate) fn ntt_multiply_binomials( - (a0, a1): (FieldElement, FieldElement), - (b0, b1): (FieldElement, FieldElement), - zeta: FieldElementTimesMontgomeryR, -) -> (MontgomeryFieldElement, MontgomeryFieldElement) { - ( - montgomery_reduce_element( - (a0 as i32) * (b0 as i32) - + (montgomery_reduce_element((a1 as i32) * (b1 as i32)) as i32) * (zeta as i32), - ), - montgomery_reduce_element((a0 as i32) * (b1 as i32) + (a1 as i32) * (b0 as i32)), - ) -} - -#[inline(always)] -pub(crate) fn ntt_multiply( - lhs: &PortableVector, - rhs: &PortableVector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> PortableVector { - let mut out = zero(); - - // First 8 elements. - let product = ntt_multiply_binomials( - (lhs.elements[0], lhs.elements[1]), - (rhs.elements[0], rhs.elements[1]), - zeta0, - ); - out.elements[0] = product.0; - out.elements[1] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[2], lhs.elements[3]), - (rhs.elements[2], rhs.elements[3]), - -zeta0, - ); - out.elements[2] = product.0; - out.elements[3] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[4], lhs.elements[5]), - (rhs.elements[4], rhs.elements[5]), - zeta1, - ); - out.elements[4] = product.0; - out.elements[5] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[6], lhs.elements[7]), - (rhs.elements[6], rhs.elements[7]), - -zeta1, - ); - out.elements[6] = product.0; - out.elements[7] = product.1; - - // Next 8 elements. - let product = ntt_multiply_binomials( - (lhs.elements[8 + 0], lhs.elements[8 + 1]), - (rhs.elements[8 + 0], rhs.elements[8 + 1]), - zeta2, - ); - out.elements[8 + 0] = product.0; - out.elements[8 + 1] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[8 + 2], lhs.elements[8 + 3]), - (rhs.elements[8 + 2], rhs.elements[8 + 3]), - -zeta2, - ); - out.elements[8 + 2] = product.0; - out.elements[8 + 3] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[8 + 4], lhs.elements[8 + 5]), - (rhs.elements[8 + 4], rhs.elements[8 + 5]), - zeta3, - ); - out.elements[8 + 4] = product.0; - out.elements[8 + 5] = product.1; - - let product = ntt_multiply_binomials( - (lhs.elements[8 + 6], lhs.elements[8 + 7]), - (rhs.elements[8 + 6], rhs.elements[8 + 7]), - -zeta3, - ); - out.elements[8 + 6] = product.0; - out.elements[8 + 7] = product.1; - - out -} - -#[inline(always)] -pub(crate) fn deserialize_1(v: &[u8]) -> PortableVector { - let mut result = zero(); - - for i in 0..8 { - result.elements[i] = ((v[0] >> i) & 0x1) as i16; - } - for i in 8..FIELD_ELEMENTS_IN_VECTOR { - result.elements[i] = ((v[1] >> (i - 8)) & 0x1) as i16; - } - - result -} - -#[inline(always)] -pub(crate) fn deserialize_4(bytes: &[u8]) -> PortableVector { - let mut v = zero(); - - v.elements[0] = (bytes[0] & 0x0F) as i16; - v.elements[1] = ((bytes[0] >> 4) & 0x0F) as i16; - v.elements[2] = (bytes[1] & 0x0F) as i16; - v.elements[3] = ((bytes[1] >> 4) & 0x0F) as i16; - v.elements[4] = (bytes[2] & 0x0F) as i16; - v.elements[5] = ((bytes[2] >> 4) & 0x0F) as i16; - v.elements[6] = (bytes[3] & 0x0F) as i16; - v.elements[7] = ((bytes[3] >> 4) & 0x0F) as i16; - - v.elements[8] = (bytes[4] & 0x0F) as i16; - v.elements[9] = ((bytes[4] >> 4) & 0x0F) as i16; - v.elements[10] = (bytes[5] & 0x0F) as i16; - v.elements[11] = ((bytes[5] >> 4) & 0x0F) as i16; - v.elements[12] = (bytes[6] & 0x0F) as i16; - v.elements[13] = ((bytes[6] >> 4) & 0x0F) as i16; - v.elements[14] = (bytes[7] & 0x0F) as i16; - v.elements[15] = ((bytes[7] >> 4) & 0x0F) as i16; - - v -} - #[inline(always)] pub(crate) fn deserialize_5(bytes: &[u8]) -> PortableVector { let mut v = zero(); @@ -537,33 +99,6 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> PortableVector { v } -#[inline(always)] -pub(crate) fn deserialize_10(bytes: &[u8]) -> PortableVector { - let mut result = zero(); - - result.elements[0] = ((bytes[1] as i16 & 0x03) << 8 | (bytes[0] as i16 & 0xFF)) as i16; - result.elements[1] = ((bytes[2] as i16 & 0x0F) << 6 | (bytes[1] as i16 >> 2)) as i16; - result.elements[2] = ((bytes[3] as i16 & 0x3F) << 4 | (bytes[2] as i16 >> 4)) as i16; - result.elements[3] = (((bytes[4] as i16) << 2) | (bytes[3] as i16 >> 6)) as i16; - result.elements[4] = ((bytes[6] as i16 & 0x03) << 8 | (bytes[5] as i16 & 0xFF)) as i16; - result.elements[5] = ((bytes[7] as i16 & 0x0F) << 6 | (bytes[6] as i16 >> 2)) as i16; - result.elements[6] = ((bytes[8] as i16 & 0x3F) << 4 | (bytes[7] as i16 >> 4)) as i16; - result.elements[7] = (((bytes[9] as i16) << 2) | (bytes[8] as i16 >> 6)) as i16; - - result.elements[8] = - ((bytes[10 + 1] as i16 & 0x03) << 8 | (bytes[10 + 0] as i16 & 0xFF)) as i16; - result.elements[9] = ((bytes[10 + 2] as i16 & 0x0F) << 6 | (bytes[10 + 1] as i16 >> 2)) as i16; - result.elements[10] = ((bytes[10 + 3] as i16 & 0x3F) << 4 | (bytes[10 + 2] as i16 >> 4)) as i16; - result.elements[11] = (((bytes[10 + 4] as i16) << 2) | (bytes[10 + 3] as i16 >> 6)) as i16; - result.elements[12] = - ((bytes[10 + 6] as i16 & 0x03) << 8 | (bytes[10 + 5] as i16 & 0xFF)) as i16; - result.elements[13] = ((bytes[10 + 7] as i16 & 0x0F) << 6 | (bytes[10 + 6] as i16 >> 2)) as i16; - result.elements[14] = ((bytes[10 + 8] as i16 & 0x3F) << 4 | (bytes[10 + 7] as i16 >> 4)) as i16; - result.elements[15] = (((bytes[10 + 9] as i16) << 2) | (bytes[10 + 8] as i16 >> 6)) as i16; - - result -} - #[inline(always)] pub(crate) fn serialize_11(v: PortableVector) -> [u8; 22] { let mut result = [0u8; 22]; @@ -627,57 +162,6 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> PortableVector { result } -#[inline(always)] -pub(crate) fn deserialize_12(bytes: &[u8]) -> PortableVector { - let mut re = zero(); - - let byte0 = bytes[0] as i16; - let byte1 = bytes[1] as i16; - let byte2 = bytes[2] as i16; - let byte3 = bytes[3] as i16; - let byte4 = bytes[4] as i16; - let byte5 = bytes[5] as i16; - let byte6 = bytes[6] as i16; - let byte7 = bytes[7] as i16; - let byte8 = bytes[8] as i16; - let byte9 = bytes[9] as i16; - let byte10 = bytes[10] as i16; - let byte11 = bytes[11] as i16; - - re.elements[0] = (byte1 & 0x0F) << 8 | (byte0 & 0xFF); - re.elements[1] = (byte2 << 4) | ((byte1 >> 4) & 0x0F); - re.elements[2] = (byte4 & 0x0F) << 8 | (byte3 & 0xFF); - re.elements[3] = (byte5 << 4) | ((byte4 >> 4) & 0x0F); - re.elements[4] = (byte7 & 0x0F) << 8 | (byte6 & 0xFF); - re.elements[5] = (byte8 << 4) | ((byte7 >> 4) & 0x0F); - re.elements[6] = (byte10 & 0x0F) << 8 | (byte9 & 0xFF); - re.elements[7] = (byte11 << 4) | ((byte10 >> 4) & 0x0F); - - let byte12 = bytes[12] as i16; - let byte13 = bytes[13] as i16; - let byte14 = bytes[14] as i16; - let byte15 = bytes[15] as i16; - let byte16 = bytes[16] as i16; - let byte17 = bytes[17] as i16; - let byte18 = bytes[18] as i16; - let byte19 = bytes[19] as i16; - let byte20 = bytes[20] as i16; - let byte21 = bytes[21] as i16; - let byte22 = bytes[22] as i16; - let byte23 = bytes[23] as i16; - - re.elements[8] = (byte13 & 0x0F) << 8 | (byte12 & 0xFF); - re.elements[9] = (byte14 << 4) | ((byte13 >> 4) & 0x0F); - re.elements[10] = (byte16 & 0x0F) << 8 | (byte15 & 0xFF); - re.elements[11] = (byte17 << 4) | ((byte16 >> 4) & 0x0F); - re.elements[12] = (byte19 & 0x0F) << 8 | (byte18 & 0xFF); - re.elements[13] = (byte20 << 4) | ((byte19 >> 4) & 0x0F); - re.elements[14] = (byte22 & 0x0F) << 8 | (byte21 & 0xFF); - re.elements[15] = (byte23 << 4) | ((byte22 >> 4) & 0x0F); - - re -} - #[inline(always)] pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { let mut result = [0i16; 16]; diff --git a/polynomials/src/lib.rs b/polynomials/src/lib.rs index 2c727d249..34db69b5e 100644 --- a/polynomials/src/lib.rs +++ b/polynomials/src/lib.rs @@ -338,7 +338,9 @@ fn compress(mut v: PortableVector) -> PortableVecto } #[inline(always)] -fn decompress(mut v: PortableVector) -> PortableVector { +fn decompress_ciphertext_coefficient( + mut v: PortableVector, +) -> PortableVector { debug_assert!(to_i16_array(v) .into_iter() .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS)); @@ -1117,8 +1119,8 @@ impl Operations for PortableVector { compress::(v) } - fn decompress(v: Self) -> Self { - decompress::(v) + fn decompress_ciphertext_coefficient(v: Self) -> Self { + decompress_ciphertext_coefficient::(v) } fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { diff --git a/traits/src/lib.rs b/traits/src/lib.rs index 06391ff82..4000ab7d2 100644 --- a/traits/src/lib.rs +++ b/traits/src/lib.rs @@ -28,7 +28,7 @@ pub trait Operations: Copy + Clone { // Compression fn compress_1(v: Self) -> Self; fn compress(v: Self) -> Self; - fn decompress(v: Self) -> Self; + fn decompress_ciphertext_coefficient(v: Self) -> Self; // NTT fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self;