Skip to content

Commit

Permalink
First shot at AVX2 implementations of Barrett reduction, Montgomery r…
Browse files Browse the repository at this point in the history
…eduction, and ntt_multiply. (#242)
  • Loading branch information
xvzcf authored Apr 30, 2024
1 parent bc488f3 commit f29a751
Showing 1 changed file with 74 additions and 10 deletions.
84 changes: 74 additions & 10 deletions libcrux-ml-kem/src/simd/simd256.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use crate::{
arithmetic::{
BARRETT_MULTIPLIER, BARRETT_R, BARRETT_SHIFT, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R,
MONTGOMERY_SHIFT,
},
constants::FIELD_MODULUS,
simd::{portable, simd_trait::*},
};
Expand All @@ -9,6 +13,13 @@ pub(crate) struct SIMD256Vector {
elements: __m256i,
}

#[allow(dead_code)]
fn print_m256i_as_i32s(a: __m256i, prefix: String) {
let mut a_bytes = [0i32; 8];
unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) };
println!("{}: {:?}", prefix, a_bytes);
}

#[allow(non_snake_case)]
#[inline(always)]
fn ZERO() -> SIMD256Vector {
Expand Down Expand Up @@ -111,18 +122,52 @@ fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector {

#[inline(always)]
fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::barrett_reduce(input);
let reduced = unsafe {
let barrett_multiplier = _mm256_set1_epi32(BARRETT_MULTIPLIER as i32);
let barrett_r_halved = _mm256_set1_epi64x(BARRETT_R >> 1);
let field_modulus = _mm256_set1_epi32(FIELD_MODULUS);

from_i32_array(portable::PortableVector::to_i32_array(output))
let mut t_low = _mm256_mul_epi32(v.elements, barrett_multiplier);
t_low = _mm256_add_epi64(t_low, barrett_r_halved);
let quotient_low = _mm256_srli_epi64(t_low, BARRETT_SHIFT as i32);

let mut t_high = _mm256_shuffle_epi32(v.elements, 0b00_11_00_01);
t_high = _mm256_mul_epi32(t_high, barrett_multiplier);
t_high = _mm256_add_epi64(t_high, barrett_r_halved);
let quotient_high = _mm256_slli_epi64(t_high, 6);

let quotient = _mm256_blend_epi32(quotient_low, quotient_high, 0b1_0_1_0_1_0_1_0);
let quotient = _mm256_mullo_epi32(quotient, field_modulus);

_mm256_sub_epi32(v.elements, quotient)
};

SIMD256Vector { elements: reduced }
}

#[inline(always)]
fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::montgomery_reduce(input);
let reduced = unsafe {
let montgomery_shift_mask = _mm256_set1_epi32((1 << MONTGOMERY_SHIFT) - 1);
let field_modulus = _mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
_mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);

from_i32_array(portable::PortableVector::to_i32_array(output))
let t = _mm256_and_si256(v.elements, montgomery_shift_mask);
let t = _mm256_mullo_epi32(t, inverse_of_modulus_mod_montgomery_r);

let k = _mm256_and_si256(t, montgomery_shift_mask);
let k = _mm256_slli_epi32(k, 16);
let k = _mm256_srai_epi32(k, 16);

let k_times_modulus = _mm256_mullo_epi32(k, field_modulus);
let c = _mm256_srai_epi32(k_times_modulus, MONTGOMERY_SHIFT as i32);
let value_high = _mm256_srai_epi32(v.elements, MONTGOMERY_SHIFT as i32);

_mm256_sub_epi32(value_high, c)
};

SIMD256Vector { elements: reduced }
}

#[inline(always)]
Expand Down Expand Up @@ -185,12 +230,31 @@ fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector {

#[inline(always)]
fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32) -> SIMD256Vector {
let input1 = portable::PortableVector::from_i32_array(to_i32_array(*lhs));
let input2 = portable::PortableVector::from_i32_array(to_i32_array(*rhs));
let result = unsafe {
// Calculate the first element of the output binomial
let zetas = _mm256_set_epi32(-zeta1, 0, zeta1, 0, -zeta0, 0, zeta0, 0);

let output = portable::PortableVector::ntt_multiply(&input1, &input2, zeta0, zeta1);
let left = _mm256_mullo_epi32(lhs.elements, rhs.elements);
let right = montgomery_reduce(SIMD256Vector { elements: left });

from_i32_array(portable::PortableVector::to_i32_array(output))
let right = _mm256_mullo_epi32(right.elements, zetas);

let right = _mm256_shuffle_epi32(right, 0b00_11_00_01);

let result_0 = _mm256_add_epi32(left, right);

// Calculate the second element in the output binomial
let rhs_adjacent_swapped = _mm256_shuffle_epi32(rhs.elements, 0b10_11_00_01);
let result_1 = _mm256_mullo_epi32(lhs.elements, rhs_adjacent_swapped);

let swapped = _mm256_shuffle_epi32(result_1, 0b10_00_00_00);
let result_1 = _mm256_add_epi32(result_1, swapped);

// Put them together
_mm256_blend_epi32(result_0, result_1, 0b1_0_1_0_1_0_1_0)
};

montgomery_reduce(SIMD256Vector { elements: result })
}

#[inline(always)]
Expand Down

0 comments on commit f29a751

Please sign in to comment.