diff --git a/libcrux-ml-kem/src/simd/simd256.rs b/libcrux-ml-kem/src/simd/simd256.rs index 3a0f14d81..510e2ce4d 100644 --- a/libcrux-ml-kem/src/simd/simd256.rs +++ b/libcrux-ml-kem/src/simd/simd256.rs @@ -1,4 +1,8 @@ use crate::{ + arithmetic::{ + BARRETT_MULTIPLIER, BARRETT_R, BARRETT_SHIFT, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R, + MONTGOMERY_SHIFT, + }, constants::FIELD_MODULUS, simd::{portable, simd_trait::*}, }; @@ -9,6 +13,13 @@ pub(crate) struct SIMD256Vector { elements: __m256i, } +#[allow(dead_code)] +fn print_m256i_as_i32s(a: __m256i, prefix: String) { + let mut a_bytes = [0i32; 8]; + unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; + println!("{}: {:?}", prefix, a_bytes); +} + #[allow(non_snake_case)] #[inline(always)] fn ZERO() -> SIMD256Vector { @@ -111,18 +122,52 @@ fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { #[inline(always)] fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector { - let input = portable::PortableVector::from_i32_array(to_i32_array(v)); - let output = portable::PortableVector::barrett_reduce(input); + let reduced = unsafe { + let barrett_multiplier = _mm256_set1_epi32(BARRETT_MULTIPLIER as i32); + let barrett_r_halved = _mm256_set1_epi64x(BARRETT_R >> 1); + let field_modulus = _mm256_set1_epi32(FIELD_MODULUS); - from_i32_array(portable::PortableVector::to_i32_array(output)) + let mut t_low = _mm256_mul_epi32(v.elements, barrett_multiplier); + t_low = _mm256_add_epi64(t_low, barrett_r_halved); + let quotient_low = _mm256_srli_epi64(t_low, BARRETT_SHIFT as i32); + + let mut t_high = _mm256_shuffle_epi32(v.elements, 0b00_11_00_01); + t_high = _mm256_mul_epi32(t_high, barrett_multiplier); + t_high = _mm256_add_epi64(t_high, barrett_r_halved); + let quotient_high = _mm256_slli_epi64(t_high, 6); + + let quotient = _mm256_blend_epi32(quotient_low, quotient_high, 0b1_0_1_0_1_0_1_0); + let quotient = _mm256_mullo_epi32(quotient, field_modulus); + + _mm256_sub_epi32(v.elements, quotient) + }; + + SIMD256Vector { elements: reduced } } #[inline(always)] fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector { - let input = portable::PortableVector::from_i32_array(to_i32_array(v)); - let output = portable::PortableVector::montgomery_reduce(input); + let reduced = unsafe { + let montgomery_shift_mask = _mm256_set1_epi32((1 << MONTGOMERY_SHIFT) - 1); + let field_modulus = _mm256_set1_epi32(FIELD_MODULUS); + let inverse_of_modulus_mod_montgomery_r = + _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32); - from_i32_array(portable::PortableVector::to_i32_array(output)) + let t = _mm256_and_si256(v.elements, montgomery_shift_mask); + let t = _mm256_mullo_epi32(t, inverse_of_modulus_mod_montgomery_r); + + let k = _mm256_and_si256(t, montgomery_shift_mask); + let k = _mm256_slli_epi32(k, 16); + let k = _mm256_srai_epi32(k, 16); + + let k_times_modulus = _mm256_mullo_epi32(k, field_modulus); + let c = _mm256_srai_epi32(k_times_modulus, MONTGOMERY_SHIFT as i32); + let value_high = _mm256_srai_epi32(v.elements, MONTGOMERY_SHIFT as i32); + + _mm256_sub_epi32(value_high, c) + }; + + SIMD256Vector { elements: reduced } } #[inline(always)] @@ -185,12 +230,31 @@ fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector { #[inline(always)] fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32) -> SIMD256Vector { - let input1 = portable::PortableVector::from_i32_array(to_i32_array(*lhs)); - let input2 = portable::PortableVector::from_i32_array(to_i32_array(*rhs)); + let result = unsafe { + // Calculate the first element of the output binomial + let zetas = _mm256_set_epi32(-zeta1, 0, zeta1, 0, -zeta0, 0, zeta0, 0); - let output = portable::PortableVector::ntt_multiply(&input1, &input2, zeta0, zeta1); + let left = _mm256_mullo_epi32(lhs.elements, rhs.elements); + let right = montgomery_reduce(SIMD256Vector { elements: left }); - from_i32_array(portable::PortableVector::to_i32_array(output)) + let right = _mm256_mullo_epi32(right.elements, zetas); + + let right = _mm256_shuffle_epi32(right, 0b00_11_00_01); + + let result_0 = _mm256_add_epi32(left, right); + + // Calculate the second element in the output binomial + let rhs_adjacent_swapped = _mm256_shuffle_epi32(rhs.elements, 0b10_11_00_01); + let result_1 = _mm256_mullo_epi32(lhs.elements, rhs_adjacent_swapped); + + let swapped = _mm256_shuffle_epi32(result_1, 0b10_00_00_00); + let result_1 = _mm256_add_epi32(result_1, swapped); + + // Put them together + _mm256_blend_epi32(result_0, result_1, 0b1_0_1_0_1_0_1_0) + }; + + montgomery_reduce(SIMD256Vector { elements: result }) } #[inline(always)]