Skip to content

Commit

Permalink
ntt multiplication on arm, and compress has a const parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
karthikbhargavan committed Apr 27, 2024
1 parent 306b13e commit 5c8a660
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 8 deletions.
30 changes: 26 additions & 4 deletions libcrux-ml-kem/src/simd/simd128.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,34 @@ fn inv_ntt_layer_2_step(mut v: SIMD128Vector, zeta: i32) -> SIMD128Vector {

/// Multiplies `lhs` and `rhs` as four pairs of coefficients using NEON intrinsics.
///
/// For each pair `(a0, a1)` of `lhs`, the matching pair `(b0, b1)` of `rhs` and that
/// pair's zeta, the output pair is:
///
/// ```text
/// montgomery_reduce(a0 * b0 + montgomery_reduce(a1 * b1) * zeta),
/// montgomery_reduce(a0 * b1 + a1 * b0)
/// ```
///
/// Pairs 0 and 1 (the `low` half) use `zeta0` and `-zeta0`; pairs 2 and 3 (the
/// `high` half) use `zeta1` and `-zeta1`.
#[inline(always)]
fn ntt_multiply(
    lhs: &SIMD128Vector,
    rhs: &SIMD128Vector,
    zeta0: i32,
    zeta1: i32,
) -> SIMD128Vector {
    // NOTE(review): every intrinsic below works on register values only, except the
    // single load from `zetas`. This assumes the module is compiled only for targets
    // where NEON is available -- confirm against the module's cfg.

    // Transpose so that the even coefficients of all four pairs share one register
    // and the odd coefficients share another. Lanes end up in pair order 0, 2, 1, 3.
    let a0 = unsafe { vtrn1q_s32(lhs.low, lhs.high) }; // a0, a4, a2, a6
    let a1 = unsafe { vtrn2q_s32(lhs.low, lhs.high) }; // a1, a5, a3, a7
    let b0 = unsafe { vtrn1q_s32(rhs.low, rhs.high) }; // b0, b4, b2, b6
    let b1 = unsafe { vtrn2q_s32(rhs.low, rhs.high) }; // b1, b5, b3, b7

    // Zetas laid out in the same 0, 2, 1, 3 pair order as the transposed lanes.
    let zetas: [i32; 4] = [zeta0, zeta1, -zeta0, -zeta1];
    // SAFETY: `zetas` is a live local array of exactly four `i32`, which is the
    // amount `vld1q_s32` reads from the pointer.
    let zeta = unsafe { vld1q_s32(zetas.as_ptr()) };

    // First output of each pair: a0 * b0 + montgomery_reduce(a1 * b1) * zeta.
    let a0b0 = unsafe { vmulq_s32(a0, b0) };
    let a1b1 = unsafe { vmulq_s32(a1, b1) };
    let a1b1 = montgomery_reduce_i32x4_t(a1b1);
    let a1b1_zeta = unsafe { vmulq_s32(a1b1, zeta) };
    let fst = unsafe { vaddq_s32(a0b0, a1b1_zeta) };
    let fst = montgomery_reduce_i32x4_t(fst);

    // Second output of each pair: a0 * b1 + a1 * b0.
    let a0b1 = unsafe { vmulq_s32(a0, b1) };
    let a1b0 = unsafe { vmulq_s32(a1, b0) };
    let snd = unsafe { vaddq_s32(a0b1, a1b0) };
    let snd = montgomery_reduce_i32x4_t(snd);

    // Transpose back: interleaving `fst` and `snd` restores the original
    // coefficient order across `low` and `high`.
    SIMD128Vector {
        low: unsafe { vtrn1q_s32(fst, snd) },
        high: unsafe { vtrn2q_s32(fst, snd) },
    }
}

#[inline(always)]
Expand Down
45 changes: 41 additions & 4 deletions libcrux-ml-kem/src/simd/simd256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ pub(crate) struct SIMD256Vector {
}

/// Returns a vector whose 256 bits are all zero.
#[allow(non_snake_case)]
#[inline(always)]
fn ZERO() -> SIMD256Vector {
    let elements = unsafe { _mm256_setzero_si256() };

    SIMD256Vector { elements }
}

#[inline(always)]
fn to_i32_array(v: SIMD256Vector) -> [i32; 8] {
let mut out = [0i32; 8];

Expand All @@ -26,31 +28,37 @@ fn to_i32_array(v: SIMD256Vector) -> [i32; 8] {
out
}

/// Loads the eight `i32` values of `array` into a vector with an unaligned 256-bit load.
#[inline(always)]
fn from_i32_array(array: [i32; 8]) -> SIMD256Vector {
    // The load intrinsic takes a `__m256i` pointer; the array supplies the 32 bytes it reads.
    let source = array.as_ptr() as *const __m256i;
    let elements = unsafe { _mm256_loadu_si256(source) };

    SIMD256Vector { elements }
}

#[inline(always)]
fn add_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector {
let c = unsafe { _mm256_set1_epi32(c) };

v.elements = unsafe { _mm256_add_epi32(v.elements, c) };

v
}

#[inline(always)]
fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector {
lhs.elements = unsafe { _mm256_add_epi32(lhs.elements, rhs.elements) };

lhs
}

#[inline(always)]
fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector {
lhs.elements = unsafe { _mm256_sub_epi32(lhs.elements, rhs.elements) };

lhs
}

#[inline(always)]
fn multiply_by_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector {
let c = unsafe { _mm256_set1_epi32(c) };

Expand All @@ -62,6 +70,7 @@ fn multiply_by_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector {
v
}

#[inline(always)]
fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector {
let c = unsafe { _mm256_set1_epi32(c) };

Expand All @@ -70,18 +79,21 @@ fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i32) -> SIMD256Vector {
v
}

#[inline(always)]
fn shift_right<const SHIFT_BY: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
v.elements = unsafe { _mm256_srai_epi32(v.elements, SHIFT_BY) };

v
}

#[inline(always)]
fn shift_left<const SHIFT_BY: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
v.elements = unsafe { _mm256_slli_epi32(v.elements, SHIFT_BY) };

v
}

#[inline(always)]
fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector {
unsafe {
let field_modulus = _mm256_set1_epi32(FIELD_MODULUS);
Expand All @@ -97,20 +109,23 @@ fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector {
v
}

/// Applies the portable `barrett_reduce` to the lanes of `v`.
///
/// The vector is unpacked into an `[i32; 8]`, handed to the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn barrett_reduce(v: SIMD256Vector) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let reduced =
        portable::PortableVector::barrett_reduce(portable::PortableVector::from_i32_array(lanes));

    from_i32_array(portable::PortableVector::to_i32_array(reduced))
}

/// Applies the portable `montgomery_reduce` to the lanes of `v`.
///
/// The vector is unpacked into an `[i32; 8]`, handed to the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn montgomery_reduce(v: SIMD256Vector) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let portable_input = portable::PortableVector::from_i32_array(lanes);
    let reduced = portable::PortableVector::montgomery_reduce(portable_input);

    from_i32_array(portable::PortableVector::to_i32_array(reduced))
}

#[inline(always)]
fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector {
unsafe {
let field_modulus_halved = _mm256_set1_epi32((FIELD_MODULUS - 1) / 2);
Expand All @@ -128,41 +143,47 @@ fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector {
v
}

fn compress(coefficient_bits: u8, v: SIMD256Vector) -> SIMD256Vector {
#[inline(always)]
fn compress<const COEFFICIENT_BITS: i32>(v: SIMD256Vector) -> SIMD256Vector {
let input = portable::PortableVector::from_i32_array(to_i32_array(v));
let output = portable::PortableVector::compress(coefficient_bits, input);
let output = portable::PortableVector::compress::<COEFFICIENT_BITS>(input);

from_i32_array(portable::PortableVector::to_i32_array(output))
}

/// Runs the portable `ntt_layer_1_step` on the lanes of `v` with `zeta1` and `zeta2`.
///
/// The vector is unpacked into an `[i32; 8]`, processed by the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let portable_input = portable::PortableVector::from_i32_array(lanes);
    let stepped = portable::PortableVector::ntt_layer_1_step(portable_input, zeta1, zeta2);

    from_i32_array(portable::PortableVector::to_i32_array(stepped))
}

/// Runs the portable `ntt_layer_2_step` on the lanes of `v` with `zeta`.
///
/// The vector is unpacked into an `[i32; 8]`, processed by the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let portable_input = portable::PortableVector::from_i32_array(lanes);
    let stepped = portable::PortableVector::ntt_layer_2_step(portable_input, zeta);

    from_i32_array(portable::PortableVector::to_i32_array(stepped))
}

/// Runs the portable `inv_ntt_layer_1_step` on the lanes of `v` with `zeta1` and `zeta2`.
///
/// The vector is unpacked into an `[i32; 8]`, processed by the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn inv_ntt_layer_1_step(v: SIMD256Vector, zeta1: i32, zeta2: i32) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let portable_input = portable::PortableVector::from_i32_array(lanes);
    let stepped = portable::PortableVector::inv_ntt_layer_1_step(portable_input, zeta1, zeta2);

    from_i32_array(portable::PortableVector::to_i32_array(stepped))
}

/// Runs the portable `inv_ntt_layer_2_step` on the lanes of `v` with `zeta`.
///
/// The vector is unpacked into an `[i32; 8]`, processed by the portable implementation,
/// and the result is packed back into a `SIMD256Vector`.
#[inline(always)]
fn inv_ntt_layer_2_step(v: SIMD256Vector, zeta: i32) -> SIMD256Vector {
    let lanes: [i32; 8] = to_i32_array(v);
    let portable_input = portable::PortableVector::from_i32_array(lanes);
    let stepped = portable::PortableVector::inv_ntt_layer_2_step(portable_input, zeta);

    from_i32_array(portable::PortableVector::to_i32_array(stepped))
}

#[inline(always)]
fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32) -> SIMD256Vector {
let input1 = portable::PortableVector::from_i32_array(to_i32_array(*lhs));
let input2 = portable::PortableVector::from_i32_array(to_i32_array(*rhs));
Expand All @@ -172,6 +193,7 @@ fn ntt_multiply(lhs: &SIMD256Vector, rhs: &SIMD256Vector, zeta0: i32, zeta1: i32
from_i32_array(portable::PortableVector::to_i32_array(output))
}

#[inline(always)]
fn serialize_1(mut v: SIMD256Vector) -> u8 {
let mut shifted_bytes = [0i32; 8];

Expand All @@ -193,59 +215,74 @@ fn serialize_1(mut v: SIMD256Vector) -> u8 {
| shifted_bytes[7]) as u8
}

/// Builds a vector from the byte `a` using the portable `deserialize_1` routine.
#[inline(always)]
fn deserialize_1(a: u8) -> SIMD256Vector {
    let portable_vector = portable::PortableVector::deserialize_1(a);
    let lanes = portable::PortableVector::to_i32_array(portable_vector);

    from_i32_array(lanes)
}

/// Serializes `v` into 4 bytes via the portable `serialize_4` routine.
#[inline(always)]
fn serialize_4(v: SIMD256Vector) -> [u8; 4] {
    let lanes: [i32; 8] = to_i32_array(v);

    portable::PortableVector::serialize_4(portable::PortableVector::from_i32_array(lanes))
}

/// Builds a vector from the bytes in `v` using the portable `deserialize_4` routine.
#[inline(always)]
fn deserialize_4(v: &[u8]) -> SIMD256Vector {
    let portable_vector = portable::PortableVector::deserialize_4(v);
    let lanes = portable::PortableVector::to_i32_array(portable_vector);

    from_i32_array(lanes)
}

/// Serializes `v` into 5 bytes via the portable `serialize_5` routine.
#[inline(always)]
fn serialize_5(v: SIMD256Vector) -> [u8; 5] {
    let lanes: [i32; 8] = to_i32_array(v);

    portable::PortableVector::serialize_5(portable::PortableVector::from_i32_array(lanes))
}

/// Builds a vector from the bytes in `v` using the portable `deserialize_5` routine.
#[inline(always)]
fn deserialize_5(v: &[u8]) -> SIMD256Vector {
    let portable_vector = portable::PortableVector::deserialize_5(v);
    let lanes = portable::PortableVector::to_i32_array(portable_vector);

    from_i32_array(lanes)
}

/// Serializes `v` into 10 bytes via the portable `serialize_10` routine.
#[inline(always)]
fn serialize_10(v: SIMD256Vector) -> [u8; 10] {
    let lanes: [i32; 8] = to_i32_array(v);

    portable::PortableVector::serialize_10(portable::PortableVector::from_i32_array(lanes))
}

/// Builds a vector from the bytes in `v` using the portable `deserialize_10` routine.
#[inline(always)]
fn deserialize_10(v: &[u8]) -> SIMD256Vector {
    let portable_vector = portable::PortableVector::deserialize_10(v);
    let lanes = portable::PortableVector::to_i32_array(portable_vector);

    from_i32_array(lanes)
}

/// Serializes `v` into 11 bytes via the portable `serialize_11` routine.
#[inline(always)]
fn serialize_11(v: SIMD256Vector) -> [u8; 11] {
    let lanes: [i32; 8] = to_i32_array(v);

    portable::PortableVector::serialize_11(portable::PortableVector::from_i32_array(lanes))
}

/// Builds a vector from the bytes in `v` using the portable `deserialize_11` routine.
#[inline(always)]
fn deserialize_11(v: &[u8]) -> SIMD256Vector {
    let portable_vector = portable::PortableVector::deserialize_11(v);
    let lanes = portable::PortableVector::to_i32_array(portable_vector);

    from_i32_array(lanes)
}

/// Serializes `v` into 12 bytes via the portable `serialize_12` routine.
#[inline(always)]
fn serialize_12(v: SIMD256Vector) -> [u8; 12] {
    let lanes: [i32; 8] = to_i32_array(v);

    portable::PortableVector::serialize_12(portable::PortableVector::from_i32_array(lanes))
}

#[inline(always)]
fn deserialize_12(v: &[u8]) -> SIMD256Vector {
let output = portable::PortableVector::deserialize_12(v);

Expand Down Expand Up @@ -309,8 +346,8 @@ impl Operations for SIMD256Vector {
compress_1(v)
}

fn compress(coefficient_bits: u8, v: Self) -> Self {
compress(coefficient_bits, v)
fn compress<const COEFFICIENT_BITS: i32>(v: Self) -> Self {
compress::<COEFFICIENT_BITS>(v)
}

fn ntt_layer_1_step(a: Self, zeta1: i32, zeta2: i32) -> Self {
Expand Down

0 comments on commit 5c8a660

Please sign in to comment.