diff --git a/libcrux-ml-kem/src/simd/simd256.rs b/libcrux-ml-kem/src/simd/simd256.rs index a2ce2e36c..17ba61313 100644 --- a/libcrux-ml-kem/src/simd/simd256.rs +++ b/libcrux-ml-kem/src/simd/simd256.rs @@ -17,10 +17,17 @@ pub(crate) struct SIMD256Vector { #[allow(dead_code)] fn print_m256i_as_i32s(a: __m256i, prefix: String) { let mut a_bytes = [0i32; 8]; - unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; + unsafe { _mm256_storeu_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; println!("{}: {:?}", prefix, a_bytes); } +#[allow(dead_code)] +fn print_m256i_as_i64s(a: __m256i, prefix: String) { + let mut a_bytes = [0i64; 4]; + unsafe { _mm256_storeu_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; + println!("{}: {:x?}", prefix, a_bytes); +} + #[allow(non_snake_case)] #[inline(always)] fn ZERO() -> SIMD256Vector { @@ -158,7 +165,7 @@ fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector { let t = _mm256_mullo_epi16(v.elements, inverse_of_modulus_mod_montgomery_r); let k_times_modulus = _mm256_mulhi_epi16(t, field_modulus); let value_high = _mm256_srli_epi32(v.elements, MONTGOMERY_SHIFT as i32); - let res =_mm256_sub_epi16(value_high, k_times_modulus); + let res = _mm256_sub_epi16(value_high, k_times_modulus); let res = _mm256_slli_epi32(res, 16); _mm256_srai_epi32(res, 16) }; @@ -423,8 +430,30 @@ fn deserialize_5(v: &[u8]) -> SIMD256Vector { #[inline(always)] fn serialize_10(v: SIMD256Vector) -> [u8; 10] { - let input = portable::PortableVector::from_i32_array(to_i32_array(v)); - portable::PortableVector::serialize_10(input) + let mut out = [0u8; 16]; + + unsafe { + let shifted = _mm256_sllv_epi32(v.elements, _mm256_set_epi32(10, 0, 10, 0, 10, 0, 10, 0)); + let shifted = _mm256_shuffle_epi32(shifted, 0b_00_11_00_01); + + let bits = _mm256_add_epi32(v.elements, shifted); + + let bits = _mm256_shuffle_epi32(bits, 0b_00_00_10_00); + let bits = _mm256_sllv_epi32(bits, _mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12)); + let bits = _mm256_srli_epi64(bits, 12); + + let bits = _mm256_permute4x64_epi64(bits, 0b00_00_10_00); + let shuffle_by = _mm256_set_epi8( + 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, -1, -1, -1, -1, -1, -1, 12, + 11, 10, 9, 8, 4, 3, 2, 1, 0, + ); + + let bits_sequential = _mm256_shuffle_epi8(bits, shuffle_by); + let bits_sequential = _mm256_castsi256_si128(bits_sequential); + _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, bits_sequential); + }; + + out[0..10].try_into().unwrap() } #[inline(always)]