diff --git a/polynomials-avx2/src/intrinsics.rs b/polynomials-avx2/src/intrinsics.rs index d28b227c7..93133cc18 100644 --- a/polynomials-avx2/src/intrinsics.rs +++ b/polynomials-avx2/src/intrinsics.rs @@ -3,16 +3,171 @@ pub(crate) use core::arch::x86::*; #[cfg(target_arch = "x86_64")] pub(crate) use core::arch::x86_64::*; +pub(crate) fn mm256_storeu_si256(output: &mut [i16], vector: __m256i) { + debug_assert_eq!(output.len(), 16); + unsafe { + _mm256_storeu_si256(output.as_mut_ptr() as *mut __m256i, vector); + } +} +pub(crate) fn mm_storeu_si128(output: &mut [i16], vector: __m128i) { + debug_assert_eq!(output.len(), 8); + unsafe { + _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector); + } +} + +pub(crate) fn mm_storeu_bytes_si128(output: &mut [u8], vector: __m128i) { + debug_assert_eq!(output.len(), 16); + unsafe { + _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector); + } +} + +pub(crate) fn mm_loadu_si128(input: &[u8]) -> __m128i { + debug_assert_eq!(input.len(), 16); + unsafe { _mm_loadu_si128(input.as_ptr() as *const __m128i) } +} + +pub(crate) fn mm256_loadu_si256(input: &[i16]) -> __m256i { + debug_assert_eq!(input.len(), 16); + unsafe { _mm256_loadu_si256(input.as_ptr() as *const __m256i) } +} + +pub(crate) fn mm256_setzero_si256() -> __m256i { + unsafe { _mm256_setzero_si256() } +} + +pub(crate) fn mm_set_epi8( + byte15: i8, + byte14: i8, + byte13: i8, + byte12: i8, + byte11: i8, + byte10: i8, + byte9: i8, + byte8: i8, + byte7: i8, + byte6: i8, + byte5: i8, + byte4: i8, + byte3: i8, + byte2: i8, + byte1: i8, + byte0: i8, +) -> __m128i { + unsafe { + _mm_set_epi8( + byte15, byte14, byte13, byte12, byte11, byte10, + byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + ) + } +} + +pub(crate) fn mm256_set_epi8( + byte31: i8, + byte30: i8, + byte29: i8, + byte28: i8, + byte27: i8, + byte26: i8, + byte25: i8, + byte24: i8, + byte23: i8, + byte22: i8, + byte21: i8, + byte20: i8, + byte19: i8, + byte18: i8, + byte17: i8, + byte16: i8, + byte15: i8, + byte14: i8, + byte13: i8, + byte12: i8, + byte11: i8, + byte10: i8, + byte9: i8, + byte8: i8, + byte7: i8, + byte6: i8, + byte5: i8, + byte4: i8, + byte3: i8, + byte2: i8, + byte1: i8, + byte0: i8, +) -> __m256i { + unsafe { + _mm256_set_epi8( + byte31, byte30, byte29, byte28, byte27, byte26, byte25, byte24, byte23, byte22, byte21, + byte20, byte19, byte18, byte17, byte16, byte15, byte14, byte13, byte12, byte11, byte10, + byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + ) + } +} + pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i { unsafe { _mm256_set1_epi16(constant) } } +pub(crate) fn mm256_set_epi16( + input15: i16, + input14: i16, + input13: i16, + input12: i16, + input11: i16, + input10: i16, + input9: i16, + input8: i16, + input7: i16, + input6: i16, + input5: i16, + input4: i16, + input3: i16, + input2: i16, + input1: i16, + input0: i16, +) -> __m256i { + unsafe { + _mm256_set_epi16( + input15, input14, input13, input12, input11, input10, input9, input8, input7, input6, + input5, input4, input3, input2, input1, input0, + ) + } +} + +pub(crate) fn mm_set1_epi16(constant: i16) -> __m128i { + unsafe { _mm_set1_epi16(constant) } +} + pub(crate) fn mm256_set1_epi32(constant: i32) -> __m256i { unsafe { _mm256_set1_epi32(constant) } } +pub(crate) fn mm256_set_epi32( + input7: i32, + input6: i32, + input5: i32, + input4: i32, + input3: i32, + input2: i32, + input1: i32, + input0: i32, +) -> __m256i { + unsafe { + _mm256_set_epi32( + input7, input6, input5, input4, input3, 
input2, input1, input0,
+        )
+    }
+}
+
+pub(crate) fn mm_add_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_add_epi16(lhs, rhs) }
+}
 pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_add_epi16(lhs, rhs) }
 }
+pub(crate) fn mm256_madd_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_madd_epi16(lhs, rhs) }
+}
 pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_add_epi32(lhs, rhs) }
 }
@@ -20,10 +175,26 @@ pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
 pub(crate) fn mm256_sub_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_sub_epi16(lhs, rhs) }
 }
+pub(crate) fn mm_sub_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_sub_epi16(lhs, rhs) }
+}
 pub(crate) fn mm256_mullo_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_mullo_epi16(lhs, rhs) }
 }
+
+pub(crate) fn mm_mullo_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_mullo_epi16(lhs, rhs) }
+}
+
+pub(crate) fn mm256_cmpgt_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_cmpgt_epi16(lhs, rhs) }
+}
+
+pub(crate) fn mm_mulhi_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_mulhi_epi16(lhs, rhs) }
+}
+
 pub(crate) fn mm256_mullo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_mullo_epi32(lhs, rhs) }
 }
@@ -48,6 +219,11 @@ pub(crate) fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
     debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
     unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
 }
+pub(crate) fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
+    unsafe { _mm256_srai_epi32(vector, SHIFT_BY) }
+}
+
 pub(crate) fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
     debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
     unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
@@ -57,6 +233,11 @@ pub(crate) fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
     unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
 }
+pub(crate) fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
+    unsafe { _mm256_srli_epi64(vector, SHIFT_BY) }
+}
+
 pub(crate) fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
     debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
     unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
@@ -67,6 +248,12 @@ pub(crate) fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
     unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
 }
+pub(crate) fn mm_shuffle_epi8(vector: __m128i, control: __m128i) -> __m128i {
+    unsafe { _mm_shuffle_epi8(vector, control) }
+}
+pub(crate) fn mm256_shuffle_epi8(vector: __m256i, control: __m256i) -> __m256i {
+    unsafe { _mm256_shuffle_epi8(vector, control) }
+}
 pub(crate) fn mm256_shuffle_epi32<const CONTROL: i32>(vector: __m256i) -> __m256i {
     debug_assert!(CONTROL >= 0 && CONTROL < 256);
     unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
@@ -92,11 +279,17 @@ pub(crate) fn mm256_unpackhi_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
 pub(crate) fn mm256_castsi256_si128(vector: __m256i) -> __m128i {
     unsafe { _mm256_castsi256_si128(vector) }
 }
+pub(crate) fn mm256_castsi128_si256(vector: __m128i) -> __m256i {
+    unsafe { _mm256_castsi128_si256(vector) }
+}
 pub(crate) fn mm256_cvtepi16_epi32(vector: __m128i) -> __m256i {
     unsafe { _mm256_cvtepi16_epi32(vector) }
 }
+pub(crate) fn mm_packs_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_packs_epi16(lhs, rhs) }
+}
 pub(crate) fn mm256_packs_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
     unsafe { _mm256_packs_epi32(lhs, rhs) }
 }
@@ -105,3 +298,28 @@ pub(crate) fn mm256_extracti128_si256<const CONTROL: i32>(vector: __m256i) -> __m128i {
     debug_assert!(CONTROL == 0 || CONTROL == 1);
     unsafe { _mm256_extracti128_si256(vector, CONTROL) }
 }
+
+pub(crate) fn mm256_inserti128_si256<const CONTROL: i32>(
+    vector: __m256i,
+    vector_i128: __m128i,
+) -> __m256i {
+    debug_assert!(CONTROL == 0 || CONTROL == 1);
+    unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) }
+}
+
+pub(crate) fn mm256_blend_epi16<const CONTROL: i32>(lhs: __m256i, rhs: __m256i) -> __m256i {
+    debug_assert!(CONTROL >= 0 && CONTROL < 256);
+    unsafe { _mm256_blend_epi16(lhs, rhs, CONTROL) }
+}
+
+pub(crate) fn mm_movemask_epi8(vector: __m128i) -> i32 {
+    unsafe { _mm_movemask_epi8(vector) }
+}
+
+pub(crate) fn mm256_permutevar8x32_epi32(vector: __m256i, control: __m256i) -> __m256i {
+    unsafe { _mm256_permutevar8x32_epi32(vector, control) }
+}
+
+pub(crate) fn mm256_sllv_epi32(vector: __m256i, counts: __m256i) -> __m256i {
+    unsafe { _mm256_sllv_epi32(vector, counts) }
+}
diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs
index 52de02fd0..e24519afd 100644
--- a/polynomials-avx2/src/lib.rs
+++ b/polynomials-avx2/src/lib.rs
@@ -1,7 +1,4 @@
-#[cfg(target_arch = "x86")]
-use core::arch::x86::*;
-#[cfg(target_arch = "x86_64")]
-use core::arch::x86_64::*;
+use crate::intrinsics::*;
 use libcrux_traits::Operations;

 #[cfg(test)]
@@ -24,24 +21,21 @@ pub struct SIMD256Vector {
 #[inline(always)]
 fn zero() -> SIMD256Vector {
     SIMD256Vector {
-        elements: unsafe { _mm256_setzero_si256() },
+        elements: mm256_setzero_si256(),
     }
 }

 #[inline(always)]
 fn to_i16_array(v: SIMD256Vector) -> [i16; 16] {
-    let mut out = [0i16; 16];
+    let mut output = [0i16; 16];
+    mm256_storeu_si256(&mut output[..], v.elements);

-    unsafe {
-        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, v.elements);
-    }
-
-    out
+    output
 }

 #[inline(always)]
 fn from_i16_array(array: &[i16]) -> SIMD256Vector {
     SIMD256Vector {
-        elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) },
+        elements: mm256_loadu_si256(array),
     }
 }

@@ -187,9 +181,9 @@ impl Operations for SIMD256Vector {
         serialize::serialize_1(vector.elements)
     }

-    fn deserialize_1(input: &[u8]) -> Self {
+    fn deserialize_1(bytes: &[u8]) -> Self {
         Self {
-            elements: serialize::deserialize_1(input),
+            elements: serialize::deserialize_1(bytes),
         }
     }
diff --git a/polynomials-avx2/src/ntt.rs b/polynomials-avx2/src/ntt.rs
index 28377dbfc..2ebb1561d 100644
--- a/polynomials-avx2/src/ntt.rs
+++ b/polynomials-avx2/src/ntt.rs
@@ -1,210 +1,172 @@
-#[cfg(target_arch = "x86")]
-use core::arch::x86::*;
-#[cfg(target_arch = "x86_64")]
-use core::arch::x86_64::*;
+use crate::intrinsics::*;

 use crate::arithmetic;
 use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R};

 #[inline(always)]
-fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i {
-    v = unsafe {
-        let value_low = _mm256_mullo_epi16(v, c);
+fn montgomery_multiply_by_constants(v: __m256i, c: __m256i) -> __m256i {
+    let value_low = mm256_mullo_epi16(v, c);

-        let k = _mm256_mullo_epi16(
-            value_low,
-            _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
-        );
-        let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS));
+    let k = mm256_mullo_epi16(
+        value_low,
+        mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
+    );
+    let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS));

-        let value_high = _mm256_mulhi_epi16(v, c);
+    let value_high = mm256_mulhi_epi16(v, c);

-        _mm256_sub_epi16(value_high, k_times_modulus)
-    };
-
-    v
+    mm256_sub_epi16(value_high,
k_times_modulus) } #[inline(always)] -fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { - v = unsafe { - let k = _mm256_mullo_epi16( - v, - _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); +fn montgomery_reduce_i32s(v: __m256i) -> __m256i { + let k = mm256_mullo_epi16( + v, + mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi32(FIELD_MODULUS as i32)); - let value_high = _mm256_srli_epi32(v, 16); + let value_high = mm256_srli_epi32::<16>(v); - let result = _mm256_sub_epi16(value_high, k_times_modulus); + let result = mm256_sub_epi16(value_high, k_times_modulus); - let result = _mm256_slli_epi32(result, 16); - _mm256_srai_epi32(result, 16) - }; + let result = mm256_slli_epi32::<16>(result); - v + mm256_srai_epi32::<16>(result) } #[inline(always)] -fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { - v = unsafe { - let value_low = _mm_mullo_epi16(v, c); - - let k = _mm_mullo_epi16( - value_low, - _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); +fn montgomery_multiply_m128i_by_constants(v: __m128i, c: __m128i) -> __m128i { + let value_low = mm_mullo_epi16(v, c); - let value_high = _mm_mulhi_epi16(v, c); + let k = mm_mullo_epi16( + value_low, + mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm_mulhi_epi16(k, mm_set1_epi16(FIELD_MODULUS)); - _mm_sub_epi16(value_high, k_times_modulus) - }; + let value_high = mm_mulhi_epi16(v, c); - v + mm_sub_epi16(value_high, k_times_modulus) } #[inline(always)] pub(crate) fn ntt_layer_1_step( - mut vector: __m256i, + vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> __m256i { - vector = unsafe { - let zetas = _mm256_set_epi16( - -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, - zeta1, -zeta0, -zeta0, zeta0, zeta0, - ); + let zetas = mm256_set_epi16( + -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, zeta1, + -zeta0, -zeta0, zeta0, zeta0, + ); - let rhs = _mm256_shuffle_epi32(vector, 0b11_11_01_01); - let rhs = montgomery_multiply_by_constants(rhs, zetas); + let rhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); - let lhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); + let lhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); - _mm256_add_epi16(lhs, rhs) - }; - - vector + mm256_add_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { - vector = unsafe { - let zetas = _mm256_set_epi16( - -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, - -zeta0, zeta0, zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(vector, 0b11_10_11_10); - let rhs = montgomery_multiply_by_constants(rhs, zetas); +pub(crate) fn ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let zetas = mm256_set_epi16( + -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, -zeta0, + zeta0, zeta0, zeta0, zeta0, + ); - let lhs = _mm256_shuffle_epi32(vector, 0b01_00_01_00); + let rhs = mm256_shuffle_epi32::<0b11_10_11_10>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); - _mm256_add_epi16(lhs, rhs) - }; + let lhs = 
mm256_shuffle_epi32::<0b01_00_01_00>(vector); - vector + mm256_add_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { - vector = unsafe { - let rhs = _mm256_extracti128_si256(vector, 1); - let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); +pub(crate) fn ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let rhs = mm256_extracti128_si256::<1>(vector); + let rhs = montgomery_multiply_m128i_by_constants(rhs, mm_set1_epi16(zeta)); - let lhs = _mm256_castsi256_si128(vector); + let lhs = mm256_castsi256_si128(vector); - let lower_coefficients = _mm_add_epi16(lhs, rhs); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); + let lower_coefficients = mm_add_epi16(lhs, rhs); + let upper_coefficients = mm_sub_epi16(lhs, rhs); - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); - combined - }; - - vector + combined } #[inline(always)] pub(crate) fn inv_ntt_layer_1_step( - mut vector: __m256i, + vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> __m256i { - vector = unsafe { - let lhs = _mm256_shuffle_epi32(vector, 0b11_11_01_01); - - let rhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), - ); + let lhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, - ), - ); + let rhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), + ); - let sum = arithmetic::barrett_reduce(sum); + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + mm256_set_epi16( + zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, + ), + ); - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) - }; + let sum = arithmetic::barrett_reduce(sum); - vector + mm256_blend_epi16::<0b1_1_0_0_1_1_0_0>(sum, sum_times_zetas) } #[inline(always)] -pub(crate) fn inv_ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { - vector = unsafe { - let lhs = _mm256_permute4x64_epi64(vector, 0b11_11_01_01); - - let rhs = _mm256_permute4x64_epi64(vector, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, - ), - ); - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) - }; - - vector +pub(crate) fn inv_ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let lhs = mm256_permute4x64_epi64::<0b11_11_01_01>(vector); + + let rhs = mm256_permute4x64_epi64::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), + ); + + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( 
+ sum, + mm256_set_epi16( + zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, + ), + ); + + mm256_blend_epi16::<0b1_1_1_1_0_0_0_0>(sum, sum_times_zetas) } #[inline(always)] -pub(crate) fn inv_ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { - vector = unsafe { - let lhs = _mm256_extracti128_si256(vector, 1); - let rhs = _mm256_castsi256_si128(vector); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); +pub(crate) fn inv_ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let lhs = mm256_extracti128_si256::<1>(vector); + let rhs = mm256_castsi256_si128(vector); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - let upper_coefficients = - montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); + let lower_coefficients = mm_add_epi16(lhs, rhs); - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + let upper_coefficients = mm_sub_epi16(lhs, rhs); + let upper_coefficients = + montgomery_multiply_m128i_by_constants(upper_coefficients, mm_set1_epi16(zeta)); - combined - }; + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); - vector + combined } #[inline(always)] @@ -216,69 +178,67 @@ pub(crate) fn ntt_multiply( zeta2: i16, zeta3: i16, ) -> __m256i { - return unsafe { - // Compute the first term of the product - let shuffle_with = _mm256_set_epi8( - 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, - 12, 9, 8, 5, 4, 1, 0, - ); - const PERMUTE_WITH: i32 = 0b11_01_10_00; - - // Prepare the left hand side - let lhs_shuffled = _mm256_shuffle_epi8(lhs, shuffle_with); - let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); - - let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); - let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); - - let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); - let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); - - // Prepare the right hand side - let rhs_shuffled = _mm256_shuffle_epi8(rhs, shuffle_with); - let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); - - let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); - let rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); - - let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); - let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); - - // Start operating with them - let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); - - let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); - let right = montgomery_reduce_i32s(right); - let right = _mm256_mullo_epi32( - right, - _mm256_set_epi32( - -(zeta3 as i32), - zeta3 as i32, - -(zeta2 as i32), - zeta2 as i32, - -(zeta1 as i32), - zeta1 as i32, - -(zeta0 as i32), - zeta0 as i32, - ), - ); - - let products_left = _mm256_add_epi32(left, right); - let products_left = montgomery_reduce_i32s(products_left); - - // Compute the second term of the product - let rhs_adjacent_swapped = _mm256_shuffle_epi8( - rhs, - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, - 5, 4, 7, 6, 1, 0, 3, 2, - ), - ); - let products_right = _mm256_madd_epi16(lhs, rhs_adjacent_swapped); - let products_right = montgomery_reduce_i32s(products_right); - let products_right = _mm256_slli_epi32(products_right, 16); - - // Combine them into one vector - _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) - }; + // Compute the first term 
of the product + let shuffle_with = mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, + 9, 8, 5, 4, 1, 0, + ); + const PERMUTE_WITH: i32 = 0b11_01_10_00; + + // Prepare the left hand side + let lhs_shuffled = mm256_shuffle_epi8(lhs, shuffle_with); + let lhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(lhs_shuffled); + + let lhs_evens = mm256_castsi256_si128(lhs_shuffled); + let lhs_evens = mm256_cvtepi16_epi32(lhs_evens); + + let lhs_odds = mm256_extracti128_si256::<1>(lhs_shuffled); + let lhs_odds = mm256_cvtepi16_epi32(lhs_odds); + + // Prepare the right hand side + let rhs_shuffled = mm256_shuffle_epi8(rhs, shuffle_with); + let rhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(rhs_shuffled); + + let rhs_evens = mm256_castsi256_si128(rhs_shuffled); + let rhs_evens = mm256_cvtepi16_epi32(rhs_evens); + + let rhs_odds = mm256_extracti128_si256::<1>(rhs_shuffled); + let rhs_odds = mm256_cvtepi16_epi32(rhs_odds); + + // Start operating with them + let left = mm256_mullo_epi32(lhs_evens, rhs_evens); + + let right = mm256_mullo_epi32(lhs_odds, rhs_odds); + let right = montgomery_reduce_i32s(right); + let right = mm256_mullo_epi32( + right, + mm256_set_epi32( + -(zeta3 as i32), + zeta3 as i32, + -(zeta2 as i32), + zeta2 as i32, + -(zeta1 as i32), + zeta1 as i32, + -(zeta0 as i32), + zeta0 as i32, + ), + ); + + let products_left = mm256_add_epi32(left, right); + let products_left = montgomery_reduce_i32s(products_left); + + // Compute the second term of the product + let rhs_adjacent_swapped = mm256_shuffle_epi8( + rhs, + mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5, + 4, 7, 6, 1, 0, 3, 2, + ), + ); + let products_right = mm256_madd_epi16(lhs, rhs_adjacent_swapped); + let products_right = montgomery_reduce_i32s(products_right); + let products_right = mm256_slli_epi32::<16>(products_right); + + // Combine them into one vector + mm256_blend_epi16::<0b1_0_1_0_1_0_1_0>(products_left, products_right) } diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index aa6cc6ca6..40542efec 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -1,7 +1,4 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; use crate::serialize::{deserialize_12, serialize_1}; use libcrux_traits::FIELD_MODULUS; @@ -756,34 +753,30 @@ const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ #[inline(always)] pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { - let count = unsafe { - let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); + let field_modulus = mm256_set1_epi16(FIELD_MODULUS); - let potential_coefficients = deserialize_12(input); + let potential_coefficients = deserialize_12(input); - let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); - let good = serialize_1(compare_with_field_modulus); + let compare_with_field_modulus = mm256_cmpgt_epi16(field_modulus, potential_coefficients); + let good = serialize_1(compare_with_field_modulus); - let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; - let lower_shuffles = _mm_loadu_si128(lower_shuffles.as_ptr() as *const __m128i); - let lower_coefficients = _mm256_castsi256_si128(potential_coefficients); - let lower_coefficients = _mm_shuffle_epi8(lower_coefficients, lower_shuffles); + let lower_shuffles = 
REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; + let lower_shuffles = mm_loadu_si128(&lower_shuffles); + let lower_coefficients = mm256_castsi256_si128(potential_coefficients); + let lower_coefficients = mm_shuffle_epi8(lower_coefficients, lower_shuffles); - _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, lower_coefficients); - let sampled_count = good[0].count_ones(); + mm_storeu_si128(&mut output[0..8], lower_coefficients); + let sampled_count = good[0].count_ones() as usize; - let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; - let upper_shuffles = _mm_loadu_si128(upper_shuffles.as_ptr() as *const __m128i); - let upper_coefficients = _mm256_extractf128_si256(potential_coefficients, 1); - let upper_coefficients = _mm_shuffle_epi8(upper_coefficients, upper_shuffles); + let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; + let upper_shuffles = mm_loadu_si128(&upper_shuffles); + let upper_coefficients = mm256_extracti128_si256::<1>(potential_coefficients); + let upper_coefficients = mm_shuffle_epi8(upper_coefficients, upper_shuffles); - _mm_storeu_si128( - output.as_mut_ptr().offset(sampled_count as isize) as *mut __m128i, - upper_coefficients, - ); + mm_storeu_si128( + &mut output[sampled_count..sampled_count + 8], + upper_coefficients, + ); - sampled_count + good[1].count_ones() - }; - - count as usize + sampled_count + (good[1].count_ones() as usize) } diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs index 39b75ea2c..3483c7afb 100644 --- a/polynomials-avx2/src/serialize.rs +++ b/polynomials-avx2/src/serialize.rs @@ -1,23 +1,17 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; -use crate::portable; -use crate::SIMD256Vector; +use crate::{portable, SIMD256Vector}; #[inline(always)] pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] { - let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(vector, 15); + let lsb_shifted_up = mm256_slli_epi16::<15>(vector); - let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); - let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); + let low_lanes = mm256_castsi256_si128(lsb_shifted_up); + let high_lanes = mm256_extracti128_si256::<1>(lsb_shifted_up); - let msbs = _mm_packs_epi16(low_lanes, high_lanes); + let msbs = mm_packs_epi16(low_lanes, high_lanes); - _mm_movemask_epi8(msbs) - }; + let bits_packed = mm_movemask_epi8(msbs); let mut serialized = [0u8; 2]; serialized[0] = bits_packed as u8; @@ -28,193 +22,185 @@ pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] { #[inline(always)] pub(crate) fn deserialize_1(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsb_to_msb = _mm256_set_epi16( - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - ); - - let coefficients = _mm256_set_epi16( - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) - }; + let shift_lsb_to_msb 
= mm256_set_epi16( + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + ); + + let coefficients = mm256_set_epi16( + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsb_to_msb); + let coefficients_in_lsb = mm256_srli_epi16::<7>(coefficients_in_msb); + + mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 1) - 1)) } #[inline(always)] pub(crate) fn serialize_4(vector: __m256i) -> [u8; 8] { let mut serialized = [0u8; 16]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - ), - ); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_2_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, - ), - ); - - let combined = _mm256_permutevar8x32_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), - ); - let combined = _mm256_castsi256_si128(combined); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, combined); - } - - serialized[0..8].try_into().unwrap() -} - -#[inline(always)] -pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - ); - - let coefficients = _mm256_set_epi16( - bytes[7] as i16, - bytes[7] as i16, - bytes[6] as i16, - bytes[6] as i16, - bytes[5] as i16, - bytes[5] as i16, - bytes[4] as i16, - bytes[4] as i16, - bytes[3] as i16, - bytes[3] as i16, - bytes[2] as i16, - bytes[2] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) - }; + 1, + ), + ); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_2_combined, + mm256_set_epi8( + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, + ), + ); + + let combined = mm256_permutevar8x32_epi32( + adjacent_8_combined, + mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), + ); + let combined = mm256_castsi256_si128(combined); + + mm_storeu_bytes_si128(&mut serialized[..], combined); + + serialized[0..8].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i { + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let coefficients = mm256_set_epi16( + bytes[7] as i16, + bytes[7] as i16, + bytes[6] as i16, + bytes[6] as i16, + bytes[5] as i16, + bytes[5] as i16, + bytes[4] 
as i16, + bytes[4] as i16, + bytes[3] as i16, + bytes[3] as i16, + bytes[2] as i16, + bytes[2] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb); + + mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1)) } #[inline(always)] pub(crate) fn serialize_5(vector: __m256i) -> [u8; 10] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 22); - - let adjacent_8_combined = _mm256_shuffle_epi32(adjacent_4_combined, 0b00_00_10_00); - let adjacent_8_combined = _mm256_sllv_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_8_combined = _mm256_srli_epi64(adjacent_8_combined, 12); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(5) as *mut __m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), + ); + let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined); + let adjacent_8_combined = mm256_sllv_epi32( + adjacent_8_combined, + mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[5..21], upper_8); serialized[0..10].try_into().unwrap() } @@ -230,95 +216,91 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> __m256i { pub(crate) fn serialize_10(vector: __m256i) -> [u8; 20] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 12); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, - 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(10) as *mut 
__m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_4_combined, + mm256_set_epi8( + -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, + 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[10..26], upper_8); serialized[0..20].try_into().unwrap() } #[inline(always)] pub(crate) fn deserialize_10(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - ); - - let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(4) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 6); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); - - coefficients - }; + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + ); + + let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap()); + let lower_coefficients = mm_shuffle_epi8( + lower_coefficients, + mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), + ); + let upper_coefficients = mm_loadu_si128(bytes[4..20].try_into().unwrap()); + let upper_coefficients = mm_shuffle_epi8( + upper_coefficients, + mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), + ); + + let coefficients = mm256_castsi128_si256(lower_coefficients); + let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients); + + let coefficients = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = mm256_srli_epi16::<6>(coefficients); + let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 10) - 1)); + + coefficients } #[inline(always)] @@ -339,93 +321,89 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> __m256i { pub(crate) fn serialize_12(vector: __m256i) -> [u8; 24] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - ), - ); - - let adjacent_4_combined = 
_mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 8); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, - 10, 9, 8, 5, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(12) as *mut __m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), + ); + let adjacent_4_combined = mm256_srli_epi64::<8>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_4_combined, + mm256_set_epi8( + -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, + 10, 9, 8, 5, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[12..28], upper_8); serialized[0..24].try_into().unwrap() } #[inline(always)] pub(crate) fn deserialize_12(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - - let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(8) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 4); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1)); - - coefficients - }; + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap()); + let lower_coefficients = mm_shuffle_epi8( + lower_coefficients, + mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), + ); + let upper_coefficients = mm_loadu_si128(bytes[8..24].try_into().unwrap()); + let upper_coefficients = mm_shuffle_epi8( + upper_coefficients, + mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), + ); + + let coefficients = mm256_castsi128_si256(lower_coefficients); + let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients); + + let coefficients = 
mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = mm256_srli_epi16::<4>(coefficients); + let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 12) - 1)); + + coefficients }
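
Reviewer notes (appended, not part of the patch). The scalar sketches below document what the vectorized routines compute; every helper name, constant, and test in them is illustrative, written against the usual Kyber parameters (q = 3329, Montgomery R = 2^16), not code taken from this change.

The shift, shuffle, blend, extract and insert wrappers take their immediate as a const generic (mm256_srli_epi16::<4>(v) and so on), so the wrapped intrinsic still sees a compile-time constant after inlining, and the debug_assert! range checks compile away in release builds. A plain-array model of the same pattern:

    // Hypothetical scalar stand-in for mm256_srli_epi16::<SHIFT_BY>: logical
    // right shift of each 16-bit lane by a compile-time constant.
    fn srli_epi16_scalar<const SHIFT_BY: i32>(lanes: [i16; 16]) -> [i16; 16] {
        debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
        lanes.map(|lane| ((lane as u16) >> SHIFT_BY) as i16)
    }

    fn shift_demo() {
        let x = [-1i16; 16]; // 0xFFFF in every lane
        // The turbofish supplies the immediate, as in mm256_srli_epi16::<4>(v).
        let y = srli_epi16_scalar::<4>(x);
        assert!(y.iter().all(|&lane| lane == 0x0FFF));
    }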
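montgomery_multiply_by_constants, montgomery_multiply_m128i_by_constants and montgomery_reduce_i32s are the same reduction at different lane widths: from a product v*c they return a representative of v*c*2^(-16) mod q. The mullo/mulhi/sub sequence is the 16-bit-lane form of signed Montgomery reduction, which a scalar sketch makes explicit (illustrative helpers, assuming the libcrux constants):

    const FIELD_MODULUS: i16 = 3329;
    // q^(-1) mod 2^16: 3329 * 62209 = 207_093_761, which is 1 mod 2^16.
    const INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: u32 = 62209;

    // Scalar Montgomery reduction: for |value| < 2^15 * q, returns r with
    // r congruent to value * 2^(-16) mod q and |r| < q.
    fn montgomery_reduce(value: i32) -> i16 {
        // k = value * q^(-1) mod 2^16; the wrapping multiply keeps the low
        // 16 bits, exactly like the mullo_epi16 of value_low with the inverse.
        let k = (value as i16).wrapping_mul(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16);
        // value - k*q has a zero low half, so the shift is exact; this is the
        // mulhi_epi16 / sub_epi16 pair in the vectorized code.
        let k_times_modulus = (k as i32) * (FIELD_MODULUS as i32);
        ((value - k_times_modulus) >> 16) as i16
    }

    // Lane-level model of montgomery_multiply_by_constants.
    fn montgomery_multiply(v: i16, c: i16) -> i16 {
        montgomery_reduce(v as i32 * c as i32)
    }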
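ntt_layer_1_step, ntt_layer_2_step and ntt_layer_3_step are Cooley-Tukey butterflies at distances 2, 4 and 8. The trick is that (a, b) -> (a + zeta*b, a - zeta*b) becomes a single add once the sign of zeta is baked into the constant vector, which is what the interleaved zeta/-zeta patterns do; the shuffles just line each a up with its b. A scalar model of layer 1 (my reading of the lane layout: four blocks of (a0, a1, b0, b1), one zeta per block), reusing montgomery_multiply from the sketch above:

    // One butterfly: (a, b) -> (a + zeta*b, a - zeta*b), with zeta*b reduced.
    fn butterfly(a: i16, b: i16, zeta: i16) -> (i16, i16) {
        let t = montgomery_multiply(b, zeta);
        (a.wrapping_add(t), a.wrapping_sub(t)) // add_epi16 wraps the same way
    }

    // Scalar model of ntt_layer_1_step: butterflies at distance 2 inside
    // each 4-lane block.
    fn ntt_layer_1_step_scalar(mut lanes: [i16; 16], zetas: [i16; 4]) -> [i16; 16] {
        for block in 0..4 {
            let base = block * 4;
            for j in 0..2 {
                let (x, y) = butterfly(lanes[base + j], lanes[base + j + 2], zetas[block]);
                lanes[base + j] = x;
                lanes[base + j + 2] = y;
            }
        }
        lanes
    }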
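ntt_multiply works in Z_q[X]/(X^2 - zeta): each output pair is c0 = a0*b0 + (a1*b1 reduced)*zeta and c1 = a0*b1 + a1*b0. The even/odd shuffles widen the first term into 32-bit multiplies, madd_epi16 against the adjacent-lane-swapped rhs computes a0*b1 + a1*b0 in one instruction, and the zeta sign flips for the negated halves are folded into the set_epi32 constant. A scalar sketch of one base-case product (zeta already carrying its sign), using the helpers above:

    // (a0 + a1 X) * (b0 + b1 X) mod (X^2 - zeta), values in Montgomery form.
    fn base_case_multiply(a: (i16, i16), b: (i16, i16), zeta: i16) -> (i16, i16) {
        let (a0, a1) = a;
        let (b0, b1) = b;
        // The mullo_epi32 / montgomery_reduce_i32s chain for the X^2 term.
        let a1b1 = montgomery_reduce(a1 as i32 * b1 as i32);
        let c0 = montgomery_reduce(a0 as i32 * b0 as i32 + a1b1 as i32 * zeta as i32);
        // The madd_epi16 term.
        let c1 = montgomery_reduce(a0 as i32 * b1 as i32 + a1 as i32 * b0 as i32);
        (c0, c1)
    }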
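rejection_sample compares all 16 candidate 12-bit values against q at once, serialize_1 compresses the comparison mask into two bytes, and REJECTION_SAMPLE_SHUFFLE_TABLE[good[i]] is the precomputed shuffle that compacts the surviving lanes to the front of each half; good[i].count_ones() is then how many coefficients were actually written. Note that both stores still write full 8-lane blocks, so output needs room for sampled_count + 8 entries even when fewer survive. The scalar contract, with FIELD_MODULUS as above:

    // Keep the candidates that are < q, packed at the front of `output`;
    // return how many were kept.
    fn rejection_sample_scalar(candidates: [i16; 16], output: &mut [i16]) -> usize {
        let mut sampled = 0;
        for candidate in candidates {
            if candidate < FIELD_MODULUS {
                output[sampled] = candidate;
                sampled += 1;
            }
        }
        sampled
    }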
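serialize_1 is a movemask trick: slli_epi16::<15> moves each lane's lsb into the sign bit, mm_packs_epi16 narrows with signed saturation (0x8000 saturates to 0x80, preserving the sign bit), and mm_movemask_epi8 gathers the 16 sign bits so that bit i of the result is the lsb of lane i. deserialize_1 inverts this with per-lane multiplies and shifts. The bit layout, as a scalar sketch:

    // Scalar model of serialize_1: 16 lanes, each 0 or 1, into 2 bytes.
    fn serialize_1_scalar(lanes: [i16; 16]) -> [u8; 2] {
        let mut bits = 0u16;
        for (i, lane) in lanes.into_iter().enumerate() {
            bits |= ((lane & 1) as u16) << i; // lane i's lsb lands in bit i
        }
        bits.to_le_bytes()
    }

    // Scalar model of deserialize_1.
    fn deserialize_1_scalar(bytes: [u8; 2]) -> [i16; 16] {
        let bits = u16::from_le_bytes(bytes);
        core::array::from_fn(|i| ((bits >> i) & 1) as i16)
    }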
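The wider serializers share one recipe: madd_epi16 against (1, 1 << bits) pairs merges adjacent lanes, the sllv/srli shifts align the merged groups, and shuffle_epi8 drops the padding bytes; the two overlapping 16-byte stores (serialized[0..16] and serialized[12..28] for width 12) are why the scratch buffers are 32 bytes and why mm_storeu_bytes_si128 exists alongside the i16 store. Scalar reference for width 12, two coefficients into three bytes:

    // Scalar model of serialize_12.
    fn serialize_12_scalar(coefficients: [i16; 16]) -> [u8; 24] {
        let mut serialized = [0u8; 24];
        for (pair, bytes) in coefficients
            .chunks_exact(2)
            .zip(serialized.chunks_exact_mut(3))
        {
            let (c0, c1) = (pair[0] as u32 & 0xFFF, pair[1] as u32 & 0xFFF);
            let combined = c0 | (c1 << 12); // the madd_epi16 with (1, 1 << 12)
            bytes[0] = combined as u8;
            bytes[1] = (combined >> 8) as u8;
            bytes[2] = (combined >> 16) as u8;
        }
        serialized
    }

    // Scalar model of deserialize_12.
    fn deserialize_12_scalar(bytes: &[u8; 24]) -> [i16; 16] {
        let mut coefficients = [0i16; 16];
        for (pair, triple) in coefficients
            .chunks_exact_mut(2)
            .zip(bytes.chunks_exact(3))
        {
            let combined =
                (triple[0] as u32) | ((triple[1] as u32) << 8) | ((triple[2] as u32) << 16);
            pair[0] = (combined & 0xFFF) as i16;
            pair[1] = (combined >> 12) as i16;
        }
        coefficients
    }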
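A hypothetical round-trip test tying the two sketches together; the vectorized serialize_12/deserialize_12 pair should satisfy the same identity on canonical coefficients, which would make a good property-test target:

    #[test]
    fn serialize_12_round_trip() {
        // Canonical coefficients, all below q = 3329 and hence 12 bits wide.
        let coefficients: [i16; 16] = core::array::from_fn(|i| (i as i16) * 207 % 3329);
        let bytes = serialize_12_scalar(coefficients);
        assert_eq!(deserialize_12_scalar(&bytes), coefficients);
    }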