Skip to content

Commit

Permalink
avx2 implementation of serialize_10
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed May 2, 2024
1 parent 727d034 commit 477df73
Showing 1 changed file with 33 additions and 4 deletions.
37 changes: 33 additions & 4 deletions libcrux-ml-kem/src/simd/simd256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ pub(crate) struct SIMD256Vector {
/// Debug helper: prints the eight 32-bit lanes of `a`, lowest lane first,
/// as `"<prefix>: [l0, l1, ..., l7]"` on standard output.
// Kept for ad-hoc debugging of the vector code; `dead_code` is allowed
// because it is presumably not called from regular builds.
#[allow(dead_code)]
fn print_m256i_as_i32s(a: __m256i, prefix: String) {
    let mut lanes = [0i32; 8];
    // SAFETY: `lanes` is exactly 32 bytes, so the 256-bit store stays in
    // bounds. The *unaligned* store is required here: an `[i32; 8]` is only
    // guaranteed 4-byte alignment, whereas `_mm256_store_si256` demands a
    // 32-byte-aligned destination and may fault otherwise.
    // NOTE(review): assumes AVX is available wherever this module is used.
    unsafe { _mm256_storeu_si256(lanes.as_mut_ptr() as *mut __m256i, a) };
    println!("{}: {:?}", prefix, lanes);
}

/// Debug helper: prints the four 64-bit lanes of `a`, lowest lane first,
/// in hexadecimal as `"<prefix>: [l0, l1, l2, l3]"` on standard output.
#[allow(dead_code)]
fn print_m256i_as_i64s(a: __m256i, prefix: String) {
    let mut lanes = [0i64; 4];
    let dst = lanes.as_mut_ptr().cast::<__m256i>();
    // SAFETY: `lanes` is exactly 32 bytes and the unaligned store places no
    // alignment requirement on `dst`.
    // NOTE(review): assumes AVX is available wherever this module is used.
    unsafe {
        _mm256_storeu_si256(dst, a);
    }
    println!("{}: {:x?}", prefix, lanes);
}

#[allow(non_snake_case)]
#[inline(always)]
fn ZERO() -> SIMD256Vector {
Expand Down Expand Up @@ -158,7 +165,7 @@ fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector {
let t = _mm256_mullo_epi16(v.elements, inverse_of_modulus_mod_montgomery_r);
let k_times_modulus = _mm256_mulhi_epi16(t, field_modulus);
let value_high = _mm256_srli_epi32(v.elements, MONTGOMERY_SHIFT as i32);
let res =_mm256_sub_epi16(value_high, k_times_modulus);
let res = _mm256_sub_epi16(value_high, k_times_modulus);
let res = _mm256_slli_epi32(res, 16);
_mm256_srai_epi32(res, 16)
};
Expand Down Expand Up @@ -423,8 +430,30 @@ fn deserialize_5(v: &[u8]) -> SIMD256Vector {

#[inline(always)]
fn serialize_10(v: SIMD256Vector) -> [u8; 10] {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
portable::PortableVector::serialize_10(input)
let mut out = [0u8; 16];

unsafe {
let shifted = _mm256_sllv_epi32(v.elements, _mm256_set_epi32(10, 0, 10, 0, 10, 0, 10, 0));
let shifted = _mm256_shuffle_epi32(shifted, 0b_00_11_00_01);

let bits = _mm256_add_epi32(v.elements, shifted);

let bits = _mm256_shuffle_epi32(bits, 0b_00_00_10_00);
let bits = _mm256_sllv_epi32(bits, _mm256_set_epi32(0, 0, 0, 12, 0, 0, 0, 12));
let bits = _mm256_srli_epi64(bits, 12);

let bits = _mm256_permute4x64_epi64(bits, 0b00_00_10_00);
let shuffle_by = _mm256_set_epi8(
8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, -1, -1, -1, -1, -1, -1, 12,
11, 10, 9, 8, 4, 3, 2, 1, 0,
);

let bits_sequential = _mm256_shuffle_epi8(bits, shuffle_by);
let bits_sequential = _mm256_castsi256_si128(bits_sequential);
_mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, bits_sequential);
};

out[0..10].try_into().unwrap()
}

#[inline(always)]
Expand Down

0 comments on commit 477df73

Please sign in to comment.