From 5c8a66013b3fcac270aedfb7c2551d5aaafd291d Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Sat, 27 Apr 2024 17:05:39 +0200 Subject: [PATCH] ntt multiplication on arm, and compress has a const parameter --- libcrux-ml-kem/src/simd/simd128.rs | 30 +++++++++++++++++--- libcrux-ml-kem/src/simd/simd256.rs | 45 +++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/libcrux-ml-kem/src/simd/simd128.rs b/libcrux-ml-kem/src/simd/simd128.rs index 228c75aa8..a2dabfa55 100644 --- a/libcrux-ml-kem/src/simd/simd128.rs +++ b/libcrux-ml-kem/src/simd/simd128.rs @@ -347,12 +347,34 @@ fn inv_ntt_layer_2_step(mut v: SIMD128Vector, zeta: i32) -> SIMD128Vector { #[inline(always)] fn ntt_multiply(lhs: &SIMD128Vector, rhs: &SIMD128Vector, zeta0: i32, zeta1: i32) -> SIMD128Vector { - let input1 = portable::PortableVector::from_i32_array(SIMD128Vector::to_i32_array(*lhs)); - let input2 = portable::PortableVector::from_i32_array(SIMD128Vector::to_i32_array(*rhs)); + // montgomery_reduce(a0 * b0 + montgomery_reduce(a1 * b1) * zeta), + // montgomery_reduce(a0 * b1 + a1 * b0) - let output = portable::PortableVector::ntt_multiply(&input1, &input2, zeta0, zeta1); + let a0 = unsafe { vtrn1q_s32(lhs.low, lhs.high) }; // a0, a4, a2, a6 + let a1 = unsafe { vtrn2q_s32(lhs.low, lhs.high) }; // a1, a5, a3, a7 + let b0 = unsafe { vtrn1q_s32(rhs.low, rhs.high) }; // b0, b4, b2, b6 + let b1 = unsafe { vtrn2q_s32(rhs.low, rhs.high) }; // b1, b5, b3, b7 - SIMD128Vector::from_i32_array(portable::PortableVector::to_i32_array(output)) + let zetas: [i32; 4] = [zeta0, zeta1, -zeta0, -zeta1]; + let zeta = unsafe { vld1q_s32(zetas.as_ptr() as *const i32) }; + + let a0b0 = unsafe { vmulq_s32(a0, b0) }; + let a1b1 = unsafe { vmulq_s32(a1, b1) }; + + let a1b1 = montgomery_reduce_i32x4_t(a1b1); + let a1b1_zeta = unsafe { vmulq_s32(a1b1, zeta) }; + let fst = unsafe { vaddq_s32(a0b0, a1b1_zeta) }; + let fst = montgomery_reduce_i32x4_t(fst); + + let a0b1 = unsafe { 
vmulq_s32(a0, b1) }; + let a1b0 = unsafe { vmulq_s32(a1, b0) }; + let snd = unsafe { vaddq_s32(a0b1, a1b0) }; + let snd = montgomery_reduce_i32x4_t(snd); + + SIMD128Vector { + low: unsafe { vtrn1q_s32(fst, snd) }, + high: unsafe { vtrn2q_s32(fst, snd) }, + } } #[inline(always)] diff --git a/libcrux-ml-kem/src/simd/simd256.rs b/libcrux-ml-kem/src/simd/simd256.rs index 74fa4dae7..a4c5d5651 100644 --- a/libcrux-ml-kem/src/simd/simd256.rs +++ b/libcrux-ml-kem/src/simd/simd256.rs @@ -10,12 +10,14 @@ pub(crate) struct SIMD256Vector { } #[allow(non_snake_case)] +#[inline(always)] fn ZERO() -> SIMD256Vector { SIMD256Vector { elements: unsafe { _mm256_setzero_si256() }, } } +#[inline(always)] fn to_i32_array(v: SIMD256Vector) -> [i32; 8] { let mut out = [0i32; 8]; @@ -26,12 +28,14 @@ fn to_i32_array(v: SIMD256Vector) -> [i32; 8] { out } +#[inline(always)] fn from_i32_array(array: [i32; 8]) -> SIMD256Vector { SIMD256Vector { elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) }, } } +#[inline(always)] fn add_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector { let c = unsafe { _mm256_set1_epi32(c) }; @@ -39,18 +43,22 @@ fn add_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector { v } + +#[inline(always)] fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { lhs.elements = unsafe { _mm256_add_epi32(lhs.elements, rhs.elements) }; lhs } +#[inline(always)] fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { lhs.elements = unsafe { _mm256_sub_epi32(lhs.elements, rhs.elements) }; lhs } +#[inline(always)] fn multiply_by_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector { let c = unsafe { _mm256_set1_epi32(c) }; @@ -62,6 +70,7 @@ fn multiply_by_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector { v } +#[inline(always)] fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector { let c = unsafe { _mm256_set1_epi32(c) }; @@ -70,18 +79,21 @@ fn bitwise_and_with_constant(mut v: SIMD256Vector, c: 
i32) -> SIMD256Vector { v } +#[inline(always)] fn shift_right(mut v: SIMD256Vector) -> SIMD256Vector { v.elements = unsafe { _mm256_srai_epi32(v.elements, SHIFT_BY) }; v } +#[inline(always)] fn shift_left(mut v: SIMD256Vector) -> SIMD256Vector { v.elements = unsafe { _mm256_slli_epi32(v.elements, SHIFT_BY) }; v } +#[inline(always)] fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { unsafe { let field_modulus = _mm256_set1_epi32(FIELD_MODULUS); @@ -97,6 +109,7 @@ fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { v } +#[inline(always)] fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::barrett_reduce(input); @@ -104,6 +117,7 @@ fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector { from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::montgomery_reduce(input); @@ -111,6 +125,7 @@ fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector { from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector { unsafe { let field_modulus_halved = _mm256_set1_epi32((FIELD_MODULUS - 1) / 2); @@ -128,13 +143,15 @@ fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector { v } -fn compress(coefficient_bits: u8, v: SIMD256Vector) -> SIMD256Vector { +#[inline(always)] +fn compress<const COEFFICIENT_BITS: i32>(v: SIMD256Vector) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); - let output = portable::PortableVector::compress(coefficient_bits, input); + let output = portable::PortableVector::compress::<COEFFICIENT_BITS>(input); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector { 
let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::ntt_layer_1_step(input, zeta1, zeta2); @@ -142,6 +159,7 @@ fn ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector { from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::ntt_layer_2_step(input, zeta); @@ -149,6 +167,7 @@ fn ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector { from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn inv_ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::inv_ntt_layer_1_step(input, zeta1, zeta2); @@ -156,6 +175,7 @@ fn inv_ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vect from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); let output = portable::PortableVector::inv_ntt_layer_2_step(input, zeta); @@ -163,6 +183,7 @@ fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector { from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32) -> SIMD256Vector { let input1 = portable::PortableVector::from_i32_array(to_i32_array(*lhs)); let input2 = portable::PortableVector::from_i32_array(to_i32_array(*rhs)); @@ -172,6 +193,7 @@ fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32 from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_1(mut v: 
SIMD256Vector) -> u8 { let mut shifted_bytes = [0i32; 8]; @@ -193,59 +215,74 @@ fn serialize_1(mut v: SIMD256Vector) -> u8 { | shifted_bytes[7]) as u8 } +#[inline(always)] fn deserialize_1(a: u8) -> SIMD256Vector { let output = portable::PortableVector::deserialize_1(a); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_4(v: SIMD256Vector) -> [u8; 4] { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); portable::PortableVector::serialize_4(input) } + +#[inline(always)] fn deserialize_4(v: &[u8]) -> SIMD256Vector { let output = portable::PortableVector::deserialize_4(v); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_5(v: SIMD256Vector) -> [u8; 5] { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); portable::PortableVector::serialize_5(input) } + +#[inline(always)] fn deserialize_5(v: &[u8]) -> SIMD256Vector { let output = portable::PortableVector::deserialize_5(v); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_10(v: SIMD256Vector) -> [u8; 10] { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); portable::PortableVector::serialize_10(input) } + +#[inline(always)] fn deserialize_10(v: &[u8]) -> SIMD256Vector { let output = portable::PortableVector::deserialize_10(v); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_11(v: SIMD256Vector) -> [u8; 11] { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); portable::PortableVector::serialize_11(input) } + +#[inline(always)] fn deserialize_11(v: &[u8]) -> SIMD256Vector { let output = portable::PortableVector::deserialize_11(v); from_i32_array(portable::PortableVector::to_i32_array(output)) } +#[inline(always)] fn serialize_12(v: SIMD256Vector) -> [u8; 12] { let input = portable::PortableVector::from_i32_array(to_i32_array(v)); 
portable::PortableVector::serialize_12(input) } +#[inline(always)] fn deserialize_12(v: &[u8]) -> SIMD256Vector { let output = portable::PortableVector::deserialize_12(v); @@ -309,8 +346,8 @@ impl Operations for SIMD256Vector { compress_1(v) } - fn compress(coefficient_bits: u8, v: Self) -> Self { - compress(coefficient_bits, v) + fn compress<const COEFFICIENT_BITS: i32>(v: Self) -> Self { + compress::<COEFFICIENT_BITS>(v) } fn ntt_layer_1_step(a: Self, zeta1: i32, zeta2: i32) -> Self {