Skip to content

Commit

Permalink
More safe wrappers around avx2 intrinsics (#283).
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf authored May 17, 2024
1 parent 8fd44b7 commit 6f7f943
Show file tree
Hide file tree
Showing 5 changed files with 733 additions and 590 deletions.
218 changes: 218 additions & 0 deletions polynomials-avx2/src/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,198 @@ pub(crate) use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
pub(crate) use core::arch::x86_64::*;

/// Writes the sixteen 16-bit lanes of `vector` into `output` using an
/// unaligned 256-bit store.
///
/// # Panics
///
/// Panics if `output.len() != 16`.
pub(crate) fn mm256_storeu_si256(output: &mut [i16], vector: __m256i) {
    // This must be a real assertion, not a debug-only one: the store below
    // writes 32 bytes through a raw pointer, so a shorter slice would be an
    // out-of-bounds write reachable from safe code in release builds.
    assert_eq!(output.len(), 16);

    // SAFETY: `output` is exactly 16 * 2 = 32 writable bytes (checked above)
    // and the `storeu` variant has no alignment requirement.
    // NOTE(review): assumes the running CPU supports AVX — confirm callers
    // guarantee this.
    unsafe {
        _mm256_storeu_si256(output.as_mut_ptr().cast::<__m256i>(), vector);
    }
}
/// Writes the eight 16-bit lanes of `vector` into `output` using an
/// unaligned 128-bit store.
///
/// # Panics
///
/// Panics if `output.len() != 8`.
pub(crate) fn mm_storeu_si128(output: &mut [i16], vector: __m128i) {
    // Enforced in release builds too: the raw store writes 16 bytes, so a
    // shorter slice would otherwise be written out of bounds from safe code.
    assert_eq!(output.len(), 8);

    // SAFETY: `output` is exactly 8 * 2 = 16 writable bytes (checked above)
    // and the `storeu` variant has no alignment requirement.
    // NOTE(review): assumes the running CPU supports SSE2 — confirm callers
    // guarantee this.
    unsafe {
        _mm_storeu_si128(output.as_mut_ptr().cast::<__m128i>(), vector);
    }
}

/// Writes the sixteen bytes of `vector` into `output` using an unaligned
/// 128-bit store.
///
/// # Panics
///
/// Panics if `output.len() != 16`.
pub(crate) fn mm_storeu_bytes_si128(output: &mut [u8], vector: __m128i) {
    // Enforced in release builds too: the raw store writes 16 bytes, so a
    // shorter slice would otherwise be written out of bounds from safe code.
    assert_eq!(output.len(), 16);

    // SAFETY: `output` is exactly 16 writable bytes (checked above) and the
    // `storeu` variant has no alignment requirement.
    // NOTE(review): assumes the running CPU supports SSE2 — confirm callers
    // guarantee this.
    unsafe {
        _mm_storeu_si128(output.as_mut_ptr().cast::<__m128i>(), vector);
    }
}

/// Reads sixteen bytes from `input` into a 128-bit vector using an unaligned
/// load.
///
/// # Panics
///
/// Panics if `input.len() != 16`.
pub(crate) fn mm_loadu_si128(input: &[u8]) -> __m128i {
    // Enforced in release builds too: the raw load reads 16 bytes, so a
    // shorter slice would otherwise be read out of bounds from safe code.
    assert_eq!(input.len(), 16);

    // SAFETY: `input` is exactly 16 readable bytes (checked above) and the
    // `loadu` variant has no alignment requirement.
    // NOTE(review): assumes the running CPU supports SSE2 — confirm callers
    // guarantee this.
    unsafe { _mm_loadu_si128(input.as_ptr().cast::<__m128i>()) }
}

/// Reads sixteen `i16` values from `input` into a 256-bit vector using an
/// unaligned load.
///
/// # Panics
///
/// Panics if `input.len() != 16`.
pub(crate) fn mm256_loadu_si256(input: &[i16]) -> __m256i {
    // Enforced in release builds too: the raw load reads 32 bytes, so a
    // shorter slice would otherwise be read out of bounds from safe code.
    assert_eq!(input.len(), 16);

    // SAFETY: `input` is exactly 16 * 2 = 32 readable bytes (checked above)
    // and the `loadu` variant has no alignment requirement.
    // NOTE(review): assumes the running CPU supports AVX — confirm callers
    // guarantee this.
    unsafe { _mm256_loadu_si256(input.as_ptr().cast::<__m256i>()) }
}

/// Safe wrapper around `_mm256_setzero_si256`: returns a 256-bit vector with
/// every bit cleared.
pub(crate) fn mm256_setzero_si256() -> __m256i {
// SAFETY: takes no arguments and touches no memory.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe { _mm256_setzero_si256() }
}

/// Safe wrapper around `_mm_set_epi8`: builds a 128-bit vector from sixteen
/// `i8` values.
///
/// Arguments are forwarded in the order given, from `byte15` (most
/// significant byte) down to `byte0` (least significant byte).
pub(crate) fn mm_set_epi8(
byte15: i8,
byte14: i8,
byte13: i8,
byte12: i8,
byte11: i8,
byte10: i8,
byte9: i8,
byte8: i8,
byte7: i8,
byte6: i8,
byte5: i8,
byte4: i8,
byte3: i8,
byte2: i8,
byte1: i8,
byte0: i8,
) -> __m128i {
// SAFETY: only plain integer values are passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe {
_mm_set_epi8(
byte15, byte14, byte13, byte12, byte11, byte10,
byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0,
)
}
}

/// Safe wrapper around `_mm256_set_epi8`: builds a 256-bit vector from
/// thirty-two `i8` values.
///
/// Arguments are forwarded in the order given, from `byte31` (most
/// significant byte) down to `byte0` (least significant byte).
pub(crate) fn mm256_set_epi8(
byte31: i8,
byte30: i8,
byte29: i8,
byte28: i8,
byte27: i8,
byte26: i8,
byte25: i8,
byte24: i8,
byte23: i8,
byte22: i8,
byte21: i8,
byte20: i8,
byte19: i8,
byte18: i8,
byte17: i8,
byte16: i8,
byte15: i8,
byte14: i8,
byte13: i8,
byte12: i8,
byte11: i8,
byte10: i8,
byte9: i8,
byte8: i8,
byte7: i8,
byte6: i8,
byte5: i8,
byte4: i8,
byte3: i8,
byte2: i8,
byte1: i8,
byte0: i8,
) -> __m256i {
// SAFETY: only plain integer values are passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe {
_mm256_set_epi8(
byte31, byte30, byte29, byte28, byte27, byte26, byte25, byte24, byte23, byte22, byte21,
byte20, byte19, byte18, byte17, byte16, byte15, byte14, byte13, byte12, byte11, byte10,
byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0,
)
}
}

/// Safe wrapper around `_mm256_set1_epi16`: broadcasts `constant` to all
/// sixteen 16-bit lanes of a 256-bit vector.
pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i {
// SAFETY: only a plain integer is passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe { _mm256_set1_epi16(constant) }
}
/// Safe wrapper around `_mm256_set_epi16`: builds a 256-bit vector from
/// sixteen `i16` values.
///
/// Arguments are forwarded in the order given, from `input15` (most
/// significant lane) down to `input0` (least significant lane).
pub(crate) fn mm256_set_epi16(
input15: i16,
input14: i16,
input13: i16,
input12: i16,
input11: i16,
input10: i16,
input9: i16,
input8: i16,
input7: i16,
input6: i16,
input5: i16,
input4: i16,
input3: i16,
input2: i16,
input1: i16,
input0: i16,
) -> __m256i {
// SAFETY: only plain integer values are passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe {
_mm256_set_epi16(
input15, input14, input13, input12, input11, input10, input9, input8, input7, input6,
input5, input4, input3, input2, input1, input0,
)
}
}

/// Safe wrapper around `_mm_set1_epi16`: broadcasts `constant` to all eight
/// 16-bit lanes of a 128-bit vector.
pub(crate) fn mm_set1_epi16(constant: i16) -> __m128i {
// SAFETY: only a plain integer is passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_set1_epi16(constant) }
}

/// Safe wrapper around `_mm256_set1_epi32`: broadcasts `constant` to all
/// eight 32-bit lanes of a 256-bit vector.
pub(crate) fn mm256_set1_epi32(constant: i32) -> __m256i {
// SAFETY: only a plain integer is passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe { _mm256_set1_epi32(constant) }
}
/// Safe wrapper around `_mm256_set_epi32`: builds a 256-bit vector from eight
/// `i32` values.
///
/// Arguments are forwarded in the order given, from `input7` (most
/// significant lane) down to `input0` (least significant lane).
pub(crate) fn mm256_set_epi32(
input7: i32,
input6: i32,
input5: i32,
input4: i32,
input3: i32,
input2: i32,
input1: i32,
input0: i32,
) -> __m256i {
// SAFETY: only plain integer values are passed; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe {
_mm256_set_epi32(
input7, input6, input5, input4, input3, input2, input1, input0,
)
}
}

/// Safe wrapper around `_mm_add_epi16`: lane-wise wrapping addition of the
/// eight 16-bit lanes of `lhs` and `rhs`.
pub(crate) fn mm_add_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_add_epi16(lhs, rhs) }
}
/// Safe wrapper around `_mm256_add_epi16`: lane-wise wrapping addition of the
/// sixteen 16-bit lanes of `lhs` and `rhs`.
pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_add_epi16(lhs, rhs) }
}
/// Safe wrapper around `_mm256_madd_epi16`: multiplies corresponding signed
/// 16-bit lanes of `lhs` and `rhs` and adds each adjacent pair of 32-bit
/// products, yielding eight 32-bit lanes.
pub(crate) fn mm256_madd_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_madd_epi16(lhs, rhs) }
}
/// Safe wrapper around `_mm256_add_epi32`: lane-wise wrapping addition of the
/// eight 32-bit lanes of `lhs` and `rhs`.
pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_add_epi32(lhs, rhs) }
}

/// Safe wrapper around `_mm256_sub_epi16`: lane-wise wrapping subtraction
/// (`lhs - rhs`) over sixteen 16-bit lanes.
pub(crate) fn mm256_sub_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_sub_epi16(lhs, rhs) }
}
/// Safe wrapper around `_mm_sub_epi16`: lane-wise wrapping subtraction
/// (`lhs - rhs`) over eight 16-bit lanes.
pub(crate) fn mm_sub_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_sub_epi16(lhs, rhs) }
}

/// Safe wrapper around `_mm256_mullo_epi16`: multiplies corresponding 16-bit
/// lanes of `lhs` and `rhs` and keeps the low 16 bits of each product.
pub(crate) fn mm256_mullo_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_mullo_epi16(lhs, rhs) }
}

/// Safe wrapper around `_mm_mullo_epi16`: multiplies corresponding 16-bit
/// lanes of `lhs` and `rhs` and keeps the low 16 bits of each product.
pub(crate) fn mm_mullo_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_mullo_epi16(lhs, rhs) }
}

/// Safe wrapper around `_mm256_cmpgt_epi16`: signed lane-wise `lhs > rhs`
/// comparison over 16-bit lanes; a lane is all ones where the comparison
/// holds and zero otherwise.
pub(crate) fn mm256_cmpgt_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_cmpgt_epi16(lhs, rhs) }
}

/// Safe wrapper around `_mm_mulhi_epi16`: multiplies corresponding signed
/// 16-bit lanes of `lhs` and `rhs` and keeps the high 16 bits of each 32-bit
/// product.
pub(crate) fn mm_mulhi_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_mulhi_epi16(lhs, rhs) }
}

/// Safe wrapper around `_mm256_mullo_epi32`: multiplies corresponding 32-bit
/// lanes of `lhs` and `rhs` and keeps the low 32 bits of each product.
pub(crate) fn mm256_mullo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_mullo_epi32(lhs, rhs) }
}
Expand All @@ -48,6 +219,11 @@ pub(crate) fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
}
/// Arithmetic (sign-preserving) right shift of every 32-bit lane of `vector`
/// by `SHIFT_BY` bits.
///
/// In debug builds, asserts that `SHIFT_BY` lies in `0..32`.
pub(crate) fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
    debug_assert!((0..32).contains(&SHIFT_BY));

    // SAFETY: operates on a vector value only; no memory is accessed.
    // NOTE(review): assumes the running CPU supports AVX2 — confirm callers
    // guarantee this.
    unsafe { _mm256_srai_epi32::<SHIFT_BY>(vector) }
}

pub(crate) fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
Expand All @@ -57,6 +233,11 @@ pub(crate) fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
}

/// Logical (zero-filling) right shift of every 64-bit lane of `vector` by
/// `SHIFT_BY` bits.
///
/// In debug builds, asserts that `SHIFT_BY` lies in `0..64`.
pub(crate) fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
    debug_assert!((0..64).contains(&SHIFT_BY));

    // SAFETY: operates on a vector value only; no memory is accessed.
    // NOTE(review): assumes the running CPU supports AVX2 — confirm callers
    // guarantee this.
    unsafe { _mm256_srli_epi64::<SHIFT_BY>(vector) }
}

pub(crate) fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
Expand All @@ -67,6 +248,12 @@ pub(crate) fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i
unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
}

/// Safe wrapper around `_mm_shuffle_epi8`: rearranges the bytes of `vector`
/// according to the per-byte selectors in `control` (a selector with its top
/// bit set produces a zero byte).
pub(crate) fn mm_shuffle_epi8(vector: __m128i, control: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSSE3 — confirm callers guarantee this.
unsafe { _mm_shuffle_epi8(vector, control) }
}
/// Safe wrapper around `_mm256_shuffle_epi8`: rearranges the bytes of
/// `vector` according to the per-byte selectors in `control`, independently
/// within each 128-bit half (a selector with its top bit set produces a zero
/// byte).
pub(crate) fn mm256_shuffle_epi8(vector: __m256i, control: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_shuffle_epi8(vector, control) }
}
pub(crate) fn mm256_shuffle_epi32<const CONTROL: i32>(vector: __m256i) -> __m256i {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
Expand All @@ -92,11 +279,17 @@ pub(crate) fn mm256_unpackhi_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
/// Safe wrapper around `_mm256_castsi256_si128`: returns the lower 128 bits
/// of `vector`.
pub(crate) fn mm256_castsi256_si128(vector: __m256i) -> __m128i {
// SAFETY: operates on a vector value only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe { _mm256_castsi256_si128(vector) }
}
/// Safe wrapper around `_mm256_castsi128_si256`: widens `vector` to 256 bits,
/// placing it in the lower half.
///
/// The intrinsic leaves the upper 128 bits unspecified, so callers must not
/// rely on their contents.
pub(crate) fn mm256_castsi128_si256(vector: __m128i) -> __m256i {
// SAFETY: operates on a vector value only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX — confirm callers guarantee this.
unsafe { _mm256_castsi128_si256(vector) }
}

/// Safe wrapper around `_mm256_cvtepi16_epi32`: sign-extends the eight 16-bit
/// lanes of `vector` into eight 32-bit lanes.
pub(crate) fn mm256_cvtepi16_epi32(vector: __m128i) -> __m256i {
// SAFETY: operates on a vector value only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_cvtepi16_epi32(vector) }
}

/// Safe wrapper around `_mm_packs_epi16`: narrows the signed 16-bit lanes of
/// `lhs` (low half of the result) and `rhs` (high half) to 8-bit lanes using
/// signed saturation.
pub(crate) fn mm_packs_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_packs_epi16(lhs, rhs) }
}
/// Safe wrapper around `_mm256_packs_epi32`: narrows the signed 32-bit lanes
/// of `lhs` and `rhs` to 16-bit lanes using signed saturation.
///
/// The packing is done separately for each 128-bit half, so the result is not
/// simply all of `lhs` followed by all of `rhs`.
pub(crate) fn mm256_packs_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_packs_epi32(lhs, rhs) }
}
Expand All @@ -105,3 +298,28 @@ pub(crate) fn mm256_extracti128_si256<const CONTROL: i32>(vector: __m256i) -> __
debug_assert!(CONTROL == 0 || CONTROL == 1);
unsafe { _mm256_extracti128_si256(vector, CONTROL) }
}

/// Returns a copy of `vector` in which one 128-bit half has been replaced by
/// `vector_i128`.
///
/// `CONTROL` selects the half: `0` replaces the lower half, `1` the upper
/// half. In debug builds, asserts that `CONTROL` is one of those two values.
pub(crate) fn mm256_inserti128_si256<const CONTROL: i32>(
    vector: __m256i,
    vector_i128: __m128i,
) -> __m256i {
    debug_assert!(matches!(CONTROL, 0 | 1));

    // SAFETY: operates on vector values only; no memory is accessed.
    // NOTE(review): assumes the running CPU supports AVX2 — confirm callers
    // guarantee this.
    unsafe { _mm256_inserti128_si256::<CONTROL>(vector, vector_i128) }
}

/// Picks each 16-bit lane from either `lhs` or `rhs` according to the 8-bit
/// mask `CONTROL`.
///
/// Bit `j` of `CONTROL` governs lane `j` in both 128-bit halves: a set bit
/// takes the lane from `rhs`, a clear bit takes it from `lhs`. In debug
/// builds, asserts that `CONTROL` lies in `0..256`.
pub(crate) fn mm256_blend_epi16<const CONTROL: i32>(lhs: __m256i, rhs: __m256i) -> __m256i {
    debug_assert!((0..256).contains(&CONTROL));

    // SAFETY: operates on vector values only; no memory is accessed.
    // NOTE(review): assumes the running CPU supports AVX2 — confirm callers
    // guarantee this.
    unsafe { _mm256_blend_epi16::<CONTROL>(lhs, rhs) }
}

/// Safe wrapper around `_mm_movemask_epi8`: collects the most significant bit
/// of each of the sixteen bytes of `vector` into the low 16 bits of the
/// returned integer.
pub(crate) fn mm_movemask_epi8(vector: __m128i) -> i32 {
// SAFETY: operates on a vector value only; no memory is accessed.
// NOTE(review): assumes the running CPU supports SSE2 — confirm callers guarantee this.
unsafe { _mm_movemask_epi8(vector) }
}

/// Safe wrapper around `_mm256_permutevar8x32_epi32`: permutes the eight
/// 32-bit lanes of `vector` across the full 256 bits, with each output lane
/// chosen by the index held in the corresponding lane of `control`.
pub(crate) fn mm256_permutevar8x32_epi32(vector: __m256i, control: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_permutevar8x32_epi32(vector, control) }
}

/// Safe wrapper around `_mm256_sllv_epi32`: shifts each 32-bit lane of
/// `vector` left by the amount held in the corresponding lane of `counts`,
/// shifting in zeros.
pub(crate) fn mm256_sllv_epi32(vector: __m256i, counts: __m256i) -> __m256i {
// SAFETY: operates on vector values only; no memory is accessed.
// NOTE(review): assumes the running CPU supports AVX2 — confirm callers guarantee this.
unsafe { _mm256_sllv_epi32(vector, counts) }
}
22 changes: 8 additions & 14 deletions polynomials-avx2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use crate::intrinsics::*;
use libcrux_traits::Operations;

#[cfg(test)]
Expand All @@ -24,24 +21,21 @@ pub struct SIMD256Vector {
#[inline(always)]
fn zero() -> SIMD256Vector {
SIMD256Vector {
elements: unsafe { _mm256_setzero_si256() },
elements: mm256_setzero_si256(),
}
}

#[inline(always)]
fn to_i16_array(v: SIMD256Vector) -> [i16; 16] {
let mut out = [0i16; 16];
let mut output = [0i16; 16];
mm256_storeu_si256(&mut output[..], v.elements);

unsafe {
_mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, v.elements);
}

out
output
}
#[inline(always)]
fn from_i16_array(array: &[i16]) -> SIMD256Vector {
SIMD256Vector {
elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) },
elements: mm256_loadu_si256(array),
}
}

Expand Down Expand Up @@ -187,9 +181,9 @@ impl Operations for SIMD256Vector {
serialize::serialize_1(vector.elements)
}

fn deserialize_1(input: &[u8]) -> Self {
fn deserialize_1(bytes: &[u8]) -> Self {
Self {
elements: serialize::deserialize_1(input),
elements: serialize::deserialize_1(bytes),
}
}

Expand Down
Loading

0 comments on commit 6f7f943

Please sign in to comment.