From f59d36fcce94f6ab9dcd24b3fe0ec092e12107f8 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Fri, 10 May 2024 10:34:08 +0200 Subject: [PATCH 01/59] sha3 for arm64 --- libcrux-sha3/Cargo.toml | 1 + libcrux-sha3/benches/sha3.rs | 21 ++- libcrux-sha3/src/lib.rs | 3 + libcrux-sha3/src/rust_simd.rs | 117 +++++++++++++ libcrux-sha3/src/rust_simd/sha3_arm64.rs | 211 +++++++++++++++++++++++ libcrux-sha3/tests/sha3.rs | 22 +++ 6 files changed, 374 insertions(+), 1 deletion(-) create mode 100644 libcrux-sha3/src/rust_simd.rs create mode 100644 libcrux-sha3/src/rust_simd/sha3_arm64.rs create mode 100644 libcrux-sha3/tests/sha3.rs diff --git a/libcrux-sha3/Cargo.toml b/libcrux-sha3/Cargo.toml index 5ec4ed9be..6326ac64f 100644 --- a/libcrux-sha3/Cargo.toml +++ b/libcrux-sha3/Cargo.toml @@ -13,6 +13,7 @@ libcrux-hacl = { version = "0.0.2-pre.2", path = "../sys/hacl", features = [ "sha3", ] } libcrux-platform = { version = "0.0.2-pre.2", path = "../sys/platform" } +hex = { version = "0.4.3", features = ["serde"] } # This is only required for verification. # The hax config is set by the hax toolchain. diff --git a/libcrux-sha3/benches/sha3.rs b/libcrux-sha3/benches/sha3.rs index 6ff6628e7..2bda837e1 100644 --- a/libcrux-sha3/benches/sha3.rs +++ b/libcrux-sha3/benches/sha3.rs @@ -19,7 +19,7 @@ pub fn fmt(x: usize) -> String { } macro_rules! impl_comp { - ($fun:ident, $libcrux:expr, $rust_crypto:ty, $openssl:expr) => { + ($fun:ident, $libcrux:expr, $arm64:ident, $rust_crypto:ty, $openssl:expr) => { // Comparing libcrux performance for different payload sizes and other implementations. fn $fun(c: &mut Criterion) { const PAYLOAD_SIZES: [usize; 1] = [1024 * 1024 * 10]; @@ -43,6 +43,21 @@ macro_rules! impl_comp { }, ); + #[cfg(feature = "simd128")] + group.bench_with_input( + BenchmarkId::new("arm64", fmt(*payload_size)), + payload_size, + |b, payload_size| { + b.iter_batched( + || randombytes(*payload_size), + |payload| { + let _d: [u8; digest_size($libcrux)] = rust_simd::$arm64(&payload); + }, + BatchSize::SmallInput, + ) + }, + ); + // group.bench_with_input( // BenchmarkId::new("RustCrypto", fmt(*payload_size)), // payload_size, @@ -109,24 +124,28 @@ macro_rules! impl_comp { impl_comp!( Sha3_224, Algorithm::Sha3_224, + sha3_224, sha3::Sha3_224, MessageDigest::sha3_224() // libcrux_pqclean::sha3_256 // This is wrong, but it's not actually used. 
); impl_comp!( Sha3_256, Algorithm::Sha3_256, + sha3_256, sha3::Sha3_256, MessageDigest::sha3_256() // libcrux_pqclean::sha3_256 ); impl_comp!( Sha3_384, Algorithm::Sha3_384, + sha3_384, sha3::Sha3_384, MessageDigest::sha3_384() // libcrux_pqclean::sha3_384 ); impl_comp!( Sha3_512, Algorithm::Sha3_512, + sha3_512, sha3::Sha3_512, MessageDigest::sha3_512() // libcrux_pqclean::sha3_512 ); diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 63c1e61c0..4eee94b68 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -11,6 +11,9 @@ /// A Sha3x4 API pub mod x4; +#[cfg(feature = "simd128")] +pub mod rust_simd; + pub type Sha3_224Digest = [u8; 28]; pub type Sha3_256Digest = [u8; 32]; pub type Sha3_384Digest = [u8; 48]; diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs new file mode 100644 index 000000000..a2c7fa185 --- /dev/null +++ b/libcrux-sha3/src/rust_simd.rs @@ -0,0 +1,117 @@ +mod sha3_arm64; +use sha3_arm64::*; + + +#[inline(always)] +fn squeeze_first_block2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + squeeze2::(s, out0, out1) +} + +#[inline(always)] +fn squeeze_next_block2(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + keccakf1600(s); + squeeze2::(s, out0, out1) +} + +#[inline(always)] +pub fn squeeze_first_three_blocks2(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + squeeze_first_block2::(s, out0, out1); + squeeze_next_block2::(s, &mut out0[RATE..2*RATE], &mut out1[RATE..2*RATE]); + squeeze_next_block2::(s, &mut out0[2*RATE..3*RATE], &mut out1[2*RATE..3*RATE]) +} + +#[inline(always)] +fn squeeze_last2(mut s: KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + let mut b0 = [0u8; 200]; + let mut b1 = [0u8; 200]; + squeeze_next_block2::(&mut s, &mut b0, &mut b1); + out0.copy_from_slice(&b0[0..out0.len()]); + out1.copy_from_slice(&b1[0..out1.len()]); +} + +#[inline(always)] +fn squeeze_first_and_last2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + let mut b0 = [0u8; 200]; + let mut b1 = [0u8; 200]; + squeeze_first_block2::(s, &mut b0, &mut b1); + out0.copy_from_slice(&b0[0..out0.len()]); + out1.copy_from_slice(&b1[0..out1.len()]); +} + +#[inline(always)] +fn keccak(data0: &[u8], data1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + debug_assert!(data0.len() == data1.len()); + debug_assert!(out0.len() == out1.len()); + let mut s = KeccakStateX2::new(); + for i in 0..data0.len()/RATE { + absorb_block2::(&mut s, &data0[i*RATE..(i+1)*RATE], &data1[i*RATE..(i+1)*RATE]); + } + let rem = data0.len() % RATE; + absorb_final2::(&mut s, &data0[data0.len()-rem..data0.len()], &data1[data1.len()-rem..data1.len()]); + + let blocks = out0.len()/RATE; + let last = out0.len() - out0.len()%RATE; + + if blocks == 0 { + squeeze_first_and_last2::(&s, out0, out1) + } else { + squeeze_first_block2::(&s, out0, out1); + for i in 1..blocks { + squeeze_next_block2::(&mut s, &mut out0[i*RATE..(i+1)*RATE], &mut out1[i*RATE..(i+1)*RATE]); + } + if last < out0.len() {squeeze_last2::(s, &mut out0[last..], &mut out1[last..])} + } +} + +pub fn sha3_224(data: &[u8]) -> [u8;28] { + let mut d0 = [0u8; 28]; + let mut d1 = [0u8; 28]; + keccak::<144,0x06u8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn sha3_256(data: &[u8]) -> [u8;32] { + let mut d0 = [0u8; 32]; + let mut d1 = [0u8; 32]; + keccak::<136, 0x06u8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn sha3_384(data: &[u8]) -> [u8;48] { + let mut d0 = [0u8; 48]; + let mut d1 = [0u8; 48]; + keccak::<104, 0x06u8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn 
sha3_512(data: &[u8]) -> [u8;64] { + let mut d0 = [0u8; 64]; + let mut d1 = [0u8; 64]; + keccak::<72,0x06u8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn shake128(data: &[u8]) -> [u8; LEN] { + let mut d0 = [0u8; LEN]; + let mut d1 = [0u8; LEN]; + keccak::<168, 0x1fu8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn shake128x2_init_absorb_final(data0: &[u8], data1: &[u8]) -> KeccakStateX2 { + let mut s = KeccakStateX2::new(); + absorb_final2::<168, 0x1fu8>(&mut s,data0,data1); + s +} + +pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakStateX2, out0:&mut [u8], out1:&mut [u8]) { + squeeze_first_three_blocks2::<168>(s, out0, out1) +} + +pub fn shake128x2_squeeze_next_block(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + squeeze_next_block2::<168>(s, out0, out1) +} + +pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + keccak::<136, 0x1fu8>(input0, input1, out0, out1); +} diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs new file mode 100644 index 000000000..ab87fd08b --- /dev/null +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -0,0 +1,211 @@ +use core::arch::aarch64::*; + +// This file optimizes for the stable Rust Neon Intrinsics +// If we want to use the unstable neon-sha3 instructions, we could use: +// veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 +// These instructions might speed up our code even more. + + +/// Incremental state +#[cfg_attr(hax, hax_lib::opaque_type)] +#[derive(Clone, Copy)] +pub struct KeccakStateX2 { + pub st: [[uint64x2_t; 5]; 5], +} + +#[inline(always)] +fn rotate_left(x:uint64x2_t) -> uint64x2_t { + debug_assert!(LEFT+RIGHT == 64); + //unsafe { vsriq_n_u64::(vshlq_n_u64::(x), x) } + unsafe { veorq_u64(vshlq_n_u64::(x), vshrq_n_u64::(x)) } +} + + +impl KeccakStateX2 { + /// Create a new Shake128 x4 state. 
+ #[inline(always)] + pub(crate) fn new() -> Self { + unsafe{ + Self { + st: [[vdupq_n_u64(0); 5]; 5], + } + } + } +} + +#[inline(always)] +fn theta(s: &mut KeccakStateX2) { + let mut c : [uint64x2_t; 5] = unsafe {[vdupq_n_u64(0); 5]}; + for i in 0..5 { + c[i] = unsafe {veorq_u64(s.st[0][i],veorq_u64(s.st[1][i], + veorq_u64(s.st[2][i],veorq_u64(s.st[3][i],s.st[4][i]))))}; + } + + for i in 0..5 { + let t = unsafe { veorq_u64(c[(i + 4) % 5], rotate_left::<1,63>(c[(i+1)%5])) }; + // let t = unsafe { vrax1q_u64(c[(i + 1) % 5], c[(i+4)%5]) }; + for j in 0..5 { + s.st[j][i] = unsafe {veorq_u64(s.st[j][i],t)}; + } + } +} + +const _ROTC: [usize;24] = + [1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,]; + +#[inline(always)] +fn rho(s: &mut KeccakStateX2) { + s.st[0][0] = s.st[0][0]; + s.st[0][1] = rotate_left::<1,63>(s.st[0][1]); + s.st[0][2] = rotate_left::<62,2>(s.st[0][2]); + s.st[0][3] = rotate_left::<28,36>(s.st[0][3]); + s.st[0][4] = rotate_left::<27,37>(s.st[0][4]); + s.st[1][0] = rotate_left::<36,28>(s.st[1][0]); + s.st[1][1] = rotate_left::<44,20>(s.st[1][1]); + s.st[1][2] = rotate_left::<6,58>(s.st[1][2]); + s.st[1][3] = rotate_left::<55,9>(s.st[1][3]); + s.st[1][4] = rotate_left::<20,44>(s.st[1][4]); + s.st[2][0] = rotate_left::<3,61>(s.st[2][0]); + s.st[2][1] = rotate_left::<10,54>(s.st[2][1]); + s.st[2][2] = rotate_left::<43,21>(s.st[2][2]); + s.st[2][3] = rotate_left::<25,39>(s.st[2][3]); + s.st[2][4] = rotate_left::<39,25>(s.st[2][4]); + s.st[3][0] = rotate_left::<41,23>(s.st[3][0]); + s.st[3][1] = rotate_left::<45,19>(s.st[3][1]); + s.st[3][2] = rotate_left::<15,49>(s.st[3][2]); + s.st[3][3] = rotate_left::<21,43>(s.st[3][3]); + s.st[3][4] = rotate_left::<8,56>(s.st[3][4]); + s.st[4][0] = rotate_left::<18,46>(s.st[4][0]); + s.st[4][1] = rotate_left::<2,62>(s.st[4][1]); + s.st[4][2] = rotate_left::<61,3>(s.st[4][2]); + s.st[4][3] = rotate_left::<56,8>(s.st[4][3]); + s.st[4][4] = rotate_left::<14,50>(s.st[4][4]); +} + + +const _PI : [usize;24] = [ + 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7, 13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21, +]; + +#[inline(always)] +fn pi(s: &mut KeccakStateX2) { + let old = s.st.clone(); + s.st[0][1] = old[1][1]; + s.st[0][2] = old[2][2]; + s.st[0][3] = old[3][3]; + s.st[0][4] = old[4][4]; + s.st[1][0] = old[0][3]; + s.st[1][1] = old[1][4]; + s.st[1][2] = old[2][0]; + s.st[1][3] = old[3][1]; + s.st[1][4] = old[4][2]; + s.st[2][0] = old[0][1]; + s.st[2][1] = old[1][2]; + s.st[2][2] = old[2][3]; + s.st[2][3] = old[3][4]; + s.st[2][4] = old[4][0]; + s.st[3][0] = old[0][4]; + s.st[3][1] = old[1][0]; + s.st[3][2] = old[2][1]; + s.st[3][3] = old[3][2]; + s.st[3][4] = old[4][3]; + s.st[4][0] = old[0][2]; + s.st[4][1] = old[1][3]; + s.st[4][2] = old[2][4]; + s.st[4][3] = old[3][0]; + s.st[4][4] = old[4][1]; +} + +#[inline(always)] +fn chi(s: &mut KeccakStateX2) { + let mut c : [uint64x2_t; 5] = unsafe {[vdupq_n_u64(0); 5]}; + for i in 0..5 { + for j in 0..5 { + c[j] = s.st[i][j] + } + for j in 0..5 { + s.st[i][j] = unsafe{ veorq_u64(s.st[i][j], vbicq_u64(c[(j + 2) % 5], c[(j + 1) % 5])) }; + } + } +} + +const ROUNDCONSTANTS: [u64;24] = [ + 0x0000_0000_0000_0001u64, 0x0000_0000_0000_8082u64, 0x8000_0000_0000_808au64, + 0x8000_0000_8000_8000u64, 0x0000_0000_0000_808bu64, 0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8081u64, 0x8000_0000_0000_8009u64, 0x0000_0000_0000_008au64, + 0x0000_0000_0000_0088u64, 0x0000_0000_8000_8009u64, 0x0000_0000_8000_000au64, + 0x0000_0000_8000_808bu64, 0x8000_0000_0000_008bu64, 
0x8000_0000_0000_8089u64, + 0x8000_0000_0000_8003u64, 0x8000_0000_0000_8002u64, 0x8000_0000_0000_0080u64, + 0x0000_0000_0000_800au64, 0x8000_0000_8000_000au64, 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8080u64, 0x0000_0000_8000_0001u64, 0x8000_0000_8000_8008u64, +]; + +#[inline(always)] +fn iota(s: &mut KeccakStateX2, round:usize) { + let c = unsafe { vdupq_n_u64(ROUNDCONSTANTS[round]) }; + s.st[0][0] = unsafe { veorq_u64(s.st[0][0], c) }; +} + +#[inline(always)] +pub(crate) fn keccakf1600(s: &mut KeccakStateX2) { + for i in 0..24 { + theta(s); + rho(s); + pi(s); + chi(s); + iota(s, i); + } +} + +#[inline(always)] +pub(crate) fn absorb_block2(s: &mut KeccakStateX2, block0: &[u8], block1: &[u8]) { + debug_assert!(RATE == block0.len() && RATE == block1.len() && RATE % 8 == 0); + for i in 0..RATE/16 { + let v0 = unsafe { vld1q_u64(block0[16*i..16*i+16].as_ptr() as *const u64) }; + let v1 = unsafe { vld1q_u64(block1[16*i..16*i+16].as_ptr() as *const u64) }; + s.st[(2*i)/5][(2*i)%5] = unsafe { veorq_u64(s.st[(2*i)/5][(2*i)%5], vtrn1q_u64(v0,v1)) }; + s.st[(2*i+1)/5][(2*i+1)%5] = unsafe { veorq_u64(s.st[(2*i+1)/5][(2*i+1)%5], vtrn2q_u64(v0,v1)) }; + } + if RATE%16 != 0 { + let i = (RATE/8 - 1)/5; + let j = (RATE/8 - 1)%5; + let mut u = [0u64; 2]; + u[0] = u64::from_le_bytes(block0[RATE-8..].try_into().unwrap()); + u[1] = u64::from_le_bytes(block1[RATE-8..].try_into().unwrap()); + s.st[i][j] = unsafe { vld1q_u64(u.as_ptr() as *const u64) }; + } + keccakf1600(s) +} + +#[inline(always)] +pub(crate) fn absorb_final2(s: &mut KeccakStateX2, last0: &[u8], last1: &[u8]) { + debug_assert!(last0.len() == last1.len() && last0.len() < RATE); + let mut b0 = [0u8; 200]; + let mut b1 = [0u8; 200]; + b0[0..last0.len()].copy_from_slice(last0); + b1[0..last1.len()].copy_from_slice(last1); + b0[last0.len()] = DELIM; + b0[RATE-1] = b0[RATE-1] | 128u8; + b1[last1.len()] = DELIM; + b1[RATE-1] = b1[RATE-1] | 128u8; + absorb_block2::(s, &b0[0..RATE], &b1[0..RATE]) +} + +#[inline(always)] +pub(crate) fn squeeze2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { + for i in 0..RATE/16 { + let v0 = unsafe { vtrn1q_u64(s.st[(2*i)/5][(2*i)%5], s.st[(2*i+1)/5][(2*i+1)%5]) }; + let v1 = unsafe { vtrn2q_u64(s.st[(2*i)/5][(2*i)%5], s.st[(2*i+1)/5][(2*i+1)%5]) }; + unsafe { vst1q_u64(out0[16*i..16*i+16].as_mut_ptr() as *mut u64, v0) }; + unsafe { vst1q_u64(out1[16*i..16*i+16].as_mut_ptr() as *mut u64, v1) }; + } + if RATE%16 != 0 { + debug_assert!(RATE % 8 == 0); + let i = (RATE/8 - 1)/5; + let j = (RATE/8 - 1)%5; + let mut u = [0u8;16]; + unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s.st[i][j])}; + out0[RATE-8..RATE].copy_from_slice(&u[0..8]); + out1[RATE-8..RATE].copy_from_slice(&u[8..16]); + } +} \ No newline at end of file diff --git a/libcrux-sha3/tests/sha3.rs b/libcrux-sha3/tests/sha3.rs new file mode 100644 index 000000000..342d229db --- /dev/null +++ b/libcrux-sha3/tests/sha3.rs @@ -0,0 +1,22 @@ +#[test] +fn sha3_kat_oneshot() { + let d256 = libcrux_sha3::sha256(b"Hello, World!"); + let expected256 = "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"; + assert_eq!(hex::encode(&d256), expected256); + + let dshake = libcrux_sha3::shake128::<42>(b"Hello, World!"); + let expectedshake = "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + assert_eq!(hex::encode(&dshake), expectedshake); +} + +#[cfg(feature = "simd128")] +#[test] +fn sha3_simd_kat_oneshot() { + let d256 = libcrux_sha3::rust_simd::sha3_256(b"Hello, World!"); + let expected256 = 
"1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"; + assert_eq!(hex::encode(&d256), expected256); + + let dshake = libcrux_sha3::rust_simd::shake128::<42>(b"Hello, World!"); + let expectedshake = "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + assert_eq!(hex::encode(&dshake), expectedshake); +} From 73a6bdb3878bbd9d1aedba74713b0d8531f14b4d Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Fri, 10 May 2024 13:00:32 +0200 Subject: [PATCH 02/59] hooked up simd arm64 to ml-kem --- libcrux-ml-kem/src/hash_functions.rs | 131 ++++++++++++++++++++++- libcrux-sha3/src/rust_simd.rs | 19 +++- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 7 +- 3 files changed, 146 insertions(+), 11 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 67d862cfa..0b11ca41c 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -1,22 +1,76 @@ #![allow(non_snake_case)] use crate::constants::H_DIGEST_SIZE; + +#[cfg(feature = "simd128")] +use libcrux_sha3::rust_simd; +#[cfg(not(feature = "simd128"))] use libcrux_sha3::{x4::Shake128StateX4, *}; +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn G(input: &[u8]) -> [u8; 64] { + rust_simd::sha3_512(input) +} +#[cfg(not(feature = "simd128"))] +#[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; digest_size(Algorithm::Sha3_512)] { sha512(input) } +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + rust_simd::sha3_256(input) +} +#[cfg(not(feature = "simd128"))] +#[inline(always)] pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { sha256(input) } +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { + rust_simd::shake256::(input) +} +#[cfg(not(feature = "simd128"))] +#[inline(always)] pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { shake256::(input) } +#[cfg(feature = "simd128")] +pub(crate) type Shake128x4State = [rust_simd::KeccakStateX2;2]; + +#[cfg(not(feature = "simd128"))] +pub(crate) type Shake128x4State = Shake128StateX4; + +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut states = [rust_simd::shake128x2_init();2]; + match K { + 2 => { + rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); + }, + 3 => { + rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); + rust_simd::shake128x2_absorb_final(&mut states[1],&input[2],&input[2]); + }, + _ => { + rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); + rust_simd::shake128x2_absorb_final(&mut states[1],&input[2],&input[3]); + }, + } + states +} + +#[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128StateX4 { +pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { debug_assert!(K == 2 || K == 3 || K == 4); let mut state = Shake128StateX4::new(); @@ -32,9 +86,40 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128StateX4 { pub(crate) const BLOCK_SIZE: usize = 168; pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn squeeze_three_blocks ( + state: &mut Shake128x4State, +) -> [[u8; THREE_BLOCKS]; K] { + let mut out0 = [0u8; THREE_BLOCKS]; + let mut out1 = [0u8; THREE_BLOCKS]; + let mut out2 = [0u8; THREE_BLOCKS]; + let mut out3 = [0u8; THREE_BLOCKS]; + let 
mut out = [[0u8; THREE_BLOCKS]; K]; + + match K { + 2 => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0, &mut out1); + out[0] = out0; + out[1] = out1; } + 3 => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0, &mut out1); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; } + _ => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0, &mut out1); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; + out[3] = out3; } + } + out +} + +#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn squeeze_three_blocks( - xof_state: &mut Shake128StateX4, + xof_state: &mut Shake128x4State, ) -> [[u8; THREE_BLOCKS]; K] { let output: [[u8; THREE_BLOCKS]; K] = xof_state.squeeze_blocks(); let mut out = [[0u8; THREE_BLOCKS]; K]; @@ -44,9 +129,41 @@ pub(crate) fn squeeze_three_blocks( out } +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn squeeze_block( + state: &mut Shake128x4State, +) -> [[u8; BLOCK_SIZE]; K] { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + let mut out2 = [0u8; BLOCK_SIZE]; + let mut out3 = [0u8; BLOCK_SIZE]; + + let mut out = [[0u8; BLOCK_SIZE]; K]; + + match K { + 2 => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + out[0] = out0; + out[1] = out1; } + 3 => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; } + _ => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; + out[3] = out3; } + } + out +} + +#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn squeeze_block( - xof_state: &mut Shake128StateX4, + xof_state: &mut Shake128x4State, ) -> [[u8; BLOCK_SIZE]; K] { let output: [[u8; BLOCK_SIZE]; K] = xof_state.squeeze_blocks(); let mut out = [[0u8; BLOCK_SIZE]; K]; @@ -59,7 +176,13 @@ pub(crate) fn squeeze_block( /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. 
+#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn free_state(_xof_state: Shake128x4State) { +} + +#[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn free_state(xof_state: Shake128StateX4) { +pub(crate) fn free_state(xof_state: Shake128x4State) { xof_state.free_memory(); } diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index a2c7fa185..88ef5274c 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -1,6 +1,7 @@ mod sha3_arm64; use sha3_arm64::*; +pub use sha3_arm64::KeccakStateX2; #[inline(always)] fn squeeze_first_block2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { @@ -66,7 +67,7 @@ fn keccak(data0: &[u8], data1: &[u8], out0: &m pub fn sha3_224(data: &[u8]) -> [u8;28] { let mut d0 = [0u8; 28]; let mut d1 = [0u8; 28]; - keccak::<144,0x06u8>(data, data, &mut d0, &mut d1); + keccak::<144, 0x06u8>(data, data, &mut d0, &mut d1); d0 } @@ -98,12 +99,22 @@ pub fn shake128(data: &[u8]) -> [u8; LEN] { d0 } -pub fn shake128x2_init_absorb_final(data0: &[u8], data1: &[u8]) -> KeccakStateX2 { - let mut s = KeccakStateX2::new(); - absorb_final2::<168, 0x1fu8>(&mut s,data0,data1); +pub fn shake256(data: &[u8]) -> [u8; LEN] { + let mut d0 = [0u8; LEN]; + let mut d1 = [0u8; LEN]; + keccak::<136, 0x1fu8>(data, data, &mut d0, &mut d1); + d0 +} + +pub fn shake128x2_init() -> KeccakStateX2 { + let s = KeccakStateX2::new(); s } +pub fn shake128x2_absorb_final(s:&mut KeccakStateX2, data0: &[u8], data1: &[u8]) { + absorb_final2::<168, 0x1fu8>(s,data0,data1); +} + pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakStateX2, out0:&mut [u8], out1:&mut [u8]) { squeeze_first_three_blocks2::<168>(s, out0, out1) } diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index ab87fd08b..491f20df4 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -161,8 +161,8 @@ pub(crate) fn keccakf1600(s: &mut KeccakStateX2) { pub(crate) fn absorb_block2(s: &mut KeccakStateX2, block0: &[u8], block1: &[u8]) { debug_assert!(RATE == block0.len() && RATE == block1.len() && RATE % 8 == 0); for i in 0..RATE/16 { - let v0 = unsafe { vld1q_u64(block0[16*i..16*i+16].as_ptr() as *const u64) }; - let v1 = unsafe { vld1q_u64(block1[16*i..16*i+16].as_ptr() as *const u64) }; + let v0 = unsafe { vld1q_u64(block0[16*i..(16*i)+16].as_ptr() as *const u64) }; + let v1 = unsafe { vld1q_u64(block1[16*i..(16*i)+16].as_ptr() as *const u64) }; s.st[(2*i)/5][(2*i)%5] = unsafe { veorq_u64(s.st[(2*i)/5][(2*i)%5], vtrn1q_u64(v0,v1)) }; s.st[(2*i+1)/5][(2*i+1)%5] = unsafe { veorq_u64(s.st[(2*i+1)/5][(2*i+1)%5], vtrn2q_u64(v0,v1)) }; } @@ -172,7 +172,8 @@ pub(crate) fn absorb_block2(s: &mut KeccakStateX2, block0: &[u let mut u = [0u64; 2]; u[0] = u64::from_le_bytes(block0[RATE-8..].try_into().unwrap()); u[1] = u64::from_le_bytes(block1[RATE-8..].try_into().unwrap()); - s.st[i][j] = unsafe { vld1q_u64(u.as_ptr() as *const u64) }; + let uvec = unsafe { vld1q_u64(u.as_ptr() as *const u64) }; + s.st[i][j] = unsafe { veorq_u64(s.st[i][j], uvec)}; } keccakf1600(s) } From 28ad26522e26ab99a11186462f9188bf8a9e6154 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Fri, 10 May 2024 13:46:11 +0200 Subject: [PATCH 03/59] disable nightly instructions --- libcrux-sha3/src/lib.rs | 2 +- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 9 ++++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 4eee94b68..4f17cf9ac 
100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -11,7 +11,7 @@ /// A Sha3x4 API pub mod x4; -#[cfg(feature = "simd128")] +//#[cfg(feature = "simd128")] pub mod rust_simd; pub type Sha3_224Digest = [u8; 28]; diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index 491f20df4..47a9b387c 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -39,11 +39,15 @@ fn theta(s: &mut KeccakStateX2) { for i in 0..5 { c[i] = unsafe {veorq_u64(s.st[0][i],veorq_u64(s.st[1][i], veorq_u64(s.st[2][i],veorq_u64(s.st[3][i],s.st[4][i]))))}; + // Needs nightly + // c[i] = unsafe {veor3q_u64(s.st[0][i],s.st[1][i], + // veor3q_u64(s.st[2][i],s.st[3][i],s.st[4][i]))}; } for i in 0..5 { let t = unsafe { veorq_u64(c[(i + 4) % 5], rotate_left::<1,63>(c[(i+1)%5])) }; - // let t = unsafe { vrax1q_u64(c[(i + 1) % 5], c[(i+4)%5]) }; + // Needs nightly + // let t = unsafe { vrax1q_u64(c[(i+4)%5], c[(i+1)%5]) }; for j in 0..5 { s.st[j][i] = unsafe {veorq_u64(s.st[j][i],t)}; } @@ -55,6 +59,7 @@ const _ROTC: [usize;24] = #[inline(always)] fn rho(s: &mut KeccakStateX2) { + // If combined with theta, we could use Nightly: vxarq_u64 s.st[0][0] = s.st[0][0]; s.st[0][1] = rotate_left::<1,63>(s.st[0][1]); s.st[0][2] = rotate_left::<62,2>(s.st[0][2]); @@ -125,6 +130,8 @@ fn chi(s: &mut KeccakStateX2) { } for j in 0..5 { s.st[i][j] = unsafe{ veorq_u64(s.st[i][j], vbicq_u64(c[(j + 2) % 5], c[(j + 1) % 5])) }; + // Needs nightly + //s.st[i][j] = unsafe{ vbcaxq_u64(s.st[i][j], c[(j + 2) % 5], c[(j + 1) % 5]) }; } } } From 4c61c4fcdfb3ef08dafbd8f5d21d5a43969b933c Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Sun, 12 May 2024 14:33:49 +0200 Subject: [PATCH 04/59] eliminated some memmoves --- libcrux-ml-kem/src/polynomial.rs | 6 ++---- libcrux-ml-kem/src/sampling.rs | 29 +++++++++++---------------- polynomials-aarch64/src/lib.rs | 6 +++--- polynomials-aarch64/src/rejsample.rs | 5 ++--- polynomials-aarch64/src/simd128ops.rs | 2 +- polynomials/src/lib.rs | 15 +++++++------- traits/src/lib.rs | 4 ++-- 7 files changed, 29 insertions(+), 38 deletions(-) diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index da38d39cb..02dd071c0 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -31,13 +31,11 @@ impl PolynomialRingElement { } #[inline(always)] - pub(crate) fn from_i16_array(a: [i16; 256]) -> Self { + pub(crate) fn from_i16_array(a: &[i16]) -> Self { let mut result = PolynomialRingElement::ZERO(); for i in 0..VECTORS_IN_RING_ELEMENT { result.coefficients[i] = Vector::from_i16_array( - a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR] - .try_into() - .unwrap(), + &a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR] ); } result diff --git a/libcrux-ml-kem/src/sampling.rs b/libcrux-ml-kem/src/sampling.rs index fe864bb9e..95bd23dca 100644 --- a/libcrux-ml-kem/src/sampling.rs +++ b/libcrux-ml-kem/src/sampling.rs @@ -47,28 +47,23 @@ use crate::{ fn sample_from_uniform_distribution_next( randomness: [[u8; N]; K], sampled_coefficients: &mut [usize; K], - out: &mut [[i16; 256]; K], + out: &mut [[i16; 272]; K], ) -> bool { // Would be great to trigger auto-vectorization or at least loop unrolling here for i in 0..K { for r in 0..N / 24 { - let remaining = COEFFICIENTS_IN_RING_ELEMENT - sampled_coefficients[i]; - if remaining > 0 { - let (sampled, vec) = Vector::rej_sample(&randomness[i][r * 24..(r * 24) + 24]); - let pick = if 
sampled <= remaining { - sampled - } else { - remaining - }; - out[i][sampled_coefficients[i]..sampled_coefficients[i] + pick] - .copy_from_slice(&vec[0..pick]); - sampled_coefficients[i] += pick; + if sampled_coefficients[i] < COEFFICIENTS_IN_RING_ELEMENT { + let out0 = out[i][sampled_coefficients[i]..sampled_coefficients[i]+16].as_mut(); + let sampled = Vector::rej_sample(&randomness[i][r * 24..(r * 24) + 24], out0); + sampled_coefficients[i] += sampled; } } } let mut done = true; for i in 0..K { - if sampled_coefficients[i] < COEFFICIENTS_IN_RING_ELEMENT { + if sampled_coefficients[i] >= COEFFICIENTS_IN_RING_ELEMENT { + sampled_coefficients[i] = COEFFICIENTS_IN_RING_ELEMENT; + } else { done = false } } @@ -80,7 +75,7 @@ pub(super) fn sample_from_xof( seeds: [[u8; 34]; K], ) -> [PolynomialRingElement; K] { let mut sampled_coefficients: [usize; K] = [0; K]; - let mut out: [[i16; 256]; K] = [[0; 256]; K]; + let mut out: [[i16; 272]; K] = [[0; 272]; K]; let mut xof_state = absorb(seeds); let randomness = squeeze_three_blocks(&mut xof_state); @@ -107,7 +102,7 @@ pub(super) fn sample_from_xof( // XXX: We have to manually free the state here due to a Eurydice issue. free_state(xof_state); - out.map(PolynomialRingElement::::from_i16_array) + out.map(|s| PolynomialRingElement::::from_i16_array(&s[0..256])) } /// Given a series of uniformly random bytes in `randomness`, for some number `eta`, @@ -192,7 +187,7 @@ fn sample_from_binomial_distribution_2( } } } - PolynomialRingElement::from_i16_array(sampled_i16s) + PolynomialRingElement::from_i16_array(&sampled_i16s) } #[cfg_attr(hax, hax_lib::requires(randomness.len() == 3 * 64))] @@ -229,7 +224,7 @@ fn sample_from_binomial_distribution_3( } } } - PolynomialRingElement::from_i16_array(sampled_i16s) + PolynomialRingElement::from_i16_array(&sampled_i16s) } #[inline(always)] diff --git a/polynomials-aarch64/src/lib.rs b/polynomials-aarch64/src/lib.rs index 49af38cf3..4e894ba76 100644 --- a/polynomials-aarch64/src/lib.rs +++ b/polynomials-aarch64/src/lib.rs @@ -22,7 +22,7 @@ impl Operations for SIMD128Vector { to_i16_array(v) } - fn from_i16_array(array: [i16; 16]) -> Self { + fn from_i16_array(array: &[i16]) -> Self { from_i16_array(array) } @@ -157,7 +157,7 @@ impl Operations for SIMD128Vector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rejsample::rej_sample(a) + fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { + rejsample::rej_sample(a, out) } } diff --git a/polynomials-aarch64/src/rejsample.rs b/polynomials-aarch64/src/rejsample.rs index e667bc3ab..00cdb5471 100644 --- a/polynomials-aarch64/src/rejsample.rs +++ b/polynomials-aarch64/src/rejsample.rs @@ -768,7 +768,7 @@ const IDX_TABLE: [[u8; 16]; 256] = [ ]; #[inline(always)] -pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { +pub(crate) fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { let neon_bits: [u16; 8] = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80]; let bits = _vld1q_u16(&neon_bits); let fm = _vdupq_n_s16(3328); @@ -787,9 +787,8 @@ pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { let shifted1 = _vreinterpretq_s16_u8(_vqtbl1q_u8(_vreinterpretq_u8_s16(input.high), index_vec1)); - let mut out: [i16; 16] = [0i16; 16]; let idx0 = pick0 as usize; _vst1q_s16(&mut out[0..8], shifted0); _vst1q_s16(&mut out[idx0..idx0 + 8], shifted1); - ((pick0 + pick1) as usize, out) + (pick0 + pick1) as usize } diff --git a/polynomials-aarch64/src/simd128ops.rs b/polynomials-aarch64/src/simd128ops.rs index a6fb55a46..20cf3a840 100644 --- 
a/polynomials-aarch64/src/simd128ops.rs +++ b/polynomials-aarch64/src/simd128ops.rs @@ -27,7 +27,7 @@ pub(crate) fn to_i16_array(v: SIMD128Vector) -> [i16; 16] { } #[inline(always)] -pub(crate) fn from_i16_array(array: [i16; 16]) -> SIMD128Vector { +pub(crate) fn from_i16_array(array: &[i16]) -> SIMD128Vector { SIMD128Vector { low: _vld1q_s16(&array[0..8]), high: _vld1q_s16(&array[8..16]), diff --git a/polynomials/src/lib.rs b/polynomials/src/lib.rs index c04bdd4eb..f57db39aa 100644 --- a/polynomials/src/lib.rs +++ b/polynomials/src/lib.rs @@ -213,8 +213,8 @@ fn to_i16_array(v: PortableVector) -> [i16; FIELD_ELEMENTS_IN_VECTOR] { } #[inline(always)] -fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> PortableVector { - PortableVector { elements: array } +fn from_i16_array(array: &[i16]) -> PortableVector { + PortableVector { elements: array[0..16].try_into().unwrap() } } #[inline(always)] @@ -1041,8 +1041,7 @@ fn deserialize_12(bytes: &[u8]) -> PortableVector { } #[inline(always)] -fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - let mut result = [0i16; 16]; +fn rej_sample(a: &[u8], result: &mut[i16]) -> usize { let mut sampled = 0; for bytes in a.chunks(3) { let b1 = bytes[0] as i16; @@ -1061,7 +1060,7 @@ fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { sampled += 1 } } - (sampled, result) + sampled } impl Operations for PortableVector { @@ -1073,7 +1072,7 @@ impl Operations for PortableVector { to_i16_array(v) } - fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> Self { + fn from_i16_array(array: &[i16]) -> Self { from_i16_array(array) } @@ -1208,7 +1207,7 @@ impl Operations for PortableVector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rej_sample(a) + fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { + rej_sample(a, out) } } diff --git a/traits/src/lib.rs b/traits/src/lib.rs index 06391ff82..809fe7eb2 100644 --- a/traits/src/lib.rs +++ b/traits/src/lib.rs @@ -8,7 +8,7 @@ pub trait Operations: Copy + Clone { fn ZERO() -> Self; fn to_i16_array(v: Self) -> [i16; 16]; - fn from_i16_array(array: [i16; 16]) -> Self; + fn from_i16_array(array: &[i16]) -> Self; // Basic arithmetic fn add(lhs: Self, rhs: &Self) -> Self; @@ -61,7 +61,7 @@ pub trait Operations: Copy + Clone { fn serialize_12(a: Self) -> [u8; 24]; fn deserialize_12(a: &[u8]) -> Self; - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]); + fn rej_sample(a: &[u8], out:&mut [i16]) -> usize; } // hax does not support trait with default implementations, so we use the following patter From 35d60d101a6ce0c59f8f6c882f0edee60552d1a3 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Sun, 12 May 2024 18:53:53 +0200 Subject: [PATCH 05/59] removed some memmove in squeeze_three_blocks --- libcrux-ml-kem/src/hash_functions.rs | 33 +++++++++++++--------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 0b11ca41c..570cb0313 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -91,27 +91,24 @@ pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; pub(crate) fn squeeze_three_blocks ( state: &mut Shake128x4State, ) -> [[u8; THREE_BLOCKS]; K] { - let mut out0 = [0u8; THREE_BLOCKS]; - let mut out1 = [0u8; THREE_BLOCKS]; - let mut out2 = [0u8; THREE_BLOCKS]; - let mut out3 = [0u8; THREE_BLOCKS]; let mut out = [[0u8; THREE_BLOCKS]; K]; + let mut extra = [0u8; THREE_BLOCKS]; match K { - 2 => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], 
&mut out0, &mut out1); - out[0] = out0; - out[1] = out1; } - 3 => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; } - _ => { rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; - out[3] = out3; } + 2 => { let (out0,out1) = out.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); + } + 3 => { let (out0,out12) = out.split_at_mut(1); + let (out1,out2) = out12.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2[0], &mut extra); + } + _ => { let (out0,out123) = out.split_at_mut(1); + let (out1,out23) = out123.split_at_mut(1); + let (out2,out3) = out23.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); + rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2[0], &mut out3[0]); + } } out } From 96485c2eb1ae5240372a1e61ebb790a0c6add12e Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Sun, 12 May 2024 21:03:34 +0200 Subject: [PATCH 06/59] PRFxN for Arm64 --- libcrux-ml-kem/src/hash_functions.rs | 30 +++++++++++++++++++ libcrux-ml-kem/src/ind_cpa.rs | 43 +++++++++++++++------------- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 570cb0313..fdf702295 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -40,6 +40,36 @@ pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { shake256::(input) } +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + let mut out = [[0u8; LEN]; K]; + let mut extra = [0u8; LEN]; + + match K { + 2 => { let (out0,out1) = out.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + } + 3 => { let (out0,out12) = out.split_at_mut(1); + let (out1,out2) = out12.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + rust_simd::shake256x2(&input[2], &input[2], &mut out2[0], &mut extra); + } + _ => { let (out0,out123) = out.split_at_mut(1); + let (out1,out23) = out123.split_at_mut(1); + let (out2,out3) = out23.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + rust_simd::shake256x2(&input[2], &input[3], &mut out2[0], &mut out3[0]); + } + } + out +} +#[cfg(not(feature = "simd128"))] +#[inline(always)] +pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + core::array::from_fn(|i| shake256::(&input[i])) +} + #[cfg(feature = "simd128")] pub(crate) type Shake128x4State = [rust_simd::KeccakStateX2;2]; diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index c11e67772..ff7960476 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -2,7 +2,7 @@ use libcrux_polynomials::Operations; use crate::{ constants::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, SHARED_SECRET_SIZE}, - hash_functions::{G, PRF}, + hash_functions::{G, PRF, PRFxN}, helper::cloop, matrix::*, 
ntt::{ntt_binomially_sampled_ring_element, ntt_vector_u}, @@ -70,18 +70,20 @@ fn sample_ring_element_cbd< const ETA2: usize, Vector: Operations, >( - prf_input: &mut [u8; 33], - domain_separator: &mut u8, -) -> [PolynomialRingElement; K] { + prf_input: [u8; 33], + mut domain_separator: u8, +) -> ([PolynomialRingElement; K], u8) { let mut error_1 = [PolynomialRingElement::::ZERO(); K]; + let mut prf_inputs = [prf_input; K]; for i in 0..K { - prf_input[32] = *domain_separator; - *domain_separator += 1; - - let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = PRF(prf_input); - error_1[i] = sample_from_binomial_distribution::(&prf_output); + prf_inputs[i][32] = domain_separator; + domain_separator += 1; + } + let prf_outputs : [[u8; ETA2_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + for i in 0..K { + error_1[i] = sample_from_binomial_distribution::(&prf_outputs[i]); } - error_1 + (error_1, domain_separator) } /// Sample a vector of ring elements from a centered binomial distribution and @@ -93,17 +95,18 @@ fn sample_vector_cbd_then_ntt< const ETA_RANDOMNESS_SIZE: usize, Vector: Operations, >( - mut prf_input: [u8; 33], + prf_input: [u8; 33], mut domain_separator: u8, ) -> ([PolynomialRingElement; K], u8) { let mut re_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut prf_inputs = [prf_input; K]; for i in 0..K { - prf_input[32] = domain_separator; + prf_inputs[i][32] = domain_separator; domain_separator += 1; - - let prf_output: [u8; ETA_RANDOMNESS_SIZE] = PRF(&prf_input); - - let r = sample_from_binomial_distribution::(&prf_output); + } + let prf_outputs : [[u8; ETA_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + for i in 0..K { + let r = sample_from_binomial_distribution::(&prf_outputs[i]); re_as_ntt[i] = ntt_binomially_sampled_ring_element(r); } (re_as_ntt, domain_separator) @@ -289,16 +292,16 @@ pub(crate) fn encrypt< // end for // rˆ := NTT(r) let mut prf_input: [u8; 33] = into_padded_array(randomness); - let (r_as_ntt, mut domain_separator) = + let (r_as_ntt, domain_separator) = sample_vector_cbd_then_ntt::(prf_input, 0); // for i from 0 to k−1 do // e1[i] := CBD_{η2}(PRF(r,N)) // N := N + 1 // end for - let error_1 = sample_ring_element_cbd::( - &mut prf_input, - &mut domain_separator, + let (error_1, domain_separator) = sample_ring_element_cbd::( + prf_input, + domain_separator, ); // e_2 := CBD{η2}(PRF(r, N)) From 857802de4c982ae71620884a8251ae5bb475f676 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Mon, 13 May 2024 12:54:08 +0200 Subject: [PATCH 07/59] sha3 instruc --- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 132 +++++++++++++---------- 1 file changed, 77 insertions(+), 55 deletions(-) diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index 47a9b387c..f7ebb13ce 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -16,6 +16,7 @@ pub struct KeccakStateX2 { #[inline(always)] fn rotate_left(x:uint64x2_t) -> uint64x2_t { debug_assert!(LEFT+RIGHT == 64); + // The following looks faster but is actually significantly slower //unsafe { vsriq_n_u64::(vshlq_n_u64::(x), x) } unsafe { veorq_u64(vshlq_n_u64::(x), vshrq_n_u64::(x)) } } @@ -32,59 +33,85 @@ impl KeccakStateX2 { } } } - + #[inline(always)] -fn theta(s: &mut KeccakStateX2) { - let mut c : [uint64x2_t; 5] = unsafe {[vdupq_n_u64(0); 5]}; - for i in 0..5 { - c[i] = unsafe {veorq_u64(s.st[0][i],veorq_u64(s.st[1][i], - veorq_u64(s.st[2][i],veorq_u64(s.st[3][i],s.st[4][i]))))}; - // Needs nightly - // c[i] = unsafe 
{veor3q_u64(s.st[0][i],s.st[1][i], - // veor3q_u64(s.st[2][i],s.st[3][i],s.st[4][i]))}; - } - - for i in 0..5 { - let t = unsafe { veorq_u64(c[(i + 4) % 5], rotate_left::<1,63>(c[(i+1)%5])) }; - // Needs nightly - // let t = unsafe { vrax1q_u64(c[(i+4)%5], c[(i+1)%5]) }; - for j in 0..5 { - s.st[j][i] = unsafe {veorq_u64(s.st[j][i],t)}; - } - } +fn _veor5q_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t, d: uint64x2_t, e: uint64x2_t) -> uint64x2_t { + let ab = unsafe {veorq_u64(a,b)}; + let cd = unsafe {veorq_u64(c,d)}; + let abcd = unsafe {veorq_u64(ab,cd)}; + unsafe {veorq_u64(abcd,e)} + // Needs nightly+neon-sha3 + //unsafe {veor3q_u64(veor3q_u64(a,b,c),d,e)} +} + +#[inline(always)] +fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + unsafe { veorq_u64(a, rotate_left::<1,63>(b)) } + // Needs nightly+neon-sha3 + //unsafe { vrax1q_u64(a, b) } +} + +#[inline(always)] +fn _vxarq_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + let ab = unsafe { veorq_u64(a, b) }; + rotate_left::(ab) + // Needs nightly+neon-sha3 + // unsafe { vxarq_u64::(a,b) } +} + +#[inline(always)] +fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t { + unsafe{ veorq_u64(a, vbicq_u64(b, c)) } + // Needs nightly+neon-sha3 + // unsafe{ vbcaxq_u64(a, b, c) } +} + + + + +#[inline(always)] +fn theta(s: &KeccakStateX2) -> [uint64x2_t; 5] { + let c: [uint64x2_t;5] = core::array::from_fn(|j| _veor5q_u64(s.st[0][j],s.st[1][j],s.st[2][j],s.st[3][j],s.st[4][j])); + let t : [uint64x2_t; 5] = core::array::from_fn(|j| _vrax1q_u64(c[(j+4)%5], c[(j+1)%5])); + t } const _ROTC: [usize;24] = [1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,]; #[inline(always)] -fn rho(s: &mut KeccakStateX2) { - // If combined with theta, we could use Nightly: vxarq_u64 - s.st[0][0] = s.st[0][0]; - s.st[0][1] = rotate_left::<1,63>(s.st[0][1]); - s.st[0][2] = rotate_left::<62,2>(s.st[0][2]); - s.st[0][3] = rotate_left::<28,36>(s.st[0][3]); - s.st[0][4] = rotate_left::<27,37>(s.st[0][4]); - s.st[1][0] = rotate_left::<36,28>(s.st[1][0]); - s.st[1][1] = rotate_left::<44,20>(s.st[1][1]); - s.st[1][2] = rotate_left::<6,58>(s.st[1][2]); - s.st[1][3] = rotate_left::<55,9>(s.st[1][3]); - s.st[1][4] = rotate_left::<20,44>(s.st[1][4]); - s.st[2][0] = rotate_left::<3,61>(s.st[2][0]); - s.st[2][1] = rotate_left::<10,54>(s.st[2][1]); - s.st[2][2] = rotate_left::<43,21>(s.st[2][2]); - s.st[2][3] = rotate_left::<25,39>(s.st[2][3]); - s.st[2][4] = rotate_left::<39,25>(s.st[2][4]); - s.st[3][0] = rotate_left::<41,23>(s.st[3][0]); - s.st[3][1] = rotate_left::<45,19>(s.st[3][1]); - s.st[3][2] = rotate_left::<15,49>(s.st[3][2]); - s.st[3][3] = rotate_left::<21,43>(s.st[3][3]); - s.st[3][4] = rotate_left::<8,56>(s.st[3][4]); - s.st[4][0] = rotate_left::<18,46>(s.st[4][0]); - s.st[4][1] = rotate_left::<2,62>(s.st[4][1]); - s.st[4][2] = rotate_left::<61,3>(s.st[4][2]); - s.st[4][3] = rotate_left::<56,8>(s.st[4][3]); - s.st[4][4] = rotate_left::<14,50>(s.st[4][4]); +fn theta_rho(s: &mut KeccakStateX2, t: [uint64x2_t; 5]) { + // If combined with last step of theta, we could use Nightly: vxarq_u64 + + s.st[0][0] = unsafe { veorq_u64(s.st[0][0],t[0]) }; + s.st[1][0] = _vxarq_u64::<36,28>(s.st[1][0],t[0]); + s.st[2][0] = _vxarq_u64::<3,61>(s.st[2][0],t[0]); + s.st[3][0] = _vxarq_u64::<41,23>(s.st[3][0],t[0]); + s.st[4][0] = _vxarq_u64::<18,46>(s.st[4][0],t[0]); + + s.st[0][1] = _vxarq_u64::<1,63>(s.st[0][1],t[1]); + s.st[1][1] = _vxarq_u64::<44,20>(s.st[1][1],t[1]); + s.st[2][1] = 
_vxarq_u64::<10,54>(s.st[2][1],t[1]); + s.st[3][1] = _vxarq_u64::<45,19>(s.st[3][1],t[1]); + s.st[4][1] = _vxarq_u64::<2,62>(s.st[4][1],t[1]); + + s.st[0][2] = _vxarq_u64::<62,2>(s.st[0][2],t[2]); + s.st[1][2] = _vxarq_u64::<6,58>(s.st[1][2],t[2]); + s.st[2][2] = _vxarq_u64::<43,21>(s.st[2][2],t[2]); + s.st[3][2] = _vxarq_u64::<15,49>(s.st[3][2],t[2]); + s.st[4][2] = _vxarq_u64::<61,3>(s.st[4][2],t[2]); + + s.st[0][3] = _vxarq_u64::<28,36>(s.st[0][3],t[3]); + s.st[1][3] = _vxarq_u64::<55,9>(s.st[1][3],t[3]); + s.st[2][3] = _vxarq_u64::<25,39>(s.st[2][3],t[3]); + s.st[3][3] = _vxarq_u64::<21,43>(s.st[3][3],t[3]); + s.st[4][3] = _vxarq_u64::<56,8>(s.st[4][3],t[3]); + + s.st[0][4] = _vxarq_u64::<27,37>(s.st[0][4],t[4]); + s.st[1][4] = _vxarq_u64::<20,44>(s.st[1][4],t[4]); + s.st[2][4] = _vxarq_u64::<39,25>(s.st[2][4],t[4]); + s.st[3][4] = _vxarq_u64::<8,56>(s.st[3][4],t[4]); + s.st[4][4] = _vxarq_u64::<14,50>(s.st[4][4],t[4]); } @@ -123,15 +150,10 @@ fn pi(s: &mut KeccakStateX2) { #[inline(always)] fn chi(s: &mut KeccakStateX2) { - let mut c : [uint64x2_t; 5] = unsafe {[vdupq_n_u64(0); 5]}; + let old = s.st; for i in 0..5 { for j in 0..5 { - c[j] = s.st[i][j] - } - for j in 0..5 { - s.st[i][j] = unsafe{ veorq_u64(s.st[i][j], vbicq_u64(c[(j + 2) % 5], c[(j + 1) % 5])) }; - // Needs nightly - //s.st[i][j] = unsafe{ vbcaxq_u64(s.st[i][j], c[(j + 2) % 5], c[(j + 1) % 5]) }; + s.st[i][j] = _vbcaxq_u64(s.st[i][j], old[i][(j + 2) % 5], old[i][(j + 1) % 5]); } } } @@ -156,8 +178,8 @@ fn iota(s: &mut KeccakStateX2, round:usize) { #[inline(always)] pub(crate) fn keccakf1600(s: &mut KeccakStateX2) { for i in 0..24 { - theta(s); - rho(s); + let t = theta(s); + theta_rho(s,t); pi(s); chi(s); iota(s, i); From bc97a585770f1d32f03587bec71ca8206d55129c Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Mon, 13 May 2024 20:34:51 +0200 Subject: [PATCH 08/59] rust sha3 made generic --- libcrux-ml-kem/src/hash_functions.rs | 2 +- libcrux-sha3/src/rust_simd.rs | 96 ++----- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 281 +++++++++------------ libcrux-sha3/src/rust_simd/sha3_generic.rs | 219 ++++++++++++++++ libcrux-sha3/src/rust_simd/sha3_trait.rs | 17 ++ sys/pqclean/src/bindings.rs | 2 +- 6 files changed, 371 insertions(+), 246 deletions(-) create mode 100644 libcrux-sha3/src/rust_simd/sha3_generic.rs create mode 100644 libcrux-sha3/src/rust_simd/sha3_trait.rs diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index fdf702295..6974a6984 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -71,7 +71,7 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> } #[cfg(feature = "simd128")] -pub(crate) type Shake128x4State = [rust_simd::KeccakStateX2;2]; +pub(crate) type Shake128x4State = [rust_simd::KeccakState<2,core::arch::aarch64::uint64x2_t>;2]; #[cfg(not(feature = "simd128"))] pub(crate) type Shake128x4State = Shake128StateX4; diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 88ef5274c..03b7d30ce 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -1,128 +1,68 @@ +mod sha3_trait; mod sha3_arm64; -use sha3_arm64::*; +mod sha3_generic; -pub use sha3_arm64::KeccakStateX2; - -#[inline(always)] -fn squeeze_first_block2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - squeeze2::(s, out0, out1) -} - -#[inline(always)] -fn squeeze_next_block2(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - keccakf1600(s); - squeeze2::(s, out0, out1) -} - -#[inline(always)] -pub 
fn squeeze_first_three_blocks2(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - squeeze_first_block2::(s, out0, out1); - squeeze_next_block2::(s, &mut out0[RATE..2*RATE], &mut out1[RATE..2*RATE]); - squeeze_next_block2::(s, &mut out0[2*RATE..3*RATE], &mut out1[2*RATE..3*RATE]) -} - -#[inline(always)] -fn squeeze_last2(mut s: KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - let mut b0 = [0u8; 200]; - let mut b1 = [0u8; 200]; - squeeze_next_block2::(&mut s, &mut b0, &mut b1); - out0.copy_from_slice(&b0[0..out0.len()]); - out1.copy_from_slice(&b1[0..out1.len()]); -} - -#[inline(always)] -fn squeeze_first_and_last2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - let mut b0 = [0u8; 200]; - let mut b1 = [0u8; 200]; - squeeze_first_block2::(s, &mut b0, &mut b1); - out0.copy_from_slice(&b0[0..out0.len()]); - out1.copy_from_slice(&b1[0..out1.len()]); -} - -#[inline(always)] -fn keccak(data0: &[u8], data1: &[u8], out0: &mut [u8], out1: &mut [u8]) { - debug_assert!(data0.len() == data1.len()); - debug_assert!(out0.len() == out1.len()); - let mut s = KeccakStateX2::new(); - for i in 0..data0.len()/RATE { - absorb_block2::(&mut s, &data0[i*RATE..(i+1)*RATE], &data1[i*RATE..(i+1)*RATE]); - } - let rem = data0.len() % RATE; - absorb_final2::(&mut s, &data0[data0.len()-rem..data0.len()], &data1[data1.len()-rem..data1.len()]); - - let blocks = out0.len()/RATE; - let last = out0.len() - out0.len()%RATE; - - if blocks == 0 { - squeeze_first_and_last2::(&s, out0, out1) - } else { - squeeze_first_block2::(&s, out0, out1); - for i in 1..blocks { - squeeze_next_block2::(&mut s, &mut out0[i*RATE..(i+1)*RATE], &mut out1[i*RATE..(i+1)*RATE]); - } - if last < out0.len() {squeeze_last2::(s, &mut out0[last..], &mut out1[last..])} - } -} +pub use sha3_generic::*; pub fn sha3_224(data: &[u8]) -> [u8;28] { let mut d0 = [0u8; 28]; let mut d1 = [0u8; 28]; - keccak::<144, 0x06u8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 144, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } pub fn sha3_256(data: &[u8]) -> [u8;32] { let mut d0 = [0u8; 32]; let mut d1 = [0u8; 32]; - keccak::<136, 0x06u8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 136, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } pub fn sha3_384(data: &[u8]) -> [u8;48] { let mut d0 = [0u8; 48]; let mut d1 = [0u8; 48]; - keccak::<104, 0x06u8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 104, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } pub fn sha3_512(data: &[u8]) -> [u8;64] { let mut d0 = [0u8; 64]; let mut d1 = [0u8; 64]; - keccak::<72,0x06u8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 72, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } pub fn shake128(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; LEN]; - keccak::<168, 0x1fu8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; LEN]; - keccak::<136, 0x1fu8>(data, data, &mut d0, &mut d1); + keccak::<2, core::arch::aarch64::uint64x2_t, 136, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } -pub fn shake128x2_init() -> KeccakStateX2 { - let s = KeccakStateX2::new(); +pub fn shake128x2_init() -> KeccakState<2,core::arch::aarch64::uint64x2_t> { + let s = KeccakState::new(); s } -pub fn shake128x2_absorb_final(s:&mut KeccakStateX2, data0: &[u8], data1: &[u8]) { - 
absorb_final2::<168, 0x1fu8>(s,data0,data1); +pub fn shake128x2_absorb_final(s:&mut KeccakState<2,core::arch::aarch64::uint64x2_t>, data0: &[u8], data1: &[u8]) { + absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(s,[data0,data1]); } -pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakStateX2, out0:&mut [u8], out1:&mut [u8]) { - squeeze_first_three_blocks2::<168>(s, out0, out1) +pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState<2,core::arch::aarch64::uint64x2_t>, out0:&mut [u8], out1:&mut [u8]) { + squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) } -pub fn shake128x2_squeeze_next_block(s: &mut KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { - squeeze_next_block2::<168>(s, out0, out1) +pub fn shake128x2_squeeze_next_block(s: &mut KeccakState<2,core::arch::aarch64::uint64x2_t>, out0: &mut [u8], out1: &mut [u8]) { + squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) } pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { - keccak::<136, 0x1fu8>(input0, input1, out0, out1); + keccak::<2,core::arch::aarch64::uint64x2_t,136, 0x1fu8>([input0, input1], [out0, out1]); } diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index f7ebb13ce..396304421 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -1,4 +1,5 @@ use core::arch::aarch64::*; +use crate::rust_simd::sha3_trait::*; // This file optimizes for the stable Rust Neon Intrinsics // If we want to use the unstable neon-sha3 instructions, we could use: @@ -6,13 +7,6 @@ use core::arch::aarch64::*; // These instructions might speed up our code even more. -/// Incremental state -#[cfg_attr(hax, hax_lib::opaque_type)] -#[derive(Clone, Copy)] -pub struct KeccakStateX2 { - pub st: [[uint64x2_t; 5]; 5], -} - #[inline(always)] fn rotate_left(x:uint64x2_t) -> uint64x2_t { debug_assert!(LEFT+RIGHT == 64); @@ -21,19 +15,6 @@ fn rotate_left(x:uint64x2_t) -> uint64x2_t { unsafe { veorq_u64(vshlq_n_u64::(x), vshrq_n_u64::(x)) } } - -impl KeccakStateX2 { - /// Create a new Shake128 x4 state. 
- #[inline(always)] - pub(crate) fn new() -> Self { - unsafe{ - Self { - st: [[vdupq_n_u64(0); 5]; 5], - } - } - } -} - #[inline(always)] fn _veor5q_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t, d: uint64x2_t, e: uint64x2_t) -> uint64x2_t { let ab = unsafe {veorq_u64(a,b)}; @@ -66,176 +47,144 @@ fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t { // unsafe{ vbcaxq_u64(a, b, c) } } - - - -#[inline(always)] -fn theta(s: &KeccakStateX2) -> [uint64x2_t; 5] { - let c: [uint64x2_t;5] = core::array::from_fn(|j| _veor5q_u64(s.st[0][j],s.st[1][j],s.st[2][j],s.st[3][j],s.st[4][j])); - let t : [uint64x2_t; 5] = core::array::from_fn(|j| _vrax1q_u64(c[(j+4)%5], c[(j+1)%5])); - t -} - -const _ROTC: [usize;24] = - [1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,]; - #[inline(always)] -fn theta_rho(s: &mut KeccakStateX2, t: [uint64x2_t; 5]) { - // If combined with last step of theta, we could use Nightly: vxarq_u64 - - s.st[0][0] = unsafe { veorq_u64(s.st[0][0],t[0]) }; - s.st[1][0] = _vxarq_u64::<36,28>(s.st[1][0],t[0]); - s.st[2][0] = _vxarq_u64::<3,61>(s.st[2][0],t[0]); - s.st[3][0] = _vxarq_u64::<41,23>(s.st[3][0],t[0]); - s.st[4][0] = _vxarq_u64::<18,46>(s.st[4][0],t[0]); - - s.st[0][1] = _vxarq_u64::<1,63>(s.st[0][1],t[1]); - s.st[1][1] = _vxarq_u64::<44,20>(s.st[1][1],t[1]); - s.st[2][1] = _vxarq_u64::<10,54>(s.st[2][1],t[1]); - s.st[3][1] = _vxarq_u64::<45,19>(s.st[3][1],t[1]); - s.st[4][1] = _vxarq_u64::<2,62>(s.st[4][1],t[1]); - - s.st[0][2] = _vxarq_u64::<62,2>(s.st[0][2],t[2]); - s.st[1][2] = _vxarq_u64::<6,58>(s.st[1][2],t[2]); - s.st[2][2] = _vxarq_u64::<43,21>(s.st[2][2],t[2]); - s.st[3][2] = _vxarq_u64::<15,49>(s.st[3][2],t[2]); - s.st[4][2] = _vxarq_u64::<61,3>(s.st[4][2],t[2]); - - s.st[0][3] = _vxarq_u64::<28,36>(s.st[0][3],t[3]); - s.st[1][3] = _vxarq_u64::<55,9>(s.st[1][3],t[3]); - s.st[2][3] = _vxarq_u64::<25,39>(s.st[2][3],t[3]); - s.st[3][3] = _vxarq_u64::<21,43>(s.st[3][3],t[3]); - s.st[4][3] = _vxarq_u64::<56,8>(s.st[4][3],t[3]); - - s.st[0][4] = _vxarq_u64::<27,37>(s.st[0][4],t[4]); - s.st[1][4] = _vxarq_u64::<20,44>(s.st[1][4],t[4]); - s.st[2][4] = _vxarq_u64::<39,25>(s.st[2][4],t[4]); - s.st[3][4] = _vxarq_u64::<8,56>(s.st[3][4],t[4]); - s.st[4][4] = _vxarq_u64::<14,50>(s.st[4][4],t[4]); +fn _veorq_n_u64(a: uint64x2_t, c: u64) -> uint64x2_t { + let c = unsafe { vdupq_n_u64(c) }; + unsafe { veorq_u64(a, c) } } -const _PI : [usize;24] = [ - 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7, 13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21, -]; - #[inline(always)] -fn pi(s: &mut KeccakStateX2) { - let old = s.st.clone(); - s.st[0][1] = old[1][1]; - s.st[0][2] = old[2][2]; - s.st[0][3] = old[3][3]; - s.st[0][4] = old[4][4]; - s.st[1][0] = old[0][3]; - s.st[1][1] = old[1][4]; - s.st[1][2] = old[2][0]; - s.st[1][3] = old[3][1]; - s.st[1][4] = old[4][2]; - s.st[2][0] = old[0][1]; - s.st[2][1] = old[1][2]; - s.st[2][2] = old[2][3]; - s.st[2][3] = old[3][4]; - s.st[2][4] = old[4][0]; - s.st[3][0] = old[0][4]; - s.st[3][1] = old[1][0]; - s.st[3][2] = old[2][1]; - s.st[3][3] = old[3][2]; - s.st[3][4] = old[4][3]; - s.st[4][0] = old[0][2]; - s.st[4][1] = old[1][3]; - s.st[4][2] = old[2][4]; - s.st[4][3] = old[3][0]; - s.st[4][4] = old[4][1]; -} - -#[inline(always)] -fn chi(s: &mut KeccakStateX2) { - let old = s.st; - for i in 0..5 { - for j in 0..5 { - s.st[i][j] = _vbcaxq_u64(s.st[i][j], old[i][(j + 2) % 5], old[i][(j + 1) % 5]); - } - } -} - -const ROUNDCONSTANTS: [u64;24] = [ - 0x0000_0000_0000_0001u64, 0x0000_0000_0000_8082u64, 
0x8000_0000_0000_808au64, - 0x8000_0000_8000_8000u64, 0x0000_0000_0000_808bu64, 0x0000_0000_8000_0001u64, - 0x8000_0000_8000_8081u64, 0x8000_0000_0000_8009u64, 0x0000_0000_0000_008au64, - 0x0000_0000_0000_0088u64, 0x0000_0000_8000_8009u64, 0x0000_0000_8000_000au64, - 0x0000_0000_8000_808bu64, 0x8000_0000_0000_008bu64, 0x8000_0000_0000_8089u64, - 0x8000_0000_0000_8003u64, 0x8000_0000_0000_8002u64, 0x8000_0000_0000_0080u64, - 0x0000_0000_0000_800au64, 0x8000_0000_8000_000au64, 0x8000_0000_8000_8081u64, - 0x8000_0000_0000_8080u64, 0x0000_0000_8000_0001u64, 0x8000_0000_8000_8008u64, -]; - -#[inline(always)] -fn iota(s: &mut KeccakStateX2, round:usize) { - let c = unsafe { vdupq_n_u64(ROUNDCONSTANTS[round]) }; - s.st[0][0] = unsafe { veorq_u64(s.st[0][0], c) }; -} - -#[inline(always)] -pub(crate) fn keccakf1600(s: &mut KeccakStateX2) { - for i in 0..24 { - let t = theta(s); - theta_rho(s,t); - pi(s); - chi(s); - iota(s, i); - } -} - -#[inline(always)] -pub(crate) fn absorb_block2(s: &mut KeccakStateX2, block0: &[u8], block1: &[u8]) { - debug_assert!(RATE == block0.len() && RATE == block1.len() && RATE % 8 == 0); +pub(crate) fn load_block(s: &mut [[uint64x2_t;5];5], blocks: [&[u8];2]) { + debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0); for i in 0..RATE/16 { - let v0 = unsafe { vld1q_u64(block0[16*i..(16*i)+16].as_ptr() as *const u64) }; - let v1 = unsafe { vld1q_u64(block1[16*i..(16*i)+16].as_ptr() as *const u64) }; - s.st[(2*i)/5][(2*i)%5] = unsafe { veorq_u64(s.st[(2*i)/5][(2*i)%5], vtrn1q_u64(v0,v1)) }; - s.st[(2*i+1)/5][(2*i+1)%5] = unsafe { veorq_u64(s.st[(2*i+1)/5][(2*i+1)%5], vtrn2q_u64(v0,v1)) }; + let v0 = unsafe { vld1q_u64(blocks[0][16*i..16*(i+1)].as_ptr() as *const u64) }; + let v1 = unsafe { vld1q_u64(blocks[1][16*i..16*(i+1)].as_ptr() as *const u64) }; + s[(2*i)/5][(2*i)%5] = unsafe { veorq_u64(s[(2*i)/5][(2*i)%5], vtrn1q_u64(v0,v1)) }; + s[(2*i+1)/5][(2*i+1)%5] = unsafe { veorq_u64(s[(2*i+1)/5][(2*i+1)%5], vtrn2q_u64(v0,v1)) }; } if RATE%16 != 0 { let i = (RATE/8 - 1)/5; let j = (RATE/8 - 1)%5; let mut u = [0u64; 2]; - u[0] = u64::from_le_bytes(block0[RATE-8..].try_into().unwrap()); - u[1] = u64::from_le_bytes(block1[RATE-8..].try_into().unwrap()); + u[0] = u64::from_le_bytes(blocks[0][RATE-8..RATE].try_into().unwrap()); + u[1] = u64::from_le_bytes(blocks[1][RATE-8..RATE].try_into().unwrap()); let uvec = unsafe { vld1q_u64(u.as_ptr() as *const u64) }; - s.st[i][j] = unsafe { veorq_u64(s.st[i][j], uvec)}; + s[i][j] = unsafe { veorq_u64(s[i][j], uvec)}; } - keccakf1600(s) } #[inline(always)] -pub(crate) fn absorb_final2(s: &mut KeccakStateX2, last0: &[u8], last1: &[u8]) { - debug_assert!(last0.len() == last1.len() && last0.len() < RATE); - let mut b0 = [0u8; 200]; - let mut b1 = [0u8; 200]; - b0[0..last0.len()].copy_from_slice(last0); - b1[0..last1.len()].copy_from_slice(last1); - b0[last0.len()] = DELIM; - b0[RATE-1] = b0[RATE-1] | 128u8; - b1[last1.len()] = DELIM; - b1[RATE-1] = b1[RATE-1] | 128u8; - absorb_block2::(s, &b0[0..RATE], &b1[0..RATE]) +pub(crate) fn load_block_full(s: &mut [[uint64x2_t;5];5], blocks: [[u8;200];2]) { + let [b0,b1] = blocks; + load_block::(s,[&b0 as &[u8], &b1 as &[u8]]); } #[inline(always)] -pub(crate) fn squeeze2(s: &KeccakStateX2, out0: &mut [u8], out1: &mut [u8]) { +pub(crate) fn store_block(s: &[[uint64x2_t;5];5], out: [&mut [u8];2]) { for i in 0..RATE/16 { - let v0 = unsafe { vtrn1q_u64(s.st[(2*i)/5][(2*i)%5], s.st[(2*i+1)/5][(2*i+1)%5]) }; - let v1 = unsafe { vtrn2q_u64(s.st[(2*i)/5][(2*i)%5], s.st[(2*i+1)/5][(2*i+1)%5]) }; - unsafe { 
vst1q_u64(out0[16*i..16*i+16].as_mut_ptr() as *mut u64, v0) }; - unsafe { vst1q_u64(out1[16*i..16*i+16].as_mut_ptr() as *mut u64, v1) }; + let v0 = unsafe { vtrn1q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; + let v1 = unsafe { vtrn2q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; + unsafe { vst1q_u64(out[0][16*i..16*(i+1)].as_mut_ptr() as *mut u64, v0) }; + unsafe { vst1q_u64(out[1][16*i..16*(i+1)].as_mut_ptr() as *mut u64, v1) }; } if RATE%16 != 0 { debug_assert!(RATE % 8 == 0); let i = (RATE/8 - 1)/5; let j = (RATE/8 - 1)%5; let mut u = [0u8;16]; - unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s.st[i][j])}; - out0[RATE-8..RATE].copy_from_slice(&u[0..8]); - out1[RATE-8..RATE].copy_from_slice(&u[8..16]); + unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j])}; + out[0][RATE-8..RATE].copy_from_slice(&u[0..8]); + out[1][RATE-8..RATE].copy_from_slice(&u[8..16]); + } +} + +#[inline(always)] +pub(crate) fn store_block_full(s: &[[uint64x2_t;5];5]) -> [[u8;200];2] { + let mut out0 = [0u8; 200]; + let mut out1 = [0u8; 200]; + store_block::(s,[&mut out0, &mut out1]); + [out0, out1] + + // for i in 0..RATE/16 { + // let v0 = unsafe { vtrn1q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; + // let v1 = unsafe { vtrn2q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; + // unsafe { vst1q_u64(out[0][offset+16*i..offset+16*(i+1)].as_mut_ptr() as *mut u64, v0) }; + // unsafe { vst1q_u64(out[1][offset+16*i..offset+16*(i+1)].as_mut_ptr() as *mut u64, v1) }; + // } + // if RATE%16 != 0 { + // debug_assert!(RATE % 8 == 0); + // let i = (RATE/8 - 1)/5; + // let j = (RATE/8 - 1)%5; + // let mut u = [0u8;16]; + // unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j])}; + // out[0][offset+RATE-8..offset+RATE].copy_from_slice(&u[0..8]); + // out[1][offset+RATE-8..offset+RATE].copy_from_slice(&u[8..16]); + // } +} + +fn slice_n(a: [&[u8];2], start:usize, len:usize) -> [&[u8];2] { + [&a[0][start..start+len], &a[1][start..start+len]] +} + +fn split_at_mut_2(out: [&mut [u8]; 2], mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { + let [out0, out1] = out; + let (out00,out01) = out0.split_at_mut(mid); + let (out10,out11) = out1.split_at_mut(mid); + ([out00,out10], [out01,out11]) +} + +impl KeccakItem<2> for uint64x2_t { + fn zero() -> Self { + unsafe {vdupq_n_u64(0)} + } + + fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { + _veor5q_u64(a, b, c, d, e) + } + + fn rotate_left1_and_xor(a: Self, b: Self) -> Self { + _vrax1q_u64(a, b) + } + + fn xor_and_rotate(a: Self, b: Self) -> Self { + _vxarq_u64::(a, b) + } + + fn and_not_xor(a: Self, b: Self, c: Self) -> Self { + _vbcaxq_u64(a, b, c) + } + + fn xor_constant(a: Self, c: u64) -> Self { + _veorq_n_u64(a, c) } -} \ No newline at end of file + + fn xor(a: Self, b: Self) -> Self { + unsafe {veorq_u64(a, b)} + } + + fn load_block(a:&mut [[Self;5];5], b:[&[u8];2]) { + load_block::(a, b) + } + + fn store_block(a:& [[Self;5];5], b:[&mut [u8];2]) { + store_block::(a, b) + } + + fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];2]) { + load_block_full::(a, b) + } + + fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];2] { + store_block_full::(a) + } + + fn slice_n(a:[&[u8];2],start:usize,len:usize) -> [&[u8];2] { + slice_n(a,start,len) + } + + fn split_at_mut_n(a:[&mut [u8];2],mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { + split_at_mut_2(a, mid) + } +} + diff --git a/libcrux-sha3/src/rust_simd/sha3_generic.rs b/libcrux-sha3/src/rust_simd/sha3_generic.rs new file mode 100644 index 000000000..2892be19e --- /dev/null +++ 
b/libcrux-sha3/src/rust_simd/sha3_generic.rs @@ -0,0 +1,219 @@ +use crate::rust_simd::sha3_trait::*; + +#[cfg_attr(hax, hax_lib::opaque_type)] +#[derive(Clone, Copy)] +pub struct KeccakState> { + pub st: [[T; 5]; 5], +} + +impl> KeccakState { + /// Create a new Shake128 x4 state. + #[inline(always)] + pub(crate) fn new() -> Self { + Self { + st: [[T::zero(); 5]; 5], + } + + } +} + +/// From here, everything is generic +/// +const _ROTC: [usize;24] = + [1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,]; + + +#[inline(always)] +pub(crate) fn theta_rho>(s: &mut KeccakState) { + let c: [T; 5] = core::array::from_fn(|j| T::xor5(s.st[0][j],s.st[1][j],s.st[2][j],s.st[3][j],s.st[4][j])); + let t : [T; 5] = core::array::from_fn(|j| T::rotate_left1_and_xor(c[(j+4)%5], c[(j+1)%5])); + + s.st[0][0] = T::xor(s.st[0][0],t[0]); + s.st[1][0] = T::xor_and_rotate::<36,28>(s.st[1][0],t[0]); + s.st[2][0] = T::xor_and_rotate::<3,61>(s.st[2][0],t[0]); + s.st[3][0] = T::xor_and_rotate::<41,23>(s.st[3][0],t[0]); + s.st[4][0] = T::xor_and_rotate::<18,46>(s.st[4][0],t[0]); + + s.st[0][1] = T::xor_and_rotate::<1,63>(s.st[0][1],t[1]); + s.st[1][1] = T::xor_and_rotate::<44,20>(s.st[1][1],t[1]); + s.st[2][1] = T::xor_and_rotate::<10,54>(s.st[2][1],t[1]); + s.st[3][1] = T::xor_and_rotate::<45,19>(s.st[3][1],t[1]); + s.st[4][1] = T::xor_and_rotate::<2,62>(s.st[4][1],t[1]); + + s.st[0][2] = T::xor_and_rotate::<62,2>(s.st[0][2],t[2]); + s.st[1][2] = T::xor_and_rotate::<6,58>(s.st[1][2],t[2]); + s.st[2][2] = T::xor_and_rotate::<43,21>(s.st[2][2],t[2]); + s.st[3][2] = T::xor_and_rotate::<15,49>(s.st[3][2],t[2]); + s.st[4][2] = T::xor_and_rotate::<61,3>(s.st[4][2],t[2]); + + s.st[0][3] = T::xor_and_rotate::<28,36>(s.st[0][3],t[3]); + s.st[1][3] = T::xor_and_rotate::<55,9>(s.st[1][3],t[3]); + s.st[2][3] = T::xor_and_rotate::<25,39>(s.st[2][3],t[3]); + s.st[3][3] = T::xor_and_rotate::<21,43>(s.st[3][3],t[3]); + s.st[4][3] = T::xor_and_rotate::<56,8>(s.st[4][3],t[3]); + + s.st[0][4] = T::xor_and_rotate::<27,37>(s.st[0][4],t[4]); + s.st[1][4] = T::xor_and_rotate::<20,44>(s.st[1][4],t[4]); + s.st[2][4] = T::xor_and_rotate::<39,25>(s.st[2][4],t[4]); + s.st[3][4] = T::xor_and_rotate::<8,56>(s.st[3][4],t[4]); + s.st[4][4] = T::xor_and_rotate::<14,50>(s.st[4][4],t[4]); +} + + +const _PI : [usize;24] = [ + 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7, 13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21, +]; + +#[inline(always)] +pub(crate) fn pi>(s: &mut KeccakState) { + let old = s.st.clone(); + s.st[0][1] = old[1][1]; + s.st[0][2] = old[2][2]; + s.st[0][3] = old[3][3]; + s.st[0][4] = old[4][4]; + s.st[1][0] = old[0][3]; + s.st[1][1] = old[1][4]; + s.st[1][2] = old[2][0]; + s.st[1][3] = old[3][1]; + s.st[1][4] = old[4][2]; + s.st[2][0] = old[0][1]; + s.st[2][1] = old[1][2]; + s.st[2][2] = old[2][3]; + s.st[2][3] = old[3][4]; + s.st[2][4] = old[4][0]; + s.st[3][0] = old[0][4]; + s.st[3][1] = old[1][0]; + s.st[3][2] = old[2][1]; + s.st[3][3] = old[3][2]; + s.st[3][4] = old[4][3]; + s.st[4][0] = old[0][2]; + s.st[4][1] = old[1][3]; + s.st[4][2] = old[2][4]; + s.st[4][3] = old[3][0]; + s.st[4][4] = old[4][1]; +} + +#[inline(always)] +pub(crate) fn chi>(s: &mut KeccakState) { + let old = s.st; + for i in 0..5 { + for j in 0..5 { + s.st[i][j] = T::and_not_xor(s.st[i][j], old[i][(j + 2) % 5], old[i][(j + 1) % 5]); + } + } +} + +const ROUNDCONSTANTS: [u64;24] = [ + 0x0000_0000_0000_0001u64, 0x0000_0000_0000_8082u64, 0x8000_0000_0000_808au64, + 0x8000_0000_8000_8000u64, 0x0000_0000_0000_808bu64, 
0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8081u64, 0x8000_0000_0000_8009u64, 0x0000_0000_0000_008au64, + 0x0000_0000_0000_0088u64, 0x0000_0000_8000_8009u64, 0x0000_0000_8000_000au64, + 0x0000_0000_8000_808bu64, 0x8000_0000_0000_008bu64, 0x8000_0000_0000_8089u64, + 0x8000_0000_0000_8003u64, 0x8000_0000_0000_8002u64, 0x8000_0000_0000_0080u64, + 0x0000_0000_0000_800au64, 0x8000_0000_8000_000au64, 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8080u64, 0x0000_0000_8000_0001u64, 0x8000_0000_8000_8008u64, +]; + +#[inline(always)] +pub(crate) fn iota>(s: &mut KeccakState, i:usize) { + s.st[0][0] = T::xor_constant(s.st[0][0], ROUNDCONSTANTS[i]); +} + + +#[inline(always)] +pub(crate) fn keccakf1600>(s: &mut KeccakState) { + for i in 0..24 { + theta_rho(s); + pi(s); + chi(s); + iota(s, i); + } +} + +#[inline(always)] +pub(crate) fn absorb_block,const RATE:usize>(s: &mut KeccakState, blocks: [&[u8];N]) { + T::load_block::(&mut s.st, blocks); + keccakf1600(s) +} + +#[inline(always)] +pub(crate) fn absorb_final,const RATE:usize, const DELIM:u8>( + s: &mut KeccakState, last: [&[u8];N]) { + debug_assert!(N > 0 && last[0].len() < RATE); + let last_len = last[0].len(); + let mut blocks = [[0u8; 200]; N]; + for i in 0..N { + blocks[i][0..last_len].copy_from_slice(&last[i]); + blocks[i][last_len] = DELIM; + blocks[i][RATE-1] = blocks[i][RATE-1] | 128u8; + } + T::load_block_full::(&mut s.st, blocks); + keccakf1600(s) +} + + +#[inline(always)] +pub(crate) fn squeeze_first_block,const RATE:usize>(s: &KeccakState, out: [&mut [u8];N]) { + T::store_block::(&s.st, out) +} + +#[inline(always)] +pub(crate) fn squeeze_next_block,const RATE:usize>(s: &mut KeccakState, out: [&mut [u8];N]) { + keccakf1600(s); + T::store_block::(&s.st, out) +} + + +#[inline(always)] +pub(crate) fn squeeze_first_three_blocks,const RATE:usize>( + s: &mut KeccakState, out: [&mut [u8];N]) { + let (o0,o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(s, o0); + let (o1,o2) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(s, o1); + squeeze_next_block::(s, o2); +} + +#[inline(always)] +pub(crate) fn squeeze_last,const RATE:usize>(mut s: KeccakState, out: [&mut [u8];N]) { + keccakf1600(&mut s); + let b = T::store_block_full::(&s.st); + for i in 0..N { + out[i].copy_from_slice(&b[i][0..out[i].len()]); + } +} + +#[inline(always)] +pub(crate) fn squeeze_first_and_last,const RATE:usize>(s: &KeccakState, out: [&mut [u8];N]) { + let b = T::store_block_full::(&s.st); + for i in 0..N { + out[i].copy_from_slice(&b[i][0..out[i].len()]); + } +} + +#[inline(always)] +pub(crate) fn keccak,const RATE:usize, const DELIM:u8>(data: [&[u8]; N], out: [&mut [u8]; N]) { + let mut s = KeccakState::::new(); + for i in 0..data[0].len()/RATE { + absorb_block::(&mut s, T::slice_n(data,i*RATE,RATE)); + } + let rem = data[0].len() % RATE; + absorb_final::(&mut s, T::slice_n(data,data[0].len()-rem,rem)); + + let outlen = out[0].len(); + let blocks = outlen/RATE; + let last = outlen - (outlen%RATE); + + if blocks == 0 { + squeeze_first_and_last::(&s, out) + } else { + let (o0,mut o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(&s, o0); + for _i in 1..blocks { + let (o,orest) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(&mut s, o); + o1 = orest; + } + if last < outlen {squeeze_last::(s, o1)} + } +} diff --git a/libcrux-sha3/src/rust_simd/sha3_trait.rs b/libcrux-sha3/src/rust_simd/sha3_trait.rs new file mode 100644 index 000000000..358dbff16 --- /dev/null +++ b/libcrux-sha3/src/rust_simd/sha3_trait.rs @@ -0,0 +1,17 @@ + +pub trait KeccakItem: 
Clone + Copy { + fn zero() -> Self; + fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self; + fn rotate_left1_and_xor(a: Self, b: Self) -> Self; + fn xor_and_rotate(a: Self, b: Self) -> Self; + fn and_not_xor(a: Self, b: Self, c: Self) -> Self; + fn xor_constant(a: Self, c: u64) -> Self; + fn xor(a: Self, b: Self) -> Self; + fn load_block(a:&mut [[Self;5];5], b:[&[u8];N]); + fn store_block(a:& [[Self;5];5], b:[&mut [u8];N]); + fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];N]); + fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];N]; + fn slice_n(a:[&[u8];N],start:usize,len:usize) -> [&[u8];N]; + fn split_at_mut_n(a:[&mut [u8];N],mid:usize) -> ([&mut [u8];N],[&mut [u8];N]); +} + diff --git a/sys/pqclean/src/bindings.rs b/sys/pqclean/src/bindings.rs index 5f6602af9..59a2d73d9 100644 --- a/sys/pqclean/src/bindings.rs +++ b/sys/pqclean/src/bindings.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.69.4 */ +/* automatically generated by rust-bindgen 0.69.1 */ pub const SHAKE128_RATE: u32 = 168; pub const SHAKE256_RATE: u32 = 136; From e93717b6e49841233bb0b6a964cf7a8087aad820 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Mon, 13 May 2024 23:07:51 +0200 Subject: [PATCH 09/59] some bugfixes in portable --- libcrux-ml-kem/src/hash_functions.rs | 75 +++----- libcrux-sha3/benches/sha3.rs | 4 +- libcrux-sha3/src/rust_simd.rs | 226 +++++++++++++++++++++-- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 47 ++--- libcrux-sha3/tests/sha3.rs | 1 - 5 files changed, 250 insertions(+), 103 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 6974a6984..b7b3c8623 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -2,42 +2,24 @@ use crate::constants::H_DIGEST_SIZE; -#[cfg(feature = "simd128")] -use libcrux_sha3::rust_simd; -#[cfg(not(feature = "simd128"))] -use libcrux_sha3::{x4::Shake128StateX4, *}; +use libcrux_sha3::rust_simd::{self, KeccakState4}; -#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; 64] { - rust_simd::sha3_512(input) -} -#[cfg(not(feature = "simd128"))] -#[inline(always)] -pub(crate) fn G(input: &[u8]) -> [u8; digest_size(Algorithm::Sha3_512)] { - sha512(input) + //rust_simd::sha3_512(input) + libcrux_sha3::sha512(input) } -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - rust_simd::sha3_256(input) -} -#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - sha256(input) + //rust_simd::sha3_256(input) + libcrux_sha3::sha256(input) } -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - rust_simd::shake256::(input) -} -#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - shake256::(input) + //rust_simd::shake256::(input) + libcrux_sha3::shake256::(input) } #[cfg(feature = "simd128")] @@ -67,21 +49,17 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> #[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { - core::array::from_fn(|i| shake256::(&input[i])) + core::array::from_fn(|i| rust_simd::shake256::(&input[i])) } -#[cfg(feature = "simd128")] -pub(crate) type Shake128x4State = [rust_simd::KeccakState<2,core::arch::aarch64::uint64x2_t>;2]; - -#[cfg(not(feature = "simd128"))] -pub(crate) type Shake128x4State = Shake128StateX4; +pub(crate) type Shake128x4State = KeccakState4; 
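+// Note: under `simd128` this packs the four SHAKE-128 lanes as two 2-way (NEON) Keccak states;
+// without the feature it falls back to four independent scalar states (see `rust_simd::KeccakState4`).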
#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { debug_assert!(K == 2 || K == 3 || K == 4); - let mut states = [rust_simd::shake128x2_init();2]; + let mut states = rust_simd::shake128x4_init(); match K { 2 => { rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); @@ -102,17 +80,14 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { #[inline(always)] pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { debug_assert!(K == 2 || K == 3 || K == 4); - - let mut state = Shake128StateX4::new(); - // XXX: We need to do this dance to get it through hax and eurydice for now. - let mut data: [&[u8]; K] = [&[0u8]; K]; + let mut states = rust_simd::shake128x4_init(); for i in 0..K { - data[i] = &input[i] as &[u8]; - } - state.absorb_final(data); - state + rust_simd::shake128_absorb_final(&mut states[i], &input[i]); + } + states } + pub(crate) const BLOCK_SIZE: usize = 168; pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; @@ -145,13 +120,12 @@ pub(crate) fn squeeze_three_blocks ( #[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn squeeze_three_blocks( - xof_state: &mut Shake128x4State, +pub(crate) fn squeeze_three_blocks ( + state: &mut Shake128x4State, ) -> [[u8; THREE_BLOCKS]; K] { - let output: [[u8; THREE_BLOCKS]; K] = xof_state.squeeze_blocks(); let mut out = [[0u8; THREE_BLOCKS]; K]; for i in 0..K { - out[i] = output[i]; + rust_simd::shake128_squeeze_first_three_blocks(&mut state[i], &mut out[i]); } out } @@ -190,26 +164,19 @@ pub(crate) fn squeeze_block( #[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn squeeze_block( - xof_state: &mut Shake128x4State, + state: &mut Shake128x4State, ) -> [[u8; BLOCK_SIZE]; K] { - let output: [[u8; BLOCK_SIZE]; K] = xof_state.squeeze_blocks(); let mut out = [[0u8; BLOCK_SIZE]; K]; for i in 0..K { - out[i] = output[i]; + rust_simd::shake128_squeeze_next_block(&mut state[i], &mut out[i]); } out } + /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. -#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn free_state(_xof_state: Shake128x4State) { } - -#[cfg(not(feature = "simd128"))] -#[inline(always)] -pub(crate) fn free_state(xof_state: Shake128x4State) { - xof_state.free_memory(); -} diff --git a/libcrux-sha3/benches/sha3.rs b/libcrux-sha3/benches/sha3.rs index 2bda837e1..e54089864 100644 --- a/libcrux-sha3/benches/sha3.rs +++ b/libcrux-sha3/benches/sha3.rs @@ -43,9 +43,9 @@ macro_rules! 
impl_comp { }, ); - #[cfg(feature = "simd128")] + group.bench_with_input( - BenchmarkId::new("arm64", fmt(*payload_size)), + BenchmarkId::new("rust version (simd)", fmt(*payload_size)), payload_size, |b, payload_size| { b.iter_batched( diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 03b7d30ce..f7e36fba1 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -1,68 +1,262 @@ mod sha3_trait; -mod sha3_arm64; +mod sha3_portable; mod sha3_generic; - pub use sha3_generic::*; +pub type KeccakState1 = KeccakState<1, u64>; +#[inline(always)] +fn keccakx1(data:[&[u8];1],out:[&mut[u8];1]) { + keccak::<1, u64, RATE, DELIM>(data,out) +} + +#[cfg(feature = "simd128")] +mod sha3_arm64; +#[cfg(feature = "simd128")] +pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; +#[cfg(feature = "simd128")] +#[inline(always)] +fn keccakx2(data:[&[u8];2],out:[&mut[u8];2]) { + keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data,out) +} +#[cfg(feature = "simd128")] +pub type KeccakState4 = [KeccakState2; 2]; + + +#[cfg(not(feature = "simd128"))] +pub type KeccakState2 = [KeccakState1; 2]; +#[cfg(not(feature = "simd128"))] +pub type KeccakState4 = [KeccakState1; 4]; + + +#[cfg(feature = "simd128")] pub fn sha3_224(data: &[u8]) -> [u8;28] { let mut d0 = [0u8; 28]; let mut d1 = [0u8; 28]; - keccak::<2, core::arch::aarch64::uint64x2_t, 144, 0x06u8>([data, data], [&mut d0, &mut d1]); + keccakx2::<144, 0x06u8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn sha3_224(data: &[u8]) -> [u8;28] { + let mut d0 = [0u8; 28]; + keccakx1::<144, 0x06u8>([data], [&mut d0]); d0 } +#[cfg(feature = "simd128")] pub fn sha3_256(data: &[u8]) -> [u8;32] { let mut d0 = [0u8; 32]; let mut d1 = [0u8; 32]; - keccak::<2, core::arch::aarch64::uint64x2_t, 136, 0x06u8>([data, data], [&mut d0, &mut d1]); + keccakx2::<136, 0x06u8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn sha3_256(data: &[u8]) -> [u8;32] { + let mut d0 = [0u8; 32]; + keccakx1::<136, 0x06u8>([data], [&mut d0]); d0 } +#[cfg(feature = "simd128")] pub fn sha3_384(data: &[u8]) -> [u8;48] { let mut d0 = [0u8; 48]; let mut d1 = [0u8; 48]; - keccak::<2, core::arch::aarch64::uint64x2_t, 104, 0x06u8>([data, data], [&mut d0, &mut d1]); + keccakx2::<104, 0x06u8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn sha3_384(data: &[u8]) -> [u8;48] { + let mut d0 = [0u8; 48]; + keccakx1::<104, 0x06u8>([data], [&mut d0]); d0 } +#[cfg(feature = "simd128")] pub fn sha3_512(data: &[u8]) -> [u8;64] { let mut d0 = [0u8; 64]; let mut d1 = [0u8; 64]; - keccak::<2, core::arch::aarch64::uint64x2_t, 72, 0x06u8>([data, data], [&mut d0, &mut d1]); + keccakx2::<72, 0x06u8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn sha3_512(data: &[u8]) -> [u8;64] { + let mut d0 = [0u8; 64]; + keccakx1::<72, 0x06u8>([data], [&mut d0]); d0 } +#[cfg(feature = "simd128")] pub fn shake128(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; LEN]; - keccak::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>([data, data], [&mut d0, &mut d1]); + keccakx2::<168, 0x1fu8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn shake128(data: &[u8]) -> [u8; LEN] { + let mut d0 = [0u8; LEN]; + keccakx1::<168, 0x1fu8>([data], [&mut d0]); d0 } +#[cfg(feature = "simd128")] pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; 
LEN]; - keccak::<2, core::arch::aarch64::uint64x2_t, 136, 0x1fu8>([data, data], [&mut d0, &mut d1]); + keccakx2::<136, 0x1fu8>([data, data], [&mut d0, &mut d1]); + d0 +} +#[cfg(not(feature = "simd128"))] +pub fn shake256(data: &[u8]) -> [u8; LEN] { + let mut d0 = [0u8; LEN]; + keccakx1::<136, 0x1fu8>([data], [&mut d0]); d0 } -pub fn shake128x2_init() -> KeccakState<2,core::arch::aarch64::uint64x2_t> { - let s = KeccakState::new(); - s +#[cfg(feature = "simd128")] +pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); +} +#[cfg(not(feature = "simd128"))] +pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + keccakx1::<136, 0x1fu8>([input0], [out0]); + keccakx1::<136, 0x1fu8>([input1], [out1]); +} + +#[cfg(feature = "simd128")] +pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], + out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { + keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); + keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); +} +#[cfg(not(feature = "simd128"))] +pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], + out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { + keccakx1::<136, 0x1fu8>([input0], [out0]); + keccakx1::<136, 0x1fu8>([input1], [out1]); + keccakx1::<136, 0x1fu8>([input2], [out2]); + keccakx1::<136, 0x1fu8>([input3], [out3]); +} + +/// Incremental API + +pub fn shake128_init() -> KeccakState1 { + KeccakState1::new() +} + +pub fn shake128_absorb_final(s:&mut KeccakState1, data0: &[u8]) { + absorb_final::<1,u64,168,0x1fu8>(s,[data0]); +} + +pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0:&mut [u8]) { + squeeze_first_three_blocks::<1,u64,168>(s, [out0]) +} + +pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) { + squeeze_next_block::<1,u64,168>(s, [out0]) +} + +#[cfg(feature = "simd128")] +pub fn shake128x2_init() -> KeccakState2 { + KeccakState2::new() +} +#[cfg(not(feature = "simd128"))] +pub fn shake128x2_init() -> KeccakState2 { + let s0 = KeccakState1::new(); + let s1 = KeccakState1::new(); + [s0,s1] } -pub fn shake128x2_absorb_final(s:&mut KeccakState<2,core::arch::aarch64::uint64x2_t>, data0: &[u8], data1: &[u8]) { +#[cfg(feature = "simd128")] +pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) { absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(s,[data0,data1]); } +#[cfg(not(feature = "simd128"))] +pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) { + let [mut s0, mut s1] = s; + shake128_absorb_final(&mut s0, data0); + shake128_absorb_final(&mut s1, data1); +} -pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState<2,core::arch::aarch64::uint64x2_t>, out0:&mut [u8], out1:&mut [u8]) { +#[cfg(feature = "simd128")] +pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) { squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) } +#[cfg(not(feature = "simd128"))] +pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) { + let [mut s0, mut s1] = s; + shake128_squeeze_first_three_blocks(&mut s0, out0); + shake128_squeeze_first_three_blocks(&mut s1, out1); +} -pub fn shake128x2_squeeze_next_block(s: &mut KeccakState<2,core::arch::aarch64::uint64x2_t>, out0: &mut [u8], out1: &mut [u8]) { +#[cfg(feature = 
"simd128")] +pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) } +#[cfg(not(feature = "simd128"))] +pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { + let [mut s0, mut s1] = s; + shake128_squeeze_next_block(&mut s0, out0); + shake128_squeeze_next_block(&mut s1, out1); +} -pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { - keccak::<2,core::arch::aarch64::uint64x2_t,136, 0x1fu8>([input0, input1], [out0, out1]); + +#[cfg(feature = "simd128")] +pub fn shake128x4_init() -> KeccakState4 { + let s0 = KeccakState2::new(); + let s1 = KeccakState2::new(); + [s0,s1] +} +#[cfg(not(feature = "simd128"))] +pub fn shake128x4_init() -> KeccakState4 { + let s0 = KeccakState1::new(); + let s1 = KeccakState1::new(); + let s2 = KeccakState1::new(); + let s3 = KeccakState1::new(); + [s0,s1,s2,s3] +} + +#[cfg(feature = "simd128")] +pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) { + let [mut s0, mut s1] = s; + absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s0,[data0,data1]); + absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s1,[data2,data3]); } +#[cfg(not(feature = "simd128"))] +pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) { + let [mut s0, mut s1, mut s2, mut s3] = s; + shake128_absorb_final(&mut s0, data0); + shake128_absorb_final(&mut s1, data1); + shake128_absorb_final(&mut s2, data2); + shake128_absorb_final(&mut s3, data3); +} + +#[cfg(feature = "simd128")] +pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { + let [mut s0, mut s1] = s; + squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]); + squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); +} +#[cfg(not(feature = "simd128"))] +pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { + let [mut s0, mut s1, mut s2, mut s3] = s; + shake128_squeeze_first_three_blocks(&mut s0, out0); + shake128_squeeze_first_three_blocks(&mut s1, out1); + shake128_squeeze_first_three_blocks(&mut s2, out2); + shake128_squeeze_first_three_blocks(&mut s3, out3); +} + +#[cfg(feature = "simd128")] +pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { + let [mut s0, mut s1] = s; + squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]); + squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); +} +#[cfg(not(feature = "simd128"))] +pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { + let [mut s0, mut s1, mut s2, mut s3] = s; + shake128_squeeze_next_block(&mut s0, out0); + shake128_squeeze_next_block(&mut s1, out1); + shake128_squeeze_next_block(&mut s2, out2); + shake128_squeeze_next_block(&mut s3, out3); +} + diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index 396304421..22e16b42e 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -105,28 +105,14 @@ pub(crate) fn store_block_full(s: 
&[[uint64x2_t;5];5]) -> [[u8 let mut out1 = [0u8; 200]; store_block::(s,[&mut out0, &mut out1]); [out0, out1] - - // for i in 0..RATE/16 { - // let v0 = unsafe { vtrn1q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; - // let v1 = unsafe { vtrn2q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; - // unsafe { vst1q_u64(out[0][offset+16*i..offset+16*(i+1)].as_mut_ptr() as *mut u64, v0) }; - // unsafe { vst1q_u64(out[1][offset+16*i..offset+16*(i+1)].as_mut_ptr() as *mut u64, v1) }; - // } - // if RATE%16 != 0 { - // debug_assert!(RATE % 8 == 0); - // let i = (RATE/8 - 1)/5; - // let j = (RATE/8 - 1)%5; - // let mut u = [0u8;16]; - // unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j])}; - // out[0][offset+RATE-8..offset+RATE].copy_from_slice(&u[0..8]); - // out[1][offset+RATE-8..offset+RATE].copy_from_slice(&u[8..16]); - // } } -fn slice_n(a: [&[u8];2], start:usize, len:usize) -> [&[u8];2] { +#[inline(always)] +fn slice_2(a: [&[u8];2], start:usize, len:usize) -> [&[u8];2] { [&a[0][start..start+len], &a[1][start..start+len]] } +#[inline(always)] fn split_at_mut_2(out: [&mut [u8]; 2], mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { let [out0, out1] = out; let (out00,out01) = out0.split_at_mut(mid); @@ -135,54 +121,55 @@ fn split_at_mut_2(out: [&mut [u8]; 2], mid:usize) -> ([&mut [u8];2],[&mut [u8];2 } impl KeccakItem<2> for uint64x2_t { + #[inline(always)] fn zero() -> Self { unsafe {vdupq_n_u64(0)} } - + #[inline(always)] fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { _veor5q_u64(a, b, c, d, e) } - + #[inline(always)] fn rotate_left1_and_xor(a: Self, b: Self) -> Self { _vrax1q_u64(a, b) } - + #[inline(always)] fn xor_and_rotate(a: Self, b: Self) -> Self { _vxarq_u64::(a, b) } - + #[inline(always)] fn and_not_xor(a: Self, b: Self, c: Self) -> Self { _vbcaxq_u64(a, b, c) } - + #[inline(always)] fn xor_constant(a: Self, c: u64) -> Self { _veorq_n_u64(a, c) } - + #[inline(always)] fn xor(a: Self, b: Self) -> Self { unsafe {veorq_u64(a, b)} } - + #[inline(always)] fn load_block(a:&mut [[Self;5];5], b:[&[u8];2]) { load_block::(a, b) } - + #[inline(always)] fn store_block(a:& [[Self;5];5], b:[&mut [u8];2]) { store_block::(a, b) } - + #[inline(always)] fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];2]) { load_block_full::(a, b) } - + #[inline(always)] fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];2] { store_block_full::(a) } - + #[inline(always)] fn slice_n(a:[&[u8];2],start:usize,len:usize) -> [&[u8];2] { - slice_n(a,start,len) + slice_2(a,start,len) } - + #[inline(always)] fn split_at_mut_n(a:[&mut [u8];2],mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { split_at_mut_2(a, mid) } diff --git a/libcrux-sha3/tests/sha3.rs b/libcrux-sha3/tests/sha3.rs index 342d229db..0fff1b146 100644 --- a/libcrux-sha3/tests/sha3.rs +++ b/libcrux-sha3/tests/sha3.rs @@ -9,7 +9,6 @@ fn sha3_kat_oneshot() { assert_eq!(hex::encode(&dshake), expectedshake); } -#[cfg(feature = "simd128")] #[test] fn sha3_simd_kat_oneshot() { let d256 = libcrux_sha3::rust_simd::sha3_256(b"Hello, World!"); From 2db7fb1cfbf2c679609b06c0212af83341b41734 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Mon, 13 May 2024 23:16:43 +0200 Subject: [PATCH 10/59] some bugfixes in portable --- libcrux-ml-kem/src/hash_functions.rs | 34 ++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index b7b3c8623..ac4d16737 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ 
b/libcrux-ml-kem/src/hash_functions.rs @@ -4,22 +4,46 @@ use crate::constants::H_DIGEST_SIZE; use libcrux_sha3::rust_simd::{self, KeccakState4}; +#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; 64] { - //rust_simd::sha3_512(input) - libcrux_sha3::sha512(input) + rust_simd::sha3_512(input) +} + +#[cfg(not(feature = "simd128"))] +#[inline(always)] +pub(crate) fn G(input: &[u8]) -> [u8; 64] { + libcrux_sha3::sha512(input) + //some bug in scalar version of rust_simd + // rust_simd::sha512(input) +} + +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + rust_simd::sha3_256(input) } +#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - //rust_simd::sha3_256(input) - libcrux_sha3::sha256(input) + libcrux_sha3::sha256(input) + //some bug in scalar version of rust_simd + // rust_simd::sha256(input) } +#[cfg(feature = "simd128")] +#[inline(always)] +pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { + rust_simd::shake256::(input) +} + +#[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - //rust_simd::shake256::(input) libcrux_sha3::shake256::(input) + //some bug in scalar version of rust_simd + // rust_simd::shake256::(input) } #[cfg(feature = "simd128")] From 3323b7a364b4216bf8b1b325d3d2db5eba6e48b4 Mon Sep 17 00:00:00 2001 From: xvzcf Date: Tue, 14 May 2024 02:50:38 +0200 Subject: [PATCH 11/59] Failing implementation of AVX2 rejection sampling. --- polynomials-avx2/src/lib.rs | 62 +++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 1c4cde0da..62e3046fc 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -547,12 +547,12 @@ fn serialize_1(v: SIMD256Vector) -> [u8; 2] { let mut serialized = [0u8; 2]; let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(v.elements, 7); + let lsb_shifted_up = _mm256_slli_epi16(v.elements, 15); let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); - let msbs = _mm_packus_epi16(low_lanes, high_lanes); + let msbs = _mm_packs_epi16(low_lanes, high_lanes); _mm_movemask_epi8(msbs) }; @@ -984,8 +984,58 @@ fn deserialize_12(v: &[u8]) -> SIMD256Vector { } #[inline(always)] -fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - portable::rej_sample(a) +fn rej_sample(uniform_bytes: &[u8]) -> (usize, [i16; 16]) { + let mut sampled = [0i16; 16]; + + let count = unsafe { + let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); + let ones = _mm_set1_epi8(1); + + let potential_coefficients = deserialize_12(uniform_bytes).elements; + + let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); + let good = serialize_1(SIMD256Vector { elements: compare_with_field_modulus }); + + // Write out the indices indicated by the set bits of |good| such that + // the "good" elements can be read in sequence from |potential_coefficients| + + // Start with the first 8 bits, i.e. 
|good[0]| + let byte_start_indices = _pdep_u64(good[0] as u64, 0x0101010101010101) as u128; + let byte_start_indices = ((byte_start_indices << 8) - byte_start_indices) as u64; + let byte_start_indices = _pext_u64(0x0E0C0A0806040200, byte_start_indices); + + let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); + let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); + + let byte_shuffle_indices_low = _mm_unpacklo_epi8(byte_shuffle_indices_first_byte, byte_shuffle_indices_second_byte); + + // Then the next 8 bits, i.e. |good[1]| + let byte_start_indices = _pdep_u64(good[1] as u64, 0x0101010101010101) as u128; + let byte_start_indices = ((byte_start_indices << 8) - byte_start_indices) as u64; + let byte_start_indices = _pext_u64(0x0E0C0A0806040200, byte_start_indices); + + let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); + let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); + + let byte_shuffle_indices_high = _mm_unpacklo_epi8(byte_shuffle_indices_first_byte, byte_shuffle_indices_second_byte); + + // Write out the indices to an __m256 and then shuffle + let byte_shuffle_indices = _mm256_castsi128_si256(byte_shuffle_indices_low); + let byte_shuffle_indices = _mm256_inserti128_si256(byte_shuffle_indices, byte_shuffle_indices_high, 1); + + let coefficients = _mm256_shuffle_epi8(potential_coefficients, byte_shuffle_indices); + + // Write out the elements themselves + _mm256_storeu_si256(sampled.as_mut_ptr() as *mut __m256i, coefficients); + + // Count the sampled elements + let count_sampled = good[0].count_ones() + good[1].count_ones(); + + count_sampled + }; + + (count as usize, sampled) + //portable::rej_sample(uniform_bytes) } impl Operations for SIMD256Vector { @@ -1132,7 +1182,7 @@ impl Operations for SIMD256Vector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rej_sample(a) + fn rej_sample(uniform_bytes: &[u8]) -> (usize, [i16; 16]) { + rej_sample(uniform_bytes) } } From 08b6022ef0f95caf61f5431f32a54ef0225ffac9 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Tue, 14 May 2024 09:18:25 +0200 Subject: [PATCH 12/59] wip avx2 --- libcrux-sha3/src/rust_simd.rs | 2 + libcrux-sha3/src/rust_simd/sha3_avx2.rs | 228 ++++++++++++++++++++ libcrux-sha3/src/rust_simd/sha3_portable.rs | 137 ++++++++++++ 3 files changed, 367 insertions(+) create mode 100644 libcrux-sha3/src/rust_simd/sha3_avx2.rs create mode 100644 libcrux-sha3/src/rust_simd/sha3_portable.rs diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index f7e36fba1..87a5b9700 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -21,6 +21,8 @@ fn keccakx2(data:[&[u8];2],out:[&mut[u8];2]) { #[cfg(feature = "simd128")] pub type KeccakState4 = [KeccakState2; 2]; +#[cfg(feature = "simd256")] +mod sha3_avx2; #[cfg(not(feature = "simd128"))] pub type KeccakState2 = [KeccakState1; 2]; diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/rust_simd/sha3_avx2.rs new file mode 100644 index 000000000..6aa90d11a --- /dev/null +++ b/libcrux-sha3/src/rust_simd/sha3_avx2.rs @@ -0,0 +1,228 @@ +use core::arch::x86_64::*; +use libcrux_hacl::__m256i; + +use crate::rust_simd::sha3_trait::*; + +// This file optimizes for the stable Rust Neon Intrinsics +// If we want to use the unstable neon-sha3 instructions, we could use: +// veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 +// These instructions might speed up our 
code even more. + + +#[inline(always)] +fn rotate_left(x:__m256i) -> __m256i { + debug_assert!(LEFT+RIGHT == 64); + // XXX: This could be done more efficiently, if the shift values are multiples of 8. + unsafe { _mm256_xor_si256(_mm256_slli_epi64::(x), _mm256_srli_epi64::(x)) } +} + +#[inline(always)] +fn _veor5q_u64(a: __m256i, b: __m256i, c: __m256i, d: __m256i, e: __m256i) -> __m256i { + let ab = unsafe { _mm256_xor_si256(a, b) }; + let cd = unsafe { _mm256_xor_si256(c, d) }; + let abcd = unsafe { _mm256_xor_si256(ab, cd) }; + unsafe { _mm256_xor_si256(abcd, e) } +} + +#[inline(always)] +fn _vrax1q_u64(a: __m256i, b: __m256i) -> __m256i { + unsafe { _mm256_xor_si256(a, rotate_left::<1,63>(b)) } +} + +#[inline(always)] +fn _vxarq_u64(a: __m256i, b: __m256i) -> __m256i { + let ab = unsafe { _mm256_xor_si256(a, b) }; + rotate_left::(ab) +} + +#[inline(always)] +fn _vbcaxq_u64(a: __m256i, b: __m256i, c: __m256i) -> __m256i { + unsafe{ _mm256_xor_si256(a, _mm256_andnot_si256(b, c)) } +} + +#[inline(always)] +fn _veorq_n_u64(a: __m256i, c: u64) -> __m256i { + // Casting here is required, doesn't change the value. + let c = unsafe { _mm256_set1_epi64x(c as i64) }; + unsafe { _mm256_xor_si256(a, c) } +} + + +#[inline(always)] +pub(crate) fn load_block(s: &mut [[__m256i;5];5], blocks: [&[u8];4]) { + debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16)); + for i in 0..RATE/32 { + let v0 = unsafe { _mm256_loadu_epi64(blocks[0][start..32*(i+1)].as_ptr() as *const i64)}; + let v1 = unsafe { _mm256_loadu_epi64(blocks[1][start..32*(i+1)].as_ptr() as *const i64)}; + let v2 = unsafe { _mm256_loadu_epi64(blocks[2][start..32*(i+1)].as_ptr() as *const i64)}; + let v3 = unsafe { _mm256_loadu_epi64(blocks[3][start..32*(i+1)].as_ptr() as *const i64)}; + + let v0l = unsafe { _mm256_unpacklo_epi64(v0,v1) }; // 0 0 2 2 + let v1h = unsafe { _mm256_unpackhi_epi64(v0,v1) }; // 1 1 3 3 + let v2l = unsafe { _mm256_unpacklo_epi64(v2,v3) }; // 0 0 2 2 + let v3h = unsafe { _mm256_unpackhi_epi64(v2,v3) }; // 1 1 3 3 + + let v0 = unsafe { _mm256_permute2x128_epi256(v0l,v2l,0x20) }; // 0 0 0 0 + let v1 = unsafe { _mm256_permute2x128_epi256(v1h,v3h,0x20) }; // 1 1 1 1 + let v2 = unsafe { _mm256_permute2x128_epi256(v0l,v2l,0x31) }; // 2 2 2 2 + let v3 = unsafe { _mm256_permute2x128_epi256(v1h,v3h,0x31) }; // 3 3 3 3 + + s[(4*i)/5][(4*i)%5] = unsafe { veorq_u64(s[(4*i)/5][(4*i)%5], v0) }; + s[(4*i+1)/5][(4*i+1)%5] = unsafe { veorq_u64(s[(4*i+1)/5][(4*i+1)%5], v1) }; + s[(4*i+2)/5][(4*i+2)%5] = unsafe { veorq_u64(s[(4*i+2)/5][(4*i+2)%5], v2) }; + s[(4*i+3)/5][(4*i+3)%5] = unsafe { veorq_u64(s[(4*i+3)/5][(4*i+3)%5], v3) }; + } + + let rem = RATE%32; // has to be 8 or 16 + let start = 32 * (RATE/32); + let u8s = [0u8;32]; + u8s[0..8].copy_from_slice(&blocks[0][start..start+8]); + u8s[8..16].copy_from_slice(&blocks[1][start..start+8]); + u8s[16..24].copy_from_slice(&blocks[2][start..start+8]); + u8s[24..32].copy_from_slice(&blocks[3][start..start+8]); + let u = unsafe { _mm256_loadu_epi64(u8s.as_ptr() as *const i64)}; + let i = (4*(RATE/32))/5; + let j = (4*(RATE/32))%5; + s[i][j] = unsafe { veorq_u64(s[i][j], u)}; + if rem == 16 { + let u8s = [0u8;32]; + u8s[0..8].copy_from_slice(&blocks[0][start+8..start+16]); + u8s[8..16].copy_from_slice(&blocks[1][start+8..start+16]); + u8s[16..24].copy_from_slice(&blocks[2][start+8..start+16]); + u8s[24..32].copy_from_slice(&blocks[3][start+8..start+16]); + let u = unsafe { _mm256_loadu_epi64(u8s.as_ptr() as *const i64)}; + let i = (4*(RATE/32) + 
1)/5; + let j = (4*(RATE/32) + 1)%5; + s[i][j] = unsafe { veorq_u64(s[i][j], u)}; + } +} + +#[inline(always)] +pub(crate) fn load_block_full(s: &mut [[__m256i;5];5], blocks: [[u8;200];4]) { + let [b0,b1,b2,b3] = blocks; + load_block::(s,[&b0 as &[u8], &b1 as &[u8], &b2 as &[u8], &b3 as &[u8]]); +} + +#[inline(always)] +pub(crate) fn store_block(s: &[[__m256i;5];5], out: [&mut [u8];2]) { + for i in 0..RATE/32 { + let v0l = unsafe { _mm256_permute2x128_epi256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x20) }; // 0 0 2 2 + let v1h = unsafe { _mm256_permute2x128_epi256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x20) }; // 1 1 3 3 + let v2l = unsafe { _mm256_permute2x128_epi256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x31) }; // 0 0 2 2 + let v3h = unsafe { _mm256_permute2x128_epi256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x31) }; // 1 1 3 3 + + + let v0 = unsafe { _mm256_unpacklo_epi64(v0l, v1h) }; // 0 1 2 3 + let v1 = unsafe { _mm256_unpackhi_epi64(v0l, v1h) }; // 0 1 2 3 + let v2 = unsafe { _mm256_unpacklo_epi64(v2l, v3h) }; // 0 1 2 3 + let v3 = unsafe { _mm256_unpackhi_epi64(v2l, v3h) }; // 0 1 2 3 + + unsafe { _mm256_storeu_epi64(out[0][start..32*(i+1)].as_mut_ptr() as *mut i64, v0) }; + unsafe { _mm256_storeu_epi64(out[1][start..32*(i+1)].as_mut_ptr() as *mut i64, v1) }; + unsafe { _mm256_storeu_epi64(out[2][start..32*(i+1)].as_mut_ptr() as *mut i64, v2) }; + unsafe { _mm256_storeu_epi64(out[3][start..32*(i+1)].as_mut_ptr() as *mut i64, v3) }; + } + + let rem = RATE%32; // has to be 8 or 16 + let start = 32 * (RATE/32); + let u8s = [0u8;32]; + let i = (4*(RATE/32))/5; + let j = (4*(RATE/32))%5; + unsafe { _mm256_storeu_epi64(u8s.as_mut_ptr() as *const i64, s[i][j])}; + out[0][start..start+8].copy_from_slice(&u8s[0..8]); + out[1][start..start+8].copy_from_slice(&u8s[8..16]); + out[2][start..start+8].copy_from_slice(&u8s[16..24]); + out[3][start..start+8].copy_from_slice(&u8s[24..32]); + if rem == 16 { + let u8s = [0u8;32]; + let i = (4*(RATE/32) + 1)/5; + let j = (4*(RATE/32) + 1)%5; + unsafe { _mm256_storeu_epi64(u8s.as_mut_ptr() as *const i64, s[i][j])}; + out[0][start+8..start+16].copy_from_slice(&u8s[0..8]); + out[1][start+8..start+16].copy_from_slice(&u8s[8..16]); + out[2][start+8..start+16].copy_from_slice(&u8s[16..24]); + out[3][start+8..start+16].copy_from_slice(&u8s[24..32]); + } +} + +#[inline(always)] +pub(crate) fn store_block_full(s: &[[__m256i;5];5]) -> [[u8;200];4] { + let mut out0 = [0u8; 200]; + let mut out1 = [0u8; 200]; + let mut out2 = [0u8; 200]; + let mut out3 = [0u8; 200]; + store_block::(s,[&mut out0, &mut out1, &mut out2, &mut out3]); + [out0, out1, out2, out3] +} + +#[inline(always)] +fn slice_4(a: [&[u8];4], start:usize, len:usize) -> [&[u8];4] { + [&a[0][start..start+len], &a[1][start..start+len], &a[2][start..start+len], &a[3][start..start+len]] +} + +#[inline(always)] +fn split_at_mut_4(out: [&mut [u8]; 4], mid:usize) -> ([&mut [u8];4],[&mut [u8];4]) { + let [out0, out1, out2, out3] = out; + let (out00,out01) = out0.split_at_mut(mid); + let (out10,out11) = out1.split_at_mut(mid); + let (out20,out21) = out2.split_at_mut(mid); + let (out30,out31) = out3.split_at_mut(mid); + ([out00,out10,out20,out30], + [out01,out11,out21,out31]) +} + +impl KeccakItem<4> for __m256i { + #[inline(always)] + fn zero() -> Self { + unsafe {vdupq_n_u64(0)} + } + #[inline(always)] + fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { + _veor5q_u64(a, b, c, d, e) + } + #[inline(always)] + fn rotate_left1_and_xor(a: Self, b: Self) -> Self { + 
_vrax1q_u64(a, b)
+    }
+    #[inline(always)]
+    fn xor_and_rotate(a: Self, b: Self) -> Self {
+        _vxarq_u64::(a, b)
+    }
+    #[inline(always)]
+    fn and_not_xor(a: Self, b: Self, c: Self) -> Self {
+        _vbcaxq_u64(a, b, c)
+    }
+    #[inline(always)]
+    fn xor_constant(a: Self, c: u64) -> Self {
+        _veorq_n_u64(a, c)
+    }
+    #[inline(always)]
+    fn xor(a: Self, b: Self) -> Self {
+        unsafe {veorq_u64(a, b)}
+    }
+    #[inline(always)]
+    fn load_block(a:&mut [[Self;5];5], b:[&[u8];4]) {
+        load_block::(a, b)
+    }
+    #[inline(always)]
+    fn store_block(a:& [[Self;5];5], b:[&mut [u8];4]) {
+        store_block::(a, b)
+    }
+    #[inline(always)]
+    fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];4]) {
+        load_block_full::(a, b)
+    }
+    #[inline(always)]
+    fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];4] {
+        store_block_full::(a)
+    }
+    #[inline(always)]
+    fn slice_n(a:[&[u8];4],start:usize,len:usize) -> [&[u8];4] {
+        slice_4(a,start,len)
+    }
+    #[inline(always)]
+    fn split_at_mut_n(a:[&mut [u8];4],mid:usize) -> ([&mut [u8];4],[&mut [u8];4]) {
+        split_at_mut_4(a, mid)
+    }
+}
+
diff --git a/libcrux-sha3/src/rust_simd/sha3_portable.rs b/libcrux-sha3/src/rust_simd/sha3_portable.rs
new file mode 100644
index 000000000..f837470db
--- /dev/null
+++ b/libcrux-sha3/src/rust_simd/sha3_portable.rs
@@ -0,0 +1,137 @@
+use crate::rust_simd::sha3_trait::*;
+
+// This is the portable (scalar) implementation of Keccak-f[1600] on plain u64 lanes.
+// It mirrors the structure of the NEON and AVX2 backends so that the generic code in
+// sha3_generic.rs can be instantiated with N = 1.
+// It is the fallback for builds without the `simd128`/`simd256` features.
+
+
+#[inline(always)]
+fn rotate_left(x:u64) -> u64 {
+    debug_assert!(LEFT+RIGHT == 64);
+    (x << LEFT) | (x >> RIGHT)
+}
+
+#[inline(always)]
+fn _veor5q_u64(a: u64, b: u64, c: u64, d: u64, e: u64) -> u64 {
+    let ab = a ^ b;
+    let cd = c ^ d;
+    let abcd = ab ^ cd;
+    abcd ^ e
+}
+
+#[inline(always)]
+fn _vrax1q_u64(a: u64, b: u64) -> u64 {
+    a ^ rotate_left::<1,63>(b)
+}
+
+#[inline(always)]
+fn _vxarq_u64(a: u64, b: u64) -> u64 {
+    let ab = a ^ b;
+    rotate_left::(ab)
+}
+
+#[inline(always)]
+fn _vbcaxq_u64(a: u64, b: u64, c: u64) -> u64 {
+    a ^ (b & !c)
+}
+
+#[inline(always)]
+fn _veorq_n_u64(a: u64, c: u64) -> u64 {
+    a ^ c
+}
+
+#[inline(always)]
+pub(crate) fn load_block(s: &mut [[u64;5];5], blocks: [&[u8];1]) {
+    debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0);
+    for i in 0..RATE/8 {
+        s[i/5][i%5] = u64::from_le_bytes(blocks[0][8*i..8*i+8].try_into().unwrap());
+    }
+}
+
+#[inline(always)]
+pub(crate) fn load_block_full(s: &mut [[u64;5];5], blocks: [[u8;200];1]) {
+    load_block::(s,[&blocks[0] as &[u8]]);
+}
+
+#[inline(always)]
+pub(crate) fn store_block(s: &[[u64;5];5], out: [&mut [u8];1]) {
+    for i in 0..RATE/8 {
+        out[0][8*i..8*i+8].copy_from_slice(&s[i/5][i%5].to_le_bytes());
+    }
+}
+
+#[inline(always)]
+pub(crate) fn store_block_full(s: &[[u64;5];5]) -> [[u8;200];1] {
+    let mut out = [0u8; 200];
+    store_block::(s,[&mut out]);
+    [out]
+}
+
+#[inline(always)]
+fn slice_1(a: [&[u8];1], start:usize, len:usize) -> [&[u8];1] {
+    [&a[0][start..start+len]]
+}
+
+#[inline(always)]
+fn split_at_mut_1(out: [&mut [u8]; 1], mid:usize) -> ([&mut [u8];1],[&mut [u8];1]) {
+    let [out0] = out;
+    let (out00,out01) = out0.split_at_mut(mid);
+    ([out00], [out01])
+}
+
+impl KeccakItem<1> for u64 {
+    #[inline(always)]
+    fn zero() -> Self {
+        0
+    }
+    #[inline(always)]
+    fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self {
+        _veor5q_u64(a, b, c, d, e)
+    }
+    #[inline(always)]
+    fn rotate_left1_and_xor(a: Self, b: Self) -> Self {
+        _vrax1q_u64(a, b)
+ } + #[inline(always)] + fn xor_and_rotate(a: Self, b: Self) -> Self { + _vxarq_u64::(a, b) + } + #[inline(always)] + fn and_not_xor(a: Self, b: Self, c: Self) -> Self { + _vbcaxq_u64(a, b, c) + } + #[inline(always)] + fn xor_constant(a: Self, c: u64) -> Self { + _veorq_n_u64(a, c) + } + #[inline(always)] + fn xor(a: Self, b: Self) -> Self { + a^b + } + #[inline(always)] + fn load_block(a:&mut [[Self;5];5], b:[&[u8];1]) { + load_block::(a, b) + } + #[inline(always)] + fn store_block(a:& [[Self;5];5], b:[&mut [u8];1]) { + store_block::(a, b) + } + #[inline(always)] + fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];1]) { + load_block_full::(a, b) + } + #[inline(always)] + fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];1] { + store_block_full::(a) + } + #[inline(always)] + fn slice_n(a:[&[u8];1],start:usize,len:usize) -> [&[u8];1] { + slice_1(a,start,len) + } + #[inline(always)] + fn split_at_mut_n(a:[&mut [u8];1],mid:usize) -> ([&mut [u8];1],[&mut [u8];1]) { + split_at_mut_1(a, mid) + } +} + From bbe2c8ec9bf31ed90f382053d7daa24ed7bff81e Mon Sep 17 00:00:00 2001 From: Karthik Bhargavan Date: Tue, 14 May 2024 09:49:39 +0200 Subject: [PATCH 13/59] avx2 --- libcrux-sha3/src/rust_simd.rs | 46 ++++++++++++----- libcrux-sha3/src/rust_simd/sha3_avx2.rs | 69 ++++++++++++------------- 2 files changed, 68 insertions(+), 47 deletions(-) diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 87a5b9700..f456567b1 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -23,10 +23,12 @@ pub type KeccakState4 = [KeccakState2; 2]; #[cfg(feature = "simd256")] mod sha3_avx2; +#[cfg(feature = "simd256")] +pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] pub type KeccakState2 = [KeccakState1; 2]; -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] pub type KeccakState4 = [KeccakState1; 4]; @@ -124,13 +126,18 @@ pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8] keccakx1::<136, 0x1fu8>([input1], [out1]); } +#[cfg(feature = "simd256")] +pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], + out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { + keccak::<4,core::arch::x86_64::__m256i,136, 0x1fu8>([input0, input1, input2, input3], [out0, out1, out2, out3]); +} #[cfg(feature = "simd128")] pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128",feature = "simd256")))] pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { keccakx1::<136, 0x1fu8>([input0], [out0]); @@ -161,7 +168,7 @@ pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) { pub fn shake128x2_init() -> KeccakState2 { KeccakState2::new() } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128",feature = "simd256")))] pub fn shake128x2_init() -> KeccakState2 { let s0 = KeccakState1::new(); let s1 = KeccakState1::new(); @@ -172,7 +179,7 @@ pub fn shake128x2_init() -> KeccakState2 { pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) { 
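+    // Both lanes are padded with the SHAKE-128 domain separator (0x1f) and absorbed
+    // with a single 2-way Keccak-f[1600] permutation call.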
     absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(s,[data0,data1]);
 }
-#[cfg(not(feature = "simd128"))]
+#[cfg(not(any(feature = "simd128",feature = "simd256")))]
 pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) {
     let [mut s0, mut s1] = s;
     shake128_absorb_final(&mut s0, data0);
@@ -183,7 +190,7 @@ pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8])
 pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) {
     squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1])
 }
-#[cfg(not(feature = "simd128"))]
+#[cfg(not(any(feature = "simd128",feature = "simd256")))]
 pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) {
     let [mut s0, mut s1] = s;
     shake128_squeeze_first_three_blocks(&mut s0, out0);
@@ -194,21 +201,24 @@ pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8
 pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) {
     squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1])
 }
-#[cfg(not(feature = "simd128"))]
+#[cfg(not(any(feature = "simd128",feature = "simd256")))]
 pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) {
     let [mut s0, mut s1] = s;
     shake128_squeeze_next_block(&mut s0, out0);
     shake128_squeeze_next_block(&mut s1, out1);
 }
 
-
+#[cfg(feature = "simd256")]
+pub fn shake128x4_init() -> KeccakState4 {
+    KeccakState4::new()
+}
 #[cfg(feature = "simd128")]
 pub fn shake128x4_init() -> KeccakState4 {
     let s0 = KeccakState2::new();
     let s1 = KeccakState2::new();
     [s0,s1]
 }
-#[cfg(not(feature = "simd128"))]
+#[cfg(not(any(feature = "simd128",feature = "simd256")))]
 pub fn shake128x4_init() -> KeccakState4 {
     let s0 = KeccakState1::new();
     let s1 = KeccakState1::new();
@@ -217,13 +227,17 @@ pub fn shake128x4_init() -> KeccakState4 {
     [s0,s1,s2,s3]
 }
 
+#[cfg(feature = "simd256")]
+pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) {
+    absorb_final::<4,core::arch::x86_64::__m256i,168, 0x1fu8>(s,[data0,data1,data2,data3]);
+}
 #[cfg(feature = "simd128")]
 pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) {
     let [mut s0, mut s1] = s;
     absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s0,[data0,data1]);
     absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s1,[data2,data3]);
 }
-#[cfg(not(feature = "simd128"))]
+#[cfg(not(any(feature = "simd128",feature = "simd256")))]
 pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) {
     let [mut s0, mut s1, mut s2, mut s3] = s;
     shake128_absorb_final(&mut s0, data0);
@@ -232,13 +246,17 @@ pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8],
     shake128_absorb_final(&mut s3, data3);
 }
 
+#[cfg(feature = "simd256")]
+pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) {
+    squeeze_first_three_blocks::<4,core::arch::x86_64::__m256i,168>(s, [out0, out1, out2, out3]);
+}
 #[cfg(feature = "simd128")]
 pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) {
     let [mut s0, mut s1] = s;
     squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]);
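+    // Under `simd128` the four lanes are produced by two independent 2-way squeezes;
+    // the `simd256` variant above squeezes all four lanes with one AVX2 state.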
squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128",feature = "simd256")))] pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { let [mut s0, mut s1, mut s2, mut s3] = s; shake128_squeeze_first_three_blocks(&mut s0, out0); @@ -247,13 +265,17 @@ pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8 shake128_squeeze_first_three_blocks(&mut s3, out3); } +#[cfg(feature = "simd128")] +pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { + squeeze_next_block::<4,core::arch::x86_64::__m256i,168>(&mut s0, [out0, out1, out2, out3]); +} #[cfg(feature = "simd128")] pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { let [mut s0, mut s1] = s; squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]); squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd128",feature = "simd256")))] pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { let [mut s0, mut s1, mut s2, mut s3] = s; shake128_squeeze_next_block(&mut s0, out0); diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/rust_simd/sha3_avx2.rs index 6aa90d11a..ed47f93c5 100644 --- a/libcrux-sha3/src/rust_simd/sha3_avx2.rs +++ b/libcrux-sha3/src/rust_simd/sha3_avx2.rs @@ -1,5 +1,4 @@ use core::arch::x86_64::*; -use libcrux_hacl::__m256i; use crate::rust_simd::sha3_trait::*; @@ -52,48 +51,48 @@ fn _veorq_n_u64(a: __m256i, c: u64) -> __m256i { pub(crate) fn load_block(s: &mut [[__m256i;5];5], blocks: [&[u8];4]) { debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16)); for i in 0..RATE/32 { - let v0 = unsafe { _mm256_loadu_epi64(blocks[0][start..32*(i+1)].as_ptr() as *const i64)}; - let v1 = unsafe { _mm256_loadu_epi64(blocks[1][start..32*(i+1)].as_ptr() as *const i64)}; - let v2 = unsafe { _mm256_loadu_epi64(blocks[2][start..32*(i+1)].as_ptr() as *const i64)}; - let v3 = unsafe { _mm256_loadu_epi64(blocks[3][start..32*(i+1)].as_ptr() as *const i64)}; + let v0 = unsafe { _mm256_loadu_si256(blocks[0][32*i..32*(i+1)].as_ptr() as *const __m256i)}; + let v1 = unsafe { _mm256_loadu_si256(blocks[1][32*i..32*(i+1)].as_ptr() as *const __m256i)}; + let v2 = unsafe { _mm256_loadu_si256(blocks[2][32*i..32*(i+1)].as_ptr() as *const __m256i)}; + let v3 = unsafe { _mm256_loadu_si256(blocks[3][32*i..32*(i+1)].as_ptr() as *const __m256i)}; let v0l = unsafe { _mm256_unpacklo_epi64(v0,v1) }; // 0 0 2 2 let v1h = unsafe { _mm256_unpackhi_epi64(v0,v1) }; // 1 1 3 3 let v2l = unsafe { _mm256_unpacklo_epi64(v2,v3) }; // 0 0 2 2 let v3h = unsafe { _mm256_unpackhi_epi64(v2,v3) }; // 1 1 3 3 - let v0 = unsafe { _mm256_permute2x128_epi256(v0l,v2l,0x20) }; // 0 0 0 0 - let v1 = unsafe { _mm256_permute2x128_epi256(v1h,v3h,0x20) }; // 1 1 1 1 - let v2 = unsafe { _mm256_permute2x128_epi256(v0l,v2l,0x31) }; // 2 2 2 2 - let v3 = unsafe { _mm256_permute2x128_epi256(v1h,v3h,0x31) }; // 3 3 3 3 + let v0 = unsafe { _mm256_permute2x128_si256(v0l,v2l,0x20) }; // 0 0 0 0 + let v1 = unsafe { _mm256_permute2x128_si256(v1h,v3h,0x20) }; // 1 1 1 1 + let v2 = unsafe { 
_mm256_permute2x128_si256(v0l,v2l,0x31) }; // 2 2 2 2 + let v3 = unsafe { _mm256_permute2x128_si256(v1h,v3h,0x31) }; // 3 3 3 3 - s[(4*i)/5][(4*i)%5] = unsafe { veorq_u64(s[(4*i)/5][(4*i)%5], v0) }; - s[(4*i+1)/5][(4*i+1)%5] = unsafe { veorq_u64(s[(4*i+1)/5][(4*i+1)%5], v1) }; - s[(4*i+2)/5][(4*i+2)%5] = unsafe { veorq_u64(s[(4*i+2)/5][(4*i+2)%5], v2) }; - s[(4*i+3)/5][(4*i+3)%5] = unsafe { veorq_u64(s[(4*i+3)/5][(4*i+3)%5], v3) }; + s[(4*i)/5][(4*i)%5] = unsafe { _mm256_xor_si256(s[(4*i)/5][(4*i)%5], v0) }; + s[(4*i+1)/5][(4*i+1)%5] = unsafe { _mm256_xor_si256(s[(4*i+1)/5][(4*i+1)%5], v1) }; + s[(4*i+2)/5][(4*i+2)%5] = unsafe { _mm256_xor_si256(s[(4*i+2)/5][(4*i+2)%5], v2) }; + s[(4*i+3)/5][(4*i+3)%5] = unsafe { _mm256_xor_si256(s[(4*i+3)/5][(4*i+3)%5], v3) }; } let rem = RATE%32; // has to be 8 or 16 let start = 32 * (RATE/32); - let u8s = [0u8;32]; + let mut u8s = [0u8;32]; u8s[0..8].copy_from_slice(&blocks[0][start..start+8]); u8s[8..16].copy_from_slice(&blocks[1][start..start+8]); u8s[16..24].copy_from_slice(&blocks[2][start..start+8]); u8s[24..32].copy_from_slice(&blocks[3][start..start+8]); - let u = unsafe { _mm256_loadu_epi64(u8s.as_ptr() as *const i64)}; + let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i)}; let i = (4*(RATE/32))/5; let j = (4*(RATE/32))%5; - s[i][j] = unsafe { veorq_u64(s[i][j], u)}; + s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u)}; if rem == 16 { - let u8s = [0u8;32]; + let mut u8s = [0u8;32]; u8s[0..8].copy_from_slice(&blocks[0][start+8..start+16]); u8s[8..16].copy_from_slice(&blocks[1][start+8..start+16]); u8s[16..24].copy_from_slice(&blocks[2][start+8..start+16]); u8s[24..32].copy_from_slice(&blocks[3][start+8..start+16]); - let u = unsafe { _mm256_loadu_epi64(u8s.as_ptr() as *const i64)}; + let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i)}; let i = (4*(RATE/32) + 1)/5; let j = (4*(RATE/32) + 1)%5; - s[i][j] = unsafe { veorq_u64(s[i][j], u)}; + s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u)}; } } @@ -104,12 +103,12 @@ pub(crate) fn load_block_full(s: &mut [[__m256i;5];5], blocks: } #[inline(always)] -pub(crate) fn store_block(s: &[[__m256i;5];5], out: [&mut [u8];2]) { +pub(crate) fn store_block(s: &[[__m256i;5];5], out: [&mut [u8];4]) { for i in 0..RATE/32 { - let v0l = unsafe { _mm256_permute2x128_epi256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x20) }; // 0 0 2 2 - let v1h = unsafe { _mm256_permute2x128_epi256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x20) }; // 1 1 3 3 - let v2l = unsafe { _mm256_permute2x128_epi256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x31) }; // 0 0 2 2 - let v3h = unsafe { _mm256_permute2x128_epi256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x31) }; // 1 1 3 3 + let v0l = unsafe { _mm256_permute2x128_si256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x20) }; // 0 0 2 2 + let v1h = unsafe { _mm256_permute2x128_si256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x20) }; // 1 1 3 3 + let v2l = unsafe { _mm256_permute2x128_si256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x31) }; // 0 0 2 2 + let v3h = unsafe { _mm256_permute2x128_si256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x31) }; // 1 1 3 3 let v0 = unsafe { _mm256_unpacklo_epi64(v0l, v1h) }; // 0 1 2 3 @@ -117,27 +116,27 @@ pub(crate) fn store_block(s: &[[__m256i;5];5], out: [&mut [u8] let v2 = unsafe { _mm256_unpacklo_epi64(v2l, v3h) }; // 0 1 2 3 let v3 = unsafe { _mm256_unpackhi_epi64(v2l, v3h) }; // 0 1 2 3 - unsafe { _mm256_storeu_epi64(out[0][start..32*(i+1)].as_mut_ptr() as *mut i64, v0) }; - unsafe { 
_mm256_storeu_epi64(out[1][start..32*(i+1)].as_mut_ptr() as *mut i64, v1) }; - unsafe { _mm256_storeu_epi64(out[2][start..32*(i+1)].as_mut_ptr() as *mut i64, v2) }; - unsafe { _mm256_storeu_epi64(out[3][start..32*(i+1)].as_mut_ptr() as *mut i64, v3) }; + unsafe { _mm256_storeu_si256(out[0][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v0) }; + unsafe { _mm256_storeu_si256(out[1][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v1) }; + unsafe { _mm256_storeu_si256(out[2][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v2) }; + unsafe { _mm256_storeu_si256(out[3][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v3) }; } let rem = RATE%32; // has to be 8 or 16 let start = 32 * (RATE/32); - let u8s = [0u8;32]; + let mut u8s = [0u8;32]; let i = (4*(RATE/32))/5; let j = (4*(RATE/32))%5; - unsafe { _mm256_storeu_epi64(u8s.as_mut_ptr() as *const i64, s[i][j])}; + unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j])}; out[0][start..start+8].copy_from_slice(&u8s[0..8]); out[1][start..start+8].copy_from_slice(&u8s[8..16]); out[2][start..start+8].copy_from_slice(&u8s[16..24]); out[3][start..start+8].copy_from_slice(&u8s[24..32]); if rem == 16 { - let u8s = [0u8;32]; + let mut u8s = [0u8;32]; let i = (4*(RATE/32) + 1)/5; let j = (4*(RATE/32) + 1)%5; - unsafe { _mm256_storeu_epi64(u8s.as_mut_ptr() as *const i64, s[i][j])}; + unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j])}; out[0][start+8..start+16].copy_from_slice(&u8s[0..8]); out[1][start+8..start+16].copy_from_slice(&u8s[8..16]); out[2][start+8..start+16].copy_from_slice(&u8s[16..24]); @@ -174,7 +173,7 @@ fn split_at_mut_4(out: [&mut [u8]; 4], mid:usize) -> ([&mut [u8];4],[&mut [u8];4 impl KeccakItem<4> for __m256i { #[inline(always)] fn zero() -> Self { - unsafe {vdupq_n_u64(0)} + unsafe { _mm256_set1_epi64x(0) } } #[inline(always)] fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { @@ -198,7 +197,7 @@ impl KeccakItem<4> for __m256i { } #[inline(always)] fn xor(a: Self, b: Self) -> Self { - unsafe {veorq_u64(a, b)} + unsafe {_mm256_xor_si256(a, b)} } #[inline(always)] fn load_block(a:&mut [[Self;5];5], b:[&[u8];4]) { @@ -221,7 +220,7 @@ impl KeccakItem<4> for __m256i { slice_4(a,start,len) } #[inline(always)] - fn split_at_mut_n(a:[&mut [u8];4],mid:usize) -> ([&mut [u8];2],[&mut [u8];4]) { + fn split_at_mut_n(a:[&mut [u8];4],mid:usize) -> ([&mut [u8];4],[&mut [u8];4]) { split_at_mut_4(a, mid) } } From 3f739cb004c5f6211a0525dc64d8b0208480064f Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 11:17:25 +0200 Subject: [PATCH 14/59] make it compile; format --- libcrux-ml-kem/src/hash_functions.rs | 155 ++++++----- libcrux-ml-kem/src/ind_cpa.rs | 15 +- libcrux-ml-kem/src/polynomial.rs | 2 +- libcrux-ml-kem/src/sampling.rs | 2 +- libcrux-sha3/benches/sha3.rs | 1 - libcrux-sha3/src/rust_simd.rs | 260 ++++++++++++++----- libcrux-sha3/src/rust_simd/sha3_arm64.rs | 148 ++++++----- libcrux-sha3/src/rust_simd/sha3_avx2.rs | 272 ++++++++++++-------- libcrux-sha3/src/rust_simd/sha3_generic.rs | 226 +++++++++------- libcrux-sha3/src/rust_simd/sha3_portable.rs | 60 +++-- libcrux-sha3/src/rust_simd/sha3_trait.rs | 16 +- libcrux-sha3/tests/sha3.rs | 6 +- polynomials-aarch64/src/lib.rs | 2 +- polynomials-aarch64/src/rejsample.rs | 2 +- polynomials-avx2/src/lib.rs | 16 +- polynomials-avx2/src/portable.rs | 5 +- polynomials/src/lib.rs | 8 +- traits/src/lib.rs | 2 +- 18 files changed, 732 insertions(+), 466 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs 
b/libcrux-ml-kem/src/hash_functions.rs index ac4d16737..136afc2c0 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -13,7 +13,7 @@ pub(crate) fn G(input: &[u8]) -> [u8; 64] { #[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; 64] { - libcrux_sha3::sha512(input) + libcrux_sha3::sha512(input) //some bug in scalar version of rust_simd // rust_simd::sha512(input) } @@ -27,7 +27,7 @@ pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { #[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - libcrux_sha3::sha256(input) + libcrux_sha3::sha256(input) //some bug in scalar version of rust_simd // rust_simd::sha256(input) } @@ -53,20 +53,23 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> let mut extra = [0u8; LEN]; match K { - 2 => { let (out0,out1) = out.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); - } - 3 => { let (out0,out12) = out.split_at_mut(1); - let (out1,out2) = out12.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); - rust_simd::shake256x2(&input[2], &input[2], &mut out2[0], &mut extra); - } - _ => { let (out0,out123) = out.split_at_mut(1); - let (out1,out23) = out123.split_at_mut(1); - let (out2,out3) = out23.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); - rust_simd::shake256x2(&input[2], &input[3], &mut out2[0], &mut out3[0]); - } + 2 => { + let (out0, out1) = out.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + rust_simd::shake256x2(&input[2], &input[2], &mut out2[0], &mut extra); + } + _ => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + rust_simd::shake256x2(&input[2], &input[3], &mut out2[0], &mut out3[0]); + } } out } @@ -86,16 +89,16 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { let mut states = rust_simd::shake128x4_init(); match K { 2 => { - rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); - }, + rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); + } 3 => { - rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); - rust_simd::shake128x2_absorb_final(&mut states[1],&input[2],&input[2]); - }, + rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); + rust_simd::shake128x2_absorb_final(&mut states[1], &input[2], &input[2]); + } _ => { - rust_simd::shake128x2_absorb_final(&mut states[0],&input[0],&input[1]); - rust_simd::shake128x2_absorb_final(&mut states[1],&input[2],&input[3]); - }, + rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); + rust_simd::shake128x2_absorb_final(&mut states[1], &input[2], &input[3]); + } } states } @@ -107,44 +110,66 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { let mut states = rust_simd::shake128x4_init(); for i in 0..K { rust_simd::shake128_absorb_final(&mut states[i], &input[i]); - } + } states } - pub(crate) const BLOCK_SIZE: usize = 168; pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; #[cfg(feature = "simd128")] #[inline(always)] -pub(crate) fn 
squeeze_three_blocks ( +pub(crate) fn squeeze_three_blocks( state: &mut Shake128x4State, ) -> [[u8; THREE_BLOCKS]; K] { let mut out = [[0u8; THREE_BLOCKS]; K]; let mut extra = [0u8; THREE_BLOCKS]; match K { - 2 => { let (out0,out1) = out.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); - } - 3 => { let (out0,out12) = out.split_at_mut(1); - let (out1,out2) = out12.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2[0], &mut extra); - } - _ => { let (out0,out123) = out.split_at_mut(1); - let (out1,out23) = out123.split_at_mut(1); - let (out2,out3) = out23.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[0], &mut out0[0], &mut out1[0]); - rust_simd::shake128x2_squeeze_first_three_blocks(&mut state[1], &mut out2[0], &mut out3[0]); - } + 2 => { + let (out0, out1) = out.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks( + &mut state[0], + &mut out0[0], + &mut out1[0], + ); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks( + &mut state[0], + &mut out0[0], + &mut out1[0], + ); + rust_simd::shake128x2_squeeze_first_three_blocks( + &mut state[1], + &mut out2[0], + &mut extra, + ); + } + _ => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + rust_simd::shake128x2_squeeze_first_three_blocks( + &mut state[0], + &mut out0[0], + &mut out1[0], + ); + rust_simd::shake128x2_squeeze_first_three_blocks( + &mut state[1], + &mut out2[0], + &mut out3[0], + ); + } } out } #[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn squeeze_three_blocks ( +pub(crate) fn squeeze_three_blocks( state: &mut Shake128x4State, ) -> [[u8; THREE_BLOCKS]; K] { let mut out = [[0u8; THREE_BLOCKS]; K]; @@ -156,9 +181,7 @@ pub(crate) fn squeeze_three_blocks ( #[cfg(feature = "simd128")] #[inline(always)] -pub(crate) fn squeeze_block( - state: &mut Shake128x4State, -) -> [[u8; BLOCK_SIZE]; K] { +pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8; BLOCK_SIZE]; K] { let mut out0 = [0u8; BLOCK_SIZE]; let mut out1 = [0u8; BLOCK_SIZE]; let mut out2 = [0u8; BLOCK_SIZE]; @@ -167,29 +190,33 @@ pub(crate) fn squeeze_block( let mut out = [[0u8; BLOCK_SIZE]; K]; match K { - 2 => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - out[0] = out0; - out[1] = out1; } - 3 => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; } - _ => { rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; - out[3] = out3; } + 2 => { + rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + out[0] = out0; + out[1] = out1; + } + 3 => { + rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; + } + _ => { + rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); + 
rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); + out[0] = out0; + out[1] = out1; + out[2] = out2; + out[3] = out3; + } } out } #[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn squeeze_block( - state: &mut Shake128x4State, -) -> [[u8; BLOCK_SIZE]; K] { +pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8; BLOCK_SIZE]; K] { let mut out = [[0u8; BLOCK_SIZE]; K]; for i in 0..K { rust_simd::shake128_squeeze_next_block(&mut state[i], &mut out[i]); @@ -197,10 +224,8 @@ pub(crate) fn squeeze_block( out } - /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. #[inline(always)] -pub(crate) fn free_state(_xof_state: Shake128x4State) { -} +pub(crate) fn free_state(_xof_state: Shake128x4State) {} diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index ff7960476..dce14dec8 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -2,7 +2,7 @@ use libcrux_polynomials::Operations; use crate::{ constants::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, SHARED_SECRET_SIZE}, - hash_functions::{G, PRF, PRFxN}, + hash_functions::{PRFxN, G, PRF}, helper::cloop, matrix::*, ntt::{ntt_binomially_sampled_ring_element, ntt_vector_u}, @@ -79,7 +79,7 @@ fn sample_ring_element_cbd< prf_inputs[i][32] = domain_separator; domain_separator += 1; } - let prf_outputs : [[u8; ETA2_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + let prf_outputs: [[u8; ETA2_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); for i in 0..K { error_1[i] = sample_from_binomial_distribution::(&prf_outputs[i]); } @@ -104,7 +104,7 @@ fn sample_vector_cbd_then_ntt< prf_inputs[i][32] = domain_separator; domain_separator += 1; } - let prf_outputs : [[u8; ETA_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + let prf_outputs: [[u8; ETA_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); for i in 0..K { let r = sample_from_binomial_distribution::(&prf_outputs[i]); re_as_ntt[i] = ntt_binomially_sampled_ring_element(r); @@ -299,10 +299,11 @@ pub(crate) fn encrypt< // e1[i] := CBD_{η2}(PRF(r,N)) // N := N + 1 // end for - let (error_1, domain_separator) = sample_ring_element_cbd::( - prf_input, - domain_separator, - ); + let (error_1, domain_separator) = + sample_ring_element_cbd::( + prf_input, + domain_separator, + ); // e_2 := CBD{η2}(PRF(r, N)) prf_input[32] = domain_separator; diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index 02dd071c0..66d1aff12 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -35,7 +35,7 @@ impl PolynomialRingElement { let mut result = PolynomialRingElement::ZERO(); for i in 0..VECTORS_IN_RING_ELEMENT { result.coefficients[i] = Vector::from_i16_array( - &a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR] + &a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR], ); } result diff --git a/libcrux-ml-kem/src/sampling.rs b/libcrux-ml-kem/src/sampling.rs index 95bd23dca..4b7d6c4cf 100644 --- a/libcrux-ml-kem/src/sampling.rs +++ b/libcrux-ml-kem/src/sampling.rs @@ -53,7 +53,7 @@ fn sample_from_uniform_distribution_next; #[inline(always)] -fn keccakx1(data:[&[u8];1],out:[&mut[u8];1]) { - keccak::<1, u64, RATE, DELIM>(data,out) +fn keccakx1(data: [&[u8]; 1], out: [&mut [u8]; 1]) { + keccak::<1, u64, RATE, DELIM>(data, out) } #[cfg(feature = "simd128")] @@ -15,8 +15,8 @@ mod sha3_arm64; pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; #[cfg(feature = "simd128")] #[inline(always)] -fn 
keccakx2(data:[&[u8];2],out:[&mut[u8];2]) { - keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data,out) +fn keccakx2(data: [&[u8]; 2], out: [&mut [u8]; 2]) { + keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data, out) } #[cfg(feature = "simd128")] pub type KeccakState4 = [KeccakState2; 2]; @@ -24,6 +24,11 @@ pub type KeccakState4 = [KeccakState2; 2]; #[cfg(feature = "simd256")] mod sha3_avx2; #[cfg(feature = "simd256")] +#[inline(always)] +fn keccakx4(data: [&[u8]; 4], out: [&mut [u8]; 4]) { + keccak::<4, core::arch::x86_64::__m256i, RATE, DELIM>(data, out) +} +#[cfg(feature = "simd256")] pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; #[cfg(not(any(feature = "simd128", feature = "simd256")))] @@ -31,86 +36,111 @@ pub type KeccakState2 = [KeccakState1; 2]; #[cfg(not(any(feature = "simd128", feature = "simd256")))] pub type KeccakState4 = [KeccakState1; 4]; - #[cfg(feature = "simd128")] -pub fn sha3_224(data: &[u8]) -> [u8;28] { +pub fn sha3_224(data: &[u8]) -> [u8; 28] { let mut d0 = [0u8; 28]; let mut d1 = [0u8; 28]; keccakx2::<144, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } #[cfg(not(feature = "simd128"))] -pub fn sha3_224(data: &[u8]) -> [u8;28] { +pub fn sha3_224(data: &[u8]) -> [u8; 28] { let mut d0 = [0u8; 28]; keccakx1::<144, 0x06u8>([data], [&mut d0]); d0 } #[cfg(feature = "simd128")] -pub fn sha3_256(data: &[u8]) -> [u8;32] { +pub fn sha3_256(data: &[u8]) -> [u8; 32] { let mut d0 = [0u8; 32]; let mut d1 = [0u8; 32]; keccakx2::<136, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } -#[cfg(not(feature = "simd128"))] -pub fn sha3_256(data: &[u8]) -> [u8;32] { + +#[cfg(feature = "simd256")] +pub fn sha3_256(data: &[u8]) -> [u8; 32] { + let mut d0 = [0u8; 32]; + let mut d1 = [0u8; 32]; + let mut d2 = [0u8; 32]; + let mut d3 = [0u8; 32]; + keccakx4::<136, 0x06u8>( + [data, data, data, data], + [&mut d0, &mut d1, &mut d2, &mut d3], + ); + d0 +} + +pub fn sha3_256_portable(data: &[u8]) -> [u8; 32] { let mut d0 = [0u8; 32]; keccakx1::<136, 0x06u8>([data], [&mut d0]); d0 } #[cfg(feature = "simd128")] -pub fn sha3_384(data: &[u8]) -> [u8;48] { +pub fn sha3_384(data: &[u8]) -> [u8; 48] { let mut d0 = [0u8; 48]; let mut d1 = [0u8; 48]; keccakx2::<104, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } #[cfg(not(feature = "simd128"))] -pub fn sha3_384(data: &[u8]) -> [u8;48] { +pub fn sha3_384(data: &[u8]) -> [u8; 48] { let mut d0 = [0u8; 48]; keccakx1::<104, 0x06u8>([data], [&mut d0]); d0 } #[cfg(feature = "simd128")] -pub fn sha3_512(data: &[u8]) -> [u8;64] { +pub fn sha3_512(data: &[u8]) -> [u8; 64] { let mut d0 = [0u8; 64]; let mut d1 = [0u8; 64]; keccakx2::<72, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } -#[cfg(not(feature = "simd128"))] -pub fn sha3_512(data: &[u8]) -> [u8;64] { + +#[cfg(feature = "simd256")] +pub fn sha3_512(data: &[u8]) -> [u8; 64] { + let mut d0 = [0u8; 64]; + let mut d1 = [0u8; 64]; + let mut d2 = [0u8; 64]; + let mut d3 = [0u8; 64]; + keccakx4::<72, 0x06u8>( + [data, data, data, data], + [&mut d0, &mut d1, &mut d2, &mut d3], + ); + d0 +} + +pub fn sha3_512_portable(data: &[u8]) -> [u8; 64] { let mut d0 = [0u8; 64]; keccakx1::<72, 0x06u8>([data], [&mut d0]); d0 } #[cfg(feature = "simd128")] -pub fn shake128(data: &[u8]) -> [u8; LEN] { +pub fn shake128(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; LEN]; keccakx2::<168, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } #[cfg(not(feature = "simd128"))] -pub fn shake128(data: &[u8]) -> [u8; LEN] { +pub fn shake128(data: &[u8]) -> [u8; LEN] { let mut d0 
= [0u8; LEN]; keccakx1::<168, 0x1fu8>([data], [&mut d0]); d0 } #[cfg(feature = "simd128")] -pub fn shake256(data: &[u8]) -> [u8; LEN] { +pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; let mut d1 = [0u8; LEN]; keccakx2::<136, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } #[cfg(not(feature = "simd128"))] -pub fn shake256(data: &[u8]) -> [u8; LEN] { +pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; keccakx1::<136, 0x1fu8>([data], [&mut d0]); d0 @@ -127,19 +157,46 @@ pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8] } #[cfg(feature = "simd256")] -pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], - out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { - keccak::<4,core::arch::x86_64::__m256i,136, 0x1fu8>([input0, input1, input2, input3], [out0, out1, out2, out3]); +pub fn shake256x4( + input0: &[u8], + input1: &[u8], + input2: &[u8], + input3: &[u8], + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { + keccak::<4, core::arch::x86_64::__m256i, 136, 0x1fu8>( + [input0, input1, input2, input3], + [out0, out1, out2, out3], + ); } #[cfg(feature = "simd128")] -pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], - out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { +pub fn shake256x4( + input0: &[u8], + input1: &[u8], + input2: &[u8], + input3: &[u8], + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake256x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8], - out0: &mut [u8], out1: &mut [u8], out2: &mut [u8], out3: &mut [u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake256x4( + input0: &[u8], + input1: &[u8], + input2: &[u8], + input3: &[u8], + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { keccakx1::<136, 0x1fu8>([input0], [out0]); keccakx1::<136, 0x1fu8>([input1], [out1]); keccakx1::<136, 0x1fu8>([input2], [out2]); @@ -152,46 +209,54 @@ pub fn shake128_init() -> KeccakState1 { KeccakState1::new() } -pub fn shake128_absorb_final(s:&mut KeccakState1, data0: &[u8]) { - absorb_final::<1,u64,168,0x1fu8>(s,[data0]); +pub fn shake128_absorb_final(s: &mut KeccakState1, data0: &[u8]) { + absorb_final::<1, u64, 168, 0x1fu8>(s, [data0]); } -pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0:&mut [u8]) { - squeeze_first_three_blocks::<1,u64,168>(s, [out0]) +pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0: &mut [u8]) { + squeeze_first_three_blocks::<1, u64, 168>(s, [out0]) } pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) { - squeeze_next_block::<1,u64,168>(s, [out0]) + squeeze_next_block::<1, u64, 168>(s, [out0]) } #[cfg(feature = "simd128")] pub fn shake128x2_init() -> KeccakState2 { KeccakState2::new() } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] pub fn shake128x2_init() -> KeccakState2 { let s0 = KeccakState1::new(); let s1 = KeccakState1::new(); - [s0,s1] + [s0, s1] } #[cfg(feature = "simd128")] -pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) { - absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(s,[data0,data1]); +pub fn shake128x2_absorb_final(s: 
&mut KeccakState2, data0: &[u8], data1: &[u8]) { + absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]); } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake128x2_absorb_final(s:&mut KeccakState2, data0: &[u8], data1: &[u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake128x2_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { let [mut s0, mut s1] = s; shake128_absorb_final(&mut s0, data0); shake128_absorb_final(&mut s1, data1); } #[cfg(feature = "simd128")] -pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) { - squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) +pub fn shake128x2_squeeze_first_three_blocks( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], +) { + squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8], out1:&mut [u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake128x2_squeeze_first_three_blocks( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], +) { let [mut s0, mut s1] = s; shake128_squeeze_first_three_blocks(&mut s0, out0); shake128_squeeze_first_three_blocks(&mut s1, out1); @@ -199,9 +264,9 @@ pub fn shake128x2_squeeze_first_three_blocks(s: &mut KeccakState2, out0:&mut [u8 #[cfg(feature = "simd128")] pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { - squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(s, [out0, out1]) + squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { let [mut s0, mut s1] = s; shake128_squeeze_next_block(&mut s0, out0); @@ -216,29 +281,47 @@ pub fn shake128x4_init() -> KeccakState4 { pub fn shake128x4_init() -> KeccakState4 { let s0 = KeccakState2::new(); let s1 = KeccakState2::new(); - [s0,s1] + [s0, s1] } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] pub fn shake128x4_init() -> KeccakState4 { let s0 = KeccakState1::new(); let s1 = KeccakState1::new(); let s2 = KeccakState1::new(); let s3 = KeccakState1::new(); - [s0,s1,s2,s3] + [s0, s1, s2, s3] } #[cfg(feature = "simd128")] -pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) { - absorb_final::<4,core::arch::x86_64::__m256i,168, 0x1fu8>(s,[data0,data1,data2,data3]); +pub fn shake128x4_absorb_final( + s: &mut KeccakState4, + data0: &[u8], + data1: &[u8], + data2: &[u8], + data3: &[u8], +) { + absorb_final::<4, core::arch::x86_64::__m256i, 168, 0x1fu8>(s, [data0, data1, data2, data3]); } #[cfg(feature = "simd128")] -pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) { +pub fn shake128x4_absorb_final( + s: &mut KeccakState4, + data0: &[u8], + data1: &[u8], + data2: &[u8], + data3: &[u8], +) { let [mut s0, mut s1] = s; - absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s0,[data0,data1]); - absorb_final::<2,core::arch::aarch64::uint64x2_t,168, 0x1fu8>(&mut s1,[data2,data3]); + absorb_final::<2, 
core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(&mut s0, [data0, data1]); + absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(&mut s1, [data2, data3]); } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], data2: &[u8], data3: &[u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake128x4_absorb_final( + s: &mut KeccakState4, + data0: &[u8], + data1: &[u8], + data2: &[u8], + data3: &[u8], +) { let [mut s0, mut s1, mut s2, mut s3] = s; shake128_absorb_final(&mut s0, data0); shake128_absorb_final(&mut s1, data1); @@ -247,17 +330,35 @@ pub fn shake128x4_absorb_final(s:&mut KeccakState4, data0: &[u8], data1: &[u8], } #[cfg(feature = "simd256")] -pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { - squeeze_first_three_blocks::<4,core::arch::x86_64::__m256i,168>(s, [out0, out1, out2, out3]); +pub fn shake128x4_squeeze_first_three_blocks( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { + squeeze_first_three_blocks::<4, core::arch::x86_64::__m256i, 168>(s, [out0, out1, out2, out3]); } #[cfg(feature = "simd128")] -pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { +pub fn shake128x4_squeeze_first_three_blocks( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { let [mut s0, mut s1] = s; - squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]); - squeeze_first_three_blocks::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); + squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s0, [out0, out1]); + squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s1, [out2, out3]); } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake128x4_squeeze_first_three_blocks( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { let [mut s0, mut s1, mut s2, mut s3] = s; shake128_squeeze_first_three_blocks(&mut s0, out0); shake128_squeeze_first_three_blocks(&mut s1, out1); @@ -266,21 +367,38 @@ pub fn shake128x4_squeeze_first_three_blocks(s: &mut KeccakState4, out0:&mut [u8 } #[cfg(feature = "simd128")] -pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { - squeeze_next_block::<4,core::arch::x86_64::__m256i,168>(&mut s0, [out0, out1, out2, out3]); +pub fn shake128x4_squeeze_next_block( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { + squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>(&mut s0, [out0, out1, out2, out3]); } #[cfg(feature = "simd128")] -pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { +pub fn shake128x4_squeeze_next_block( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { let [mut s0, mut s1] = s; - squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s0, [out0, out1]); - 
squeeze_next_block::<2,core::arch::aarch64::uint64x2_t,168>(&mut s1, [out2, out3]); + squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s0, [out0, out1]); + squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s1, [out2, out3]); } -#[cfg(not(any(feature = "simd128",feature = "simd256")))] -pub fn shake128x4_squeeze_next_block(s: &mut KeccakState4, out0:&mut [u8], out1:&mut [u8], out2:&mut [u8], out3:&mut [u8]) { +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +pub fn shake128x4_squeeze_next_block( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], +) { let [mut s0, mut s1, mut s2, mut s3] = s; shake128_squeeze_next_block(&mut s0, out0); shake128_squeeze_next_block(&mut s1, out1); shake128_squeeze_next_block(&mut s2, out2); shake128_squeeze_next_block(&mut s3, out3); } - diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/rust_simd/sha3_arm64.rs index 22e16b42e..a4b3a301b 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/rust_simd/sha3_arm64.rs @@ -1,33 +1,38 @@ -use core::arch::aarch64::*; use crate::rust_simd::sha3_trait::*; +use core::arch::aarch64::*; // This file optimizes for the stable Rust Neon Intrinsics // If we want to use the unstable neon-sha3 instructions, we could use: // veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 // These instructions might speed up our code even more. - #[inline(always)] -fn rotate_left(x:uint64x2_t) -> uint64x2_t { - debug_assert!(LEFT+RIGHT == 64); +fn rotate_left(x: uint64x2_t) -> uint64x2_t { + debug_assert!(LEFT + RIGHT == 64); // The following looks faster but is actually significantly slower //unsafe { vsriq_n_u64::(vshlq_n_u64::(x), x) } unsafe { veorq_u64(vshlq_n_u64::(x), vshrq_n_u64::(x)) } } #[inline(always)] -fn _veor5q_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t, d: uint64x2_t, e: uint64x2_t) -> uint64x2_t { - let ab = unsafe {veorq_u64(a,b)}; - let cd = unsafe {veorq_u64(c,d)}; - let abcd = unsafe {veorq_u64(ab,cd)}; - unsafe {veorq_u64(abcd,e)} +fn _veor5q_u64( + a: uint64x2_t, + b: uint64x2_t, + c: uint64x2_t, + d: uint64x2_t, + e: uint64x2_t, +) -> uint64x2_t { + let ab = unsafe { veorq_u64(a, b) }; + let cd = unsafe { veorq_u64(c, d) }; + let abcd = unsafe { veorq_u64(ab, cd) }; + unsafe { veorq_u64(abcd, e) } // Needs nightly+neon-sha3 //unsafe {veor3q_u64(veor3q_u64(a,b,c),d,e)} } #[inline(always)] fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { - unsafe { veorq_u64(a, rotate_left::<1,63>(b)) } + unsafe { veorq_u64(a, rotate_left::<1, 63>(b)) } // Needs nightly+neon-sha3 //unsafe { vrax1q_u64(a, b) } } @@ -35,14 +40,14 @@ fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { #[inline(always)] fn _vxarq_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { let ab = unsafe { veorq_u64(a, b) }; - rotate_left::(ab) + rotate_left::(ab) // Needs nightly+neon-sha3 // unsafe { vxarq_u64::(a,b) } } #[inline(always)] fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t { - unsafe{ veorq_u64(a, vbicq_u64(b, c)) } + unsafe { veorq_u64(a, vbicq_u64(b, c)) } // Needs nightly+neon-sha3 // unsafe{ vbcaxq_u64(a, b, c) } } @@ -53,77 +58,91 @@ fn _veorq_n_u64(a: uint64x2_t, c: u64) -> uint64x2_t { unsafe { veorq_u64(a, c) } } - #[inline(always)] -pub(crate) fn load_block(s: &mut [[uint64x2_t;5];5], blocks: [&[u8];2]) { +pub(crate) fn load_block(s: &mut [[uint64x2_t; 5]; 5], blocks: [&[u8]; 2]) { debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0); - for i in 
0..RATE/16 { - let v0 = unsafe { vld1q_u64(blocks[0][16*i..16*(i+1)].as_ptr() as *const u64) }; - let v1 = unsafe { vld1q_u64(blocks[1][16*i..16*(i+1)].as_ptr() as *const u64) }; - s[(2*i)/5][(2*i)%5] = unsafe { veorq_u64(s[(2*i)/5][(2*i)%5], vtrn1q_u64(v0,v1)) }; - s[(2*i+1)/5][(2*i+1)%5] = unsafe { veorq_u64(s[(2*i+1)/5][(2*i+1)%5], vtrn2q_u64(v0,v1)) }; - } - if RATE%16 != 0 { - let i = (RATE/8 - 1)/5; - let j = (RATE/8 - 1)%5; + for i in 0..RATE / 16 { + let v0 = unsafe { vld1q_u64(blocks[0][16 * i..16 * (i + 1)].as_ptr() as *const u64) }; + let v1 = unsafe { vld1q_u64(blocks[1][16 * i..16 * (i + 1)].as_ptr() as *const u64) }; + s[(2 * i) / 5][(2 * i) % 5] = + unsafe { veorq_u64(s[(2 * i) / 5][(2 * i) % 5], vtrn1q_u64(v0, v1)) }; + s[(2 * i + 1) / 5][(2 * i + 1) % 5] = + unsafe { veorq_u64(s[(2 * i + 1) / 5][(2 * i + 1) % 5], vtrn2q_u64(v0, v1)) }; + } + if RATE % 16 != 0 { + let i = (RATE / 8 - 1) / 5; + let j = (RATE / 8 - 1) % 5; let mut u = [0u64; 2]; - u[0] = u64::from_le_bytes(blocks[0][RATE-8..RATE].try_into().unwrap()); - u[1] = u64::from_le_bytes(blocks[1][RATE-8..RATE].try_into().unwrap()); + u[0] = u64::from_le_bytes(blocks[0][RATE - 8..RATE].try_into().unwrap()); + u[1] = u64::from_le_bytes(blocks[1][RATE - 8..RATE].try_into().unwrap()); let uvec = unsafe { vld1q_u64(u.as_ptr() as *const u64) }; - s[i][j] = unsafe { veorq_u64(s[i][j], uvec)}; + s[i][j] = unsafe { veorq_u64(s[i][j], uvec) }; } } #[inline(always)] -pub(crate) fn load_block_full(s: &mut [[uint64x2_t;5];5], blocks: [[u8;200];2]) { - let [b0,b1] = blocks; - load_block::(s,[&b0 as &[u8], &b1 as &[u8]]); +pub(crate) fn load_block_full( + s: &mut [[uint64x2_t; 5]; 5], + blocks: [[u8; 200]; 2], +) { + let [b0, b1] = blocks; + load_block::(s, [&b0 as &[u8], &b1 as &[u8]]); } #[inline(always)] -pub(crate) fn store_block(s: &[[uint64x2_t;5];5], out: [&mut [u8];2]) { - for i in 0..RATE/16 { - let v0 = unsafe { vtrn1q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; - let v1 = unsafe { vtrn2q_u64(s[(2*i)/5][(2*i)%5], s[(2*i+1)/5][(2*i+1)%5]) }; - unsafe { vst1q_u64(out[0][16*i..16*(i+1)].as_mut_ptr() as *mut u64, v0) }; - unsafe { vst1q_u64(out[1][16*i..16*(i+1)].as_mut_ptr() as *mut u64, v1) }; - } - if RATE%16 != 0 { +pub(crate) fn store_block(s: &[[uint64x2_t; 5]; 5], out: [&mut [u8]; 2]) { + for i in 0..RATE / 16 { + let v0 = unsafe { + vtrn1q_u64( + s[(2 * i) / 5][(2 * i) % 5], + s[(2 * i + 1) / 5][(2 * i + 1) % 5], + ) + }; + let v1 = unsafe { + vtrn2q_u64( + s[(2 * i) / 5][(2 * i) % 5], + s[(2 * i + 1) / 5][(2 * i + 1) % 5], + ) + }; + unsafe { vst1q_u64(out[0][16 * i..16 * (i + 1)].as_mut_ptr() as *mut u64, v0) }; + unsafe { vst1q_u64(out[1][16 * i..16 * (i + 1)].as_mut_ptr() as *mut u64, v1) }; + } + if RATE % 16 != 0 { debug_assert!(RATE % 8 == 0); - let i = (RATE/8 - 1)/5; - let j = (RATE/8 - 1)%5; - let mut u = [0u8;16]; - unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j])}; - out[0][RATE-8..RATE].copy_from_slice(&u[0..8]); - out[1][RATE-8..RATE].copy_from_slice(&u[8..16]); - } -} + let i = (RATE / 8 - 1) / 5; + let j = (RATE / 8 - 1) % 5; + let mut u = [0u8; 16]; + unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j]) }; + out[0][RATE - 8..RATE].copy_from_slice(&u[0..8]); + out[1][RATE - 8..RATE].copy_from_slice(&u[8..16]); + } +} #[inline(always)] -pub(crate) fn store_block_full(s: &[[uint64x2_t;5];5]) -> [[u8;200];2] { +pub(crate) fn store_block_full(s: &[[uint64x2_t; 5]; 5]) -> [[u8; 200]; 2] { let mut out0 = [0u8; 200]; let mut out1 = [0u8; 200]; - store_block::(s,[&mut out0, &mut out1]); + 
store_block::(s, [&mut out0, &mut out1]); [out0, out1] -} +} #[inline(always)] -fn slice_2(a: [&[u8];2], start:usize, len:usize) -> [&[u8];2] { - [&a[0][start..start+len], &a[1][start..start+len]] +fn slice_2(a: [&[u8]; 2], start: usize, len: usize) -> [&[u8]; 2] { + [&a[0][start..start + len], &a[1][start..start + len]] } #[inline(always)] -fn split_at_mut_2(out: [&mut [u8]; 2], mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { +fn split_at_mut_2(out: [&mut [u8]; 2], mid: usize) -> ([&mut [u8]; 2], [&mut [u8]; 2]) { let [out0, out1] = out; - let (out00,out01) = out0.split_at_mut(mid); - let (out10,out11) = out1.split_at_mut(mid); - ([out00,out10], [out01,out11]) + let (out00, out01) = out0.split_at_mut(mid); + let (out10, out11) = out1.split_at_mut(mid); + ([out00, out10], [out01, out11]) } impl KeccakItem<2> for uint64x2_t { #[inline(always)] fn zero() -> Self { - unsafe {vdupq_n_u64(0)} + unsafe { vdupq_n_u64(0) } } #[inline(always)] fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { @@ -135,7 +154,7 @@ impl KeccakItem<2> for uint64x2_t { } #[inline(always)] fn xor_and_rotate(a: Self, b: Self) -> Self { - _vxarq_u64::(a, b) + _vxarq_u64::(a, b) } #[inline(always)] fn and_not_xor(a: Self, b: Self, c: Self) -> Self { @@ -147,31 +166,30 @@ impl KeccakItem<2> for uint64x2_t { } #[inline(always)] fn xor(a: Self, b: Self) -> Self { - unsafe {veorq_u64(a, b)} + unsafe { veorq_u64(a, b) } } #[inline(always)] - fn load_block(a:&mut [[Self;5];5], b:[&[u8];2]) { + fn load_block(a: &mut [[Self; 5]; 5], b: [&[u8]; 2]) { load_block::(a, b) } #[inline(always)] - fn store_block(a:& [[Self;5];5], b:[&mut [u8];2]) { + fn store_block(a: &[[Self; 5]; 5], b: [&mut [u8]; 2]) { store_block::(a, b) } #[inline(always)] - fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];2]) { + fn load_block_full(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 2]) { load_block_full::(a, b) } #[inline(always)] - fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];2] { + fn store_block_full(a: &[[Self; 5]; 5]) -> [[u8; 200]; 2] { store_block_full::(a) } #[inline(always)] - fn slice_n(a:[&[u8];2],start:usize,len:usize) -> [&[u8];2] { - slice_2(a,start,len) + fn slice_n(a: [&[u8]; 2], start: usize, len: usize) -> [&[u8]; 2] { + slice_2(a, start, len) } #[inline(always)] - fn split_at_mut_n(a:[&mut [u8];2],mid:usize) -> ([&mut [u8];2],[&mut [u8];2]) { + fn split_at_mut_n(a: [&mut [u8]; 2], mid: usize) -> ([&mut [u8]; 2], [&mut [u8]; 2]) { split_at_mut_2(a, mid) } } - diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/rust_simd/sha3_avx2.rs index ed47f93c5..130f04321 100644 --- a/libcrux-sha3/src/rust_simd/sha3_avx2.rs +++ b/libcrux-sha3/src/rust_simd/sha3_avx2.rs @@ -7,10 +7,9 @@ use crate::rust_simd::sha3_trait::*; // veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 // These instructions might speed up our code even more. - #[inline(always)] -fn rotate_left(x:__m256i) -> __m256i { - debug_assert!(LEFT+RIGHT == 64); +fn rotate_left(x: __m256i) -> __m256i { + debug_assert!(LEFT + RIGHT == 64); // XXX: This could be done more efficiently, if the shift values are multiples of 8. 
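     // Per 64-bit lane this is a plain rotate-left: with LEFT + RIGHT == 64 the
     // two shifted halves never overlap, so XOR-ing them equals OR-ing them.
     // Scalar illustration of the same identity (for exposition only):
     //     let x: u64 = 0x0123_4567_89ab_cdef;
     //     assert_eq!((x << 36) ^ (x >> 28), x.rotate_left(36));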
unsafe { _mm256_xor_si256(_mm256_slli_epi64::(x), _mm256_srli_epi64::(x)) } } @@ -25,18 +24,18 @@ fn _veor5q_u64(a: __m256i, b: __m256i, c: __m256i, d: __m256i, e: __m256i) -> __ #[inline(always)] fn _vrax1q_u64(a: __m256i, b: __m256i) -> __m256i { - unsafe { _mm256_xor_si256(a, rotate_left::<1,63>(b)) } + unsafe { _mm256_xor_si256(a, rotate_left::<1, 63>(b)) } } #[inline(always)] fn _vxarq_u64(a: __m256i, b: __m256i) -> __m256i { let ab = unsafe { _mm256_xor_si256(a, b) }; - rotate_left::(ab) + rotate_left::(ab) } #[inline(always)] fn _vbcaxq_u64(a: __m256i, b: __m256i, c: __m256i) -> __m256i { - unsafe{ _mm256_xor_si256(a, _mm256_andnot_si256(b, c)) } + unsafe { _mm256_xor_si256(a, _mm256_andnot_si256(b, c)) } } #[inline(always)] @@ -46,128 +45,188 @@ fn _veorq_n_u64(a: __m256i, c: u64) -> __m256i { unsafe { _mm256_xor_si256(a, c) } } - #[inline(always)] -pub(crate) fn load_block(s: &mut [[__m256i;5];5], blocks: [&[u8];4]) { +pub(crate) fn load_block(s: &mut [[__m256i; 5]; 5], blocks: [&[u8]; 4]) { debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16)); - for i in 0..RATE/32 { - let v0 = unsafe { _mm256_loadu_si256(blocks[0][32*i..32*(i+1)].as_ptr() as *const __m256i)}; - let v1 = unsafe { _mm256_loadu_si256(blocks[1][32*i..32*(i+1)].as_ptr() as *const __m256i)}; - let v2 = unsafe { _mm256_loadu_si256(blocks[2][32*i..32*(i+1)].as_ptr() as *const __m256i)}; - let v3 = unsafe { _mm256_loadu_si256(blocks[3][32*i..32*(i+1)].as_ptr() as *const __m256i)}; - - let v0l = unsafe { _mm256_unpacklo_epi64(v0,v1) }; // 0 0 2 2 - let v1h = unsafe { _mm256_unpackhi_epi64(v0,v1) }; // 1 1 3 3 - let v2l = unsafe { _mm256_unpacklo_epi64(v2,v3) }; // 0 0 2 2 - let v3h = unsafe { _mm256_unpackhi_epi64(v2,v3) }; // 1 1 3 3 - - let v0 = unsafe { _mm256_permute2x128_si256(v0l,v2l,0x20) }; // 0 0 0 0 - let v1 = unsafe { _mm256_permute2x128_si256(v1h,v3h,0x20) }; // 1 1 1 1 - let v2 = unsafe { _mm256_permute2x128_si256(v0l,v2l,0x31) }; // 2 2 2 2 - let v3 = unsafe { _mm256_permute2x128_si256(v1h,v3h,0x31) }; // 3 3 3 3 - - s[(4*i)/5][(4*i)%5] = unsafe { _mm256_xor_si256(s[(4*i)/5][(4*i)%5], v0) }; - s[(4*i+1)/5][(4*i+1)%5] = unsafe { _mm256_xor_si256(s[(4*i+1)/5][(4*i+1)%5], v1) }; - s[(4*i+2)/5][(4*i+2)%5] = unsafe { _mm256_xor_si256(s[(4*i+2)/5][(4*i+2)%5], v2) }; - s[(4*i+3)/5][(4*i+3)%5] = unsafe { _mm256_xor_si256(s[(4*i+3)/5][(4*i+3)%5], v3) }; - } - - let rem = RATE%32; // has to be 8 or 16 - let start = 32 * (RATE/32); - let mut u8s = [0u8;32]; - u8s[0..8].copy_from_slice(&blocks[0][start..start+8]); - u8s[8..16].copy_from_slice(&blocks[1][start..start+8]); - u8s[16..24].copy_from_slice(&blocks[2][start..start+8]); - u8s[24..32].copy_from_slice(&blocks[3][start..start+8]); - let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i)}; - let i = (4*(RATE/32))/5; - let j = (4*(RATE/32))%5; - s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u)}; + for i in 0..RATE / 32 { + let v0 = unsafe { + _mm256_loadu_si256(blocks[0][32 * i..32 * (i + 1)].as_ptr() as *const __m256i) + }; + let v1 = unsafe { + _mm256_loadu_si256(blocks[1][32 * i..32 * (i + 1)].as_ptr() as *const __m256i) + }; + let v2 = unsafe { + _mm256_loadu_si256(blocks[2][32 * i..32 * (i + 1)].as_ptr() as *const __m256i) + }; + let v3 = unsafe { + _mm256_loadu_si256(blocks[3][32 * i..32 * (i + 1)].as_ptr() as *const __m256i) + }; + + let v0l = unsafe { _mm256_unpacklo_epi64(v0, v1) }; // 0 0 2 2 + let v1h = unsafe { _mm256_unpackhi_epi64(v0, v1) }; // 1 1 3 3 + let v2l = unsafe { _mm256_unpacklo_epi64(v2, 
v3) }; // 0 0 2 2 + let v3h = unsafe { _mm256_unpackhi_epi64(v2, v3) }; // 1 1 3 3 + + let v0 = unsafe { _mm256_permute2x128_si256(v0l, v2l, 0x20) }; // 0 0 0 0 + let v1 = unsafe { _mm256_permute2x128_si256(v1h, v3h, 0x20) }; // 1 1 1 1 + let v2 = unsafe { _mm256_permute2x128_si256(v0l, v2l, 0x31) }; // 2 2 2 2 + let v3 = unsafe { _mm256_permute2x128_si256(v1h, v3h, 0x31) }; // 3 3 3 3 + + s[(4 * i) / 5][(4 * i) % 5] = unsafe { _mm256_xor_si256(s[(4 * i) / 5][(4 * i) % 5], v0) }; + s[(4 * i + 1) / 5][(4 * i + 1) % 5] = + unsafe { _mm256_xor_si256(s[(4 * i + 1) / 5][(4 * i + 1) % 5], v1) }; + s[(4 * i + 2) / 5][(4 * i + 2) % 5] = + unsafe { _mm256_xor_si256(s[(4 * i + 2) / 5][(4 * i + 2) % 5], v2) }; + s[(4 * i + 3) / 5][(4 * i + 3) % 5] = + unsafe { _mm256_xor_si256(s[(4 * i + 3) / 5][(4 * i + 3) % 5], v3) }; + } + + let rem = RATE % 32; // has to be 8 or 16 + let start = 32 * (RATE / 32); + let mut u8s = [0u8; 32]; + u8s[0..8].copy_from_slice(&blocks[0][start..start + 8]); + u8s[8..16].copy_from_slice(&blocks[1][start..start + 8]); + u8s[16..24].copy_from_slice(&blocks[2][start..start + 8]); + u8s[24..32].copy_from_slice(&blocks[3][start..start + 8]); + let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i) }; + let i = (4 * (RATE / 32)) / 5; + let j = (4 * (RATE / 32)) % 5; + s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u) }; if rem == 16 { - let mut u8s = [0u8;32]; - u8s[0..8].copy_from_slice(&blocks[0][start+8..start+16]); - u8s[8..16].copy_from_slice(&blocks[1][start+8..start+16]); - u8s[16..24].copy_from_slice(&blocks[2][start+8..start+16]); - u8s[24..32].copy_from_slice(&blocks[3][start+8..start+16]); - let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i)}; - let i = (4*(RATE/32) + 1)/5; - let j = (4*(RATE/32) + 1)%5; - s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u)}; + let mut u8s = [0u8; 32]; + u8s[0..8].copy_from_slice(&blocks[0][start + 8..start + 16]); + u8s[8..16].copy_from_slice(&blocks[1][start + 8..start + 16]); + u8s[16..24].copy_from_slice(&blocks[2][start + 8..start + 16]); + u8s[24..32].copy_from_slice(&blocks[3][start + 8..start + 16]); + let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i) }; + let i = (4 * (RATE / 32) + 1) / 5; + let j = (4 * (RATE / 32) + 1) % 5; + s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u) }; } } #[inline(always)] -pub(crate) fn load_block_full(s: &mut [[__m256i;5];5], blocks: [[u8;200];4]) { - let [b0,b1,b2,b3] = blocks; - load_block::(s,[&b0 as &[u8], &b1 as &[u8], &b2 as &[u8], &b3 as &[u8]]); +pub(crate) fn load_block_full( + s: &mut [[__m256i; 5]; 5], + blocks: [[u8; 200]; 4], +) { + let [b0, b1, b2, b3] = blocks; + load_block::(s, [&b0 as &[u8], &b1 as &[u8], &b2 as &[u8], &b3 as &[u8]]); } #[inline(always)] -pub(crate) fn store_block(s: &[[__m256i;5];5], out: [&mut [u8];4]) { - for i in 0..RATE/32 { - let v0l = unsafe { _mm256_permute2x128_si256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x20) }; // 0 0 2 2 - let v1h = unsafe { _mm256_permute2x128_si256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x20) }; // 1 1 3 3 - let v2l = unsafe { _mm256_permute2x128_si256(s[(4*i)/5][(4*i)%5],s[(4*i+2)/5][(4*i+2)%5],0x31) }; // 0 0 2 2 - let v3h = unsafe { _mm256_permute2x128_si256(s[(4*i+1)/5][(4*i+1)%5],s[(4*i+3)/5][(4*3+1)%5],0x31) }; // 1 1 3 3 - +pub(crate) fn store_block(s: &[[__m256i; 5]; 5], out: [&mut [u8]; 4]) { + for i in 0..RATE / 32 { + let v0l = unsafe { + _mm256_permute2x128_si256( + s[(4 * i) / 5][(4 * i) % 5], + s[(4 * i + 2) / 5][(4 * i + 2) % 5], + 0x20, + ) + }; // 0 0 2 2 + let 
v1h = unsafe { + _mm256_permute2x128_si256( + s[(4 * i + 1) / 5][(4 * i + 1) % 5], + s[(4 * i + 3) / 5][(4 * 3 + 1) % 5], + 0x20, + ) + }; // 1 1 3 3 + let v2l = unsafe { + _mm256_permute2x128_si256( + s[(4 * i) / 5][(4 * i) % 5], + s[(4 * i + 2) / 5][(4 * i + 2) % 5], + 0x31, + ) + }; // 0 0 2 2 + let v3h = unsafe { + _mm256_permute2x128_si256( + s[(4 * i + 1) / 5][(4 * i + 1) % 5], + s[(4 * i + 3) / 5][(4 * 3 + 1) % 5], + 0x31, + ) + }; // 1 1 3 3 let v0 = unsafe { _mm256_unpacklo_epi64(v0l, v1h) }; // 0 1 2 3 let v1 = unsafe { _mm256_unpackhi_epi64(v0l, v1h) }; // 0 1 2 3 let v2 = unsafe { _mm256_unpacklo_epi64(v2l, v3h) }; // 0 1 2 3 let v3 = unsafe { _mm256_unpackhi_epi64(v2l, v3h) }; // 0 1 2 3 - unsafe { _mm256_storeu_si256(out[0][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v0) }; - unsafe { _mm256_storeu_si256(out[1][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v1) }; - unsafe { _mm256_storeu_si256(out[2][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v2) }; - unsafe { _mm256_storeu_si256(out[3][32*i..32*(i+1)].as_mut_ptr() as *mut __m256i, v3) }; - } - - let rem = RATE%32; // has to be 8 or 16 - let start = 32 * (RATE/32); - let mut u8s = [0u8;32]; - let i = (4*(RATE/32))/5; - let j = (4*(RATE/32))%5; - unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j])}; - out[0][start..start+8].copy_from_slice(&u8s[0..8]); - out[1][start..start+8].copy_from_slice(&u8s[8..16]); - out[2][start..start+8].copy_from_slice(&u8s[16..24]); - out[3][start..start+8].copy_from_slice(&u8s[24..32]); + unsafe { + _mm256_storeu_si256( + out[0][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i, + v0, + ) + }; + unsafe { + _mm256_storeu_si256( + out[1][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i, + v1, + ) + }; + unsafe { + _mm256_storeu_si256( + out[2][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i, + v2, + ) + }; + unsafe { + _mm256_storeu_si256( + out[3][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i, + v3, + ) + }; + } + + let rem = RATE % 32; // has to be 8 or 16 + let start = 32 * (RATE / 32); + let mut u8s = [0u8; 32]; + let i = (4 * (RATE / 32)) / 5; + let j = (4 * (RATE / 32)) % 5; + unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j]) }; + out[0][start..start + 8].copy_from_slice(&u8s[0..8]); + out[1][start..start + 8].copy_from_slice(&u8s[8..16]); + out[2][start..start + 8].copy_from_slice(&u8s[16..24]); + out[3][start..start + 8].copy_from_slice(&u8s[24..32]); if rem == 16 { - let mut u8s = [0u8;32]; - let i = (4*(RATE/32) + 1)/5; - let j = (4*(RATE/32) + 1)%5; - unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j])}; - out[0][start+8..start+16].copy_from_slice(&u8s[0..8]); - out[1][start+8..start+16].copy_from_slice(&u8s[8..16]); - out[2][start+8..start+16].copy_from_slice(&u8s[16..24]); - out[3][start+8..start+16].copy_from_slice(&u8s[24..32]); + let mut u8s = [0u8; 32]; + let i = (4 * (RATE / 32) + 1) / 5; + let j = (4 * (RATE / 32) + 1) % 5; + unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j]) }; + out[0][start + 8..start + 16].copy_from_slice(&u8s[0..8]); + out[1][start + 8..start + 16].copy_from_slice(&u8s[8..16]); + out[2][start + 8..start + 16].copy_from_slice(&u8s[16..24]); + out[3][start + 8..start + 16].copy_from_slice(&u8s[24..32]); } -} +} #[inline(always)] -pub(crate) fn store_block_full(s: &[[__m256i;5];5]) -> [[u8;200];4] { +pub(crate) fn store_block_full(s: &[[__m256i; 5]; 5]) -> [[u8; 200]; 4] { let mut out0 = [0u8; 200]; let mut out1 = [0u8; 200]; let mut out2 = [0u8; 200]; let mut out3 
= [0u8; 200]; - store_block::(s,[&mut out0, &mut out1, &mut out2, &mut out3]); + store_block::(s, [&mut out0, &mut out1, &mut out2, &mut out3]); [out0, out1, out2, out3] -} +} #[inline(always)] -fn slice_4(a: [&[u8];4], start:usize, len:usize) -> [&[u8];4] { - [&a[0][start..start+len], &a[1][start..start+len], &a[2][start..start+len], &a[3][start..start+len]] +fn slice_4(a: [&[u8]; 4], start: usize, len: usize) -> [&[u8]; 4] { + [ + &a[0][start..start + len], + &a[1][start..start + len], + &a[2][start..start + len], + &a[3][start..start + len], + ] } #[inline(always)] -fn split_at_mut_4(out: [&mut [u8]; 4], mid:usize) -> ([&mut [u8];4],[&mut [u8];4]) { +fn split_at_mut_4(out: [&mut [u8]; 4], mid: usize) -> ([&mut [u8]; 4], [&mut [u8]; 4]) { let [out0, out1, out2, out3] = out; - let (out00,out01) = out0.split_at_mut(mid); - let (out10,out11) = out1.split_at_mut(mid); - let (out20,out21) = out2.split_at_mut(mid); - let (out30,out31) = out3.split_at_mut(mid); - ([out00,out10,out20,out30], - [out01,out11,out21,out31]) + let (out00, out01) = out0.split_at_mut(mid); + let (out10, out11) = out1.split_at_mut(mid); + let (out20, out21) = out2.split_at_mut(mid); + let (out30, out31) = out3.split_at_mut(mid); + ([out00, out10, out20, out30], [out01, out11, out21, out31]) } impl KeccakItem<4> for __m256i { @@ -185,7 +244,7 @@ impl KeccakItem<4> for __m256i { } #[inline(always)] fn xor_and_rotate(a: Self, b: Self) -> Self { - _vxarq_u64::(a, b) + _vxarq_u64::(a, b) } #[inline(always)] fn and_not_xor(a: Self, b: Self, c: Self) -> Self { @@ -197,31 +256,30 @@ impl KeccakItem<4> for __m256i { } #[inline(always)] fn xor(a: Self, b: Self) -> Self { - unsafe {_mm256_xor_si256(a, b)} + unsafe { _mm256_xor_si256(a, b) } } #[inline(always)] - fn load_block(a:&mut [[Self;5];5], b:[&[u8];4]) { + fn load_block(a: &mut [[Self; 5]; 5], b: [&[u8]; 4]) { load_block::(a, b) } #[inline(always)] - fn store_block(a:& [[Self;5];5], b:[&mut [u8];4]) { + fn store_block(a: &[[Self; 5]; 5], b: [&mut [u8]; 4]) { store_block::(a, b) } #[inline(always)] - fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];4]) { + fn load_block_full(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 4]) { load_block_full::(a, b) } #[inline(always)] - fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];4] { + fn store_block_full(a: &[[Self; 5]; 5]) -> [[u8; 200]; 4] { store_block_full::(a) } #[inline(always)] - fn slice_n(a:[&[u8];4],start:usize,len:usize) -> [&[u8];4] { - slice_4(a,start,len) + fn slice_n(a: [&[u8]; 4], start: usize, len: usize) -> [&[u8]; 4] { + slice_4(a, start, len) } #[inline(always)] - fn split_at_mut_n(a:[&mut [u8];4],mid:usize) -> ([&mut [u8];4],[&mut [u8];4]) { + fn split_at_mut_n(a: [&mut [u8]; 4], mid: usize) -> ([&mut [u8]; 4], [&mut [u8]; 4]) { split_at_mut_4(a, mid) } } - diff --git a/libcrux-sha3/src/rust_simd/sha3_generic.rs b/libcrux-sha3/src/rust_simd/sha3_generic.rs index 2892be19e..d9a46718d 100644 --- a/libcrux-sha3/src/rust_simd/sha3_generic.rs +++ b/libcrux-sha3/src/rust_simd/sha3_generic.rs @@ -1,71 +1,82 @@ +use std::ops::Index; + use crate::rust_simd::sha3_trait::*; #[cfg_attr(hax, hax_lib::opaque_type)] #[derive(Clone, Copy)] -pub struct KeccakState> { +pub struct KeccakState> { pub st: [[T; 5]; 5], } -impl> KeccakState { +impl> Index for KeccakState { + type Output = [T; 5]; + + fn index(&self, index: usize) -> &Self::Output { + &self.st[index] + } +} + +impl> KeccakState { /// Create a new Shake128 x4 state. 
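     // The state is the usual Keccak 5x5 matrix of 64-bit lanes, with each
     // entry packing N independent instances side by side: T is u64 in the
     // portable build, uint64x2_t in the 2-way NEON build and __m256i in the
     // 4-way AVX2 build. A flat lane index i lives at st[i / 5][i % 5], the
     // mapping the load_block/store_block helpers rely on; new() below simply
     // zero-initialises all 25 packed lanes via T::zero().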
#[inline(always)] - pub(crate) fn new() -> Self { + pub(crate) fn new() -> Self { Self { st: [[T::zero(); 5]; 5], } - } } /// From here, everything is generic -/// -const _ROTC: [usize;24] = - [1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,]; - +/// +const _ROTC: [usize; 24] = [ + 1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14, +]; #[inline(always)] -pub(crate) fn theta_rho>(s: &mut KeccakState) { - let c: [T; 5] = core::array::from_fn(|j| T::xor5(s.st[0][j],s.st[1][j],s.st[2][j],s.st[3][j],s.st[4][j])); - let t : [T; 5] = core::array::from_fn(|j| T::rotate_left1_and_xor(c[(j+4)%5], c[(j+1)%5])); - - s.st[0][0] = T::xor(s.st[0][0],t[0]); - s.st[1][0] = T::xor_and_rotate::<36,28>(s.st[1][0],t[0]); - s.st[2][0] = T::xor_and_rotate::<3,61>(s.st[2][0],t[0]); - s.st[3][0] = T::xor_and_rotate::<41,23>(s.st[3][0],t[0]); - s.st[4][0] = T::xor_and_rotate::<18,46>(s.st[4][0],t[0]); - - s.st[0][1] = T::xor_and_rotate::<1,63>(s.st[0][1],t[1]); - s.st[1][1] = T::xor_and_rotate::<44,20>(s.st[1][1],t[1]); - s.st[2][1] = T::xor_and_rotate::<10,54>(s.st[2][1],t[1]); - s.st[3][1] = T::xor_and_rotate::<45,19>(s.st[3][1],t[1]); - s.st[4][1] = T::xor_and_rotate::<2,62>(s.st[4][1],t[1]); - - s.st[0][2] = T::xor_and_rotate::<62,2>(s.st[0][2],t[2]); - s.st[1][2] = T::xor_and_rotate::<6,58>(s.st[1][2],t[2]); - s.st[2][2] = T::xor_and_rotate::<43,21>(s.st[2][2],t[2]); - s.st[3][2] = T::xor_and_rotate::<15,49>(s.st[3][2],t[2]); - s.st[4][2] = T::xor_and_rotate::<61,3>(s.st[4][2],t[2]); - - s.st[0][3] = T::xor_and_rotate::<28,36>(s.st[0][3],t[3]); - s.st[1][3] = T::xor_and_rotate::<55,9>(s.st[1][3],t[3]); - s.st[2][3] = T::xor_and_rotate::<25,39>(s.st[2][3],t[3]); - s.st[3][3] = T::xor_and_rotate::<21,43>(s.st[3][3],t[3]); - s.st[4][3] = T::xor_and_rotate::<56,8>(s.st[4][3],t[3]); - - s.st[0][4] = T::xor_and_rotate::<27,37>(s.st[0][4],t[4]); - s.st[1][4] = T::xor_and_rotate::<20,44>(s.st[1][4],t[4]); - s.st[2][4] = T::xor_and_rotate::<39,25>(s.st[2][4],t[4]); - s.st[3][4] = T::xor_and_rotate::<8,56>(s.st[3][4],t[4]); - s.st[4][4] = T::xor_and_rotate::<14,50>(s.st[4][4],t[4]); -} - - -const _PI : [usize;24] = [ +pub(crate) fn theta_rho>(s: &mut KeccakState) { + let c: [T; 5] = core::array::from_fn(|j| { + T::xor5(s.st[0][j], s.st[1][j], s.st[2][j], s.st[3][j], s.st[4][j]) + }); + let t: [T; 5] = + core::array::from_fn(|j| T::rotate_left1_and_xor(c[(j + 4) % 5], c[(j + 1) % 5])); + + s.st[0][0] = T::xor(s.st[0][0], t[0]); + s.st[1][0] = T::xor_and_rotate::<36, 28>(s.st[1][0], t[0]); + s.st[2][0] = T::xor_and_rotate::<3, 61>(s.st[2][0], t[0]); + s.st[3][0] = T::xor_and_rotate::<41, 23>(s.st[3][0], t[0]); + s.st[4][0] = T::xor_and_rotate::<18, 46>(s.st[4][0], t[0]); + + s.st[0][1] = T::xor_and_rotate::<1, 63>(s.st[0][1], t[1]); + s.st[1][1] = T::xor_and_rotate::<44, 20>(s.st[1][1], t[1]); + s.st[2][1] = T::xor_and_rotate::<10, 54>(s.st[2][1], t[1]); + s.st[3][1] = T::xor_and_rotate::<45, 19>(s.st[3][1], t[1]); + s.st[4][1] = T::xor_and_rotate::<2, 62>(s.st[4][1], t[1]); + + s.st[0][2] = T::xor_and_rotate::<62, 2>(s.st[0][2], t[2]); + s.st[1][2] = T::xor_and_rotate::<6, 58>(s.st[1][2], t[2]); + s.st[2][2] = T::xor_and_rotate::<43, 21>(s.st[2][2], t[2]); + s.st[3][2] = T::xor_and_rotate::<15, 49>(s.st[3][2], t[2]); + s.st[4][2] = T::xor_and_rotate::<61, 3>(s.st[4][2], t[2]); + + s.st[0][3] = T::xor_and_rotate::<28, 36>(s.st[0][3], t[3]); + s.st[1][3] = T::xor_and_rotate::<55, 9>(s.st[1][3], t[3]); + s.st[2][3] = T::xor_and_rotate::<25, 
39>(s.st[2][3], t[3]); + s.st[3][3] = T::xor_and_rotate::<21, 43>(s.st[3][3], t[3]); + s.st[4][3] = T::xor_and_rotate::<56, 8>(s.st[4][3], t[3]); + + s.st[0][4] = T::xor_and_rotate::<27, 37>(s.st[0][4], t[4]); + s.st[1][4] = T::xor_and_rotate::<20, 44>(s.st[1][4], t[4]); + s.st[2][4] = T::xor_and_rotate::<39, 25>(s.st[2][4], t[4]); + s.st[3][4] = T::xor_and_rotate::<8, 56>(s.st[3][4], t[4]); + s.st[4][4] = T::xor_and_rotate::<14, 50>(s.st[4][4], t[4]); +} + +const _PI: [usize; 24] = [ 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7, 13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21, ]; #[inline(always)] -pub(crate) fn pi>(s: &mut KeccakState) { +pub(crate) fn pi>(s: &mut KeccakState) { let old = s.st.clone(); s.st[0][1] = old[1][1]; s.st[0][2] = old[2][2]; @@ -94,7 +105,7 @@ pub(crate) fn pi>(s: &mut KeccakState) { } #[inline(always)] -pub(crate) fn chi>(s: &mut KeccakState) { +pub(crate) fn chi>(s: &mut KeccakState) { let old = s.st; for i in 0..5 { for j in 0..5 { @@ -103,25 +114,40 @@ pub(crate) fn chi>(s: &mut KeccakState) { } } -const ROUNDCONSTANTS: [u64;24] = [ - 0x0000_0000_0000_0001u64, 0x0000_0000_0000_8082u64, 0x8000_0000_0000_808au64, - 0x8000_0000_8000_8000u64, 0x0000_0000_0000_808bu64, 0x0000_0000_8000_0001u64, - 0x8000_0000_8000_8081u64, 0x8000_0000_0000_8009u64, 0x0000_0000_0000_008au64, - 0x0000_0000_0000_0088u64, 0x0000_0000_8000_8009u64, 0x0000_0000_8000_000au64, - 0x0000_0000_8000_808bu64, 0x8000_0000_0000_008bu64, 0x8000_0000_0000_8089u64, - 0x8000_0000_0000_8003u64, 0x8000_0000_0000_8002u64, 0x8000_0000_0000_0080u64, - 0x0000_0000_0000_800au64, 0x8000_0000_8000_000au64, 0x8000_0000_8000_8081u64, - 0x8000_0000_0000_8080u64, 0x0000_0000_8000_0001u64, 0x8000_0000_8000_8008u64, +const ROUNDCONSTANTS: [u64; 24] = [ + 0x0000_0000_0000_0001u64, + 0x0000_0000_0000_8082u64, + 0x8000_0000_0000_808au64, + 0x8000_0000_8000_8000u64, + 0x0000_0000_0000_808bu64, + 0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8009u64, + 0x0000_0000_0000_008au64, + 0x0000_0000_0000_0088u64, + 0x0000_0000_8000_8009u64, + 0x0000_0000_8000_000au64, + 0x0000_0000_8000_808bu64, + 0x8000_0000_0000_008bu64, + 0x8000_0000_0000_8089u64, + 0x8000_0000_0000_8003u64, + 0x8000_0000_0000_8002u64, + 0x8000_0000_0000_0080u64, + 0x0000_0000_0000_800au64, + 0x8000_0000_8000_000au64, + 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8080u64, + 0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8008u64, ]; #[inline(always)] -pub(crate) fn iota>(s: &mut KeccakState, i:usize) { +pub(crate) fn iota>(s: &mut KeccakState, i: usize) { s.st[0][0] = T::xor_constant(s.st[0][0], ROUNDCONSTANTS[i]); } - #[inline(always)] -pub(crate) fn keccakf1600>(s: &mut KeccakState) { +pub(crate) fn keccakf1600>(s: &mut KeccakState) { for i in 0..24 { theta_rho(s); pi(s); @@ -131,51 +157,65 @@ pub(crate) fn keccakf1600>(s: &mut KeccakState,const RATE:usize>(s: &mut KeccakState, blocks: [&[u8];N]) { +pub(crate) fn absorb_block, const RATE: usize>( + s: &mut KeccakState, + blocks: [&[u8]; N], +) { T::load_block::(&mut s.st, blocks); keccakf1600(s) } #[inline(always)] -pub(crate) fn absorb_final,const RATE:usize, const DELIM:u8>( - s: &mut KeccakState, last: [&[u8];N]) { +pub(crate) fn absorb_final, const RATE: usize, const DELIM: u8>( + s: &mut KeccakState, + last: [&[u8]; N], +) { debug_assert!(N > 0 && last[0].len() < RATE); let last_len = last[0].len(); let mut blocks = [[0u8; 200]; N]; for i in 0..N { blocks[i][0..last_len].copy_from_slice(&last[i]); blocks[i][last_len] = DELIM; - blocks[i][RATE-1] = blocks[i][RATE-1] | 128u8; + 
blocks[i][RATE - 1] = blocks[i][RATE - 1] | 128u8; } T::load_block_full::(&mut s.st, blocks); keccakf1600(s) } - #[inline(always)] -pub(crate) fn squeeze_first_block,const RATE:usize>(s: &KeccakState, out: [&mut [u8];N]) { +pub(crate) fn squeeze_first_block, const RATE: usize>( + s: &KeccakState, + out: [&mut [u8]; N], +) { T::store_block::(&s.st, out) } #[inline(always)] -pub(crate) fn squeeze_next_block,const RATE:usize>(s: &mut KeccakState, out: [&mut [u8];N]) { +pub(crate) fn squeeze_next_block, const RATE: usize>( + s: &mut KeccakState, + out: [&mut [u8]; N], +) { keccakf1600(s); T::store_block::(&s.st, out) } - #[inline(always)] -pub(crate) fn squeeze_first_three_blocks,const RATE:usize>( - s: &mut KeccakState, out: [&mut [u8];N]) { - let (o0,o1) = T::split_at_mut_n(out, RATE); - squeeze_first_block::(s, o0); - let (o1,o2) = T::split_at_mut_n(o1, RATE); - squeeze_next_block::(s, o1); - squeeze_next_block::(s, o2); +pub(crate) fn squeeze_first_three_blocks, const RATE: usize>( + s: &mut KeccakState, + out: [&mut [u8]; N], +) { + let (o0, o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(s, o0); + let (o1, o2) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(s, o1); + squeeze_next_block::(s, o2); } #[inline(always)] -pub(crate) fn squeeze_last,const RATE:usize>(mut s: KeccakState, out: [&mut [u8];N]) { +pub(crate) fn squeeze_last, const RATE: usize>( + mut s: KeccakState, + out: [&mut [u8]; N], +) { keccakf1600(&mut s); let b = T::store_block_full::(&s.st); for i in 0..N { @@ -184,7 +224,10 @@ pub(crate) fn squeeze_last,const RATE:usize>(mut } #[inline(always)] -pub(crate) fn squeeze_first_and_last,const RATE:usize>(s: &KeccakState, out: [&mut [u8];N]) { +pub(crate) fn squeeze_first_and_last, const RATE: usize>( + s: &KeccakState, + out: [&mut [u8]; N], +) { let b = T::store_block_full::(&s.st); for i in 0..N { out[i].copy_from_slice(&b[i][0..out[i].len()]); @@ -192,28 +235,33 @@ pub(crate) fn squeeze_first_and_last,const RATE:u } #[inline(always)] -pub(crate) fn keccak,const RATE:usize, const DELIM:u8>(data: [&[u8]; N], out: [&mut [u8]; N]) { - let mut s = KeccakState::::new(); - for i in 0..data[0].len()/RATE { - absorb_block::(&mut s, T::slice_n(data,i*RATE,RATE)); +pub(crate) fn keccak, const RATE: usize, const DELIM: u8>( + data: [&[u8]; N], + out: [&mut [u8]; N], +) { + let mut s = KeccakState::::new(); + for i in 0..data[0].len() / RATE { + absorb_block::(&mut s, T::slice_n(data, i * RATE, RATE)); } let rem = data[0].len() % RATE; - absorb_final::(&mut s, T::slice_n(data,data[0].len()-rem,rem)); + absorb_final::(&mut s, T::slice_n(data, data[0].len() - rem, rem)); let outlen = out[0].len(); - let blocks = outlen/RATE; - let last = outlen - (outlen%RATE); + let blocks = outlen / RATE; + let last = outlen - (outlen % RATE); if blocks == 0 { - squeeze_first_and_last::(&s, out) + squeeze_first_and_last::(&s, out) } else { - let (o0,mut o1) = T::split_at_mut_n(out, RATE); - squeeze_first_block::(&s, o0); + let (o0, mut o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(&s, o0); for _i in 1..blocks { - let (o,orest) = T::split_at_mut_n(o1, RATE); - squeeze_next_block::(&mut s, o); + let (o, orest) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(&mut s, o); o1 = orest; } - if last < outlen {squeeze_last::(s, o1)} + if last < outlen { + squeeze_last::(s, o1) + } } } diff --git a/libcrux-sha3/src/rust_simd/sha3_portable.rs b/libcrux-sha3/src/rust_simd/sha3_portable.rs index f837470db..cbf2af8b5 100644 --- a/libcrux-sha3/src/rust_simd/sha3_portable.rs +++ 
b/libcrux-sha3/src/rust_simd/sha3_portable.rs @@ -5,10 +5,9 @@ use crate::rust_simd::sha3_trait::*; // veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 // These instructions might speed up our code even more. - #[inline(always)] -fn rotate_left(x:u64) -> u64 { - debug_assert!(LEFT+RIGHT == 64); +fn rotate_left(x: u64) -> u64 { + debug_assert!(LEFT + RIGHT == 64); (x << LEFT) | (x >> RIGHT) } @@ -22,13 +21,13 @@ fn _veor5q_u64(a: u64, b: u64, c: u64, d: u64, e: u64) -> u64 { #[inline(always)] fn _vrax1q_u64(a: u64, b: u64) -> u64 { - a ^ rotate_left::<1,63>(b) + a ^ rotate_left::<1, 63>(b) } #[inline(always)] fn _vxarq_u64(a: u64, b: u64) -> u64 { let ab = a ^ b; - rotate_left::(ab) + rotate_left::(ab) } #[inline(always)] @@ -42,41 +41,41 @@ fn _veorq_n_u64(a: u64, c: u64) -> u64 { } #[inline(always)] -pub(crate) fn load_block(s: &mut [[u64;5];5], blocks: [&[u8];1]) { +pub(crate) fn load_block(s: &mut [[u64; 5]; 5], blocks: [&[u8]; 1]) { debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0); - for i in 0..RATE/8 { - s[i/5][i%5] = u64::from_le_bytes(blocks[0][8*i..8*i+8].try_into().unwrap()); + for i in 0..RATE / 8 { + s[i / 5][i % 5] = u64::from_le_bytes(blocks[0][8 * i..8 * i + 8].try_into().unwrap()); } } #[inline(always)] -pub(crate) fn load_block_full(s: &mut [[u64;5];5], blocks: [[u8;200];1]) { - load_block::(s,[&blocks[0] as &[u8]]); +pub(crate) fn load_block_full(s: &mut [[u64; 5]; 5], blocks: [[u8; 200]; 1]) { + load_block::(s, [&blocks[0] as &[u8]]); } #[inline(always)] -pub(crate) fn store_block(s: &[[u64;5];5], out: [&mut [u8];1]) { - for i in 0..RATE/8 { - out[0][8*i..8*i+8].copy_from_slice(&s[i/5][i%5].to_le_bytes()); +pub(crate) fn store_block(s: &[[u64; 5]; 5], out: [&mut [u8]; 1]) { + for i in 0..RATE / 8 { + out[0][8 * i..8 * i + 8].copy_from_slice(&s[i / 5][i % 5].to_le_bytes()); } -} +} #[inline(always)] -pub(crate) fn store_block_full(s: &[[u64;5];5]) -> [[u8;200];1] { +pub(crate) fn store_block_full(s: &[[u64; 5]; 5]) -> [[u8; 200]; 1] { let mut out = [0u8; 200]; - store_block::(s,[&mut out]); + store_block::(s, [&mut out]); [out] -} +} #[inline(always)] -fn slice_1(a: [&[u8];1], start:usize, len:usize) -> [&[u8];1] { - [&a[0][start..start+len]] +fn slice_1(a: [&[u8]; 1], start: usize, len: usize) -> [&[u8]; 1] { + [&a[0][start..start + len]] } #[inline(always)] -fn split_at_mut_1(out: [&mut [u8]; 1], mid:usize) -> ([&mut [u8];1],[&mut [u8];1]) { +fn split_at_mut_1(out: [&mut [u8]; 1], mid: usize) -> ([&mut [u8]; 1], [&mut [u8]; 1]) { let [out0] = out; - let (out00,out01) = out0.split_at_mut(mid); + let (out00, out01) = out0.split_at_mut(mid); ([out00], [out01]) } @@ -95,7 +94,7 @@ impl KeccakItem<1> for u64 { } #[inline(always)] fn xor_and_rotate(a: Self, b: Self) -> Self { - _vxarq_u64::(a, b) + _vxarq_u64::(a, b) } #[inline(always)] fn and_not_xor(a: Self, b: Self, c: Self) -> Self { @@ -107,31 +106,30 @@ impl KeccakItem<1> for u64 { } #[inline(always)] fn xor(a: Self, b: Self) -> Self { - a^b + a ^ b } #[inline(always)] - fn load_block(a:&mut [[Self;5];5], b:[&[u8];1]) { + fn load_block(a: &mut [[Self; 5]; 5], b: [&[u8]; 1]) { load_block::(a, b) } #[inline(always)] - fn store_block(a:& [[Self;5];5], b:[&mut [u8];1]) { + fn store_block(a: &[[Self; 5]; 5], b: [&mut [u8]; 1]) { store_block::(a, b) } #[inline(always)] - fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];1]) { + fn load_block_full(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 1]) { load_block_full::(a, b) } #[inline(always)] - fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];1] { + fn store_block_full(a: 
&[[Self; 5]; 5]) -> [[u8; 200]; 1] { store_block_full::(a) } #[inline(always)] - fn slice_n(a:[&[u8];1],start:usize,len:usize) -> [&[u8];1] { - slice_1(a,start,len) + fn slice_n(a: [&[u8]; 1], start: usize, len: usize) -> [&[u8]; 1] { + slice_1(a, start, len) } #[inline(always)] - fn split_at_mut_n(a:[&mut [u8];1],mid:usize) -> ([&mut [u8];1],[&mut [u8];1]) { + fn split_at_mut_n(a: [&mut [u8]; 1], mid: usize) -> ([&mut [u8]; 1], [&mut [u8]; 1]) { split_at_mut_1(a, mid) } } - diff --git a/libcrux-sha3/src/rust_simd/sha3_trait.rs b/libcrux-sha3/src/rust_simd/sha3_trait.rs index 358dbff16..0ad85dae8 100644 --- a/libcrux-sha3/src/rust_simd/sha3_trait.rs +++ b/libcrux-sha3/src/rust_simd/sha3_trait.rs @@ -1,5 +1,4 @@ - -pub trait KeccakItem: Clone + Copy { +pub trait KeccakItem: Clone + Copy { fn zero() -> Self; fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self; fn rotate_left1_and_xor(a: Self, b: Self) -> Self; @@ -7,11 +6,10 @@ pub trait KeccakItem: Clone + Copy { fn and_not_xor(a: Self, b: Self, c: Self) -> Self; fn xor_constant(a: Self, c: u64) -> Self; fn xor(a: Self, b: Self) -> Self; - fn load_block(a:&mut [[Self;5];5], b:[&[u8];N]); - fn store_block(a:& [[Self;5];5], b:[&mut [u8];N]); - fn load_block_full(a:&mut [[Self;5];5], b:[[u8;200];N]); - fn store_block_full(a:&[[Self;5];5]) -> [[u8;200];N]; - fn slice_n(a:[&[u8];N],start:usize,len:usize) -> [&[u8];N]; - fn split_at_mut_n(a:[&mut [u8];N],mid:usize) -> ([&mut [u8];N],[&mut [u8];N]); + fn load_block(a: &mut [[Self; 5]; 5], b: [&[u8]; N]); + fn store_block(a: &[[Self; 5]; 5], b: [&mut [u8]; N]); + fn load_block_full(a: &mut [[Self; 5]; 5], b: [[u8; 200]; N]); + fn store_block_full(a: &[[Self; 5]; 5]) -> [[u8; 200]; N]; + fn slice_n(a: [&[u8]; N], start: usize, len: usize) -> [&[u8]; N]; + fn split_at_mut_n(a: [&mut [u8]; N], mid: usize) -> ([&mut [u8]; N], [&mut [u8]; N]); } - diff --git a/libcrux-sha3/tests/sha3.rs b/libcrux-sha3/tests/sha3.rs index 0fff1b146..dfc5f35dd 100644 --- a/libcrux-sha3/tests/sha3.rs +++ b/libcrux-sha3/tests/sha3.rs @@ -5,7 +5,8 @@ fn sha3_kat_oneshot() { assert_eq!(hex::encode(&d256), expected256); let dshake = libcrux_sha3::shake128::<42>(b"Hello, World!"); - let expectedshake = "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + let expectedshake = + "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; assert_eq!(hex::encode(&dshake), expectedshake); } @@ -16,6 +17,7 @@ fn sha3_simd_kat_oneshot() { assert_eq!(hex::encode(&d256), expected256); let dshake = libcrux_sha3::rust_simd::shake128::<42>(b"Hello, World!"); - let expectedshake = "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + let expectedshake = + "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; assert_eq!(hex::encode(&dshake), expectedshake); } diff --git a/polynomials-aarch64/src/lib.rs b/polynomials-aarch64/src/lib.rs index bf4195788..a37953217 100644 --- a/polynomials-aarch64/src/lib.rs +++ b/polynomials-aarch64/src/lib.rs @@ -157,7 +157,7 @@ impl Operations for SIMD128Vector { deserialize_12(a) } - fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { + fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { rejsample::rej_sample(a, out) } } diff --git a/polynomials-aarch64/src/rejsample.rs b/polynomials-aarch64/src/rejsample.rs index 00cdb5471..b25fc0f42 100644 --- a/polynomials-aarch64/src/rejsample.rs +++ b/polynomials-aarch64/src/rejsample.rs @@ -768,7 +768,7 @@ const IDX_TABLE: [[u8; 
16]; 256] = [ ]; #[inline(always)] -pub(crate) fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { +pub(crate) fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { let neon_bits: [u16; 8] = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80]; let bits = _vld1q_u16(&neon_bits); let fm = _vdupq_n_s16(3328); diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 1c4cde0da..c81bfc932 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -32,7 +32,7 @@ fn to_i16_array(v: SIMD256Vector) -> [i16; 16] { out } #[inline(always)] -fn from_i16_array(array: [i16; 16]) -> SIMD256Vector { +fn from_i16_array(array: &[i16]) -> SIMD256Vector { SIMD256Vector { elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) }, } @@ -768,7 +768,7 @@ fn serialize_5(v: SIMD256Vector) -> [u8; 10] { fn deserialize_5(v: &[u8]) -> SIMD256Vector { let output = portable::deserialize_5(v); - from_i16_array(portable::to_i16_array(output)) + from_i16_array(&portable::to_i16_array(output)) } #[inline(always)] @@ -881,7 +881,7 @@ fn serialize_11(v: SIMD256Vector) -> [u8; 22] { fn deserialize_11(v: &[u8]) -> SIMD256Vector { let output = portable::deserialize_11(v); - from_i16_array(portable::to_i16_array(output)) + from_i16_array(&portable::to_i16_array(output)) } #[inline(always)] @@ -984,8 +984,8 @@ fn deserialize_12(v: &[u8]) -> SIMD256Vector { } #[inline(always)] -fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - portable::rej_sample(a) +fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { + portable::rej_sample(a, out) } impl Operations for SIMD256Vector { @@ -997,7 +997,7 @@ impl Operations for SIMD256Vector { to_i16_array(v) } - fn from_i16_array(array: [i16; 16]) -> Self { + fn from_i16_array(array: &[i16]) -> Self { from_i16_array(array) } @@ -1132,7 +1132,7 @@ impl Operations for SIMD256Vector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rej_sample(a) + fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { + rej_sample(a, out) } } diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs index b18c02d19..eed92bc6b 100644 --- a/polynomials-avx2/src/portable.rs +++ b/polynomials-avx2/src/portable.rs @@ -114,8 +114,7 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> PortableVector { } #[inline(always)] -pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - let mut result = [0i16; 16]; +pub(crate) fn rej_sample(a: &[u8], result: &mut [i16]) -> usize { let mut sampled = 0; for bytes in a.chunks(3) { let b1 = bytes[0] as i16; @@ -134,5 +133,5 @@ pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { sampled += 1 } } - (sampled, result) + sampled } diff --git a/polynomials/src/lib.rs b/polynomials/src/lib.rs index cb8052616..1dd33fd91 100644 --- a/polynomials/src/lib.rs +++ b/polynomials/src/lib.rs @@ -210,7 +210,9 @@ fn to_i16_array(v: PortableVector) -> [i16; FIELD_ELEMENTS_IN_VECTOR] { #[inline(always)] fn from_i16_array(array: &[i16]) -> PortableVector { - PortableVector { elements: array[0..16].try_into().unwrap() } + PortableVector { + elements: array[0..16].try_into().unwrap(), + } } #[inline(always)] @@ -1039,7 +1041,7 @@ fn deserialize_12(bytes: &[u8]) -> PortableVector { } #[inline(always)] -fn rej_sample(a: &[u8], result: &mut[i16]) -> usize { +fn rej_sample(a: &[u8], result: &mut [i16]) -> usize { let mut sampled = 0; for bytes in a.chunks(3) { let b1 = bytes[0] as i16; @@ -1205,7 +1207,7 @@ impl Operations for PortableVector { deserialize_12(a) } - fn rej_sample(a: &[u8], out:&mut [i16]) -> usize { + fn 
rej_sample(a: &[u8], out: &mut [i16]) -> usize { rej_sample(a, out) } } diff --git a/traits/src/lib.rs b/traits/src/lib.rs index 4b8fc796b..ef59e381f 100644 --- a/traits/src/lib.rs +++ b/traits/src/lib.rs @@ -61,7 +61,7 @@ pub trait Operations: Copy + Clone { fn serialize_12(a: Self) -> [u8; 24]; fn deserialize_12(a: &[u8]) -> Self; - fn rej_sample(a: &[u8], out:&mut [i16]) -> usize; + fn rej_sample(a: &[u8], out: &mut [i16]) -> usize; } // hax does not support trait with default implementations, so we use the following patter From 8488d3293e620144ce8bd12f6570ee58e83b4aed Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 11:27:43 +0200 Subject: [PATCH 15/59] fix avx2 sha3 --- libcrux-sha3/src/rust_simd/sha3_avx2.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/rust_simd/sha3_avx2.rs index 130f04321..a9856a7aa 100644 --- a/libcrux-sha3/src/rust_simd/sha3_avx2.rs +++ b/libcrux-sha3/src/rust_simd/sha3_avx2.rs @@ -35,7 +35,7 @@ fn _vxarq_u64(a: __m256i, b: __m256i) -> __m2 #[inline(always)] fn _vbcaxq_u64(a: __m256i, b: __m256i, c: __m256i) -> __m256i { - unsafe { _mm256_xor_si256(a, _mm256_andnot_si256(b, c)) } + unsafe { _mm256_xor_si256(a, _mm256_andnot_si256(c, b)) } } #[inline(always)] From 6dc64714ab7be390a32693ba30255673b255fb03 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 11:34:04 +0200 Subject: [PATCH 16/59] fix build on avx2 using portable sha3 in mlkem --- libcrux-ml-kem/src/hash_functions.rs | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 136afc2c0..391836953 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -105,9 +105,11 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { #[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { +pub(crate) fn absorb( + input: [[u8; 34]; K], +) -> [libcrux_sha3::rust_simd::KeccakState1; K] { debug_assert!(K == 2 || K == 3 || K == 4); - let mut states = rust_simd::shake128x4_init(); + let mut states = [rust_simd::shake128_init(); K]; for i in 0..K { rust_simd::shake128_absorb_final(&mut states[i], &input[i]); } @@ -170,7 +172,7 @@ pub(crate) fn squeeze_three_blocks( #[cfg(not(feature = "simd128"))] #[inline(always)] pub(crate) fn squeeze_three_blocks( - state: &mut Shake128x4State, + state: &mut [libcrux_sha3::rust_simd::KeccakState1], ) -> [[u8; THREE_BLOCKS]; K] { let mut out = [[0u8; THREE_BLOCKS]; K]; for i in 0..K { @@ -216,7 +218,9 @@ pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8 #[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8; BLOCK_SIZE]; K] { +pub(crate) fn squeeze_block( + state: &mut [libcrux_sha3::rust_simd::KeccakState1; K], +) -> [[u8; BLOCK_SIZE]; K] { let mut out = [[0u8; BLOCK_SIZE]; K]; for i in 0..K { rust_simd::shake128_squeeze_next_block(&mut state[i], &mut out[i]); @@ -227,5 +231,6 @@ pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8 /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. 
+#[cfg(not(feature = "simd128"))] #[inline(always)] -pub(crate) fn free_state(_xof_state: Shake128x4State) {} +pub(crate) fn free_state(_xof_state: [libcrux_sha3::rust_simd::KeccakState1; K]) {} From b5bdc54a99c2c40ae2096d9b26f9c000b3ffe3af Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Tue, 14 May 2024 11:32:41 +0200 Subject: [PATCH 17/59] fixed portable --- libcrux-ml-kem/src/hash_functions.rs | 27 --------------------- libcrux-sha3/src/rust_simd/sha3_portable.rs | 2 +- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 136afc2c0..2e03404ed 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -4,48 +4,21 @@ use crate::constants::H_DIGEST_SIZE; use libcrux_sha3::rust_simd::{self, KeccakState4}; -#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; 64] { rust_simd::sha3_512(input) } -#[cfg(not(feature = "simd128"))] -#[inline(always)] -pub(crate) fn G(input: &[u8]) -> [u8; 64] { - libcrux_sha3::sha512(input) - //some bug in scalar version of rust_simd - // rust_simd::sha512(input) -} - -#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { rust_simd::sha3_256(input) } -#[cfg(not(feature = "simd128"))] -#[inline(always)] -pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - libcrux_sha3::sha256(input) - //some bug in scalar version of rust_simd - // rust_simd::sha256(input) -} - -#[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { rust_simd::shake256::(input) } -#[cfg(not(feature = "simd128"))] -#[inline(always)] -pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - libcrux_sha3::shake256::(input) - //some bug in scalar version of rust_simd - // rust_simd::shake256::(input) -} - #[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { diff --git a/libcrux-sha3/src/rust_simd/sha3_portable.rs b/libcrux-sha3/src/rust_simd/sha3_portable.rs index cbf2af8b5..8d08ed7b8 100644 --- a/libcrux-sha3/src/rust_simd/sha3_portable.rs +++ b/libcrux-sha3/src/rust_simd/sha3_portable.rs @@ -44,7 +44,7 @@ fn _veorq_n_u64(a: u64, c: u64) -> u64 { pub(crate) fn load_block(s: &mut [[u64; 5]; 5], blocks: [&[u8]; 1]) { debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0); for i in 0..RATE / 8 { - s[i / 5][i % 5] = u64::from_le_bytes(blocks[0][8 * i..8 * i + 8].try_into().unwrap()); + s[i / 5][i % 5] ^= u64::from_le_bytes(blocks[0][8 * i..8 * i + 8].try_into().unwrap()); } } From 4e25d054a15ee6d4c085e8bfc7c8b21b59f34b40 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Tue, 14 May 2024 11:44:57 +0200 Subject: [PATCH 18/59] fixed some feature flags --- libcrux-sha3/src/rust_simd.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 0c00e2ee8..9aeba2a76 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -292,7 +292,7 @@ pub fn shake128x4_init() -> KeccakState4 { [s0, s1, s2, s3] } -#[cfg(feature = "simd128")] +#[cfg(feature = "simd256")] pub fn shake128x4_absorb_final( s: &mut KeccakState4, data0: &[u8], @@ -366,7 +366,7 @@ pub fn shake128x4_squeeze_first_three_blocks( shake128_squeeze_first_three_blocks(&mut s3, out3); } -#[cfg(feature = "simd128")] +#[cfg(feature = "simd256")] pub fn shake128x4_squeeze_next_block( s: &mut KeccakState4, out0: &mut [u8], From 
1b66ac4b2a7b4a1d0935b0e5c631c5f6b57e9f2a Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 11:50:56 +0200 Subject: [PATCH 19/59] fix av2 build --- libcrux-sha3/src/rust_simd.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 9aeba2a76..a32c434dc 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -374,7 +374,7 @@ pub fn shake128x4_squeeze_next_block( out2: &mut [u8], out3: &mut [u8], ) { - squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>(&mut s0, [out0, out1, out2, out3]); + squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>(s, [out0, out1, out2, out3]); } #[cfg(feature = "simd128")] pub fn shake128x4_squeeze_next_block( From cd5aa4d76b67746a841e47e71de2cf351a9b0257 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 12:00:51 +0200 Subject: [PATCH 20/59] fix avx2 and allow portable --- libcrux-sha3/src/rust_simd.rs | 6 ++++-- libcrux-sha3/src/rust_simd/sha3_avx2.rs | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index a32c434dc..fd862da26 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -71,7 +71,8 @@ pub fn sha3_256(data: &[u8]) -> [u8; 32] { d0 } -pub fn sha3_256_portable(data: &[u8]) -> [u8; 32] { +#[cfg(not(any(feature = "simd256", feature = "simd128")))] +pub fn sha3_256(data: &[u8]) -> [u8; 32] { let mut d0 = [0u8; 32]; keccakx1::<136, 0x06u8>([data], [&mut d0]); d0 @@ -112,7 +113,8 @@ pub fn sha3_512(data: &[u8]) -> [u8; 64] { d0 } -pub fn sha3_512_portable(data: &[u8]) -> [u8; 64] { +#[cfg(not(any(feature = "simd256", feature = "simd128")))] +pub fn sha3_512(data: &[u8]) -> [u8; 64] { let mut d0 = [0u8; 64]; keccakx1::<72, 0x06u8>([data], [&mut d0]); d0 diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/rust_simd/sha3_avx2.rs index a9856a7aa..8146ccec6 100644 --- a/libcrux-sha3/src/rust_simd/sha3_avx2.rs +++ b/libcrux-sha3/src/rust_simd/sha3_avx2.rs @@ -127,7 +127,7 @@ pub(crate) fn store_block(s: &[[__m256i; 5]; 5], out: [&mut [ let v1h = unsafe { _mm256_permute2x128_si256( s[(4 * i + 1) / 5][(4 * i + 1) % 5], - s[(4 * i + 3) / 5][(4 * 3 + 1) % 5], + s[(4 * i + 3) / 5][(4 * i + 3) % 5], 0x20, ) }; // 1 1 3 3 @@ -141,7 +141,7 @@ pub(crate) fn store_block(s: &[[__m256i; 5]; 5], out: [&mut [ let v3h = unsafe { _mm256_permute2x128_si256( s[(4 * i + 1) / 5][(4 * i + 1) % 5], - s[(4 * i + 3) / 5][(4 * 3 + 1) % 5], + s[(4 * i + 3) / 5][(4 * i + 3) % 5], 0x31, ) }; // 1 1 3 3 From 7103d4a7cfc4dca83ae361f4098b62d083fdd632 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 12:21:25 +0200 Subject: [PATCH 21/59] simd256 hash functions --- libcrux-ml-kem/src/hash_functions.rs | 149 ++++++++++++++++++++++++++- 1 file changed, 145 insertions(+), 4 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 43293bae5..032785818 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -76,7 +76,7 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { states } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn absorb( input: [[u8; 34]; K], @@ -89,6 +89,45 @@ pub(crate) fn absorb( states } +#[cfg(feature = "simd256")] +#[inline(always)] +pub(crate) fn absorb(input: [[u8; 34]; K]) -> KeccakState4 { + 
debug_assert!(K == 2 || K == 3 || K == 4); + let mut states = rust_simd::shake128x4_init(); + + match K { + 2 => { + rust_simd::shake128x4_absorb_final( + &mut states, + &input[0], + &input[1], + &input[0], + &input[0], + ); + } + 3 => { + rust_simd::shake128x4_absorb_final( + &mut states, + &input[0], + &input[1], + &input[2], + &input[0], + ); + } + 4 => { + rust_simd::shake128x4_absorb_final( + &mut states, + &input[0], + &input[1], + &input[2], + &input[3], + ); + } + _ => unreachable!(), + } + states +} + pub(crate) const BLOCK_SIZE: usize = 168; pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; @@ -142,7 +181,7 @@ pub(crate) fn squeeze_three_blocks( out } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn squeeze_three_blocks( state: &mut [libcrux_sha3::rust_simd::KeccakState1], @@ -154,6 +193,54 @@ pub(crate) fn squeeze_three_blocks( out } +#[cfg(feature = "simd256")] +#[inline(always)] +pub(crate) fn squeeze_three_blocks( + state: &mut KeccakState4, +) -> [[u8; THREE_BLOCKS]; K] { + let mut out = [[0u8; THREE_BLOCKS]; K]; + let mut dummy_out0 = [0u8; THREE_BLOCKS]; + let mut dummy_out1 = [0u8; THREE_BLOCKS]; + + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + rust_simd::shake128x4_squeeze_first_three_blocks( + state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + rust_simd::shake128x4_squeeze_first_three_blocks( + state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + rust_simd::shake128x4_squeeze_first_three_blocks( + state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!(), + } + out +} + #[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8; BLOCK_SIZE]; K] { @@ -189,7 +276,7 @@ pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8 out } -#[cfg(not(feature = "simd128"))] +#[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn squeeze_block( state: &mut [libcrux_sha3::rust_simd::KeccakState1; K], @@ -201,9 +288,63 @@ pub(crate) fn squeeze_block( out } +#[cfg(feature = "simd256")] +#[inline(always)] +pub(crate) fn squeeze_block(state: &mut KeccakState4) -> [[u8; BLOCK_SIZE]; K] { + let mut dummy_out0 = [0u8; BLOCK_SIZE]; + let mut dummy_out1 = [0u8; BLOCK_SIZE]; + + let mut out = [[0u8; BLOCK_SIZE]; K]; + + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + rust_simd::shake128x4_squeeze_next_block( + state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + rust_simd::shake128x4_squeeze_next_block( + state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + rust_simd::shake128x4_squeeze_next_block( + state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!(), + } + out +} + /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. 
-#[cfg(not(feature = "simd128"))] +#[cfg(feature = "simd256")] +#[inline(always)] +pub(crate) fn free_state(_xof_state: KeccakState4) {} + +/// Free the memory of the state. +/// +/// **NOTE:** That this needs to be done manually for now. +#[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn free_state(_xof_state: [libcrux_sha3::rust_simd::KeccakState1; K]) {} From 0376b125c82fe699729ae1494e1e0ca808bebe71 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 12:31:29 +0200 Subject: [PATCH 22/59] drop other sha3 code --- libcrux-sha3/src/hacl.rs | 3 - libcrux-sha3/src/hacl/hash_sha3.rs | 644 ----------------------- libcrux-sha3/src/hacl/streaming_types.rs | 39 -- libcrux-sha3/src/lib.rs | 94 ---- libcrux-sha3/src/lowstar.rs | 2 - libcrux-sha3/src/lowstar/endianness.rs | 53 -- libcrux-sha3/src/lowstar/ignore.rs | 1 - libcrux-sha3/src/x4.rs | 61 --- libcrux-sha3/src/x4/internal.rs | 337 ------------ 9 files changed, 1234 deletions(-) delete mode 100644 libcrux-sha3/src/hacl.rs delete mode 100644 libcrux-sha3/src/hacl/hash_sha3.rs delete mode 100644 libcrux-sha3/src/hacl/streaming_types.rs delete mode 100644 libcrux-sha3/src/lowstar.rs delete mode 100644 libcrux-sha3/src/lowstar/endianness.rs delete mode 100644 libcrux-sha3/src/lowstar/ignore.rs delete mode 100644 libcrux-sha3/src/x4.rs delete mode 100644 libcrux-sha3/src/x4/internal.rs diff --git a/libcrux-sha3/src/hacl.rs b/libcrux-sha3/src/hacl.rs deleted file mode 100644 index 7fa89d0d1..000000000 --- a/libcrux-sha3/src/hacl.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod streaming_types; - -pub(crate) mod hash_sha3; diff --git a/libcrux-sha3/src/hacl/hash_sha3.rs b/libcrux-sha3/src/hacl/hash_sha3.rs deleted file mode 100644 index c54d0458b..000000000 --- a/libcrux-sha3/src/hacl/hash_sha3.rs +++ /dev/null @@ -1,644 +0,0 @@ -#![allow(non_snake_case)] -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(unused_assignments)] -#![allow(unused_mut)] -#![allow(unreachable_patterns)] -#![allow(const_item_mutation)] - -#[inline(always)] -fn block_len(a: crate::hacl::streaming_types::hash_alg) -> u32 { - match a { - crate::hacl::streaming_types::hash_alg::SHA3_224 => 144u32, - crate::hacl::streaming_types::hash_alg::SHA3_256 => 136u32, - crate::hacl::streaming_types::hash_alg::SHA3_384 => 104u32, - crate::hacl::streaming_types::hash_alg::SHA3_512 => 72u32, - crate::hacl::streaming_types::hash_alg::Shake128 => 168u32, - crate::hacl::streaming_types::hash_alg::Shake256 => 136u32, - _ => panic!("Precondition of the function most likely violated"), - } -} - -#[inline(always)] -fn hash_len(a: crate::hacl::streaming_types::hash_alg) -> u32 { - match a { - crate::hacl::streaming_types::hash_alg::SHA3_224 => 28u32, - crate::hacl::streaming_types::hash_alg::SHA3_256 => 32u32, - crate::hacl::streaming_types::hash_alg::SHA3_384 => 48u32, - crate::hacl::streaming_types::hash_alg::SHA3_512 => 64u32, - _ => panic!("Precondition of the function most likely violated"), - } -} - -#[inline(always)] -pub fn update_multi_sha3( - a: crate::hacl::streaming_types::hash_alg, - s: &mut [u64], - blocks: &mut [u8], - n_blocks: u32, -) -> () { - for i in 0u32..n_blocks { - let block: (&mut [u8], &mut [u8]) = - blocks.split_at_mut(i.wrapping_mul(block_len(a)) as usize); - absorb_inner(block_len(a), block.1, s) - } -} - -#[inline(always)] -pub fn update_last_sha3( - a: crate::hacl::streaming_types::hash_alg, - s: &mut [u64], - input: &mut [u8], - input_len: u32, -) -> () { - let suffix: u8 = if a == 
crate::hacl::streaming_types::hash_alg::Shake128 - || a == crate::hacl::streaming_types::hash_alg::Shake256 - { - 0x1fu8 - } else { - 0x06u8 - }; - let len: u32 = block_len(a); - if input_len == len { - absorb_inner(len, input, s); - let mut lastBlock_: [u8; 200] = [0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..0usize]) - .copy_from_slice(&(&mut input[input_len as usize..])[0usize..0usize]); - lastBlock.1[0usize] = suffix; - loadState(len, lastBlock.1, s); - if !(suffix & 0x80u8 == 0u8) && 0u32 == len.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[len.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(len, nextBlock.1, s); - state_permute(s) - } else { - let mut lastBlock_: [u8; 200] = [0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..input_len as usize]) - .copy_from_slice(&input[0usize..input_len as usize]); - lastBlock.1[input_len as usize] = suffix; - loadState(len, lastBlock.1, s); - if !(suffix & 0x80u8 == 0u8) && input_len == len.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[len.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(len, nextBlock.1, s); - state_permute(s) - } -} - -pub struct hash_buf { - pub fst: crate::hacl::streaming_types::hash_alg, - pub snd: Vec, -} - -pub struct state_t { - pub block_state: hash_buf, - pub buf: Vec, - pub total_len: u64, -} - -#[inline(always)] -pub fn get_alg(s: &mut [state_t]) -> crate::hacl::streaming_types::hash_alg { - let mut block_state: &mut hash_buf = &mut (s[0usize]).block_state; - (*block_state).fst -} - -#[inline(always)] -pub fn malloc(a: crate::hacl::streaming_types::hash_alg) -> Vec { - let mut buf: Vec = vec![0u8; block_len(a) as usize]; - let mut buf0: Vec = vec![0u64; 25usize]; - let mut block_state: hash_buf = hash_buf { fst: a, snd: buf0 }; - let s: &mut [u64] = &mut block_state.snd; - (s[0usize..25usize]).copy_from_slice(&[0u64; 25usize]); - let mut s0: state_t = state_t { - block_state: block_state, - buf: buf, - total_len: 0u32 as u64, - }; - let mut p: Vec = { - let mut tmp: Vec = Vec::new(); - tmp.push(s0); - tmp - }; - p -} - -#[inline(always)] -pub fn copy(state: &mut [state_t]) -> Vec { - let mut block_state0: &mut hash_buf = &mut (state[0usize]).block_state; - let buf0: &mut [u8] = &mut (state[0usize]).buf; - let total_len0: u64 = (state[0usize]).total_len; - let i: crate::hacl::streaming_types::hash_alg = (*block_state0).fst; - let mut buf: Vec = vec![0u8; block_len(i) as usize]; - ((&mut buf)[0usize..block_len(i) as usize]) - .copy_from_slice(&buf0[0usize..block_len(i) as usize]); - let mut buf1: Vec = vec![0u64; 25usize]; - let mut block_state: hash_buf = hash_buf { fst: i, snd: buf1 }; - let s_src: &mut [u64] = &mut (*block_state0).snd; - let s_dst: &mut [u64] = &mut block_state.snd; - (s_dst[0usize..25usize]).copy_from_slice(&s_src[0usize..25usize]); - let mut s: state_t = state_t { - block_state: block_state, - buf: buf, - total_len: total_len0, - }; - let mut p: Vec = { - let mut tmp: Vec = Vec::new(); - tmp.push(s); - tmp - }; - p -} - -#[inline(always)] -pub fn reset(state: &mut [state_t]) -> () { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let i: 
crate::hacl::streaming_types::hash_alg = (*block_state).fst; - crate::lowstar::ignore::ignore::(i); - let s: &mut [u64] = &mut (*block_state).snd; - (s[0usize..25usize]).copy_from_slice(&[0u64; 25usize]); - (state[0usize]).total_len = 0u32 as u64 -} - -#[inline(always)] -pub fn update( - state: &mut [state_t], - chunk: &mut [u8], - chunk_len: u32, -) -> crate::hacl::streaming_types::error_code { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let total_len: u64 = (state[0usize]).total_len; - let i: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - if chunk_len as u64 > 0xFFFFFFFFFFFFFFFFu64.wrapping_sub(total_len) { - crate::hacl::streaming_types::error_code::MaximumLengthExceeded - } else { - let sz: u32 = if total_len.wrapping_rem(block_len(i) as u64) == 0u64 && total_len > 0u64 { - block_len(i) - } else { - total_len.wrapping_rem(block_len(i) as u64) as u32 - }; - if chunk_len <= (block_len(i)).wrapping_sub(sz) { - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - let buf2: (&mut [u8], &mut [u8]) = buf.split_at_mut(sz1 as usize); - (buf2.1[0usize..chunk_len as usize]) - .copy_from_slice(&chunk[0usize..chunk_len as usize]); - let total_len2: u64 = total_len1.wrapping_add(chunk_len as u64); - (state[0usize]).total_len = total_len2 - } else if sz == 0u32 { - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - if !(sz1 == 0u32) { - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, buf, (block_len(i)).wrapping_div(block_len(a1))) - }; - let ite: u32 = if (chunk_len as u64).wrapping_rem(block_len(i) as u64) == 0u64 - && chunk_len as u64 > 0u64 - { - block_len(i) - } else { - (chunk_len as u64).wrapping_rem(block_len(i) as u64) as u32 - }; - let n_blocks: u32 = chunk_len.wrapping_sub(ite).wrapping_div(block_len(i)); - let data1_len: u32 = n_blocks.wrapping_mul(block_len(i)); - let data2_len: u32 = chunk_len.wrapping_sub(data1_len); - let data1: (&mut [u8], &mut [u8]) = chunk.split_at_mut(0usize); - let data2: (&mut [u8], &mut [u8]) = data1.1.split_at_mut(data1_len as usize); - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, data2.0, data1_len.wrapping_div(block_len(a1))); - let dst: (&mut [u8], &mut [u8]) = buf.split_at_mut(0usize); - (dst.1[0usize..data2_len as usize]) - .copy_from_slice(&data2.1[0usize..data2_len as usize]); - (state[0usize]).total_len = total_len1.wrapping_add(chunk_len as u64) - } else { - let diff: u32 = (block_len(i)).wrapping_sub(sz); - let chunk1: (&mut [u8], &mut [u8]) = chunk.split_at_mut(0usize); - let chunk2: (&mut [u8], &mut [u8]) = chunk1.1.split_at_mut(diff as usize); - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - let buf2: (&mut [u8], &mut [u8]) 
= buf.split_at_mut(sz1 as usize); - (buf2.1[0usize..diff as usize]).copy_from_slice(&chunk2.0[0usize..diff as usize]); - let total_len2: u64 = total_len1.wrapping_add(diff as u64); - (state[0usize]).total_len = total_len2; - let buf0: &mut [u8] = &mut (state[0usize]).buf; - let total_len10: u64 = (state[0usize]).total_len; - let sz10: u32 = - if total_len10.wrapping_rem(block_len(i) as u64) == 0u64 && total_len10 > 0u64 { - block_len(i) - } else { - total_len10.wrapping_rem(block_len(i) as u64) as u32 - }; - if !(sz10 == 0u32) { - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, buf0, (block_len(i)).wrapping_div(block_len(a1))) - }; - let ite: u32 = - if (chunk_len.wrapping_sub(diff) as u64).wrapping_rem(block_len(i) as u64) == 0u64 - && chunk_len.wrapping_sub(diff) as u64 > 0u64 - { - block_len(i) - } else { - (chunk_len.wrapping_sub(diff) as u64).wrapping_rem(block_len(i) as u64) as u32 - }; - let n_blocks: u32 = chunk_len - .wrapping_sub(diff) - .wrapping_sub(ite) - .wrapping_div(block_len(i)); - let data1_len: u32 = n_blocks.wrapping_mul(block_len(i)); - let data2_len: u32 = chunk_len.wrapping_sub(diff).wrapping_sub(data1_len); - let data1: (&mut [u8], &mut [u8]) = chunk2.1.split_at_mut(0usize); - let data2: (&mut [u8], &mut [u8]) = data1.1.split_at_mut(data1_len as usize); - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, data2.0, data1_len.wrapping_div(block_len(a1))); - let dst: (&mut [u8], &mut [u8]) = buf0.split_at_mut(0usize); - (dst.1[0usize..data2_len as usize]) - .copy_from_slice(&data2.1[0usize..data2_len as usize]); - (state[0usize]).total_len = - total_len10.wrapping_add(chunk_len.wrapping_sub(diff) as u64) - }; - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -fn digest_( - a: crate::hacl::streaming_types::hash_alg, - state: &mut [state_t], - output: &mut [u8], - l: u32, -) -> () { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let buf_: &mut [u8] = &mut (state[0usize]).buf; - let total_len: u64 = (state[0usize]).total_len; - let r: u32 = if total_len.wrapping_rem(block_len(a) as u64) == 0u64 && total_len > 0u64 { - block_len(a) - } else { - total_len.wrapping_rem(block_len(a) as u64) as u32 - }; - let buf_1: (&mut [u8], &mut [u8]) = buf_.split_at_mut(0usize); - let mut buf: [u64; 25] = [0u64; 25usize]; - let mut tmp_block_state: hash_buf = hash_buf { - fst: a, - snd: Vec::from(buf), - }; - let s_src: &mut [u64] = &mut (*block_state).snd; - let s_dst: &mut [u64] = &mut tmp_block_state.snd; - (s_dst[0usize..25usize]).copy_from_slice(&s_src[0usize..25usize]); - let buf_multi: (&mut [u8], &mut [u8]) = buf_1.1.split_at_mut(0usize); - let ite: u32 = if r.wrapping_rem(block_len(a)) == 0u32 && r > 0u32 { - block_len(a) - } else { - r.wrapping_rem(block_len(a)) - }; - let buf_last: (&mut [u8], &mut [u8]) = buf_multi.1.split_at_mut(r.wrapping_sub(ite) as usize); - let a1: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s: &mut [u64] = &mut tmp_block_state.snd; - update_multi_sha3(a1, s, buf_last.0, 0u32.wrapping_div(block_len(a1))); - let a10: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s0: &mut [u64] = &mut tmp_block_state.snd; - update_last_sha3(a10, s0, buf_last.1, r); - let a11: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s1: &mut [u64] = &mut tmp_block_state.snd; 
- if a11 == crate::hacl::streaming_types::hash_alg::Shake128 - || a11 == crate::hacl::streaming_types::hash_alg::Shake256 - { - squeeze0(s1, block_len(a11), l, output) - } else { - squeeze0(s1, block_len(a11), hash_len(a11), output) - } -} - -#[inline(always)] -pub fn digest( - state: &mut [state_t], - output: &mut [u8], -) -> crate::hacl::streaming_types::error_code { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(state); - if a1 == crate::hacl::streaming_types::hash_alg::Shake128 - || a1 == crate::hacl::streaming_types::hash_alg::Shake256 - { - crate::hacl::streaming_types::error_code::InvalidAlgorithm - } else { - digest_(a1, state, output, hash_len(a1)); - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -pub fn squeeze( - s: &mut [state_t], - dst: &mut [u8], - l: u32, -) -> crate::hacl::streaming_types::error_code { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(s); - if !(a1 == crate::hacl::streaming_types::hash_alg::Shake128 - || a1 == crate::hacl::streaming_types::hash_alg::Shake256) - { - crate::hacl::streaming_types::error_code::InvalidAlgorithm - } else if l == 0u32 { - crate::hacl::streaming_types::error_code::InvalidLength - } else { - digest_(a1, s, dst, l); - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -pub fn block_len0(s: &mut [state_t]) -> u32 { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(s); - block_len(a1) -} - -#[inline(always)] -pub fn hash_len0(s: &mut [state_t]) -> u32 { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(s); - hash_len(a1) -} - -#[inline(always)] -pub fn is_shake(s: &mut [state_t]) -> bool { - let uu____0: crate::hacl::streaming_types::hash_alg = get_alg(s); - uu____0 == crate::hacl::streaming_types::hash_alg::Shake128 - || uu____0 == crate::hacl::streaming_types::hash_alg::Shake256 -} - -#[inline(always)] -pub fn shake128_hacl( - inputByteLen: u32, - input: &mut [u8], - outputByteLen: u32, - output: &mut [u8], -) -> () { - keccak( - 1344u32, - 256u32, - inputByteLen, - input, - 0x1Fu8, - outputByteLen, - output, - ) -} - -#[inline(always)] -pub fn shake256_hacl( - inputByteLen: u32, - input: &mut [u8], - outputByteLen: u32, - output: &mut [u8], -) -> () { - keccak( - 1088u32, - 512u32, - inputByteLen, - input, - 0x1Fu8, - outputByteLen, - output, - ) -} - -#[inline(always)] -pub fn sha3_224(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(1152u32, 448u32, input_len, input, 0x06u8, 28u32, output) -} - -#[inline(always)] -pub fn sha3_256(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(1088u32, 512u32, input_len, input, 0x06u8, 32u32, output) -} - -#[inline(always)] -pub fn sha3_384(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(832u32, 768u32, input_len, input, 0x06u8, 48u32, output) -} - -#[inline(always)] -pub fn sha3_512(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(576u32, 1024u32, input_len, input, 0x06u8, 64u32, output) -} - -const keccak_rotc: [u32; 24] = [ - 1u32, 3u32, 6u32, 10u32, 15u32, 21u32, 28u32, 36u32, 45u32, 55u32, 2u32, 14u32, 27u32, 41u32, - 56u32, 8u32, 25u32, 43u32, 62u32, 18u32, 39u32, 61u32, 20u32, 44u32, -]; - -const keccak_piln: [u32; 24] = [ - 10u32, 7u32, 11u32, 17u32, 18u32, 3u32, 5u32, 16u32, 8u32, 21u32, 24u32, 4u32, 15u32, 23u32, - 19u32, 13u32, 12u32, 2u32, 20u32, 14u32, 22u32, 9u32, 6u32, 1u32, -]; - -const keccak_rndc: [u64; 24] = [ - 0x0000000000000001u64, - 0x0000000000008082u64, - 0x800000000000808au64, - 
0x8000000080008000u64, - 0x000000000000808bu64, - 0x0000000080000001u64, - 0x8000000080008081u64, - 0x8000000000008009u64, - 0x000000000000008au64, - 0x0000000000000088u64, - 0x0000000080008009u64, - 0x000000008000000au64, - 0x000000008000808bu64, - 0x800000000000008bu64, - 0x8000000000008089u64, - 0x8000000000008003u64, - 0x8000000000008002u64, - 0x8000000000000080u64, - 0x000000000000800au64, - 0x800000008000000au64, - 0x8000000080008081u64, - 0x8000000000008080u64, - 0x0000000080000001u64, - 0x8000000080008008u64, -]; - -#[inline(always)] -pub fn state_permute(s: &mut [u64]) -> () { - for i in 0u32..24u32 { - let mut _C: [u64; 5] = [0u64; 5usize]; - for i0 in 0u32..5u32 { - (&mut _C)[i0 as usize] = s[i0.wrapping_add(0u32) as usize] - ^ (s[i0.wrapping_add(5u32) as usize] - ^ (s[i0.wrapping_add(10u32) as usize] - ^ (s[i0.wrapping_add(15u32) as usize] - ^ s[i0.wrapping_add(20u32) as usize]))) - } - for i0 in 0u32..5u32 { - let uu____0: u64 = (&mut _C)[i0.wrapping_add(1u32).wrapping_rem(5u32) as usize]; - let _D: u64 = (&mut _C)[i0.wrapping_add(4u32).wrapping_rem(5u32) as usize] - ^ (uu____0.wrapping_shl(1u32) | uu____0.wrapping_shr(63u32)); - for i1 in 0u32..5u32 { - s[i0.wrapping_add(5u32.wrapping_mul(i1)) as usize] = - s[i0.wrapping_add(5u32.wrapping_mul(i1)) as usize] ^ _D - } - } - let x: u64 = s[1usize]; - let mut current: [u64; 1] = [x; 1usize]; - for i0 in 0u32..24u32 { - let _Y: u32 = (&keccak_piln)[i0 as usize]; - let r: u32 = (&keccak_rotc)[i0 as usize]; - let temp: u64 = s[_Y as usize]; - let uu____1: u64 = (&mut current)[0usize]; - s[_Y as usize] = uu____1.wrapping_shl(r) | uu____1.wrapping_shr(64u32.wrapping_sub(r)); - (&mut current)[0usize] = temp - } - for i0 in 0u32..5u32 { - let v0: u64 = s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v1: u64 = s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v2: u64 = s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v3: u64 = s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v4: u64 = s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v0; - s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v1; - s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v2; - s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v3; - s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v4 - } - let c: u64 = (&keccak_rndc)[i as usize]; - s[0usize] = s[0usize] ^ c - } -} - -#[inline(always)] -pub fn loadState(rateInBytes: u32, input: &mut [u8], s: &mut [u64]) -> () { - let mut block: [u8; 200] = [0u8; 200usize]; - ((&mut block)[0usize..rateInBytes as usize]) - .copy_from_slice(&input[0usize..rateInBytes as usize]); - for i in 0u32..25u32 { - let u: u64 = crate::lowstar::endianness::load64_le( - &mut (&mut block)[i.wrapping_mul(8u32) as usize..], - ); - let x: u64 = u; - s[i as usize] = s[i as usize] ^ x - } -} - -#[inline(always)] -fn storeState(rateInBytes: u32, s: &mut [u64], res: 
&mut [u8]) -> () { - let mut block: [u8; 200] = [0u8; 200usize]; - for i in 0u32..25u32 { - let sj: u64 = s[i as usize]; - crate::lowstar::endianness::store64_le( - &mut (&mut block)[i.wrapping_mul(8u32) as usize..], - sj, - ) - } - (res[0usize..rateInBytes as usize]) - .copy_from_slice(&(&mut (&mut block)[0usize..])[0usize..rateInBytes as usize]) -} - -#[inline(always)] -pub fn absorb_inner(rateInBytes: u32, block: &mut [u8], s: &mut [u64]) -> () { - loadState(rateInBytes, block, s); - state_permute(s) -} - -#[inline(always)] -fn absorb( - s: &mut [u64], - rateInBytes: u32, - inputByteLen: u32, - input: &mut [u8], - delimitedSuffix: u8, -) -> () { - let n_blocks: u32 = inputByteLen.wrapping_div(rateInBytes); - let rem: u32 = inputByteLen.wrapping_rem(rateInBytes); - for i in 0u32..n_blocks { - let block: (&mut [u8], &mut [u8]) = - input.split_at_mut(i.wrapping_mul(rateInBytes) as usize); - absorb_inner(rateInBytes, block.1, s) - } - let last: (&mut [u8], &mut [u8]) = - input.split_at_mut(n_blocks.wrapping_mul(rateInBytes) as usize); - let mut lastBlock_: [u8; 200] = [0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..rem as usize]).copy_from_slice(&last.1[0usize..rem as usize]); - lastBlock.1[rem as usize] = delimitedSuffix; - loadState(rateInBytes, lastBlock.1, s); - if !(delimitedSuffix & 0x80u8 == 0u8) && rem == rateInBytes.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[rateInBytes.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(rateInBytes, nextBlock.1, s); - state_permute(s) -} - -#[inline(always)] -pub fn squeeze0(s: &mut [u64], rateInBytes: u32, outputByteLen: u32, output: &mut [u8]) -> () { - let outBlocks: u32 = outputByteLen.wrapping_div(rateInBytes); - let remOut: u32 = outputByteLen.wrapping_rem(rateInBytes); - let blocks: (&mut [u8], &mut [u8]) = output.split_at_mut(0usize); - let last: (&mut [u8], &mut [u8]) = blocks - .1 - .split_at_mut(outputByteLen.wrapping_sub(remOut) as usize); - for i in 0u32..outBlocks { - storeState( - rateInBytes, - s, - &mut last.0[i.wrapping_mul(rateInBytes) as usize..], - ); - state_permute(s) - } - storeState(remOut, s, last.1) -} - -#[inline(always)] -pub fn keccak( - rate: u32, - capacity: u32, - inputByteLen: u32, - input: &mut [u8], - delimitedSuffix: u8, - outputByteLen: u32, - output: &mut [u8], -) -> () { - crate::lowstar::ignore::ignore::(capacity); - let rateInBytes: u32 = rate.wrapping_div(8u32); - let mut s: [u64; 25] = [0u64; 25usize]; - absorb(&mut s, rateInBytes, inputByteLen, input, delimitedSuffix); - squeeze0(&mut s, rateInBytes, outputByteLen, output) -} diff --git a/libcrux-sha3/src/hacl/streaming_types.rs b/libcrux-sha3/src/hacl/streaming_types.rs deleted file mode 100644 index 559af9cdb..000000000 --- a/libcrux-sha3/src/hacl/streaming_types.rs +++ /dev/null @@ -1,39 +0,0 @@ -#![allow(non_snake_case)] -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(unused_assignments)] -#![allow(unused_mut)] -#![allow(unreachable_patterns)] -#![allow(const_item_mutation)] - -#[derive(PartialEq, Clone, Copy)] -pub enum hash_alg -{ - SHA2_224, - SHA2_256, - SHA2_384, - SHA2_512, - SHA1, - MD5, - Blake2S, - Blake2B, - SHA3_256, - SHA3_224, - SHA3_384, - SHA3_512, - Shake128, - Shake256 -} - -#[derive(PartialEq, Clone, Copy)] -pub enum error_code -{ - Success, - InvalidAlgorithm, - InvalidLength, - 
MaximumLengthExceeded -} - -pub struct state_32 { pub block_state: Vec, pub buf: Vec, pub total_len: u64 } - -pub struct state_64 { pub block_state: Vec, pub buf: Vec, pub total_len: u64 } diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 4f17cf9ac..545acc9e1 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -1,17 +1,6 @@ // XXX: Can't do no_std // #![no_std] -// // Low* library code -// mod lowstar; - -// // SHA3 plus helpers -// mod hacl; -// use crate::hacl::hash_sha3::{self, shake128_hacl, shake256_hacl}; - -/// A Sha3x4 API -pub mod x4; - -//#[cfg(feature = "simd128")] pub mod rust_simd; pub type Sha3_224Digest = [u8; 28]; @@ -205,86 +194,3 @@ pub fn shake256(data: &[u8]) -> [u8; BYTES] { } out } - -// mod pure { - -// /// SHA3 224 -// pub fn sha3_224(payload: &[u8]) -> Sha3_224Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 28]; - -// hash_sha3::sha3_224(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 256 -// pub fn sha3_256(payload: &[u8]) -> Sha3_256Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 32]; - -// hash_sha3::sha3_256(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 384 -// pub fn sha3_384(payload: &[u8]) -> Sha3_384Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 48]; - -// hash_sha3::sha3_384(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 512 -// pub fn sha3_512(payload: &[u8]) -> Sha3_512Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 64]; - -// hash_sha3::sha3_512(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHAKE 128 -// /// -// /// The caller must define the size of the output in the return type. -// pub fn shake128(data: &[u8]) -> [u8; LEN] { -// debug_assert!(LEN <= u32::MAX as usize && data.len() <= u32::MAX as usize); -// let data = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(data.as_ptr() as *mut u8, data.len())) -// }; -// let mut out = [0u8; LEN]; -// shake128_hacl(data.len() as u32, data, LEN as u32, &mut out); - -// out -// } - -// /// SHAKE 256 -// /// -// /// The caller must define the size of the output in the return type. 
-// pub fn shake256(data: &[u8]) -> [u8; LEN] { -// debug_assert!(LEN <= u32::MAX as usize && data.len() <= u32::MAX as usize); -// let data = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(data.as_ptr() as *mut u8, data.len())) -// }; -// let mut out = [0u8; LEN]; -// shake256_hacl(data.len() as u32, data, LEN as u32, &mut out); - -// out -// } -// } diff --git a/libcrux-sha3/src/lowstar.rs b/libcrux-sha3/src/lowstar.rs deleted file mode 100644 index f63af5cbe..000000000 --- a/libcrux-sha3/src/lowstar.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod endianness; -pub mod ignore; diff --git a/libcrux-sha3/src/lowstar/endianness.rs b/libcrux-sha3/src/lowstar/endianness.rs deleted file mode 100644 index 00d3ea9c5..000000000 --- a/libcrux-sha3/src/lowstar/endianness.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::convert::TryInto; - -// Little Endian - -pub fn load16_le(bytes: &[u8]) -> u16 { - u16::from_le_bytes(bytes[0..2].try_into().unwrap()) -} - -pub fn store16_le(bytes: &mut[u8], x: u16) { - bytes[0..2].copy_from_slice(&u16::to_le_bytes(x)) -} - -pub fn load32_le(bytes: &[u8]) -> u32 { - u32::from_le_bytes(bytes[0..4].try_into().unwrap()) -} - -pub fn store32_le(bytes: &mut[u8], x: u32) { - bytes[0..4].copy_from_slice(&u32::to_le_bytes(x)) -} - -pub fn load64_le(bytes: &[u8]) -> u64 { - u64::from_le_bytes(bytes[0..8].try_into().unwrap()) -} - -pub fn store64_le(bytes: &mut[u8], x: u64) { - bytes[0..8].copy_from_slice(&u64::to_le_bytes(x)) -} - -// Big Endian - -pub fn load32_be(bytes: &[u8]) -> u32 { - u32::from_be_bytes(bytes[0..4].try_into().unwrap()) -} - -pub fn store32_be(bytes: &mut[u8], x: u32) { - bytes[0..4].copy_from_slice(&u32::to_be_bytes(x)) -} - -pub fn load64_be(bytes: &[u8]) -> u64 { - u64::from_be_bytes(bytes[0..8].try_into().unwrap()) -} - -pub fn store64_be(bytes: &mut[u8], x: u64) { - bytes[0..8].copy_from_slice(&u64::to_be_bytes(x)) -} - -pub fn load128_be(bytes: &[u8]) -> u128 { - u128::from_be_bytes(bytes[0..16].try_into().unwrap()) -} - -pub fn store128_be(bytes: &mut[u8], x: u128) { - bytes[0..16].copy_from_slice(&u128::to_be_bytes(x)) -} diff --git a/libcrux-sha3/src/lowstar/ignore.rs b/libcrux-sha3/src/lowstar/ignore.rs deleted file mode 100644 index 919eb52f9..000000000 --- a/libcrux-sha3/src/lowstar/ignore.rs +++ /dev/null @@ -1 +0,0 @@ -pub fn ignore(_: T) {} diff --git a/libcrux-sha3/src/x4.rs b/libcrux-sha3/src/x4.rs deleted file mode 100644 index ca501b87b..000000000 --- a/libcrux-sha3/src/x4.rs +++ /dev/null @@ -1,61 +0,0 @@ -/// An incremental eXtendable Output Function API for SHA3 (shake). -/// -/// This x4 variant of the incremental API always processes 4 inputs at a time. -/// This uses AVX2 when available to run the 4 operations in parallel. -/// -/// More generic APIs will be added later. -mod internal; - -/// Incremental state -#[cfg_attr(hax, hax_lib::opaque_type)] -pub struct Shake128StateX4 { - state: internal::Shake128StateX4, -} - -impl Shake128StateX4 { - /// Create a new Shake128 x4 state. - #[inline(always)] - pub fn new() -> Self { - Self { - state: internal::Shake128StateX4::new(), - } - } - - /// This is only used internally to work around Eurydice bugs. - #[inline(always)] - pub fn free_memory(self) { - self.state.free(); - } - - /// Absorb 4 blocks. - /// - /// A blocks MUST all be the same length. - /// Each slice MUST be a multiple of the block length 168. - #[inline(always)] - pub fn absorb_4blocks(&mut self, input: [&[u8]; 4]) { - self.state.absorb_blocks(input) - } - - /// Absorb up to 4 blocks. 
- /// - /// The `input` must be of length 1 to 4. - /// A blocks MUST all be the same length. - /// Each slice MUST be a multiple of the block length 168. - #[inline(always)] - pub fn absorb_final(&mut self, input: [&[u8]; N]) { - // Pad the input to the length of 4 - let data = [ - input[0], - if N > 1 { input[1] } else { &[] }, - if N > 2 { input[2] } else { &[] }, - if N > 3 { input[3] } else { &[] }, - ]; - self.state.absorb_final(data); - } - - /// Squeeze `M` blocks of length `N` - #[inline(always)] - pub fn squeeze_blocks(&mut self) -> [[u8; N]; M] { - self.state.squeeze_blocks() - } -} diff --git a/libcrux-sha3/src/x4/internal.rs b/libcrux-sha3/src/x4/internal.rs deleted file mode 100644 index 12a75821f..000000000 --- a/libcrux-sha3/src/x4/internal.rs +++ /dev/null @@ -1,337 +0,0 @@ -use core::ptr::null_mut; - -use libcrux_hacl::{ - Hacl_Hash_SHA3_Scalar_shake128_absorb_final, Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks, - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks, Hacl_Hash_SHA3_Scalar_state_free, - Hacl_Hash_SHA3_Scalar_state_malloc, -}; -#[cfg(feature = "simd256")] -use libcrux_hacl::{ - Hacl_Hash_SHA3_Simd256_shake128_absorb_final, Hacl_Hash_SHA3_Simd256_shake128_absorb_nblocks, - Hacl_Hash_SHA3_Simd256_shake128_squeeze_nblocks, Hacl_Hash_SHA3_Simd256_state_free, - Hacl_Hash_SHA3_Simd256_state_malloc, Lib_IntVector_Intrinsics_vec256, -}; -#[cfg(feature = "simd256")] -use libcrux_platform::simd256_support; - -/// SHAKE 128 -/// -/// Handle to internal SHAKE 128 state -#[cfg(feature = "simd256")] -pub struct Shake128StateX4 { - statex4: *mut Lib_IntVector_Intrinsics_vec256, - state: [*mut u64; 4], -} - -#[cfg(not(feature = "simd256"))] -pub struct Shake128StateX4 { - state: [*mut u64; 4], -} - -impl Shake128StateX4 { - #[cfg(feature = "simd256")] - pub fn new() -> Self { - if simd256_support() { - Self { - statex4: unsafe { Hacl_Hash_SHA3_Simd256_state_malloc() }, - state: [null_mut(), null_mut(), null_mut(), null_mut()], - } - } else { - Self { - statex4: null_mut(), - state: unsafe { - [ - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - ] - }, - } - } - } - - #[cfg(not(feature = "simd256"))] - pub fn new() -> Self { - Self { - state: unsafe { - [ - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - ] - }, - } - } - - /// Free and consume the state. - /// - /// **NOTE:** This consumes the value. It is not usable after this call! - #[cfg(feature = "simd256")] - pub fn free(mut self) { - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_state_free(self.statex4); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.statex4 = null_mut(); - } - } else { - for i in 0..4 { - unsafe { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.state[i] = null_mut(); - } - } - } - } - - /// Free and consume the state. - /// - /// **NOTE:** This consumes the value. It is not usable after this call! 
- #[cfg(not(feature = "simd256"))] - pub fn free(mut self) { - for i in 0..4 { - unsafe { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.state[i] = null_mut(); - } - } - } - - /// Absorb up to 4 blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(feature = "simd256")] - pub fn absorb_blocks(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() % 168 == 0); - - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_absorb_nblocks( - self.statex4, - input[0].as_ptr() as _, - input[1].as_ptr() as _, - input[2].as_ptr() as _, - input[3].as_ptr() as _, - input[0].len() as u32, - ) - }; - } else { - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - } - - /// Absorb up to 4 blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(not(feature = "simd256"))] - pub fn absorb_blocks(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() % 168 == 0); - - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - - /// Absorb up to 4 final blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(feature = "simd256")] - pub fn absorb_final(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() < 168); - - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_absorb_final( - self.statex4, - input[0].as_ptr() as _, - input[1].as_ptr() as _, - input[2].as_ptr() as _, - input[3].as_ptr() as _, - input[0].len() as u32, - ) - }; - } else { - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_final( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - } - - /// Absorb up to 4 final blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. 
- #[cfg(not(feature = "simd256"))] - pub fn absorb_final(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() < 168); - - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_final( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - - #[cfg(feature = "simd256")] - pub fn squeeze_blocks( - &mut self, - ) -> [[u8; OUTPUT_BYTES]; M] { - debug_assert!(OUTPUT_BYTES % 168 == 0); - debug_assert!(M <= self.state.len() && (M == 2 || M == 3 || M == 4)); - - if simd256_support() { - let mut output = [[0u8; OUTPUT_BYTES]; 4]; - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_squeeze_nblocks( - self.statex4, - output[0].as_mut_ptr(), - output[1].as_mut_ptr(), - output[2].as_mut_ptr(), - output[3].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - core::array::from_fn(|i| output[i]) - } else { - let mut output = [[0u8; OUTPUT_BYTES]; M]; - for i in 0..M { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks( - self.state[i], - output[i].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - } - output - } - } - - #[cfg(not(feature = "simd256"))] - pub fn squeeze_blocks( - &mut self, - ) -> [[u8; OUTPUT_BYTES]; M] { - debug_assert!(OUTPUT_BYTES % 168 == 0); - debug_assert!(M <= self.state.len()); - - let mut output = [[0u8; OUTPUT_BYTES]; M]; - - for i in 0..M { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks( - self.state[i], - output[i].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - } - - output - } -} - -/// **NOTE:** When generating C code with Eurydice, the state needs to be freed -/// manually for now due to a bug in Eurydice. -impl Drop for Shake128StateX4 { - #[cfg(feature = "simd256")] - fn drop(&mut self) { - if simd256_support() { - // A manual free may have occurred already. - // Avoid double free. - unsafe { - if !self.statex4.is_null() { - Hacl_Hash_SHA3_Simd256_state_free(self.statex4); - } - } - } else { - // A manual free may have occurred already. - // Avoid double free. - for i in 0..4 { - unsafe { - if !self.state[i].is_null() { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]) - } - } - } - } - } - - #[cfg(not(feature = "simd256"))] - fn drop(&mut self) { - // A manual free may have occurred already. - // Avoid double free. 
- for i in 0..4 { - unsafe { - if !self.state[i].is_null() { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]) - } - } - } - } -} From 302736f9bd0b76cdbe698c740e1c52c5827c5b85 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Tue, 14 May 2024 14:27:09 +0200 Subject: [PATCH 23/59] shak256 simd256 --- libcrux-sha3/src/rust_simd.rs | 15 ++++++++++++++- sys/pqclean/src/bindings.rs | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index fd862da26..685fdd195 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -141,7 +141,20 @@ pub fn shake256(data: &[u8]) -> [u8; LEN] { keccakx2::<136, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } -#[cfg(not(feature = "simd128"))] +#[cfg(feature = "simd256")] +pub fn shake256(data: &[u8]) -> [u8; LEN] { + let mut d0 = [0u8; LEN]; + let mut d1 = [0u8; LEN]; + let mut d2 = [0u8; LEN]; + let mut d3 = [0u8; LEN]; + keccakx4::<136, 0x1fu8>( + [data, data, data, data], + [&mut d0, &mut d1, &mut d2, &mut d3], + ); + d0 +} + +#[cfg(not(any(feature = "simd256", feature = "simd128")))] pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; keccakx1::<136, 0x1fu8>([data], [&mut d0]); diff --git a/sys/pqclean/src/bindings.rs b/sys/pqclean/src/bindings.rs index 59a2d73d9..5f6602af9 100644 --- a/sys/pqclean/src/bindings.rs +++ b/sys/pqclean/src/bindings.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.69.1 */ +/* automatically generated by rust-bindgen 0.69.4 */ pub const SHAKE128_RATE: u32 = 168; pub const SHAKE256_RATE: u32 = 136; From 1389350524c03dce3f05fce35e1a725125b82487 Mon Sep 17 00:00:00 2001 From: xvzcf Date: Tue, 14 May 2024 15:55:52 +0200 Subject: [PATCH 24/59] AVX2 implementation of Kyber rejection sampling. 
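
In outline: deserialize_12 turns each 24-byte chunk of XOF output into 16
candidate 12-bit coefficients, a lane-wise compare against the field modulus
yields a 16-bit "good" mask (via serialize_1), and the accepted lanes are
compacted to the front of the output with byte shuffles built from pdep/pext.
A rough scalar sketch of what one 16-lane step computes (the helper name and
the inline constant 3329, the ML-KEM field modulus, are illustrative only and
not part of this patch):

    // Reference-only sketch of the per-chunk mask-and-compact step.
    // `output` is assumed to have room for up to 16 more coefficients.
    fn compact_by_mask(candidates: [i16; 16], output: &mut [i16]) -> usize {
        let mut good: u16 = 0;
        for (i, &c) in candidates.iter().enumerate() {
            if c < 3329 {
                // Lane i holds a canonical field element: set its mask bit,
                // mirroring the cmpgt_epi16 + serialize_1 step in the diff.
                good |= 1u16 << i;
            }
        }
        let mut sampled = 0;
        for i in 0..16 {
            if (good >> i) & 1 == 1 {
                // Compact accepted lanes contiguously, as the byte shuffle does.
                output[sampled] = candidates[i];
                sampled += 1;
            }
        }
        sampled
    }

The vectorized code returns the popcounts of the two mask bytes as the number
of coefficients written, playing the role of `sampled` above.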
--- polynomials-avx2/src/debug.rs | 8 ++++++++ polynomials-avx2/src/lib.rs | 35 +++++++++++++++++++++++--------- polynomials-avx2/src/portable.rs | 26 +----------------------- 3 files changed, 34 insertions(+), 35 deletions(-) diff --git a/polynomials-avx2/src/debug.rs b/polynomials-avx2/src/debug.rs index c49a167b8..7277a76f6 100644 --- a/polynomials-avx2/src/debug.rs +++ b/polynomials-avx2/src/debug.rs @@ -9,6 +9,14 @@ pub(crate) fn print_m256i_as_i16s(a: __m256i, prefix: &'static str) { unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; println!("{}: {:?}", prefix, a_bytes); } + +#[allow(dead_code)] +pub(crate) fn print_m256i_as_i8s(a: __m256i, prefix: &'static str) { + let mut a_bytes = [0i8; 32]; + unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) }; + println!("{}: {:?}", prefix, a_bytes); +} + #[allow(dead_code)] pub(crate) fn print_m256i_as_i32s(a: __m256i, prefix: &'static str) { let mut a_bytes = [0i32; 8]; diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 62e3046fc..1d6b45852 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -994,10 +994,13 @@ fn rej_sample(uniform_bytes: &[u8]) -> (usize, [i16; 16]) { let potential_coefficients = deserialize_12(uniform_bytes).elements; let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); - let good = serialize_1(SIMD256Vector { elements: compare_with_field_modulus }); + let good = serialize_1(SIMD256Vector { + elements: compare_with_field_modulus, + }); // Write out the indices indicated by the set bits of |good| such that - // the "good" elements can be read in sequence from |potential_coefficients| + // the "good" elements can be read in sequence from the beginning of + // |potential_coefficients| // Start with the first 8 bits, i.e. |good[0]| let byte_start_indices = _pdep_u64(good[0] as u64, 0x0101010101010101) as u128; @@ -1007,7 +1010,10 @@ fn rej_sample(uniform_bytes: &[u8]) -> (usize, [i16; 16]) { let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); - let byte_shuffle_indices_low = _mm_unpacklo_epi8(byte_shuffle_indices_first_byte, byte_shuffle_indices_second_byte); + let byte_shuffle_indices_low = _mm_unpacklo_epi8( + byte_shuffle_indices_first_byte, + byte_shuffle_indices_second_byte, + ); // Then the next 8 bits, i.e. 
|good[1]| let byte_start_indices = _pdep_u64(good[1] as u64, 0x0101010101010101) as u128; @@ -1017,25 +1023,34 @@ fn rej_sample(uniform_bytes: &[u8]) -> (usize, [i16; 16]) { let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); - let byte_shuffle_indices_high = _mm_unpacklo_epi8(byte_shuffle_indices_first_byte, byte_shuffle_indices_second_byte); + let byte_shuffle_indices_high = _mm_unpacklo_epi8( + byte_shuffle_indices_first_byte, + byte_shuffle_indices_second_byte, + ); // Write out the indices to an __m256 and then shuffle let byte_shuffle_indices = _mm256_castsi128_si256(byte_shuffle_indices_low); - let byte_shuffle_indices = _mm256_inserti128_si256(byte_shuffle_indices, byte_shuffle_indices_high, 1); + let byte_shuffle_indices = + _mm256_inserti128_si256(byte_shuffle_indices, byte_shuffle_indices_high, 1); let coefficients = _mm256_shuffle_epi8(potential_coefficients, byte_shuffle_indices); // Write out the elements themselves - _mm256_storeu_si256(sampled.as_mut_ptr() as *mut __m256i, coefficients); - - // Count the sampled elements - let count_sampled = good[0].count_ones() + good[1].count_ones(); + let low_coefficients = _mm256_castsi256_si128(coefficients); + _mm_storeu_si128(sampled.as_mut_ptr() as *mut __m128i, low_coefficients); + let count_sampled = good[0].count_ones(); + + let high_coefficients = _mm256_extracti128_si256(coefficients, 1); + _mm_storeu_si128( + sampled.as_mut_ptr().offset(count_sampled as isize) as *mut __m128i, + high_coefficients, + ); + let count_sampled = count_sampled + good[1].count_ones(); count_sampled }; (count as usize, sampled) - //portable::rej_sample(uniform_bytes) } impl Operations for SIMD256Vector { diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs index b18c02d19..844e6a3c2 100644 --- a/polynomials-avx2/src/portable.rs +++ b/polynomials-avx2/src/portable.rs @@ -1,4 +1,4 @@ -pub use libcrux_traits::{FIELD_ELEMENTS_IN_VECTOR, FIELD_MODULUS}; +pub use libcrux_traits::FIELD_ELEMENTS_IN_VECTOR; type FieldElement = i16; @@ -112,27 +112,3 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> PortableVector { result } - -#[inline(always)] -pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - let mut result = [0i16; 16]; - let mut sampled = 0; - for bytes in a.chunks(3) { - let b1 = bytes[0] as i16; - let b2 = bytes[1] as i16; - let b3 = bytes[2] as i16; - - let d1 = ((b2 & 0xF) << 8) | b1; - let d2 = (b3 << 4) | (b2 >> 4); - - if d1 < FIELD_MODULUS && sampled < 16 { - result[sampled] = d1; - sampled += 1 - } - if d2 < FIELD_MODULUS && sampled < 16 { - result[sampled] = d2; - sampled += 1 - } - } - (sampled, result) -} From 84b22fbd06ef1005bb2bbe85a4e6815a63ffea94 Mon Sep 17 00:00:00 2001 From: Karthik Bhargavan Date: Tue, 14 May 2024 16:05:35 +0200 Subject: [PATCH 25/59] prfxn simd --- libcrux-ml-kem/src/hash_functions.rs | 31 +++++++++++++++++-- libcrux-sha3/src/rust_simd.rs | 46 ++-------------------------- 2 files changed, 32 insertions(+), 45 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 032785818..1c231e69e 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -19,6 +19,33 @@ pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { rust_simd::shake256::(input) } + +#[cfg(feature = "simd256")] +#[inline(always)] +pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + let mut out = [[0u8; 
LEN]; K]; + let mut dummy_out0 = [0u8; LEN]; + let mut dummy_out1 = [0u8; LEN]; + + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + rust_simd::shake256x4(&input[0], &input[1], &input[0], &input[0], &mut out0[0], &mut out1[0], &mut dummy_out0, &mut dummy_out1); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + rust_simd::shake256x4(&input[0], &input[1], &input[2], &input[0], &mut out0[0], &mut out1[0], &mut out2[0], &mut dummy_out0); + } + _ => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + rust_simd::shake256x4(&input[0], &input[1], &input[2], &input[3], &mut out0[0], &mut out1[0], &mut out2[0], &mut out3[0]); + } + } + out +} #[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { @@ -46,8 +73,8 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> } out } -#[cfg(not(feature = "simd128"))] -#[inline(always)] +#[cfg(not(any(feature = "simd128", feature = "simd256")))] +//#[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { core::array::from_fn(|i| rust_simd::shake256::(&input[i])) } diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index 685fdd195..b2c457774 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -58,20 +58,7 @@ pub fn sha3_256(data: &[u8]) -> [u8; 32] { d0 } -#[cfg(feature = "simd256")] -pub fn sha3_256(data: &[u8]) -> [u8; 32] { - let mut d0 = [0u8; 32]; - let mut d1 = [0u8; 32]; - let mut d2 = [0u8; 32]; - let mut d3 = [0u8; 32]; - keccakx4::<136, 0x06u8>( - [data, data, data, data], - [&mut d0, &mut d1, &mut d2, &mut d3], - ); - d0 -} - -#[cfg(not(any(feature = "simd256", feature = "simd128")))] +#[cfg(not(feature = "simd128"))] pub fn sha3_256(data: &[u8]) -> [u8; 32] { let mut d0 = [0u8; 32]; keccakx1::<136, 0x06u8>([data], [&mut d0]); @@ -99,21 +86,7 @@ pub fn sha3_512(data: &[u8]) -> [u8; 64] { keccakx2::<72, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } - -#[cfg(feature = "simd256")] -pub fn sha3_512(data: &[u8]) -> [u8; 64] { - let mut d0 = [0u8; 64]; - let mut d1 = [0u8; 64]; - let mut d2 = [0u8; 64]; - let mut d3 = [0u8; 64]; - keccakx4::<72, 0x06u8>( - [data, data, data, data], - [&mut d0, &mut d1, &mut d2, &mut d3], - ); - d0 -} - -#[cfg(not(any(feature = "simd256", feature = "simd128")))] +#[cfg(not(feature = "simd128"))] pub fn sha3_512(data: &[u8]) -> [u8; 64] { let mut d0 = [0u8; 64]; keccakx1::<72, 0x06u8>([data], [&mut d0]); @@ -141,20 +114,7 @@ pub fn shake256(data: &[u8]) -> [u8; LEN] { keccakx2::<136, 0x1fu8>([data, data], [&mut d0, &mut d1]); d0 } -#[cfg(feature = "simd256")] -pub fn shake256(data: &[u8]) -> [u8; LEN] { - let mut d0 = [0u8; LEN]; - let mut d1 = [0u8; LEN]; - let mut d2 = [0u8; LEN]; - let mut d3 = [0u8; LEN]; - keccakx4::<136, 0x1fu8>( - [data, data, data, data], - [&mut d0, &mut d1, &mut d2, &mut d3], - ); - d0 -} - -#[cfg(not(any(feature = "simd256", feature = "simd128")))] +#[cfg(not(feature = "simd128"))] pub fn shake256(data: &[u8]) -> [u8; LEN] { let mut d0 = [0u8; LEN]; keccakx1::<136, 0x1fu8>([data], [&mut d0]); From e059f93b1fc88ee9037d00fe12ae586a2271c2cb Mon Sep 17 00:00:00 2001 From: Karthik Bhargavan Date: Tue, 14 May 2024 16:30:48 +0200 Subject: [PATCH 26/59] merged --- polynomials-avx2/src/lib.rs | 2 +- polynomials-avx2/src/portable.rs | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/polynomials-avx2/src/lib.rs 
b/polynomials-avx2/src/lib.rs index 3c9bd5fc9..e09655aab 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -984,7 +984,7 @@ fn deserialize_12(v: &[u8]) -> SIMD256Vector { } #[inline(always)] -pub(crate) fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { +pub(crate) fn rej_sample(uniform_bytes: &[u8], out: &mut [i16]) -> usize { let count = unsafe { let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); let ones = _mm_set1_epi8(1); diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs index a1dda7f4c..844e6a3c2 100644 --- a/polynomials-avx2/src/portable.rs +++ b/polynomials-avx2/src/portable.rs @@ -112,4 +112,3 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> PortableVector { result } - From acaa2927f2110a6ffb14abb8452219f370e2a88b Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Tue, 14 May 2024 18:33:36 +0200 Subject: [PATCH 27/59] bugfix for arm --- libcrux-ml-kem/src/hash_functions.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 1c231e69e..eebed1c50 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -365,13 +365,6 @@ pub(crate) fn squeeze_block(state: &mut KeccakState4) -> [[u8; B /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. -#[cfg(feature = "simd256")] #[inline(always)] pub(crate) fn free_state(_xof_state: KeccakState4) {} -/// Free the memory of the state. -/// -/// **NOTE:** That this needs to be done manually for now. -#[cfg(not(any(feature = "simd256", feature = "simd128")))] -#[inline(always)] -pub(crate) fn free_state(_xof_state: [libcrux_sha3::rust_simd::KeccakState1; K]) {} From f156bf4038a1adafe67643114d359200b433466c Mon Sep 17 00:00:00 2001 From: xvzcf Date: Tue, 14 May 2024 21:11:15 +0200 Subject: [PATCH 28/59] Updates to avx2 rejection sampling. --- polynomials-avx2/src/lib.rs | 73 +-- polynomials-avx2/src/sampling.rs | 787 +++++++++++++++++++++++++++++++ 2 files changed, 790 insertions(+), 70 deletions(-) create mode 100644 polynomials-avx2/src/sampling.rs diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index e09655aab..51cac80a7 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -6,6 +6,7 @@ use libcrux_traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMER mod debug; mod portable; +mod sampling; const BARRETT_MULTIPLIER: i16 = 20159; @@ -983,74 +984,6 @@ fn deserialize_12(v: &[u8]) -> SIMD256Vector { } } -#[inline(always)] -pub(crate) fn rej_sample(uniform_bytes: &[u8], out: &mut [i16]) -> usize { - let count = unsafe { - let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); - let ones = _mm_set1_epi8(1); - - let potential_coefficients = deserialize_12(uniform_bytes).elements; - - let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); - let good = serialize_1(SIMD256Vector { - elements: compare_with_field_modulus, - }); - - // Write out the indices indicated by the set bits of |good| such that - // the "good" elements can be read in sequence from the beginning of - // |potential_coefficients| - - // Start with the first 8 bits, i.e. 
|good[0]| - let byte_start_indices = _pdep_u64(good[0] as u64, 0x0101010101010101) as u128; - let byte_start_indices = ((byte_start_indices << 8) - byte_start_indices) as u64; - let byte_start_indices = _pext_u64(0x0E0C0A0806040200, byte_start_indices); - - let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); - let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); - - let byte_shuffle_indices_low = _mm_unpacklo_epi8( - byte_shuffle_indices_first_byte, - byte_shuffle_indices_second_byte, - ); - - // Then the next 8 bits, i.e. |good[1]| - let byte_start_indices = _pdep_u64(good[1] as u64, 0x0101010101010101) as u128; - let byte_start_indices = ((byte_start_indices << 8) - byte_start_indices) as u64; - let byte_start_indices = _pext_u64(0x0E0C0A0806040200, byte_start_indices); - - let byte_shuffle_indices_first_byte = _mm_cvtsi64_si128(byte_start_indices as i64); - let byte_shuffle_indices_second_byte = _mm_add_epi8(byte_shuffle_indices_first_byte, ones); - - let byte_shuffle_indices_high = _mm_unpacklo_epi8( - byte_shuffle_indices_first_byte, - byte_shuffle_indices_second_byte, - ); - - // Write out the indices to an __m256 and then shuffle - let byte_shuffle_indices = _mm256_castsi128_si256(byte_shuffle_indices_low); - let byte_shuffle_indices = - _mm256_inserti128_si256(byte_shuffle_indices, byte_shuffle_indices_high, 1); - - let coefficients = _mm256_shuffle_epi8(potential_coefficients, byte_shuffle_indices); - - // Write out the elements themselves - let low_coefficients = _mm256_castsi256_si128(coefficients); - _mm_storeu_si128(out.as_mut_ptr() as *mut __m128i, low_coefficients); - let count_sampled = good[0].count_ones(); - - let high_coefficients = _mm256_extracti128_si256(coefficients, 1); - _mm_storeu_si128( - out.as_mut_ptr().offset(count_sampled as isize) as *mut __m128i, - high_coefficients, - ); - let count_sampled = count_sampled + good[1].count_ones(); - - count_sampled - }; - - count as usize -} - impl Operations for SIMD256Vector { fn ZERO() -> Self { zero() @@ -1195,7 +1128,7 @@ impl Operations for SIMD256Vector { deserialize_12(a) } - fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { - rej_sample(a, out) + fn rej_sample(input: &[u8], output: &mut [i16]) -> usize { + sampling::rejection_sample(input, output) } } diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs new file mode 100644 index 000000000..c9b759782 --- /dev/null +++ b/polynomials-avx2/src/sampling.rs @@ -0,0 +1,787 @@ +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use crate::{SIMD256Vector, serialize_1, deserialize_12, FIELD_MODULUS}; + +const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ + [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, + ], // 0 + [ + 0, 1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 1 + [ + 2, 3, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 2 + [ + 0, 1, 2, 3, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 3 + [ + 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 4 + [ + 0, 1, 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 5 + [ + 2, 3, 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 6 + [ + 0, 1, 2, 3, 4, 5, 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 7 + [ + 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 8 + [ + 0, 1, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 9 + [ + 2, 3, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 10 + [ + 0, 1, 2, 3, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 11 + [ + 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 12 + [ + 0, 1, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 13 + [ + 2, 3, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 14 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 15 + [ + 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 16 + [ + 0, 1, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 17 + [ + 2, 3, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 18 + [ + 0, 1, 2, 3, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 19 + [ + 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 20 + [ + 0, 1, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 21 + [ + 2, 3, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 22 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 23 + [ + 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 24 + [ + 0, 1, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 25 + [ + 2, 3, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 26 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 27 + [ + 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 28 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 29 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 30 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 31 + [ + 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 32 + [ + 0, 1, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 33 + [ + 2, 3, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 34 + [ + 0, 1, 2, 3, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 35 + [ + 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 36 + [ + 0, 1, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 37 + [ + 2, 3, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 38 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 39 + [ + 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 40 + [ + 0, 1, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 41 + [ + 2, 3, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 42 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 43 
+ [ + 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 44 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 45 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 46 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 47 + [ + 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 48 + [ + 0, 1, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 49 + [ + 2, 3, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 50 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 51 + [ + 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 52 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 53 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 54 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 55 + [ + 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 56 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 57 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 58 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 59 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 60 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 61 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 62 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff], // 63 + [ + 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 64 + [ + 0, 1, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 65 + [ + 2, 3, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 66 + [ + 0, 1, 2, 3, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 67 + [ + 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 68 + [ + 0, 1, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 69 + [ + 2, 3, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 70 + [ + 0, 1, 2, 3, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 71 + [ + 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 72 + [ + 0, 1, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 73 + [ + 2, 3, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 74 + [ + 0, 1, 2, 3, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 75 + [ + 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 76 + [ + 0, 1, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 77 + [ + 2, 3, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 78 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 79 + [ + 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 80 + [ + 0, 1, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
+ ], // 81 + [ + 2, 3, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 82 + [ + 0, 1, 2, 3, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 83 + [ + 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 84 + [ + 0, 1, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 85 + [ + 2, 3, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 86 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 87 + [ + 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 88 + [ + 0, 1, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 89 + [ + 2, 3, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 90 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 91 + [ + 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 92 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 93 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 94 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff], // 95 + [ + 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 96 + [ + 0, 1, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 97 + [ + 2, 3, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 98 + [ + 0, 1, 2, 3, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 99 + [ + 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 100 + [ + 0, 1, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 101 + [ + 2, 3, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 102 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 103 + [ + 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 104 + [ + 0, 1, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 105 + [ + 2, 3, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 106 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 107 + [ + 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 108 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 109 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 110 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 111 + [ + 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 112 + [ + 0, 1, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 113 + [ + 2, 3, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 114 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 115 + [ + 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 116 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 117 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 118 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 119 + [ + 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, + ], // 120 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 121 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 122 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 123 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 124 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 125 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 126 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff], // 127 + [ + 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 128 + [ + 0, 1, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 129 + [ + 2, 3, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 130 + [ + 0, 1, 2, 3, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 131 + [ + 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 132 + [ + 0, 1, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 133 + [ + 2, 3, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 134 + [ + 0, 1, 2, 3, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 135 + [ + 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 136 + [ + 0, 1, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 137 + [ + 2, 3, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 138 + [ + 0, 1, 2, 3, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 139 + [ + 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 140 + [ + 0, 1, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 141 + [ + 2, 3, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 142 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 143 + [ + 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 144 + [ + 0, 1, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 145 + [ + 2, 3, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 146 + [ + 0, 1, 2, 3, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 147 + [ + 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 148 + [ + 0, 1, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 149 + [ + 2, 3, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 150 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 151 + [ + 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 152 + [ + 0, 1, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 153 + [ + 2, 3, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 154 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 155 + [ + 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 156 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 157 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, + ], // 158 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff], // 159 + [ + 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 160 + [ + 0, 1, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 161 + [ + 2, 3, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 162 + [ + 0, 1, 2, 3, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 163 + [ + 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 164 + [ + 0, 1, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 165 + [ + 2, 3, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 166 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 167 + [ + 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 168 + [ + 0, 1, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 169 + [ + 2, 3, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 170 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 171 + [ + 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 172 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 173 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 174 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 175 + [ + 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 176 + [ + 0, 1, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 177 + [ + 2, 3, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 178 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 179 + [ + 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 180 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 181 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 182 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 183 + [ + 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 184 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 185 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 186 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 187 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 188 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 189 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 190 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff], // 191 + [ + 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 192 + [ + 0, 1, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 193 + [ + 2, 3, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 194 + [ + 0, 1, 2, 3, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 195 + [ + 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 196 + [ + 0, 1, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, + ], // 197 + [ + 2, 3, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 198 + [ + 0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 199 + [ + 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 200 + [ + 0, 1, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 201 + [ + 2, 3, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 202 + [ + 0, 1, 2, 3, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 203 + [ + 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 204 + [ + 0, 1, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 205 + [ + 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 206 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 207 + [ + 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 208 + [ + 0, 1, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 209 + [ + 2, 3, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 210 + [ + 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 211 + [ + 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 212 + [ + 0, 1, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 213 + [ + 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 214 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 215 + [ + 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 216 + [ + 0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 217 + [ + 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 218 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 219 + [ + 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 220 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 221 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 222 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff], // 223 + [ + 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 224 + [ + 0, 1, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 225 + [ + 2, 3, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 226 + [ + 0, 1, 2, 3, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 227 + [ + 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 228 + [ + 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 229 + [ + 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 230 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 231 + [ + 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 232 + [ + 0, 1, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 233 + [ + 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 234 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 235 + [ + 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 236 + [ + 0, 1, 
4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 237 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 238 + [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 239 + [ + 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 240 + [ + 0, 1, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 241 + [ + 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 242 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 243 + [ + 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 244 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 245 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 246 + [0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 247 + [ + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 248 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 249 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 250 + [0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 251 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 252 + [0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 253 + [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 254 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 255 +]; + +#[inline(always)] +pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { + let count = unsafe { + let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); + + let potential_coefficients = deserialize_12(input).elements; + + let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); + let good = serialize_1(SIMD256Vector { + elements: compare_with_field_modulus, + }); + + let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; + let lower_shuffles = _mm_load_si128(lower_shuffles.as_ptr() as *const __m128i); + let lower_coefficients = _mm256_castsi256_si128(potential_coefficients); + let lower_coefficients = _mm_shuffle_epi8(lower_coefficients, lower_shuffles); + + _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, lower_coefficients); + let sampled_count = good[0].count_ones(); + + let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; + let upper_shuffles = _mm_load_si128(upper_shuffles.as_ptr() as *const __m128i); + let upper_coefficients = _mm256_extractf128_si256(potential_coefficients, 1); + let upper_coefficients = _mm_shuffle_epi8(upper_coefficients, upper_shuffles); + + _mm_storeu_si128(output.as_mut_ptr().offset(sampled_count as isize) as *mut __m128i, upper_coefficients); + + sampled_count + good[1].count_ones() + }; + + count as usize +} From 78da7eca10e2e402498f1cac0c72716e86e5bff9 Mon Sep 17 00:00:00 2001 From: xvzcf Date: Tue, 14 May 2024 21:13:17 +0200 Subject: [PATCH 29/59] load -> loadu --- polynomials-avx2/src/sampling.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index c9b759782..09612525a 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -766,7 +766,7 @@ pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { }); let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; - let lower_shuffles = 
_mm_load_si128(lower_shuffles.as_ptr() as *const __m128i); + let lower_shuffles = _mm_loadu_si128(lower_shuffles.as_ptr() as *const __m128i); let lower_coefficients = _mm256_castsi256_si128(potential_coefficients); let lower_coefficients = _mm_shuffle_epi8(lower_coefficients, lower_shuffles); @@ -774,7 +774,7 @@ pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { let sampled_count = good[0].count_ones(); let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; - let upper_shuffles = _mm_load_si128(upper_shuffles.as_ptr() as *const __m128i); + let upper_shuffles = _mm_loadu_si128(upper_shuffles.as_ptr() as *const __m128i); let upper_coefficients = _mm256_extractf128_si256(potential_coefficients, 1); let upper_coefficients = _mm_shuffle_epi8(upper_coefficients, upper_shuffles); From ef3d4d71142b26769da31dfdefd79f721e6f7485 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Wed, 15 May 2024 06:16:45 +0200 Subject: [PATCH 30/59] mutable inputs to lessen memmoves --- libcrux-ml-kem/src/helper.rs | 2 +- libcrux-ml-kem/src/ind_cpa.rs | 16 ++++---- libcrux-ml-kem/src/invert_ntt.rs | 38 +++++++++---------- libcrux-ml-kem/src/matrix.rs | 30 +++++++-------- libcrux-ml-kem/src/ntt.rs | 63 ++++++++++++++------------------ libcrux-ml-kem/src/polynomial.rs | 26 ++++++------- libcrux-ml-kem/src/serialize.rs | 10 ++--- 7 files changed, 85 insertions(+), 100 deletions(-) diff --git a/libcrux-ml-kem/src/helper.rs b/libcrux-ml-kem/src/helper.rs index 3c9b77dfc..c3fc25a82 100644 --- a/libcrux-ml-kem/src/helper.rs +++ b/libcrux-ml-kem/src/helper.rs @@ -23,7 +23,7 @@ macro_rules! cloop { }; (for ($i:ident, $item:ident) in $val:ident.into_iter().enumerate() $body:block) => { for $i in 0..$val.len() { - let $item = $val[$i]; + let $item = &$val[$i]; $body } }; diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index dce14dec8..a7c9db163 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -73,7 +73,7 @@ fn sample_ring_element_cbd< prf_input: [u8; 33], mut domain_separator: u8, ) -> ([PolynomialRingElement; K], u8) { - let mut error_1 = [PolynomialRingElement::::ZERO(); K]; + let mut error_1 = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); let mut prf_inputs = [prf_input; K]; for i in 0..K { prf_inputs[i][32] = domain_separator; @@ -98,7 +98,7 @@ fn sample_vector_cbd_then_ntt< prf_input: [u8; 33], mut domain_separator: u8, ) -> ([PolynomialRingElement; K], u8) { - let mut re_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut re_as_ntt = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); let mut prf_inputs = [prf_input; K]; for i in 0..K { prf_inputs[i][32] = domain_separator; @@ -106,8 +106,8 @@ fn sample_vector_cbd_then_ntt< } let prf_outputs: [[u8; ETA_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); for i in 0..K { - let r = sample_from_binomial_distribution::(&prf_outputs[i]); - re_as_ntt[i] = ntt_binomially_sampled_ring_element(r); + re_as_ntt[i] = sample_from_binomial_distribution::(&prf_outputs[i]); + ntt_binomially_sampled_ring_element(&mut re_as_ntt[i]); } (re_as_ntt, domain_separator) } @@ -340,14 +340,14 @@ fn deserialize_then_decompress_u< >( ciphertext: &[u8; CIPHERTEXT_SIZE], ) -> [PolynomialRingElement; K] { - let mut u_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut u_as_ntt = core::array::from_fn(|_| PolynomialRingElement::::ZERO()); cloop! 
{ for (i, u_bytes) in ciphertext .chunks_exact((COEFFICIENTS_IN_RING_ELEMENT * U_COMPRESSION_FACTOR) / 8) .enumerate() { - let u = deserialize_then_decompress_ring_element_u::(u_bytes); - u_as_ntt[i] = ntt_vector_u::(u); + u_as_ntt[i] = deserialize_then_decompress_ring_element_u::(u_bytes); + ntt_vector_u::(&mut u_as_ntt[i]); } } u_as_ntt @@ -358,7 +358,7 @@ fn deserialize_then_decompress_u< fn deserialize_secret_key( secret_key: &[u8], ) -> [PolynomialRingElement; K] { - let mut secret_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut secret_as_ntt = core::array::from_fn(|_| PolynomialRingElement::::ZERO()); cloop! { for (i, secret_bytes) in secret_key.chunks_exact(BYTES_PER_RING_ELEMENT).enumerate() { secret_as_ntt[i] = deserialize_to_uncompressed_ring_element(secret_bytes); diff --git a/libcrux-ml-kem/src/invert_ntt.rs b/libcrux-ml-kem/src/invert_ntt.rs index 89ca5cad3..5b5e01a43 100644 --- a/libcrux-ml-kem/src/invert_ntt.rs +++ b/libcrux-ml-kem/src/invert_ntt.rs @@ -7,9 +7,9 @@ use libcrux_polynomials::{GenericOperations, Operations, FIELD_ELEMENTS_IN_VECTO #[inline(always)] pub(crate) fn invert_ntt_at_layer_1( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_1_step( @@ -21,15 +21,14 @@ pub(crate) fn invert_ntt_at_layer_1( ); *zeta_i -= 3; } - re } #[inline(always)] pub(crate) fn invert_ntt_at_layer_2( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_2_step( @@ -39,21 +38,19 @@ pub(crate) fn invert_ntt_at_layer_2( ); *zeta_i -= 1; } - re } #[inline(always)] pub(crate) fn invert_ntt_at_layer_3( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_3_step(re.coefficients[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]); } - re } #[inline(always)] @@ -70,9 +67,9 @@ pub(crate) fn inv_ntt_layer_int_vec_step_reduce( #[inline(always)] pub(crate) fn invert_ntt_at_layer_4_plus( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, layer: usize, -) -> PolynomialRingElement { +) { let step = 1 << layer; for round in 0..(128 >> layer) { @@ -92,13 +89,12 @@ pub(crate) fn invert_ntt_at_layer_4_plus( re.coefficients[j + step_vec] = y; } } - re } #[inline(always)] pub(crate) fn invert_ntt_montgomery( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { // We only ever call this function after matrix/vector multiplication hax_debug_assert!(to_i16_array(re) .into_iter() @@ -106,13 +102,13 @@ pub(crate) fn invert_ntt_montgomery( let mut zeta_i = super::constants::COEFFICIENTS_IN_RING_ELEMENT / 2; - re = invert_ntt_at_layer_1(&mut zeta_i, re, 1); - re = invert_ntt_at_layer_2(&mut zeta_i, re, 2); - re = invert_ntt_at_layer_3(&mut zeta_i, re, 3); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 4); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 5); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 6); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 7); + invert_ntt_at_layer_1(&mut zeta_i, re, 1); + invert_ntt_at_layer_2(&mut zeta_i, re, 2); + invert_ntt_at_layer_3(&mut zeta_i, re, 3); + invert_ntt_at_layer_4_plus(&mut 
zeta_i, re, 4); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 5); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 6); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 7); hax_debug_assert!( to_i16_array(re)[0].abs() < 128 * (K as i16) * FIELD_MODULUS diff --git a/libcrux-ml-kem/src/matrix.rs b/libcrux-ml-kem/src/matrix.rs index f10565f82..f482b7730 100644 --- a/libcrux-ml-kem/src/matrix.rs +++ b/libcrux-ml-kem/src/matrix.rs @@ -11,7 +11,7 @@ pub(crate) fn sample_matrix_A( seed: [u8; 34], transpose: bool, ) -> [[PolynomialRingElement; K]; K] { - let mut A_transpose = [[PolynomialRingElement::::ZERO(); K]; K]; + let mut A_transpose = core::array::from_fn(|_i| core::array::from_fn (|_j| PolynomialRingElement::::ZERO())); for i in 0..K { let mut seeds = [seed; K]; @@ -20,12 +20,12 @@ pub(crate) fn sample_matrix_A( seeds[j][33] = j as u8; } let sampled = sample_from_xof(seeds); - for j in 0..K { + for (j,sample) in sampled.into_iter().enumerate() { // A[i][j] = A_transpose[j][i] if transpose { - A_transpose[j][i] = sampled[j]; + A_transpose[j][i] = sample; } else { - A_transpose[i][j] = sampled[j]; + A_transpose[i][j] = sample; } } } @@ -48,10 +48,10 @@ pub(crate) fn compute_message( for i in 0..K { let product = secret_as_ntt[i].ntt_multiply(&u_as_ntt[i]); - result = result.add_to_ring_element::(&product); + result.add_to_ring_element::(&product); } - result = invert_ntt_montgomery::(result); + invert_ntt_montgomery::(&mut result); result = v.subtract_reduce(result); result @@ -69,10 +69,10 @@ pub(crate) fn compute_ring_element_v( for i in 0..K { let product = t_as_ntt[i].ntt_multiply(&r_as_ntt[i]); - result = result.add_to_ring_element::(&product); + result.add_to_ring_element::(&product); } - result = invert_ntt_montgomery::(result); + invert_ntt_montgomery::(&mut result); result = error_2.add_message_error_reduce(message, result); result @@ -85,19 +85,19 @@ pub(crate) fn compute_vector_u( r_as_ntt: &[PolynomialRingElement; K], error_1: &[PolynomialRingElement; K], ) -> [PolynomialRingElement; K] { - let mut result = [PolynomialRingElement::::ZERO(); K]; + let mut result = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! { for (i, row) in a_as_ntt.iter().enumerate() { cloop! { for (j, a_element) in row.iter().enumerate() { let product = a_element.ntt_multiply(&r_as_ntt[j]); - result[i] = result[i].add_to_ring_element::(&product); + result[i].add_to_ring_element::(&product); } } - result[i] = invert_ntt_montgomery::(result[i]); - result[i] = error_1[i].add_error_reduce(result[i]); + invert_ntt_montgomery::(&mut result[i]); + result[i].add_error_reduce(&error_1[i]); } } @@ -112,17 +112,17 @@ pub(crate) fn compute_As_plus_e( s_as_ntt: &[PolynomialRingElement; K], error_as_ntt: &[PolynomialRingElement; K], ) -> [PolynomialRingElement; K] { - let mut result = [PolynomialRingElement::::ZERO(); K]; + let mut result = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! { for (i, row) in matrix_A.iter().enumerate() { cloop! 
{ for (j, matrix_element) in row.iter().enumerate() { let product = matrix_element.ntt_multiply(&s_as_ntt[j]); - result[i] = result[i].add_to_ring_element::(&product); + result[i].add_to_ring_element::(&product); } } - result[i] = error_as_ntt[i].add_standard_error_reduce(result[i]); + result[i].add_standard_error_reduce(&error_as_ntt[i]); } } diff --git a/libcrux-ml-kem/src/ntt.rs b/libcrux-ml-kem/src/ntt.rs index 9d164beaa..fa45c0b74 100644 --- a/libcrux-ml-kem/src/ntt.rs +++ b/libcrux-ml-kem/src/ntt.rs @@ -7,10 +7,10 @@ use libcrux_polynomials::{GenericOperations, Operations, FIELD_ELEMENTS_IN_VECTO #[inline(always)] pub(crate) fn ntt_at_layer_1( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_1_step( @@ -22,16 +22,15 @@ pub(crate) fn ntt_at_layer_1( ); *zeta_i += 3; } - re } #[inline(always)] pub(crate) fn ntt_at_layer_2( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_2_step( @@ -41,23 +40,20 @@ pub(crate) fn ntt_at_layer_2( ); *zeta_i += 1; } - re } #[inline(always)] pub(crate) fn ntt_at_layer_3( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_3_step(re.coefficients[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]); } - - re } #[inline(always)] @@ -74,10 +70,10 @@ fn ntt_layer_int_vec_step( #[inline(always)] pub(crate) fn ntt_at_layer_4_plus( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { debug_assert!(layer >= 4); let step = 1 << layer; @@ -98,59 +94,56 @@ pub(crate) fn ntt_at_layer_4_plus( re.coefficients[j + step_vec] = y; } } - re } #[inline(always)] pub(crate) fn ntt_at_layer_7( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { let step = VECTORS_IN_RING_ELEMENT / 2; for j in 0..step { let t = Vector::multiply_by_constant(re.coefficients[j + step], -1600); re.coefficients[j + step] = Vector::sub(re.coefficients[j], &t); re.coefficients[j] = Vector::add(re.coefficients[j], &t); } - - re } #[inline(always)] pub(crate) fn ntt_binomially_sampled_ring_element( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { // Due to the small coefficient bound, we can skip the first round of // Montgomery reductions. 
- re = ntt_at_layer_7(re); + ntt_at_layer_7(re); let mut zeta_i = 1; - re = ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3); - re = ntt_at_layer_3(&mut zeta_i, re, 3, 3); - re = ntt_at_layer_2(&mut zeta_i, re, 2, 3); - re = ntt_at_layer_1(&mut zeta_i, re, 1, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3); + ntt_at_layer_3(&mut zeta_i, re, 3, 3); + ntt_at_layer_2(&mut zeta_i, re, 2, 3); + ntt_at_layer_1(&mut zeta_i, re, 1, 3); re.poly_barrett_reduce() } #[inline(always)] pub(crate) fn ntt_vector_u( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { hax_debug_assert!(to_i16_array(re) .into_iter() .all(|coefficient| coefficient.abs() <= 3328)); let mut zeta_i = 0; - re = ntt_at_layer_4_plus(&mut zeta_i, re, 7, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3328); - re = ntt_at_layer_3(&mut zeta_i, re, 3, 3328); - re = ntt_at_layer_2(&mut zeta_i, re, 2, 3328); - re = ntt_at_layer_1(&mut zeta_i, re, 1, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 7, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3328); + ntt_at_layer_3(&mut zeta_i, re, 3, 3328); + ntt_at_layer_2(&mut zeta_i, re, 2, 3328); + ntt_at_layer_1(&mut zeta_i, re, 1, 3328); re.poly_barrett_reduce() } diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index 66d1aff12..111a64d2f 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -16,7 +16,7 @@ pub(crate) const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 128] = pub(crate) const VECTORS_IN_RING_ELEMENT: usize = super::constants::COEFFICIENTS_IN_RING_ELEMENT / FIELD_ELEMENTS_IN_VECTOR; -#[derive(Clone, Copy)] +//#[derive(Clone, Copy)] pub(crate) struct PolynomialRingElement { pub(crate) coefficients: [Vector; VECTORS_IN_RING_ELEMENT], } @@ -44,19 +44,17 @@ impl PolynomialRingElement { /// Given two polynomial ring elements `lhs` and `rhs`, compute the pointwise /// sum of their constituent coefficients. 
#[inline(always)] - pub(crate) fn add_to_ring_element(mut self, rhs: &Self) -> Self { + pub(crate) fn add_to_ring_element(&mut self, rhs: &Self) { for i in 0..self.coefficients.len() { self.coefficients[i] = Vector::add(self.coefficients[i], &rhs.coefficients[i]); } - self } #[inline(always)] - pub fn poly_barrett_reduce(mut self) -> Self { + pub fn poly_barrett_reduce(&mut self) { for i in 0..VECTORS_IN_RING_ELEMENT { self.coefficients[i] = Vector::barrett_reduce(self.coefficients[i]); } - self } #[inline(always)] @@ -84,28 +82,26 @@ impl PolynomialRingElement { } #[inline(always)] - pub(crate) fn add_error_reduce(&self, mut result: Self) -> Self { + pub(crate) fn add_error_reduce(&mut self, error: &Self) { for j in 0..VECTORS_IN_RING_ELEMENT { let coefficient_normal_form = - Vector::montgomery_multiply_by_constant(result.coefficients[j], 1441); + Vector::montgomery_multiply_by_constant(self.coefficients[j], 1441); - result.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &self.coefficients[j])); + self.coefficients[j] = + Vector::barrett_reduce(Vector::add(coefficient_normal_form, &error.coefficients[j])); } - result } #[inline(always)] - pub(crate) fn add_standard_error_reduce(&self, mut result: Self) -> Self { + pub(crate) fn add_standard_error_reduce(&mut self, error: &Self) { for j in 0..VECTORS_IN_RING_ELEMENT { // The coefficients are of the form aR^{-1} mod q, which means // calling to_montgomery_domain() on them should return a mod q. - let coefficient_normal_form = Vector::to_standard_domain(result.coefficients[j]); + let coefficient_normal_form = Vector::to_standard_domain(self.coefficients[j]); - result.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &self.coefficients[j])); + self.coefficients[j] = + Vector::barrett_reduce(Vector::add(coefficient_normal_form, &error.coefficients[j])); } - result } /// Given two `KyberPolynomialRingElement`s in their NTT representations, diff --git a/libcrux-ml-kem/src/serialize.rs b/libcrux-ml-kem/src/serialize.rs index 0d8db5d69..ecf92ddc9 100644 --- a/libcrux-ml-kem/src/serialize.rs +++ b/libcrux-ml-kem/src/serialize.rs @@ -38,7 +38,7 @@ pub(super) fn deserialize_then_decompress_message( #[inline(always)] pub(super) fn serialize_uncompressed_ring_element( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; BYTES_PER_RING_ELEMENT] { let mut serialized = [0u8; BYTES_PER_RING_ELEMENT]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -99,7 +99,7 @@ pub(super) fn deserialize_ring_elements_reduced< >( public_key: &[u8], ) -> [PolynomialRingElement; K] { - let mut deserialized_pk = [PolynomialRingElement::::ZERO(); K]; + let mut deserialized_pk = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! 
{ for (i, ring_element) in public_key .chunks_exact(BYTES_PER_RING_ELEMENT) @@ -113,7 +113,7 @@ pub(super) fn deserialize_ring_elements_reduced< #[inline(always)] fn compress_then_serialize_10( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { let mut serialized = [0u8; OUT_LEN]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -128,7 +128,7 @@ fn compress_then_serialize_10( #[inline(always)] fn compress_then_serialize_11( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { let mut serialized = [0u8; OUT_LEN]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -147,7 +147,7 @@ pub(super) fn compress_then_serialize_ring_element_u< const OUT_LEN: usize, Vector: Operations, >( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { hax_debug_assert!((COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN); From a885b69b7b9e8b9383fe7d6ba9d1ad6342ec36b5 Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Wed, 15 May 2024 06:36:00 +0200 Subject: [PATCH 31/59] more mutability --- libcrux-ml-kem/src/ind_cpa.rs | 19 +++++++++---------- libcrux-ml-kem/src/serialize.rs | 22 ++++++++++------------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index a7c9db163..0fb0e7f08 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -200,8 +200,8 @@ fn compress_then_serialize_u< Vector: Operations, >( input: [PolynomialRingElement; K], -) -> [u8; OUT_LEN] { - let mut out = [0u8; OUT_LEN]; + out: &mut [u8] +) { cloop! { for (i, re) in input.into_iter().enumerate() { out[i * (OUT_LEN / K)..(i + 1) * (OUT_LEN / K)].copy_from_slice( @@ -209,8 +209,6 @@ fn compress_then_serialize_u< ); } } - - out } /// This function implements Algorithm 13 of the @@ -271,7 +269,7 @@ pub(crate) fn encrypt< public_key: &[u8], message: [u8; SHARED_SECRET_SIZE], randomness: &[u8], -) -> [u8; CIPHERTEXT_SIZE] { +) -> [u8; CIPHERTEXT_SIZE] { // tˆ := Decode_12(pk) let t_as_ntt = deserialize_ring_elements_reduced::( &public_key[..T_AS_NTT_ENCODED_SIZE], @@ -317,14 +315,15 @@ pub(crate) fn encrypt< let message_as_ring_element = deserialize_then_decompress_message(message); let v = compute_ring_element_v(&t_as_ntt, &r_as_ntt, &error_2, &message_as_ring_element); + let mut ciphertext = [0u8; CIPHERTEXT_SIZE]; + + let (c1,c2) = ciphertext.split_at_mut(C1_LEN); + // c_1 := Encode_{du}(Compress_q(u,d_u)) - let c1 = compress_then_serialize_u::(u); + compress_then_serialize_u::(u, c1); // c_2 := Encode_{dv}(Compress_q(v,d_v)) - let c2 = compress_then_serialize_ring_element_v::(v); - - let mut ciphertext: [u8; CIPHERTEXT_SIZE] = into_padded_array(&c1); - ciphertext[C1_LEN..].copy_from_slice(c2.as_slice()); + compress_then_serialize_ring_element_v::(v, c2); ciphertext } diff --git a/libcrux-ml-kem/src/serialize.rs b/libcrux-ml-kem/src/serialize.rs index ecf92ddc9..fd77d30ac 100644 --- a/libcrux-ml-kem/src/serialize.rs +++ b/libcrux-ml-kem/src/serialize.rs @@ -159,10 +159,10 @@ pub(super) fn compress_then_serialize_ring_element_u< } #[inline(always)] -fn compress_then_serialize_4( +fn compress_then_serialize_4( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { - let mut serialized = [0u8; OUT_LEN]; + serialized: &mut [u8] +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficient = Vector::compress::<4>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -170,15 +170,13 @@ fn compress_then_serialize_4( let bytes = Vector::serialize_4(coefficient); 
serialized[8 * i..8 * i + 8].copy_from_slice(&bytes); } - serialized } #[inline(always)] -fn compress_then_serialize_5( +fn compress_then_serialize_5( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { - let mut serialized = [0u8; OUT_LEN]; - + serialized: &mut [u8] +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficients = Vector::compress::<5>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -186,7 +184,6 @@ fn compress_then_serialize_5( let bytes = Vector::serialize_5(coefficients); serialized[10 * i..10 * i + 10].copy_from_slice(&bytes); } - serialized } #[inline(always)] @@ -196,12 +193,13 @@ pub(super) fn compress_then_serialize_ring_element_v< Vector: Operations, >( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { + out: &mut [u8] +) { hax_debug_assert!((COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN); match COMPRESSION_FACTOR as u32 { - 4 => compress_then_serialize_4(re), - 5 => compress_then_serialize_5(re), + 4 => compress_then_serialize_4(re, out), + 5 => compress_then_serialize_5(re, out), _ => unreachable!(), } } From 393efa2f9d3c6f3643b16717a359153ad1333bdd Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Wed, 15 May 2024 06:58:42 +0200 Subject: [PATCH 32/59] made polynomial non-copy --- libcrux-ml-kem/src/polynomial.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index 111a64d2f..052df8a06 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -16,7 +16,6 @@ pub(crate) const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 128] = pub(crate) const VECTORS_IN_RING_ELEMENT: usize = super::constants::COEFFICIENTS_IN_RING_ELEMENT / FIELD_ELEMENTS_IN_VECTOR; -//#[derive(Clone, Copy)] pub(crate) struct PolynomialRingElement { pub(crate) coefficients: [Vector; VECTORS_IN_RING_ELEMENT], } From ed7c67bf73b70536eccb56c7d8022118f17e55fc Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Wed, 15 May 2024 13:14:40 +0200 Subject: [PATCH 33/59] misc fixes - rustfmt - some feature fixes --- libcrux-ml-kem/src/hash_functions.rs | 60 +++++++-- libcrux-ml-kem/src/helper.rs | 4 +- libcrux-ml-kem/src/ind_cpa.rs | 11 +- libcrux-ml-kem/src/invert_ntt.rs | 6 +- libcrux-ml-kem/src/matrix.rs | 6 +- libcrux-ml-kem/src/ntt.rs | 12 +- libcrux-ml-kem/src/polynomial.rs | 12 +- libcrux-ml-kem/src/serialize.rs | 12 +- libcrux-sha3/Cargo.toml | 5 +- libcrux-sha3/src/lib.rs | 189 +++++++++++---------------- libcrux-sha3/src/rust_simd.rs | 12 +- polynomials-avx2/src/lib.rs | 2 + polynomials-avx2/src/sampling.rs | 7 +- 13 files changed, 175 insertions(+), 163 deletions(-) diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index eebed1c50..9a3427fad 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -2,7 +2,9 @@ use crate::constants::H_DIGEST_SIZE; -use libcrux_sha3::rust_simd::{self, KeccakState4}; +#[cfg(feature = "simd256")] +use libcrux_sha3::rust_simd::KeccakState4; +use libcrux_sha3::*; #[inline(always)] pub(crate) fn G(input: &[u8]) -> [u8; 64] { @@ -19,7 +21,6 @@ pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { rust_simd::shake256::(input) } - #[cfg(feature = "simd256")] #[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { @@ -30,22 +31,50 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> match K { 2 => { let (out0, out1) = out.split_at_mut(1); - rust_simd::shake256x4(&input[0], &input[1], &input[0], &input[0], &mut out0[0], &mut out1[0], &mut 
dummy_out0, &mut dummy_out1); + rust_simd::shake256x4( + &input[0], + &input[1], + &input[0], + &input[0], + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); } 3 => { let (out0, out12) = out.split_at_mut(1); let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake256x4(&input[0], &input[1], &input[2], &input[0], &mut out0[0], &mut out1[0], &mut out2[0], &mut dummy_out0); + rust_simd::shake256x4( + &input[0], + &input[1], + &input[2], + &input[0], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); } _ => { let (out0, out123) = out.split_at_mut(1); let (out1, out23) = out123.split_at_mut(1); let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake256x4(&input[0], &input[1], &input[2], &input[3], &mut out0[0], &mut out1[0], &mut out2[0], &mut out3[0]); + rust_simd::shake256x4( + &input[0], + &input[1], + &input[2], + &input[3], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); } } out } + #[cfg(feature = "simd128")] #[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { @@ -73,12 +102,14 @@ pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> } out } + #[cfg(not(any(feature = "simd128", feature = "simd256")))] -//#[inline(always)] +#[inline(always)] pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { core::array::from_fn(|i| rust_simd::shake256::(&input[i])) } +#[cfg(feature = "simd128")] pub(crate) type Shake128x4State = KeccakState4; #[cfg(feature = "simd128")] @@ -105,9 +136,7 @@ pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { #[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] -pub(crate) fn absorb( - input: [[u8; 34]; K], -) -> [libcrux_sha3::rust_simd::KeccakState1; K] { +pub(crate) fn absorb(input: [[u8; 34]; K]) -> [rust_simd::KeccakState1; K] { debug_assert!(K == 2 || K == 3 || K == 4); let mut states = [rust_simd::shake128_init(); K]; for i in 0..K { @@ -211,7 +240,7 @@ pub(crate) fn squeeze_three_blocks( #[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn squeeze_three_blocks( - state: &mut [libcrux_sha3::rust_simd::KeccakState1], + state: &mut [rust_simd::KeccakState1], ) -> [[u8; THREE_BLOCKS]; K] { let mut out = [[0u8; THREE_BLOCKS]; K]; for i in 0..K { @@ -306,7 +335,7 @@ pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8 #[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] pub(crate) fn squeeze_block( - state: &mut [libcrux_sha3::rust_simd::KeccakState1; K], + state: &mut [rust_simd::KeccakState1; K], ) -> [[u8; BLOCK_SIZE]; K] { let mut out = [[0u8; BLOCK_SIZE]; K]; for i in 0..K { @@ -365,6 +394,13 @@ pub(crate) fn squeeze_block(state: &mut KeccakState4) -> [[u8; B /// Free the memory of the state. /// /// **NOTE:** That this needs to be done manually for now. +#[cfg(not(any(feature = "simd256", feature = "simd128")))] #[inline(always)] -pub(crate) fn free_state(_xof_state: KeccakState4) {} +pub(crate) fn free_state(_xof_state: [rust_simd::KeccakState1; K]) {} +/// Free the memory of the state. +/// +/// **NOTE:** That this needs to be done manually for now. 
+#[cfg(any(feature = "simd256", feature = "simd128"))] +#[inline(always)] +pub(crate) fn free_state(_xof_state: KeccakState4) {} diff --git a/libcrux-ml-kem/src/helper.rs b/libcrux-ml-kem/src/helper.rs index c3fc25a82..20e850b0d 100644 --- a/libcrux-ml-kem/src/helper.rs +++ b/libcrux-ml-kem/src/helper.rs @@ -1,7 +1,7 @@ /// The following macros are defined so that the extraction from Rust to C code /// can go through. -#[cfg(not(hax))] +#[cfg(eurydice)] macro_rules! cloop { (for ($i:ident, $chunk:ident) in $val:ident.$values:ident.chunks_exact($($chunk_size:expr),*).enumerate() $body:block) => { for $i in 0..$val.$values.len() / ($($chunk_size)*) { @@ -35,7 +35,7 @@ macro_rules! cloop { }; } -#[cfg(hax)] +#[cfg(not(eurydice))] macro_rules! cloop { (for ($i:ident, $chunk:ident) in $val:ident.$values:ident.chunks_exact($($chunk_size:expr),*).enumerate() $body:block) => { for ($i, $chunk) in $val.$values.chunks_exact($($chunk_size),*).enumerate() $body diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index 0fb0e7f08..67ac0f557 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -55,7 +55,7 @@ fn serialize_secret_key( input: [PolynomialRingElement; K], - out: &mut [u8] + out: &mut [u8], ) { cloop! { for (i, re) in input.into_iter().enumerate() { out[i * (OUT_LEN / K)..(i + 1) * (OUT_LEN / K)].copy_from_slice( - &compress_then_serialize_ring_element_u::(re), + &compress_then_serialize_ring_element_u::(&re), ); } } @@ -269,7 +269,7 @@ pub(crate) fn encrypt< public_key: &[u8], message: [u8; SHARED_SECRET_SIZE], randomness: &[u8], -) -> [u8; CIPHERTEXT_SIZE] { +) -> [u8; CIPHERTEXT_SIZE] { // tˆ := Decode_12(pk) let t_as_ntt = deserialize_ring_elements_reduced::( &public_key[..T_AS_NTT_ENCODED_SIZE], @@ -316,8 +316,7 @@ pub(crate) fn encrypt< let v = compute_ring_element_v(&t_as_ntt, &r_as_ntt, &error_2, &message_as_ring_element); let mut ciphertext = [0u8; CIPHERTEXT_SIZE]; - - let (c1,c2) = ciphertext.split_at_mut(C1_LEN); + let (c1, c2) = ciphertext.split_at_mut(C1_LEN); // c_1 := Encode_{du}(Compress_q(u,d_u)) compress_then_serialize_u::(u, c1); diff --git a/libcrux-ml-kem/src/invert_ntt.rs b/libcrux-ml-kem/src/invert_ntt.rs index 5b5e01a43..589633ac1 100644 --- a/libcrux-ml-kem/src/invert_ntt.rs +++ b/libcrux-ml-kem/src/invert_ntt.rs @@ -28,7 +28,7 @@ pub(crate) fn invert_ntt_at_layer_2( zeta_i: &mut usize, re: &mut PolynomialRingElement, _layer: usize, -) { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_2_step( @@ -45,7 +45,7 @@ pub(crate) fn invert_ntt_at_layer_3( zeta_i: &mut usize, re: &mut PolynomialRingElement, _layer: usize, -) { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = @@ -94,7 +94,7 @@ pub(crate) fn invert_ntt_at_layer_4_plus( #[inline(always)] pub(crate) fn invert_ntt_montgomery( re: &mut PolynomialRingElement, -) { +) { // We only ever call this function after matrix/vector multiplication hax_debug_assert!(to_i16_array(re) .into_iter() diff --git a/libcrux-ml-kem/src/matrix.rs b/libcrux-ml-kem/src/matrix.rs index f482b7730..e8e03253b 100644 --- a/libcrux-ml-kem/src/matrix.rs +++ b/libcrux-ml-kem/src/matrix.rs @@ -11,7 +11,9 @@ pub(crate) fn sample_matrix_A( seed: [u8; 34], transpose: bool, ) -> [[PolynomialRingElement; K]; K] { - let mut A_transpose = core::array::from_fn(|_i| core::array::from_fn (|_j| PolynomialRingElement::::ZERO())); + let mut A_transpose = core::array::from_fn(|_i| { + core::array::from_fn(|_j| PolynomialRingElement::::ZERO()) + }); for i in 
0..K { let mut seeds = [seed; K]; @@ -20,7 +22,7 @@ pub(crate) fn sample_matrix_A( seeds[j][33] = j as u8; } let sampled = sample_from_xof(seeds); - for (j,sample) in sampled.into_iter().enumerate() { + for (j, sample) in sampled.into_iter().enumerate() { // A[i][j] = A_transpose[j][i] if transpose { A_transpose[j][i] = sample; diff --git a/libcrux-ml-kem/src/ntt.rs b/libcrux-ml-kem/src/ntt.rs index fa45c0b74..afa17cf9e 100644 --- a/libcrux-ml-kem/src/ntt.rs +++ b/libcrux-ml-kem/src/ntt.rs @@ -10,7 +10,7 @@ pub(crate) fn ntt_at_layer_1( re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_1_step( @@ -73,7 +73,7 @@ pub(crate) fn ntt_at_layer_4_plus( re: &mut PolynomialRingElement, layer: usize, _initial_coefficient_bound: usize, -) { +) { debug_assert!(layer >= 4); let step = 1 << layer; @@ -97,9 +97,7 @@ pub(crate) fn ntt_at_layer_4_plus( } #[inline(always)] -pub(crate) fn ntt_at_layer_7( - re: &mut PolynomialRingElement, -) { +pub(crate) fn ntt_at_layer_7(re: &mut PolynomialRingElement) { let step = VECTORS_IN_RING_ELEMENT / 2; for j in 0..step { let t = Vector::multiply_by_constant(re.coefficients[j + step], -1600); @@ -111,7 +109,7 @@ pub(crate) fn ntt_at_layer_7( #[inline(always)] pub(crate) fn ntt_binomially_sampled_ring_element( re: &mut PolynomialRingElement, -) { +) { // Due to the small coefficient bound, we can skip the first round of // Montgomery reductions. ntt_at_layer_7(re); @@ -130,7 +128,7 @@ pub(crate) fn ntt_binomially_sampled_ring_element( #[inline(always)] pub(crate) fn ntt_vector_u( re: &mut PolynomialRingElement, -) { +) { hax_debug_assert!(to_i16_array(re) .into_iter() .all(|coefficient| coefficient.abs() <= 3328)); diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index 052df8a06..b78c0e491 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -86,8 +86,10 @@ impl PolynomialRingElement { let coefficient_normal_form = Vector::montgomery_multiply_by_constant(self.coefficients[j], 1441); - self.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &error.coefficients[j])); + self.coefficients[j] = Vector::barrett_reduce(Vector::add( + coefficient_normal_form, + &error.coefficients[j], + )); } } @@ -98,8 +100,10 @@ impl PolynomialRingElement { // calling to_montgomery_domain() on them should return a mod q. 
let coefficient_normal_form = Vector::to_standard_domain(self.coefficients[j]); - self.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &error.coefficients[j])); + self.coefficients[j] = Vector::barrett_reduce(Vector::add( + coefficient_normal_form, + &error.coefficients[j], + )); } } diff --git a/libcrux-ml-kem/src/serialize.rs b/libcrux-ml-kem/src/serialize.rs index fd77d30ac..9cc9fc9c3 100644 --- a/libcrux-ml-kem/src/serialize.rs +++ b/libcrux-ml-kem/src/serialize.rs @@ -161,8 +161,8 @@ pub(super) fn compress_then_serialize_ring_element_u< #[inline(always)] fn compress_then_serialize_4( re: PolynomialRingElement, - serialized: &mut [u8] -) { + serialized: &mut [u8], +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficient = Vector::compress::<4>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -175,8 +175,8 @@ fn compress_then_serialize_4( #[inline(always)] fn compress_then_serialize_5( re: PolynomialRingElement, - serialized: &mut [u8] -) { + serialized: &mut [u8], +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficients = Vector::compress::<5>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -193,8 +193,8 @@ pub(super) fn compress_then_serialize_ring_element_v< Vector: Operations, >( re: PolynomialRingElement, - out: &mut [u8] -) { + out: &mut [u8], +) { hax_debug_assert!((COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN); match COMPRESSION_FACTOR as u32 { diff --git a/libcrux-sha3/Cargo.toml b/libcrux-sha3/Cargo.toml index 6326ac64f..b8bb92749 100644 --- a/libcrux-sha3/Cargo.toml +++ b/libcrux-sha3/Cargo.toml @@ -9,11 +9,7 @@ repository.workspace = true readme.workspace = true [dependencies] -libcrux-hacl = { version = "0.0.2-pre.2", path = "../sys/hacl", features = [ - "sha3", -] } libcrux-platform = { version = "0.0.2-pre.2", path = "../sys/platform" } -hex = { version = "0.4.3", features = ["serde"] } # This is only required for verification. # The hax config is set by the hax toolchain. 
@@ -30,4 +26,5 @@ harness = false [dev-dependencies] criterion = "0.5.1" +hex = "0.4.3" rand = "0.8.5" diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 545acc9e1..4064e9e6f 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -51,130 +51,108 @@ pub const fn digest_size(mode: Algorithm) -> usize { } } -/// SHA3 -pub fn hash(algorithm: Algorithm, payload: &[u8]) -> [u8; LEN] { - debug_assert!(payload.len() <= u32::MAX as usize); - - let mut out = [0u8; LEN]; - match algorithm { - Algorithm::Sha3_224 => sha224_ema(&mut out, payload), - Algorithm::Sha3_256 => sha256_ema(&mut out, payload), - Algorithm::Sha3_384 => sha384_ema(&mut out, payload), - Algorithm::Sha3_512 => sha512_ema(&mut out, payload), - } - out -} - -use libcrux_hacl::{ - Hacl_Hash_SHA3_sha3_224, Hacl_Hash_SHA3_sha3_256, Hacl_Hash_SHA3_sha3_384, - Hacl_Hash_SHA3_sha3_512, Hacl_Hash_SHA3_shake128_hacl, Hacl_Hash_SHA3_shake256_hacl, -}; +// /// SHA3 +// pub fn hash(algorithm: Algorithm, payload: &[u8]) -> [u8; LEN] { +// debug_assert!(payload.len() <= u32::MAX as usize); + +// let mut out = [0u8; LEN]; +// match algorithm { +// Algorithm::Sha3_224 => sha224_ema(&mut out, payload), +// Algorithm::Sha3_256 => sha256_ema(&mut out, payload), +// Algorithm::Sha3_384 => sha384_ema(&mut out, payload), +// Algorithm::Sha3_512 => sha512_ema(&mut out, payload), +// } +// out +// } /// SHA3 224 #[inline(always)] -pub fn sha224(payload: &[u8]) -> [u8; 28] { - let mut digest = [0u8; 28]; - sha224_ema(&mut digest, payload); - digest +pub fn sha224(data: &[u8]) -> [u8; 28] { + rust_simd::sha3_224(data) } -/// SHA3 224 -#[inline(always)] -pub fn sha224_ema(digest: &mut [u8], payload: &[u8]) { - debug_assert!(payload.len() <= u32::MAX as usize); - debug_assert!(digest.len() == 28); - - unsafe { - Hacl_Hash_SHA3_sha3_224( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } -} +// /// SHA3 224 +// #[inline(always)] +// pub fn sha224_ema(digest: &mut [u8], payload: &[u8]) { +// debug_assert!(payload.len() <= u32::MAX as usize); +// debug_assert!(digest.len() == 28); + +// unsafe { +// Hacl_Hash_SHA3_sha3_224( +// digest.as_mut_ptr(), +// payload.as_ptr() as _, +// payload.len().try_into().unwrap(), +// ); +// } +// } /// SHA3 256 #[inline(always)] -pub fn sha256(payload: &[u8]) -> [u8; 32] { - let mut digest = [0u8; 32]; - sha256_ema(&mut digest, payload); - digest +pub fn sha256(data: &[u8]) -> [u8; 32] { + rust_simd::sha3_256(data) } -/// SHA3 256 -#[inline(always)] -pub fn sha256_ema(digest: &mut [u8], payload: &[u8]) { - debug_assert!(payload.len() <= u32::MAX as usize); - debug_assert!(digest.len() == 32); - - unsafe { - Hacl_Hash_SHA3_sha3_256( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } -} +// /// SHA3 256 +// #[inline(always)] +// pub fn sha256_ema(digest: &mut [u8], payload: &[u8]) { +// debug_assert!(payload.len() <= u32::MAX as usize); +// debug_assert!(digest.len() == 32); + +// unsafe { +// Hacl_Hash_SHA3_sha3_256( +// digest.as_mut_ptr(), +// payload.as_ptr() as _, +// payload.len().try_into().unwrap(), +// ); +// } +// } /// SHA3 384 #[inline(always)] -pub fn sha384(payload: &[u8]) -> [u8; 48] { - let mut digest = [0u8; 48]; - sha384_ema(&mut digest, payload); - digest +pub fn sha384(data: &[u8]) -> [u8; 48] { + rust_simd::sha3_384(data) } -/// SHA3 384 -#[inline(always)] -pub fn sha384_ema(digest: &mut [u8], payload: &[u8]) { - debug_assert!(payload.len() <= u32::MAX as usize); - debug_assert!(digest.len() 
== 48); - - unsafe { - Hacl_Hash_SHA3_sha3_384( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } -} +// /// SHA3 384 +// #[inline(always)] +// pub fn sha384_ema(digest: &mut [u8], payload: &[u8]) { +// debug_assert!(payload.len() <= u32::MAX as usize); +// debug_assert!(digest.len() == 48); + +// unsafe { +// Hacl_Hash_SHA3_sha3_384( +// digest.as_mut_ptr(), +// payload.as_ptr() as _, +// payload.len().try_into().unwrap(), +// ); +// } +// } /// SHA3 512 #[inline(always)] -pub fn sha512(payload: &[u8]) -> [u8; 64] { - let mut digest = [0u8; 64]; - sha512_ema(&mut digest, payload); - digest +pub fn sha512(data: &[u8]) -> [u8; 64] { + rust_simd::sha3_512(data) } -/// SHA3 512 -#[inline(always)] -pub fn sha512_ema(digest: &mut [u8], payload: &[u8]) { - debug_assert!(payload.len() <= u32::MAX as usize); - debug_assert!(digest.len() == 64); - - unsafe { - Hacl_Hash_SHA3_sha3_512( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } -} +// /// SHA3 512 +// #[inline(always)] +// pub fn sha512_ema(digest: &mut [u8], payload: &[u8]) { +// debug_assert!(payload.len() <= u32::MAX as usize); +// debug_assert!(digest.len() == 64); + +// unsafe { +// Hacl_Hash_SHA3_sha3_512( +// digest.as_mut_ptr(), +// payload.as_ptr() as _, +// payload.len().try_into().unwrap(), +// ); +// } +// } /// SHAKE 128 #[inline(always)] pub fn shake128(data: &[u8]) -> [u8; BYTES] { - let mut out = [0u8; BYTES]; - unsafe { - Hacl_Hash_SHA3_shake128_hacl( - data.len() as u32, - data.as_ptr() as _, - BYTES as u32, - out.as_mut_ptr(), - ); - } - out + rust_simd::shake128(data) } /// SHAKE 256 @@ -183,14 +161,5 @@ pub fn shake128(data: &[u8]) -> [u8; BYTES] { /// the output will only return `u32::MAX` bytes. 
#[inline(always)] pub fn shake256(data: &[u8]) -> [u8; BYTES] { - let mut out = [0u8; BYTES]; - unsafe { - Hacl_Hash_SHA3_shake256_hacl( - data.len() as u32, - data.as_ptr() as _, - BYTES as u32, - out.as_mut_ptr(), - ); - } - out + rust_simd::shake256(data) } diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs index b2c457774..9d6bf05d2 100644 --- a/libcrux-sha3/src/rust_simd.rs +++ b/libcrux-sha3/src/rust_simd.rs @@ -23,11 +23,12 @@ pub type KeccakState4 = [KeccakState2; 2]; #[cfg(feature = "simd256")] mod sha3_avx2; -#[cfg(feature = "simd256")] -#[inline(always)] -fn keccakx4(data: [&[u8]; 4], out: [&mut [u8]; 4]) { - keccak::<4, core::arch::x86_64::__m256i, RATE, DELIM>(data, out) -} +// #[cfg(feature = "simd256")] +// #[inline(always)] +// fn keccakx4(data: [&[u8]; 4], out: [&mut [u8]; 4]) { +// keccak::<4, core::arch::x86_64::__m256i, RATE, DELIM>(data, out) +// } + #[cfg(feature = "simd256")] pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; @@ -43,6 +44,7 @@ pub fn sha3_224(data: &[u8]) -> [u8; 28] { keccakx2::<144, 0x06u8>([data, data], [&mut d0, &mut d1]); d0 } + #[cfg(not(feature = "simd128"))] pub fn sha3_224(data: &[u8]) -> [u8; 28] { let mut d0 = [0u8; 28]; diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 51cac80a7..b766b6cb7 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -4,7 +4,9 @@ use core::arch::x86::*; use core::arch::x86_64::*; use libcrux_traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; +#[cfg(test)] mod debug; + mod portable; mod sampling; diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index 09612525a..6bd7b3168 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -3,7 +3,7 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use crate::{SIMD256Vector, serialize_1, deserialize_12, FIELD_MODULUS}; +use crate::{deserialize_12, serialize_1, SIMD256Vector, FIELD_MODULUS}; const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ [ @@ -778,7 +778,10 @@ pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { let upper_coefficients = _mm256_extractf128_si256(potential_coefficients, 1); let upper_coefficients = _mm_shuffle_epi8(upper_coefficients, upper_shuffles); - _mm_storeu_si128(output.as_mut_ptr().offset(sampled_count as isize) as *mut __m128i, upper_coefficients); + _mm_storeu_si128( + output.as_mut_ptr().offset(sampled_count as isize) as *mut __m128i, + upper_coefficients, + ); sampled_count + good[1].count_ones() }; From 2079973f005d80c342cd459abbdd38163cf57746 Mon Sep 17 00:00:00 2001 From: xvzcf Date: Wed, 15 May 2024 17:19:13 +0200 Subject: [PATCH 34/59] Break out implementations into their own modules. 
--- polynomials-avx2/src/arithmetic.rs | 112 +++ polynomials-avx2/src/compress.rs | 131 ++++ polynomials-avx2/src/lib.rs | 1015 +--------------------------- polynomials-avx2/src/ntt.rs | 286 ++++++++ polynomials-avx2/src/sampling.rs | 6 +- polynomials-avx2/src/serialize.rs | 448 ++++++++++++ 6 files changed, 1018 insertions(+), 980 deletions(-) create mode 100644 polynomials-avx2/src/arithmetic.rs create mode 100644 polynomials-avx2/src/compress.rs create mode 100644 polynomials-avx2/src/ntt.rs create mode 100644 polynomials-avx2/src/serialize.rs diff --git a/polynomials-avx2/src/arithmetic.rs b/polynomials-avx2/src/arithmetic.rs new file mode 100644 index 000000000..0262b96a7 --- /dev/null +++ b/polynomials-avx2/src/arithmetic.rs @@ -0,0 +1,112 @@ +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use crate::SIMD256Vector; +use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; + +#[inline(always)] +pub(crate) fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { + lhs.elements = unsafe { _mm256_add_epi16(lhs.elements, rhs.elements) }; + + lhs +} + +#[inline(always)] +pub(crate) fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { + lhs.elements = unsafe { _mm256_sub_epi16(lhs.elements, rhs.elements) }; + + lhs +} + +#[inline(always)] +pub(crate) fn multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { + v.elements = unsafe { + let c = _mm256_set1_epi16(c); + + _mm256_mullo_epi16(v.elements, c) + }; + + v +} + +#[inline(always)] +pub(crate) fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { + v.elements = unsafe { + let c = _mm256_set1_epi16(c); + + _mm256_and_si256(v.elements, c) + }; + + v +} + +#[inline(always)] +pub(crate) fn shift_right(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { _mm256_srai_epi16(v.elements, SHIFT_BY) }; + + v +} + +#[inline(always)] +pub(crate) fn shift_left(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { _mm256_slli_epi16(v.elements, SHIFT_BY) }; + + v +} + +#[inline(always)] +pub(crate) fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { + let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); + + let v_minus_field_modulus = _mm256_sub_epi16(v.elements, field_modulus); + + let sign_mask = _mm256_srai_epi16(v_minus_field_modulus, 15); + let conditional_add_field_modulus = _mm256_and_si256(sign_mask, field_modulus); + + _mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus) + }; + + v +} + +const BARRETT_MULTIPLIER: i16 = 20159; + +#[inline(always)] +pub(crate) fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { + let t = _mm256_mulhi_epi16(v.elements, _mm256_set1_epi16(BARRETT_MULTIPLIER)); + let t = _mm256_add_epi16(t, _mm256_set1_epi16(512)); + + let quotient = _mm256_srai_epi16(t, 10); + + let quotient_times_field_modulus = + _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); + + _mm256_sub_epi16(v.elements, quotient_times_field_modulus) + }; + + v +} + +#[inline(always)] +pub(crate) fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { + v.elements = unsafe { + let c = _mm256_set1_epi16(c); + let value_low = _mm256_mullo_epi16(v.elements, c); + + let k = _mm256_mullo_epi16( + value_low, + _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); + + 
let value_high = _mm256_mulhi_epi16(v.elements, c); + + _mm256_sub_epi16(value_high, k_times_modulus) + }; + + v +} diff --git a/polynomials-avx2/src/compress.rs b/polynomials-avx2/src/compress.rs new file mode 100644 index 000000000..197a3049c --- /dev/null +++ b/polynomials-avx2/src/compress.rs @@ -0,0 +1,131 @@ +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use crate::SIMD256Vector; +use libcrux_traits::FIELD_MODULUS; + +// This implementation was taken from: +// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html +// +// TODO: Optimize this implementation if performance numbers suggest doing so. +#[inline(always)] +fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + let result = unsafe { + let prod02 = _mm256_mul_epu32(lhs, rhs); + let prod13 = _mm256_mul_epu32( + _mm256_shuffle_epi32(lhs, 0b11_11_01_01), + _mm256_shuffle_epi32(rhs, 0b11_11_01_01), + ); + + _mm256_unpackhi_epi64( + _mm256_unpacklo_epi32(prod02, prod13), + _mm256_unpackhi_epi32(prod02, prod13), + ) + }; + + result +} + +#[inline(always)] +pub(crate) fn compress_message_coefficient(mut v: SIMD256Vector) -> SIMD256Vector { + v.elements = unsafe { + let field_modulus_halved = _mm256_set1_epi16((FIELD_MODULUS - 1) / 2); + let field_modulus_quartered = _mm256_set1_epi16((FIELD_MODULUS - 1) / 4); + + let shifted = _mm256_sub_epi16(field_modulus_halved, v.elements); + let mask = _mm256_srai_epi16(shifted, 15); + + let shifted_to_positive = _mm256_xor_si256(mask, shifted); + let shifted_to_positive_in_range = + _mm256_sub_epi16(shifted_to_positive, field_modulus_quartered); + + _mm256_srli_epi16(shifted_to_positive_in_range, 15) + }; + + v +} + +#[inline(always)] +pub(crate) fn compress_ciphertext_coefficient( + mut v: SIMD256Vector, +) -> SIMD256Vector { + v.elements = unsafe { + let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2); + let compression_factor = _mm256_set1_epi32(10_321_340); + let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1); + + // Compress the first 8 coefficients + let coefficients_low = _mm256_castsi256_si128(v.elements); + let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); + + let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS); + let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved); + + let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor); + let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32); + let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask); + + // Compress the next 8 coefficients + let coefficients_high = _mm256_extracti128_si256(v.elements, 1); + let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); + + let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS); + let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved); + + let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor); + let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32); + let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask); + + // Combine them + let compressed = _mm256_packs_epi32(compressed_low, compressed_high); + + _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) + }; + + v +} + +#[inline(always)] +pub(crate) fn decompress_ciphertext_coefficient( + mut v: SIMD256Vector, +) -> SIMD256Vector { + v.elements = unsafe { + let field_modulus = 
_mm256_set1_epi32(FIELD_MODULUS as i32); + let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS); + + // Compress the first 8 coefficients + let coefficients_low = _mm256_castsi256_si128(v.elements); + let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); + + let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus); + let decompressed_low = _mm256_slli_epi32(decompressed_low, 1); + let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits); + + // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack + // of support for const generic expressions. + let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS); + let decompressed_low = _mm256_srli_epi32(decompressed_low, 1); + + // Compress the next 8 coefficients + let coefficients_high = _mm256_extracti128_si256(v.elements, 1); + let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); + + let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus); + let decompressed_high = _mm256_slli_epi32(decompressed_high, 1); + let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits); + + // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack + // of support for const generic expressions. + let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS); + let decompressed_high = _mm256_srli_epi32(decompressed_high, 1); + + // Combine them + let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high); + + _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) + }; + + v +} diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index b766b6cb7..9b80e8d99 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -2,15 +2,17 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use libcrux_traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; +use libcrux_traits::Operations; #[cfg(test)] mod debug; +mod arithmetic; +mod compress; +mod ntt; mod portable; mod sampling; - -const BARRETT_MULTIPLIER: i16 = 20159; +mod serialize; #[derive(Clone, Copy)] pub struct SIMD256Vector { @@ -41,951 +43,6 @@ fn from_i16_array(array: &[i16]) -> SIMD256Vector { } } -#[inline(always)] -fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { - lhs.elements = unsafe { _mm256_add_epi16(lhs.elements, rhs.elements) }; - - lhs -} - -#[inline(always)] -fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { - lhs.elements = unsafe { _mm256_sub_epi16(lhs.elements, rhs.elements) }; - - lhs -} - -#[inline(always)] -fn multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - - _mm256_mullo_epi16(v.elements, c) - }; - - v -} - -#[inline(always)] -fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - - _mm256_and_si256(v.elements, c) - }; - - v -} - -#[inline(always)] -fn shift_right(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { _mm256_srai_epi16(v.elements, SHIFT_BY) }; - - v -} - -#[inline(always)] -fn shift_left(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { _mm256_slli_epi16(v.elements, SHIFT_BY) }; - - v -} - -#[inline(always)] -fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); - - let 
v_minus_field_modulus = _mm256_sub_epi16(v.elements, field_modulus); - - let sign_mask = _mm256_srai_epi16(v_minus_field_modulus, 15); - let conditional_add_field_modulus = _mm256_and_si256(sign_mask, field_modulus); - - _mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus) - }; - - v -} - -#[inline(always)] -fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let t = _mm256_mulhi_epi16(v.elements, _mm256_set1_epi16(BARRETT_MULTIPLIER)); - let t = _mm256_add_epi16(t, _mm256_set1_epi16(512)); - - let quotient = _mm256_srai_epi16(t, 10); - - let quotient_times_field_modulus = - _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); - - _mm256_sub_epi16(v.elements, quotient_times_field_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - let value_low = _mm256_mullo_epi16(v.elements, c); - - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm256_mulhi_epi16(v.elements, c); - - _mm256_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { - v = unsafe { - let value_low = _mm256_mullo_epi16(v, c); - - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm256_mulhi_epi16(v, c); - - _mm256_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { - v = unsafe { - let k = _mm256_mullo_epi16( - v, - _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); - - let value_high = _mm256_srli_epi32(v, 16); - - let result = _mm256_sub_epi16(value_high, k_times_modulus); - - let result = _mm256_slli_epi32(result, 16); - _mm256_srai_epi32(result, 16) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { - v = unsafe { - let value_low = _mm_mullo_epi16(v, c); - - let k = _mm_mullo_epi16( - value_low, - _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm_mulhi_epi16(v, c); - - _mm_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let field_modulus_halved = _mm256_set1_epi16((FIELD_MODULUS - 1) / 2); - let field_modulus_quartered = _mm256_set1_epi16((FIELD_MODULUS - 1) / 4); - - let shifted = _mm256_sub_epi16(field_modulus_halved, v.elements); - let mask = _mm256_srai_epi16(shifted, 15); - - let shifted_to_positive = _mm256_xor_si256(mask, shifted); - let shifted_to_positive_in_range = - _mm256_sub_epi16(shifted_to_positive, field_modulus_quartered); - - _mm256_srli_epi16(shifted_to_positive_in_range, 15) - }; - - v -} - -// This implementation was taken from: -// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html -// -// TODO: Optimize this implementation if performance numbers suggest doing so. 
-#[inline(always)] -fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { - let result = unsafe { - let prod02 = _mm256_mul_epu32(lhs, rhs); - let prod13 = _mm256_mul_epu32( - _mm256_shuffle_epi32(lhs, 0b11_11_01_01), - _mm256_shuffle_epi32(rhs, 0b11_11_01_01), - ); - - _mm256_unpackhi_epi64( - _mm256_unpacklo_epi32(prod02, prod13), - _mm256_unpackhi_epi32(prod02, prod13), - ) - }; - - result -} - -#[inline(always)] -fn compress(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2); - let compression_factor = _mm256_set1_epi32(10_321_340); - let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1); - - // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(v.elements); - let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); - - let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS); - let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved); - - let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor); - let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32); - let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask); - - // Compress the next 8 coefficients - let coefficients_high = _mm256_extracti128_si256(v.elements, 1); - let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); - - let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS); - let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved); - - let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor); - let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32); - let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask); - - // Combine them - let compressed = _mm256_packs_epi32(compressed_low, compressed_high); - - _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) - }; - - v -} - -#[inline(always)] -fn decompress_ciphertext_coefficient( - mut v: SIMD256Vector, -) -> SIMD256Vector { - v.elements = unsafe { - let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32); - let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS); - - // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(v.elements); - let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); - - let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus); - let decompressed_low = _mm256_slli_epi32(decompressed_low, 1); - let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits); - - // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack - // of support for const generic expressions. - let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS); - let decompressed_low = _mm256_srli_epi32(decompressed_low, 1); - - // Compress the next 8 coefficients - let coefficients_high = _mm256_extracti128_si256(v.elements, 1); - let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); - - let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus); - let decompressed_high = _mm256_slli_epi32(decompressed_high, 1); - let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits); - - // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack - // of support for const generic expressions. 
- let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS); - let decompressed_high = _mm256_srli_epi32(decompressed_high, 1); - - // Combine them - let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high); - - _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) - }; - - v -} - -#[inline(always)] -fn ntt_layer_1_step( - mut v: SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { - let zetas = _mm256_set_epi16( - -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, - zeta1, -zeta0, -zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); - let rhs = montgomery_multiply_by_constants(rhs, zetas); - - let lhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); - - _mm256_add_epi16(lhs, rhs) - }; - - v -} - -#[inline(always)] -fn ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { - let zetas = _mm256_set_epi16( - -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, - -zeta0, zeta0, zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10); - let rhs = montgomery_multiply_by_constants(rhs, zetas); - - let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00); - - _mm256_add_epi16(lhs, rhs) - }; - - v -} - -#[inline(always)] -fn ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = unsafe { - let rhs = _mm256_extracti128_si256(v.elements, 1); - let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); - - let lhs = _mm256_castsi256_si128(v.elements); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); - - combined - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_1_step( - mut v: SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, - ), - ); - - let sum = barrett_reduce(SIMD256Vector { elements: sum }).elements; - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_permute4x64_epi64(v.elements, 0b11_11_01_01); - - let rhs = _mm256_permute4x64_epi64(v.elements, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, - ), - ); - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = 
unsafe { - let lhs = _mm256_extracti128_si256(v.elements, 1); - let rhs = _mm256_castsi256_si128(v.elements); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); - - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - let upper_coefficients = - montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); - - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); - - combined - }; - - v -} - -#[inline(always)] -fn ntt_multiply( - lhs: &SIMD256Vector, - rhs: &SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - let products = unsafe { - // Compute the first term of the product - let shuffle_with = _mm256_set_epi8( - 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, - 12, 9, 8, 5, 4, 1, 0, - ); - const PERMUTE_WITH: i32 = 0b11_01_10_00; - - // Prepare the left hand side - let lhs_shuffled = _mm256_shuffle_epi8(lhs.elements, shuffle_with); - let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); - - let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); - let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); - - let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); - let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); - - // Prepare the right hand side - let rhs_shuffled = _mm256_shuffle_epi8(rhs.elements, shuffle_with); - let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); - - let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); - let rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); - - let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); - let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); - - // Start operating with them - let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); - - let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); - let right = montgomery_reduce_i32s(right); - let right = _mm256_mullo_epi32( - right, - _mm256_set_epi32( - -(zeta3 as i32), - zeta3 as i32, - -(zeta2 as i32), - zeta2 as i32, - -(zeta1 as i32), - zeta1 as i32, - -(zeta0 as i32), - zeta0 as i32, - ), - ); - - let products_left = _mm256_add_epi32(left, right); - let products_left = montgomery_reduce_i32s(products_left); - - // Compute the second term of the product - let rhs_adjacent_swapped = _mm256_shuffle_epi8( - rhs.elements, - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, - 5, 4, 7, 6, 1, 0, 3, 2, - ), - ); - let products_right = _mm256_madd_epi16(lhs.elements, rhs_adjacent_swapped); - let products_right = montgomery_reduce_i32s(products_right); - let products_right = _mm256_slli_epi32(products_right, 16); - - // Combine them into one vector - _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) - }; - - SIMD256Vector { elements: products } -} - -#[inline(always)] -fn serialize_1(v: SIMD256Vector) -> [u8; 2] { - let mut serialized = [0u8; 2]; - - let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(v.elements, 15); - - let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); - let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); - - let msbs = _mm_packs_epi16(low_lanes, high_lanes); - - _mm_movemask_epi8(msbs) - }; - - serialized[0] = bits_packed as u8; - serialized[1] = (bits_packed >> 8) as u8; - - serialized -} - -#[inline(always)] -fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsb_to_msb = _mm256_set_epi16( - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, 
- 1 << 5, - 1 << 6, - 1 << 7, - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - ); - - let coefficients = _mm256_set_epi16( - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_4(v: SIMD256Vector) -> [u8; 8] { - let mut serialized = [0u8; 16]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - ), - ); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_2_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, - ), - ); - - let combined = _mm256_permutevar8x32_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), - ); - let combined = _mm256_castsi256_si128(combined); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, combined); - } - - serialized[0..8].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - - let coefficients = _mm256_set_epi16( - bytes[7] as i16, - bytes[7] as i16, - bytes[6] as i16, - bytes[6] as i16, - bytes[5] as i16, - bytes[5] as i16, - bytes[4] as i16, - bytes[4] as i16, - bytes[3] as i16, - bytes[3] as i16, - bytes[2] as i16, - bytes[2] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_5(v: SIMD256Vector) -> [u8; 10] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 22); - - let adjacent_8_combined = _mm256_shuffle_epi32(adjacent_4_combined, 0b00_00_10_00); - let adjacent_8_combined = _mm256_sllv_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_8_combined = _mm256_srli_epi64(adjacent_8_combined, 12); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, 
lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(5) as *mut __m128i, upper_8); - } - - serialized[0..10].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_5(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_5(v); - - from_i16_array(&portable::to_i16_array(output)) -} - -#[inline(always)] -fn serialize_10(v: SIMD256Vector) -> [u8; 20] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 12); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, - 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(10) as *mut __m128i, upper_8); - } - - serialized[0..20].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_10(v: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - ); - - let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(4) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 6); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); - - coefficients - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_11(v: SIMD256Vector) -> [u8; 22] { - let input = portable::from_i16_array(to_i16_array(v)); - - portable::serialize_11(input) -} - -#[inline(always)] -fn deserialize_11(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_11(v); - - from_i16_array(&portable::to_i16_array(output)) -} - -#[inline(always)] -fn serialize_12(v: SIMD256Vector) -> [u8; 24] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 8); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, 
-1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, - 10, 9, 8, 5, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(12) as *mut __m128i, upper_8); - } - - serialized[0..24].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_12(v: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - - let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(8) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 4); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1)); - - coefficients - }; - - SIMD256Vector { - elements: deserialized, - } -} - impl Operations for SIMD256Vector { fn ZERO() -> Self { zero() @@ -1000,75 +57,75 @@ impl Operations for SIMD256Vector { } fn add(lhs: Self, rhs: &Self) -> Self { - add(lhs, rhs) + arithmetic::add(lhs, rhs) } fn sub(lhs: Self, rhs: &Self) -> Self { - sub(lhs, rhs) + arithmetic::sub(lhs, rhs) } fn multiply_by_constant(v: Self, c: i16) -> Self { - multiply_by_constant(v, c) + arithmetic::multiply_by_constant(v, c) } fn bitwise_and_with_constant(v: Self, c: i16) -> Self { - bitwise_and_with_constant(v, c) + arithmetic::bitwise_and_with_constant(v, c) } fn shift_right(v: Self) -> Self { - shift_right::<{ SHIFT_BY }>(v) + arithmetic::shift_right::<{ SHIFT_BY }>(v) } fn shift_left(v: Self) -> Self { - shift_left::<{ SHIFT_BY }>(v) + arithmetic::shift_left::<{ SHIFT_BY }>(v) } fn cond_subtract_3329(v: Self) -> Self { - cond_subtract_3329(v) + arithmetic::cond_subtract_3329(v) } fn barrett_reduce(v: Self) -> Self { - barrett_reduce(v) + arithmetic::barrett_reduce(v) } fn montgomery_multiply_by_constant(v: Self, r: i16) -> Self { - montgomery_multiply_by_constant(v, r) + arithmetic::montgomery_multiply_by_constant(v, r) } fn compress_1(v: Self) -> Self { - compress_1(v) + compress::compress_message_coefficient(v) } fn compress(v: Self) -> Self { - compress::(v) + compress::compress_ciphertext_coefficient::(v) } fn decompress_ciphertext_coefficient(v: Self) -> Self { - decompress_ciphertext_coefficient::(v) + compress::decompress_ciphertext_coefficient::(v) } fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { - ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) + ntt::ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) } fn ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self { - ntt_layer_2_step(a, zeta0, zeta1) + ntt::ntt_layer_2_step(a, zeta0, zeta1) } fn ntt_layer_3_step(a: Self, zeta: i16) -> Self { - ntt_layer_3_step(a, zeta) + 
ntt::ntt_layer_3_step(a, zeta) } fn inv_ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { - inv_ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) + ntt::inv_ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) } fn inv_ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self { - inv_ntt_layer_2_step(a, zeta0, zeta1) + ntt::inv_ntt_layer_2_step(a, zeta0, zeta1) } fn inv_ntt_layer_3_step(a: Self, zeta: i16) -> Self { - inv_ntt_layer_3_step(a, zeta) + ntt::inv_ntt_layer_3_step(a, zeta) } fn ntt_multiply( @@ -1079,55 +136,55 @@ impl Operations for SIMD256Vector { zeta2: i16, zeta3: i16, ) -> Self { - ntt_multiply(lhs, rhs, zeta0, zeta1, zeta2, zeta3) + ntt::ntt_multiply(lhs, rhs, zeta0, zeta1, zeta2, zeta3) } fn serialize_1(a: Self) -> [u8; 2] { - serialize_1(a) + serialize::serialize_1(a) } fn deserialize_1(a: &[u8]) -> Self { - deserialize_1(a) + serialize::deserialize_1(a) } fn serialize_4(a: Self) -> [u8; 8] { - serialize_4(a) + serialize::serialize_4(a) } fn deserialize_4(a: &[u8]) -> Self { - deserialize_4(a) + serialize::deserialize_4(a) } fn serialize_5(a: Self) -> [u8; 10] { - serialize_5(a) + serialize::serialize_5(a) } fn deserialize_5(a: &[u8]) -> Self { - deserialize_5(a) + serialize::deserialize_5(a) } fn serialize_10(a: Self) -> [u8; 20] { - serialize_10(a) + serialize::serialize_10(a) } fn deserialize_10(a: &[u8]) -> Self { - deserialize_10(a) + serialize::deserialize_10(a) } fn serialize_11(a: Self) -> [u8; 22] { - serialize_11(a) + serialize::serialize_11(a) } fn deserialize_11(a: &[u8]) -> Self { - deserialize_11(a) + serialize::deserialize_11(a) } fn serialize_12(a: Self) -> [u8; 24] { - serialize_12(a) + serialize::serialize_12(a) } fn deserialize_12(a: &[u8]) -> Self { - deserialize_12(a) + serialize::deserialize_12(a) } fn rej_sample(input: &[u8], output: &mut [i16]) -> usize { diff --git a/polynomials-avx2/src/ntt.rs b/polynomials-avx2/src/ntt.rs new file mode 100644 index 000000000..cca8f1447 --- /dev/null +++ b/polynomials-avx2/src/ntt.rs @@ -0,0 +1,286 @@ +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use crate::{arithmetic, SIMD256Vector}; +use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; + +#[inline(always)] +pub(crate) fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { + v = unsafe { + let value_low = _mm256_mullo_epi16(v, c); + + let k = _mm256_mullo_epi16( + value_low, + _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); + + let value_high = _mm256_mulhi_epi16(v, c); + + _mm256_sub_epi16(value_high, k_times_modulus) + }; + + v +} + +#[inline(always)] +pub(crate) fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { + v = unsafe { + let k = _mm256_mullo_epi16( + v, + _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), + ); + let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); + + let value_high = _mm256_srli_epi32(v, 16); + + let result = _mm256_sub_epi16(value_high, k_times_modulus); + + let result = _mm256_slli_epi32(result, 16); + _mm256_srai_epi32(result, 16) + }; + + v +} + +#[inline(always)] +pub(crate) fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { + v = unsafe { + let value_low = _mm_mullo_epi16(v, c); + + let k = _mm_mullo_epi16( + value_low, + _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = 
_mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); + + let value_high = _mm_mulhi_epi16(v, c); + + _mm_sub_epi16(value_high, k_times_modulus) + }; + + v +} + +#[inline(always)] +pub(crate) fn ntt_layer_1_step( + mut v: SIMD256Vector, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> SIMD256Vector { + v.elements = unsafe { + let zetas = _mm256_set_epi16( + -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, + zeta1, -zeta0, -zeta0, zeta0, zeta0, + ); + + let rhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + + _mm256_add_epi16(lhs, rhs) + }; + + v +} + +#[inline(always)] +pub(crate) fn ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { + v.elements = unsafe { + let zetas = _mm256_set_epi16( + -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, + -zeta0, zeta0, zeta0, zeta0, zeta0, + ); + + let rhs = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00); + + _mm256_add_epi16(lhs, rhs) + }; + + v +} + +#[inline(always)] +pub(crate) fn ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { + v.elements = unsafe { + let rhs = _mm256_extracti128_si256(v.elements, 1); + let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); + + let lhs = _mm256_castsi256_si128(v.elements); + + let lower_coefficients = _mm_add_epi16(lhs, rhs); + let upper_coefficients = _mm_sub_epi16(lhs, rhs); + + let combined = _mm256_castsi128_si256(lower_coefficients); + let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + + combined + }; + + v +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_1_step( + mut v: SIMD256Vector, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> SIMD256Vector { + v.elements = unsafe { + let lhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); + + let rhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + let rhs = _mm256_mullo_epi16( + rhs, + _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), + ); + + let sum = _mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + _mm256_set_epi16( + zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, + ), + ); + + let sum = arithmetic::barrett_reduce(SIMD256Vector { elements: sum }).elements; + + _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) + }; + + v +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { + v.elements = unsafe { + let lhs = _mm256_permute4x64_epi64(v.elements, 0b11_11_01_01); + + let rhs = _mm256_permute4x64_epi64(v.elements, 0b10_10_00_00); + let rhs = _mm256_mullo_epi16( + rhs, + _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), + ); + + let sum = _mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + _mm256_set_epi16( + zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, + ), + ); + + _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) + }; + + v +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { + v.elements = unsafe { + let lhs = _mm256_extracti128_si256(v.elements, 1); + let rhs = 
_mm256_castsi256_si128(v.elements); + + let lower_coefficients = _mm_add_epi16(lhs, rhs); + + let upper_coefficients = _mm_sub_epi16(lhs, rhs); + let upper_coefficients = + montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); + + let combined = _mm256_castsi128_si256(lower_coefficients); + let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + + combined + }; + + v +} + +#[inline(always)] +pub(crate) fn ntt_multiply( + lhs: &SIMD256Vector, + rhs: &SIMD256Vector, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> SIMD256Vector { + let products = unsafe { + // Compute the first term of the product + let shuffle_with = _mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, + 12, 9, 8, 5, 4, 1, 0, + ); + const PERMUTE_WITH: i32 = 0b11_01_10_00; + + // Prepare the left hand side + let lhs_shuffled = _mm256_shuffle_epi8(lhs.elements, shuffle_with); + let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); + + let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); + let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); + + let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); + let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); + + // Prepare the right hand side + let rhs_shuffled = _mm256_shuffle_epi8(rhs.elements, shuffle_with); + let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); + + let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); + let rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); + + let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); + let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); + + // Start operating with them + let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); + + let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); + let right = montgomery_reduce_i32s(right); + let right = _mm256_mullo_epi32( + right, + _mm256_set_epi32( + -(zeta3 as i32), + zeta3 as i32, + -(zeta2 as i32), + zeta2 as i32, + -(zeta1 as i32), + zeta1 as i32, + -(zeta0 as i32), + zeta0 as i32, + ), + ); + + let products_left = _mm256_add_epi32(left, right); + let products_left = montgomery_reduce_i32s(products_left); + + // Compute the second term of the product + let rhs_adjacent_swapped = _mm256_shuffle_epi8( + rhs.elements, + _mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, + 5, 4, 7, 6, 1, 0, 3, 2, + ), + ); + let products_right = _mm256_madd_epi16(lhs.elements, rhs_adjacent_swapped); + let products_right = montgomery_reduce_i32s(products_right); + let products_right = _mm256_slli_epi32(products_right, 16); + + // Combine them into one vector + _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) + }; + + SIMD256Vector { elements: products } +} diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index 6bd7b3168..5ded20880 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -3,7 +3,11 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use crate::{deserialize_12, serialize_1, SIMD256Vector, FIELD_MODULUS}; +use crate::{ + serialize::{deserialize_12, serialize_1}, + SIMD256Vector, +}; +use libcrux_traits::FIELD_MODULUS; const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ [ diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs new file mode 100644 index 000000000..b88988a2d --- /dev/null +++ b/polynomials-avx2/src/serialize.rs @@ -0,0 +1,448 @@ 
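The serialize module introduced here packs 16 coefficients at a fixed bit width per coefficient using byte shuffles and multiply-adds. For orientation, a scalar reference for the 4-bit case handled by serialize_4 in this file, with the first coefficient of each pair in the low nibble (the helper name is illustrative, not part of the crate):

fn serialize_4_scalar(coefficients: &[u16; 16]) -> [u8; 8] {
    let mut out = [0u8; 8];
    for i in 0..8 {
        // Each output byte packs two 4-bit coefficients, low coefficient first.
        out[i] = (coefficients[2 * i] & 0x0f) as u8 | (((coefficients[2 * i + 1] & 0x0f) as u8) << 4);
    }
    out
}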
+#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use crate::portable; +use crate::SIMD256Vector; + +#[inline(always)] +pub(crate) fn serialize_1(v: SIMD256Vector) -> [u8; 2] { + let mut serialized = [0u8; 2]; + + let bits_packed = unsafe { + let lsb_shifted_up = _mm256_slli_epi16(v.elements, 15); + + let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); + let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); + + let msbs = _mm_packs_epi16(low_lanes, high_lanes); + + _mm_movemask_epi8(msbs) + }; + + serialized[0] = bits_packed as u8; + serialized[1] = (bits_packed >> 8) as u8; + + serialized +} + +#[inline(always)] +pub(crate) fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsb_to_msb = _mm256_set_epi16( + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + ); + + let coefficients = _mm256_set_epi16( + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); + let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); + + _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) + }; + + SIMD256Vector { + elements: deserialized, + } +} + +#[inline(always)] +pub(crate) fn serialize_4(v: SIMD256Vector) -> [u8; 8] { + let mut serialized = [0u8; 16]; + + unsafe { + let adjacent_2_combined = _mm256_madd_epi16( + v.elements, + _mm256_set_epi16( + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + 1 << 4, + 1, + ), + ); + + let adjacent_8_combined = _mm256_shuffle_epi8( + adjacent_2_combined, + _mm256_set_epi8( + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, + ), + ); + + let combined = _mm256_permutevar8x32_epi32( + adjacent_8_combined, + _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), + ); + let combined = _mm256_castsi256_si128(combined); + + _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, combined); + } + + serialized[0..8].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let coefficients = _mm256_set_epi16( + bytes[7] as i16, + bytes[7] as i16, + bytes[6] as i16, + bytes[6] as i16, + bytes[5] as i16, + bytes[5] as i16, + bytes[4] as i16, + bytes[4] as i16, + bytes[3] as i16, + bytes[3] as i16, + bytes[2] as i16, + bytes[2] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); + + _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) + }; + + SIMD256Vector { + elements: deserialized, + } +} + +#[inline(always)] +pub(crate) fn serialize_5(v: SIMD256Vector) -> [u8; 10] { + let mut 
serialized = [0u8; 32]; + + unsafe { + let adjacent_2_combined = _mm256_madd_epi16( + v.elements, + _mm256_set_epi16( + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + ), + ); + + let adjacent_4_combined = _mm256_sllv_epi32( + adjacent_2_combined, + _mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), + ); + let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 22); + + let adjacent_8_combined = _mm256_shuffle_epi32(adjacent_4_combined, 0b00_00_10_00); + let adjacent_8_combined = _mm256_sllv_epi32( + adjacent_8_combined, + _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_8_combined = _mm256_srli_epi64(adjacent_8_combined, 12); + + let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); + + _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); + _mm_storeu_si128(serialized.as_mut_ptr().offset(5) as *mut __m128i, upper_8); + } + + serialized[0..10].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_5(v: &[u8]) -> SIMD256Vector { + let output = portable::deserialize_5(v); + + crate::from_i16_array(&portable::to_i16_array(output)) +} + +#[inline(always)] +pub(crate) fn serialize_10(v: SIMD256Vector) -> [u8; 20] { + let mut serialized = [0u8; 32]; + + unsafe { + let adjacent_2_combined = _mm256_madd_epi16( + v.elements, + _mm256_set_epi16( + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + ), + ); + + let adjacent_4_combined = _mm256_sllv_epi32( + adjacent_2_combined, + _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 12); + + let adjacent_8_combined = _mm256_shuffle_epi8( + adjacent_4_combined, + _mm256_set_epi8( + -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, + 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); + + _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); + _mm_storeu_si128(serialized.as_mut_ptr().offset(10) as *mut __m128i, upper_8); + } + + serialized[0..20].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_10(v: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + ); + + let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_shuffle_epi8( + lower_coefficients, + _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), + ); + let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(4) as *const __m128i); + let upper_coefficients = _mm_shuffle_epi8( + upper_coefficients, + _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), + ); + + let coefficients = _mm256_castsi128_si256(lower_coefficients); + let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); + + let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = _mm256_srli_epi16(coefficients, 6); + let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); + + coefficients + }; + + SIMD256Vector { + elements: deserialized, + } +} + 
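serialize_10 and deserialize_10 above implement a plain little-endian bit-packing: 16 coefficients of 10 bits each become 20 bytes, with each coefficient's least significant bits written first. A scalar sketch of that layout, assuming the same bit order as the SIMD path (the helper name is illustrative, not part of the crate):

fn serialize_10_scalar(coefficients: &[u16; 16]) -> [u8; 20] {
    let mut out = [0u8; 20];
    let mut bit = 0;
    for &coefficient in coefficients {
        // Coefficient k occupies bits [10 * k, 10 * k + 10) of the output stream.
        let value = ((coefficient & 0x3ff) as u32) << (bit % 8);
        let byte = bit / 8;
        out[byte] |= value as u8;
        out[byte + 1] |= (value >> 8) as u8;
        bit += 10;
    }
    out
}

serialize_11 and deserialize_11 below take a different route: rather than packing 11-bit coefficients in SIMD, they convert through the portable implementation and reuse its serializer.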
+#[inline(always)] +pub(crate) fn serialize_11(v: SIMD256Vector) -> [u8; 22] { + let input = portable::from_i16_array(crate::to_i16_array(v)); + + portable::serialize_11(input) +} + +#[inline(always)] +pub(crate) fn deserialize_11(v: &[u8]) -> SIMD256Vector { + let output = portable::deserialize_11(v); + + crate::from_i16_array(&portable::to_i16_array(output)) +} + +#[inline(always)] +pub(crate) fn serialize_12(v: SIMD256Vector) -> [u8; 24] { + let mut serialized = [0u8; 32]; + + unsafe { + let adjacent_2_combined = _mm256_madd_epi16( + v.elements, + _mm256_set_epi16( + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + ), + ); + + let adjacent_4_combined = _mm256_sllv_epi32( + adjacent_2_combined, + _mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), + ); + let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 8); + + let adjacent_8_combined = _mm256_shuffle_epi8( + adjacent_4_combined, + _mm256_set_epi8( + -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, + 10, 9, 8, 5, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); + + _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); + _mm_storeu_si128(serialized.as_mut_ptr().offset(12) as *mut __m128i, upper_8); + } + + serialized[0..24].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_12(v: &[u8]) -> SIMD256Vector { + let deserialized = unsafe { + let shift_lsbs_to_msbs = _mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_shuffle_epi8( + lower_coefficients, + _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), + ); + let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(8) as *const __m128i); + let upper_coefficients = _mm_shuffle_epi8( + upper_coefficients, + _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), + ); + + let coefficients = _mm256_castsi128_si256(lower_coefficients); + let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); + + let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = _mm256_srli_epi16(coefficients, 4); + let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1)); + + coefficients + }; + + SIMD256Vector { + elements: deserialized, + } +} From 7846cd25bf04f2c552beeccdb3f7ed306b9426ae Mon Sep 17 00:00:00 2001 From: xvzcf Date: Wed, 15 May 2024 19:14:19 +0200 Subject: [PATCH 35/59] Sugar and desugar the SIMD256Vector struct only in lib.rs --- polynomials-avx2/src/arithmetic.rs | 71 ++++++------ polynomials-avx2/src/compress.rs | 33 +++--- polynomials-avx2/src/lib.rs | 170 +++++++++++++++++++---------- polynomials-avx2/src/ntt.rs | 92 ++++++++-------- polynomials-avx2/src/sampling.rs | 11 +- polynomials-avx2/src/serialize.rs | 79 ++++++-------- 6 files changed, 238 insertions(+), 218 deletions(-) diff --git a/polynomials-avx2/src/arithmetic.rs b/polynomials-avx2/src/arithmetic.rs index 0262b96a7..3115867bf 100644 --- a/polynomials-avx2/src/arithmetic.rs +++ b/polynomials-avx2/src/arithmetic.rs @@ -3,65 +3,56 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use 
crate::SIMD256Vector; use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; #[inline(always)] -pub(crate) fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { - lhs.elements = unsafe { _mm256_add_epi16(lhs.elements, rhs.elements) }; +pub(crate) fn add(mut lhs: __m256i, rhs: __m256i) -> __m256i { + lhs = unsafe { _mm256_add_epi16(lhs, rhs) }; lhs } #[inline(always)] -pub(crate) fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector { - lhs.elements = unsafe { _mm256_sub_epi16(lhs.elements, rhs.elements) }; +pub(crate) fn sub(mut lhs: __m256i, rhs: __m256i) -> __m256i { + lhs = unsafe { _mm256_sub_epi16(lhs, rhs) }; lhs } #[inline(always)] -pub(crate) fn multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); +pub(crate) fn multiply_by_constant(mut vector: __m256i, constant: i16) -> __m256i { + vector = unsafe { _mm256_mullo_epi16(vector, _mm256_set1_epi16(constant)) }; - _mm256_mullo_epi16(v.elements, c) - }; - - v + vector } #[inline(always)] -pub(crate) fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - - _mm256_and_si256(v.elements, c) - }; +pub(crate) fn bitwise_and_with_constant(mut vector: __m256i, constant: i16) -> __m256i { + vector = unsafe { _mm256_and_si256(vector, _mm256_set1_epi16(constant)) }; - v + vector } #[inline(always)] -pub(crate) fn shift_right(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { _mm256_srai_epi16(v.elements, SHIFT_BY) }; +pub(crate) fn shift_right(mut vector: __m256i) -> __m256i { + vector = unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }; - v + vector } #[inline(always)] -pub(crate) fn shift_left(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { _mm256_slli_epi16(v.elements, SHIFT_BY) }; +pub(crate) fn shift_left(mut vector: __m256i) -> __m256i { + vector = unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }; - v + vector } #[inline(always)] -pub(crate) fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { +pub(crate) fn cond_subtract_3329(mut vector: __m256i) -> __m256i { + vector = unsafe { let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); - let v_minus_field_modulus = _mm256_sub_epi16(v.elements, field_modulus); + let v_minus_field_modulus = _mm256_sub_epi16(vector, field_modulus); let sign_mask = _mm256_srai_epi16(v_minus_field_modulus, 15); let conditional_add_field_modulus = _mm256_and_si256(sign_mask, field_modulus); @@ -69,15 +60,15 @@ pub(crate) fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector { _mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus) }; - v + vector } const BARRETT_MULTIPLIER: i16 = 20159; #[inline(always)] -pub(crate) fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let t = _mm256_mulhi_epi16(v.elements, _mm256_set1_epi16(BARRETT_MULTIPLIER)); +pub(crate) fn barrett_reduce(mut vector: __m256i) -> __m256i { + vector = unsafe { + let t = _mm256_mulhi_epi16(vector, _mm256_set1_epi16(BARRETT_MULTIPLIER)); let t = _mm256_add_epi16(t, _mm256_set1_epi16(512)); let quotient = _mm256_srai_epi16(t, 10); @@ -85,17 +76,17 @@ pub(crate) fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector { let quotient_times_field_modulus = _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); - _mm256_sub_epi16(v.elements, quotient_times_field_modulus) + _mm256_sub_epi16(vector, quotient_times_field_modulus) }; - v 
+ vector } #[inline(always)] -pub(crate) fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - let value_low = _mm256_mullo_epi16(v.elements, c); +pub(crate) fn montgomery_multiply_by_constant(mut vector: __m256i, constant: i16) -> __m256i { + vector = unsafe { + let constant = _mm256_set1_epi16(constant); + let value_low = _mm256_mullo_epi16(vector, constant); let k = _mm256_mullo_epi16( value_low, @@ -103,10 +94,10 @@ pub(crate) fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> S ); let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); - let value_high = _mm256_mulhi_epi16(v.elements, c); + let value_high = _mm256_mulhi_epi16(vector, constant); _mm256_sub_epi16(value_high, k_times_modulus) }; - v + vector } diff --git a/polynomials-avx2/src/compress.rs b/polynomials-avx2/src/compress.rs index 197a3049c..8f4e239fe 100644 --- a/polynomials-avx2/src/compress.rs +++ b/polynomials-avx2/src/compress.rs @@ -3,7 +3,6 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use crate::SIMD256Vector; use libcrux_traits::FIELD_MODULUS; // This implementation was taken from: @@ -29,12 +28,12 @@ fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { } #[inline(always)] -pub(crate) fn compress_message_coefficient(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { +pub(crate) fn compress_message_coefficient(mut vector: __m256i) -> __m256i { + vector = unsafe { let field_modulus_halved = _mm256_set1_epi16((FIELD_MODULUS - 1) / 2); let field_modulus_quartered = _mm256_set1_epi16((FIELD_MODULUS - 1) / 4); - let shifted = _mm256_sub_epi16(field_modulus_halved, v.elements); + let shifted = _mm256_sub_epi16(field_modulus_halved, vector); let mask = _mm256_srai_epi16(shifted, 15); let shifted_to_positive = _mm256_xor_si256(mask, shifted); @@ -44,20 +43,20 @@ pub(crate) fn compress_message_coefficient(mut v: SIMD256Vector) -> SIMD256Vecto _mm256_srli_epi16(shifted_to_positive_in_range, 15) }; - v + vector } #[inline(always)] pub(crate) fn compress_ciphertext_coefficient( - mut v: SIMD256Vector, -) -> SIMD256Vector { - v.elements = unsafe { + mut vector: __m256i, +) -> __m256i { + vector = unsafe { let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2); let compression_factor = _mm256_set1_epi32(10_321_340); let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1); // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(v.elements); + let coefficients_low = _mm256_castsi256_si128(vector); let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS); @@ -68,7 +67,7 @@ pub(crate) fn compress_ciphertext_coefficient( let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask); // Compress the next 8 coefficients - let coefficients_high = _mm256_extracti128_si256(v.elements, 1); + let coefficients_high = _mm256_extracti128_si256(vector, 1); let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS); @@ -84,19 +83,19 @@ pub(crate) fn compress_ciphertext_coefficient( _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) }; - v + vector } #[inline(always)] pub(crate) fn decompress_ciphertext_coefficient( - mut v: SIMD256Vector, -) -> SIMD256Vector { - v.elements = unsafe { + mut vector: __m256i, +) 
-> __m256i { + vector = unsafe { let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32); let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS); // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(v.elements); + let coefficients_low = _mm256_castsi256_si128(vector); let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus); @@ -109,7 +108,7 @@ pub(crate) fn decompress_ciphertext_coefficient( let decompressed_low = _mm256_srli_epi32(decompressed_low, 1); // Compress the next 8 coefficients - let coefficients_high = _mm256_extracti128_si256(v.elements, 1); + let coefficients_high = _mm256_extracti128_si256(vector, 1); let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus); @@ -127,5 +126,5 @@ pub(crate) fn decompress_ciphertext_coefficient( _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) }; - v + vector } diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 9b80e8d99..207f9b6f4 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -57,75 +57,115 @@ impl Operations for SIMD256Vector { } fn add(lhs: Self, rhs: &Self) -> Self { - arithmetic::add(lhs, rhs) + Self { + elements: arithmetic::add(lhs.elements, rhs.elements), + } } fn sub(lhs: Self, rhs: &Self) -> Self { - arithmetic::sub(lhs, rhs) + Self { + elements: arithmetic::sub(lhs.elements, rhs.elements), + } } fn multiply_by_constant(v: Self, c: i16) -> Self { - arithmetic::multiply_by_constant(v, c) + Self { + elements: arithmetic::multiply_by_constant(v.elements, c), + } } - fn bitwise_and_with_constant(v: Self, c: i16) -> Self { - arithmetic::bitwise_and_with_constant(v, c) + fn bitwise_and_with_constant(vector: Self, constant: i16) -> Self { + Self { + elements: arithmetic::bitwise_and_with_constant(vector.elements, constant), + } } - fn shift_right(v: Self) -> Self { - arithmetic::shift_right::<{ SHIFT_BY }>(v) + fn shift_right(vector: Self) -> Self { + Self { + elements: arithmetic::shift_right::<{ SHIFT_BY }>(vector.elements), + } } - fn shift_left(v: Self) -> Self { - arithmetic::shift_left::<{ SHIFT_BY }>(v) + fn shift_left(vector: Self) -> Self { + Self { + elements: arithmetic::shift_left::<{ SHIFT_BY }>(vector.elements), + } } - fn cond_subtract_3329(v: Self) -> Self { - arithmetic::cond_subtract_3329(v) + fn cond_subtract_3329(vector: Self) -> Self { + Self { + elements: arithmetic::cond_subtract_3329(vector.elements), + } } - fn barrett_reduce(v: Self) -> Self { - arithmetic::barrett_reduce(v) + fn barrett_reduce(vector: Self) -> Self { + Self { + elements: arithmetic::barrett_reduce(vector.elements), + } } - fn montgomery_multiply_by_constant(v: Self, r: i16) -> Self { - arithmetic::montgomery_multiply_by_constant(v, r) + fn montgomery_multiply_by_constant(vector: Self, constant: i16) -> Self { + Self { + elements: arithmetic::montgomery_multiply_by_constant(vector.elements, constant), + } } - fn compress_1(v: Self) -> Self { - compress::compress_message_coefficient(v) + fn compress_1(vector: Self) -> Self { + Self { + elements: compress::compress_message_coefficient(vector.elements), + } } - fn compress(v: Self) -> Self { - compress::compress_ciphertext_coefficient::(v) + fn compress(vector: Self) -> Self { + Self { + elements: compress::compress_ciphertext_coefficient::( + vector.elements, + ), + } } - fn decompress_ciphertext_coefficient(v: Self) -> 
Self { - compress::decompress_ciphertext_coefficient::(v) + fn decompress_ciphertext_coefficient(vector: Self) -> Self { + Self { + elements: compress::decompress_ciphertext_coefficient::( + vector.elements, + ), + } } - fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { - ntt::ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) + fn ntt_layer_1_step(vector: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { + Self { + elements: ntt::ntt_layer_1_step(vector.elements, zeta0, zeta1, zeta2, zeta3), + } } - fn ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self { - ntt::ntt_layer_2_step(a, zeta0, zeta1) + fn ntt_layer_2_step(vector: Self, zeta0: i16, zeta1: i16) -> Self { + Self { + elements: ntt::ntt_layer_2_step(vector.elements, zeta0, zeta1), + } } - fn ntt_layer_3_step(a: Self, zeta: i16) -> Self { - ntt::ntt_layer_3_step(a, zeta) + fn ntt_layer_3_step(vector: Self, zeta: i16) -> Self { + Self { + elements: ntt::ntt_layer_3_step(vector.elements, zeta), + } } - fn inv_ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { - ntt::inv_ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3) + fn inv_ntt_layer_1_step(vector: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self { + Self { + elements: ntt::inv_ntt_layer_1_step(vector.elements, zeta0, zeta1, zeta2, zeta3), + } } - fn inv_ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self { - ntt::inv_ntt_layer_2_step(a, zeta0, zeta1) + fn inv_ntt_layer_2_step(vector: Self, zeta0: i16, zeta1: i16) -> Self { + Self { + elements: ntt::inv_ntt_layer_2_step(vector.elements, zeta0, zeta1), + } } - fn inv_ntt_layer_3_step(a: Self, zeta: i16) -> Self { - ntt::inv_ntt_layer_3_step(a, zeta) + fn inv_ntt_layer_3_step(vector: Self, zeta: i16) -> Self { + Self { + elements: ntt::inv_ntt_layer_3_step(vector.elements, zeta), + } } fn ntt_multiply( @@ -136,55 +176,69 @@ impl Operations for SIMD256Vector { zeta2: i16, zeta3: i16, ) -> Self { - ntt::ntt_multiply(lhs, rhs, zeta0, zeta1, zeta2, zeta3) + Self { + elements: ntt::ntt_multiply(lhs.elements, rhs.elements, zeta0, zeta1, zeta2, zeta3), + } } - fn serialize_1(a: Self) -> [u8; 2] { - serialize::serialize_1(a) + fn serialize_1(vector: Self) -> [u8; 2] { + serialize::serialize_1(vector.elements) } - fn deserialize_1(a: &[u8]) -> Self { - serialize::deserialize_1(a) + fn deserialize_1(input: &[u8]) -> Self { + Self { + elements: serialize::deserialize_1(input), + } } - fn serialize_4(a: Self) -> [u8; 8] { - serialize::serialize_4(a) + fn serialize_4(vector: Self) -> [u8; 8] { + serialize::serialize_4(vector.elements) } - fn deserialize_4(a: &[u8]) -> Self { - serialize::deserialize_4(a) + fn deserialize_4(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_4(bytes), + } } - fn serialize_5(a: Self) -> [u8; 10] { - serialize::serialize_5(a) + fn serialize_5(vector: Self) -> [u8; 10] { + serialize::serialize_5(vector.elements) } - fn deserialize_5(a: &[u8]) -> Self { - serialize::deserialize_5(a) + fn deserialize_5(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_5(bytes), + } } - fn serialize_10(a: Self) -> [u8; 20] { - serialize::serialize_10(a) + fn serialize_10(vector: Self) -> [u8; 20] { + serialize::serialize_10(vector.elements) } - fn deserialize_10(a: &[u8]) -> Self { - serialize::deserialize_10(a) + fn deserialize_10(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_10(bytes), + } } - fn serialize_11(a: Self) -> [u8; 22] { - serialize::serialize_11(a) + fn 
serialize_11(vector: Self) -> [u8; 22] { + serialize::serialize_11(vector.elements) } - fn deserialize_11(a: &[u8]) -> Self { - serialize::deserialize_11(a) + fn deserialize_11(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_11(bytes), + } } - fn serialize_12(a: Self) -> [u8; 24] { - serialize::serialize_12(a) + fn serialize_12(vector: Self) -> [u8; 24] { + serialize::serialize_12(vector.elements) } - fn deserialize_12(a: &[u8]) -> Self { - serialize::deserialize_12(a) + fn deserialize_12(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_12(bytes), + } } fn rej_sample(input: &[u8], output: &mut [i16]) -> usize { diff --git a/polynomials-avx2/src/ntt.rs b/polynomials-avx2/src/ntt.rs index cca8f1447..28377dbfc 100644 --- a/polynomials-avx2/src/ntt.rs +++ b/polynomials-avx2/src/ntt.rs @@ -3,11 +3,11 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use crate::{arithmetic, SIMD256Vector}; +use crate::arithmetic; use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; #[inline(always)] -pub(crate) fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { +fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { v = unsafe { let value_low = _mm256_mullo_epi16(v, c); @@ -26,7 +26,7 @@ pub(crate) fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __ } #[inline(always)] -pub(crate) fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { +fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { v = unsafe { let k = _mm256_mullo_epi16( v, @@ -46,7 +46,7 @@ pub(crate) fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { } #[inline(always)] -pub(crate) fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { +fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { v = unsafe { let value_low = _mm_mullo_epi16(v, c); @@ -66,55 +66,55 @@ pub(crate) fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) #[inline(always)] pub(crate) fn ntt_layer_1_step( - mut v: SIMD256Vector, + mut vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { +) -> __m256i { + vector = unsafe { let zetas = _mm256_set_epi16( -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, zeta1, -zeta0, -zeta0, zeta0, zeta0, ); - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); + let rhs = _mm256_shuffle_epi32(vector, 0b11_11_01_01); let rhs = montgomery_multiply_by_constants(rhs, zetas); - let lhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + let lhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); _mm256_add_epi16(lhs, rhs) }; - v + vector } #[inline(always)] -pub(crate) fn ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { +pub(crate) fn ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + vector = unsafe { let zetas = _mm256_set_epi16( -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, -zeta0, zeta0, zeta0, zeta0, zeta0, ); - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10); + let rhs = _mm256_shuffle_epi32(vector, 0b11_10_11_10); let rhs = montgomery_multiply_by_constants(rhs, zetas); - let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00); + let lhs = _mm256_shuffle_epi32(vector, 0b01_00_01_00); _mm256_add_epi16(lhs, rhs) }; - v + vector } #[inline(always)] -pub(crate) fn ntt_layer_3_step(mut v: 
SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = unsafe { - let rhs = _mm256_extracti128_si256(v.elements, 1); +pub(crate) fn ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { + vector = unsafe { + let rhs = _mm256_extracti128_si256(vector, 1); let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); - let lhs = _mm256_castsi256_si128(v.elements); + let lhs = _mm256_castsi256_si128(vector); let lower_coefficients = _mm_add_epi16(lhs, rhs); let upper_coefficients = _mm_sub_epi16(lhs, rhs); @@ -125,21 +125,21 @@ pub(crate) fn ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector combined }; - v + vector } #[inline(always)] pub(crate) fn inv_ntt_layer_1_step( - mut v: SIMD256Vector, + mut vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); +) -> __m256i { + vector = unsafe { + let lhs = _mm256_shuffle_epi32(vector, 0b11_11_01_01); - let rhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); + let rhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); let rhs = _mm256_mullo_epi16( rhs, _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), @@ -153,20 +153,20 @@ pub(crate) fn inv_ntt_layer_1_step( ), ); - let sum = arithmetic::barrett_reduce(SIMD256Vector { elements: sum }).elements; + let sum = arithmetic::barrett_reduce(sum); _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) }; - v + vector } #[inline(always)] -pub(crate) fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_permute4x64_epi64(v.elements, 0b11_11_01_01); +pub(crate) fn inv_ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + vector = unsafe { + let lhs = _mm256_permute4x64_epi64(vector, 0b11_11_01_01); - let rhs = _mm256_permute4x64_epi64(v.elements, 0b10_10_00_00); + let rhs = _mm256_permute4x64_epi64(vector, 0b10_10_00_00); let rhs = _mm256_mullo_epi16( rhs, _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), @@ -183,14 +183,14 @@ pub(crate) fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) }; - v + vector } #[inline(always)] -pub(crate) fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_extracti128_si256(v.elements, 1); - let rhs = _mm256_castsi256_si128(v.elements); +pub(crate) fn inv_ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { + vector = unsafe { + let lhs = _mm256_extracti128_si256(vector, 1); + let rhs = _mm256_castsi256_si128(vector); let lower_coefficients = _mm_add_epi16(lhs, rhs); @@ -204,19 +204,19 @@ pub(crate) fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Ve combined }; - v + vector } #[inline(always)] pub(crate) fn ntt_multiply( - lhs: &SIMD256Vector, - rhs: &SIMD256Vector, + lhs: __m256i, + rhs: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, -) -> SIMD256Vector { - let products = unsafe { +) -> __m256i { + return unsafe { // Compute the first term of the product let shuffle_with = _mm256_set_epi8( 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, @@ -225,7 +225,7 @@ pub(crate) fn ntt_multiply( const PERMUTE_WITH: i32 = 0b11_01_10_00; // Prepare the left hand side - let lhs_shuffled = _mm256_shuffle_epi8(lhs.elements, shuffle_with); + let lhs_shuffled = 
_mm256_shuffle_epi8(lhs, shuffle_with); let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); @@ -235,7 +235,7 @@ pub(crate) fn ntt_multiply( let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); // Prepare the right hand side - let rhs_shuffled = _mm256_shuffle_epi8(rhs.elements, shuffle_with); + let rhs_shuffled = _mm256_shuffle_epi8(rhs, shuffle_with); let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); @@ -268,19 +268,17 @@ pub(crate) fn ntt_multiply( // Compute the second term of the product let rhs_adjacent_swapped = _mm256_shuffle_epi8( - rhs.elements, + rhs, _mm256_set_epi8( 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, ), ); - let products_right = _mm256_madd_epi16(lhs.elements, rhs_adjacent_swapped); + let products_right = _mm256_madd_epi16(lhs, rhs_adjacent_swapped); let products_right = montgomery_reduce_i32s(products_right); let products_right = _mm256_slli_epi32(products_right, 16); // Combine them into one vector _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) }; - - SIMD256Vector { elements: products } } diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index 5ded20880..aa6cc6ca6 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -3,10 +3,7 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -use crate::{ - serialize::{deserialize_12, serialize_1}, - SIMD256Vector, -}; +use crate::serialize::{deserialize_12, serialize_1}; use libcrux_traits::FIELD_MODULUS; const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ @@ -762,12 +759,10 @@ pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { let count = unsafe { let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); - let potential_coefficients = deserialize_12(input).elements; + let potential_coefficients = deserialize_12(input); let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); - let good = serialize_1(SIMD256Vector { - elements: compare_with_field_modulus, - }); + let good = serialize_1(compare_with_field_modulus); let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; let lower_shuffles = _mm_loadu_si128(lower_shuffles.as_ptr() as *const __m128i); diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs index b88988a2d..39b75ea2c 100644 --- a/polynomials-avx2/src/serialize.rs +++ b/polynomials-avx2/src/serialize.rs @@ -7,11 +7,9 @@ use crate::portable; use crate::SIMD256Vector; #[inline(always)] -pub(crate) fn serialize_1(v: SIMD256Vector) -> [u8; 2] { - let mut serialized = [0u8; 2]; - +pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] { let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(v.elements, 15); + let lsb_shifted_up = _mm256_slli_epi16(vector, 15); let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); @@ -21,6 +19,7 @@ pub(crate) fn serialize_1(v: SIMD256Vector) -> [u8; 2] { _mm_movemask_epi8(msbs) }; + let mut serialized = [0u8; 2]; serialized[0] = bits_packed as u8; serialized[1] = (bits_packed >> 8) as u8; @@ -28,8 +27,8 @@ pub(crate) fn serialize_1(v: SIMD256Vector) -> [u8; 2] { } #[inline(always)] -pub(crate) fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { +pub(crate) fn 
deserialize_1(bytes: &[u8]) -> __m256i { + return unsafe { let shift_lsb_to_msb = _mm256_set_epi16( 1 << 0, 1 << 1, @@ -73,19 +72,15 @@ pub(crate) fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) }; - - SIMD256Vector { - elements: deserialized, - } } #[inline(always)] -pub(crate) fn serialize_4(v: SIMD256Vector) -> [u8; 8] { +pub(crate) fn serialize_4(vector: __m256i) -> [u8; 8] { let mut serialized = [0u8; 16]; unsafe { let adjacent_2_combined = _mm256_madd_epi16( - v.elements, + vector, _mm256_set_epi16( 1 << 4, 1, @@ -127,8 +122,8 @@ pub(crate) fn serialize_4(v: SIMD256Vector) -> [u8; 8] { } #[inline(always)] -pub(crate) fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { +pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i { + return unsafe { let shift_lsbs_to_msbs = _mm256_set_epi16( 1 << 0, 1 << 4, @@ -172,19 +167,15 @@ pub(crate) fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) }; - - SIMD256Vector { - elements: deserialized, - } } #[inline(always)] -pub(crate) fn serialize_5(v: SIMD256Vector) -> [u8; 10] { +pub(crate) fn serialize_5(vector: __m256i) -> [u8; 10] { let mut serialized = [0u8; 32]; unsafe { let adjacent_2_combined = _mm256_madd_epi16( - v.elements, + vector, _mm256_set_epi16( 1 << 5, 1, @@ -229,19 +220,19 @@ pub(crate) fn serialize_5(v: SIMD256Vector) -> [u8; 10] { } #[inline(always)] -pub(crate) fn deserialize_5(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_5(v); +pub(crate) fn deserialize_5(bytes: &[u8]) -> __m256i { + let output = portable::deserialize_5(bytes); - crate::from_i16_array(&portable::to_i16_array(output)) + crate::from_i16_array(&portable::to_i16_array(output)).elements } #[inline(always)] -pub(crate) fn serialize_10(v: SIMD256Vector) -> [u8; 20] { +pub(crate) fn serialize_10(vector: __m256i) -> [u8; 20] { let mut serialized = [0u8; 32]; unsafe { let adjacent_2_combined = _mm256_madd_epi16( - v.elements, + vector, _mm256_set_epi16( 1 << 10, 1, @@ -287,8 +278,8 @@ pub(crate) fn serialize_10(v: SIMD256Vector) -> [u8; 20] { } #[inline(always)] -pub(crate) fn deserialize_10(v: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { +pub(crate) fn deserialize_10(bytes: &[u8]) -> __m256i { + return unsafe { let shift_lsbs_to_msbs = _mm256_set_epi16( 1 << 0, 1 << 2, @@ -308,12 +299,12 @@ pub(crate) fn deserialize_10(v: &[u8]) -> SIMD256Vector { 1 << 6, ); - let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); let lower_coefficients = _mm_shuffle_epi8( lower_coefficients, _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), ); - let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(4) as *const __m128i); + let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(4) as *const __m128i); let upper_coefficients = _mm_shuffle_epi8( upper_coefficients, _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), @@ -328,33 +319,29 @@ pub(crate) fn deserialize_10(v: &[u8]) -> SIMD256Vector { coefficients }; - - SIMD256Vector { - elements: deserialized, - } } #[inline(always)] -pub(crate) fn serialize_11(v: SIMD256Vector) -> [u8; 22] { - let input = portable::from_i16_array(crate::to_i16_array(v)); +pub(crate) fn serialize_11(vector: __m256i) -> [u8; 22] { + let input = portable::from_i16_array(crate::to_i16_array(SIMD256Vector { elements: vector })); 
portable::serialize_11(input) } #[inline(always)] -pub(crate) fn deserialize_11(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_11(v); +pub(crate) fn deserialize_11(bytes: &[u8]) -> __m256i { + let output = portable::deserialize_11(bytes); - crate::from_i16_array(&portable::to_i16_array(output)) + crate::from_i16_array(&portable::to_i16_array(output)).elements } #[inline(always)] -pub(crate) fn serialize_12(v: SIMD256Vector) -> [u8; 24] { +pub(crate) fn serialize_12(vector: __m256i) -> [u8; 24] { let mut serialized = [0u8; 32]; unsafe { let adjacent_2_combined = _mm256_madd_epi16( - v.elements, + vector, _mm256_set_epi16( 1 << 12, 1, @@ -400,8 +387,8 @@ pub(crate) fn serialize_12(v: SIMD256Vector) -> [u8; 24] { } #[inline(always)] -pub(crate) fn deserialize_12(v: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { +pub(crate) fn deserialize_12(bytes: &[u8]) -> __m256i { + return unsafe { let shift_lsbs_to_msbs = _mm256_set_epi16( 1 << 0, 1 << 4, @@ -421,12 +408,12 @@ pub(crate) fn deserialize_12(v: &[u8]) -> SIMD256Vector { 1 << 4, ); - let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); + let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); let lower_coefficients = _mm_shuffle_epi8( lower_coefficients, _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), ); - let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(8) as *const __m128i); + let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(8) as *const __m128i); let upper_coefficients = _mm_shuffle_epi8( upper_coefficients, _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), @@ -441,8 +428,4 @@ pub(crate) fn deserialize_12(v: &[u8]) -> SIMD256Vector { coefficients }; - - SIMD256Vector { - elements: deserialized, - } } From 8cb3fd519f9d5caed0d32e5bcec54b3afc1ad9d3 Mon Sep 17 00:00:00 2001 From: xvzcf Date: Wed, 15 May 2024 21:56:59 +0200 Subject: [PATCH 36/59] Wrapping avx2 intrinsics in safe wrappers. 
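The AVX2 polynomial code previously called the _mm256_* intrinsics directly, with an unsafe block at every call site. This patch routes those calls through thin safe wrappers in polynomials-avx2/src/intrinsics.rs, so the unsafe block for each intrinsic lives in exactly one place and arithmetic.rs and compress.rs are left without inline unsafe. Wrappers whose intrinsic takes an immediate operand (shifts, shuffles, permutes, lane extraction) take that operand as a const generic and debug_assert its valid range.

A condensed sketch of the wrapper pattern, using names that all appear in the diff below (nothing here is new API):

    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // One safe wrapper per intrinsic; the only unsafe code is inside the wrapper.
    pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i {
        unsafe { _mm256_add_epi16(lhs, rhs) }
    }

    pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i {
        unsafe { _mm256_set1_epi16(constant) }
    }

    // Callers such as arithmetic::add then become plain safe one-liners:
    pub(crate) fn add(lhs: __m256i, rhs: __m256i) -> __m256i {
        mm256_add_epi16(lhs, rhs)
    }

Keeping each wrapper a one-line forwarding function keeps the unsafe surface small and easy to audit.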
--- polynomials-avx2/src/arithmetic.rs | 100 +++++++----------- polynomials-avx2/src/compress.rs | 158 +++++++++++++---------------- polynomials-avx2/src/intrinsics.rs | 107 +++++++++++++++++++ polynomials-avx2/src/lib.rs | 2 + 4 files changed, 214 insertions(+), 153 deletions(-) create mode 100644 polynomials-avx2/src/intrinsics.rs diff --git a/polynomials-avx2/src/arithmetic.rs b/polynomials-avx2/src/arithmetic.rs index 3115867bf..701d76ad6 100644 --- a/polynomials-avx2/src/arithmetic.rs +++ b/polynomials-avx2/src/arithmetic.rs @@ -1,103 +1,75 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; - +use crate::intrinsics::*; use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; #[inline(always)] -pub(crate) fn add(mut lhs: __m256i, rhs: __m256i) -> __m256i { - lhs = unsafe { _mm256_add_epi16(lhs, rhs) }; - - lhs +pub(crate) fn add(lhs: __m256i, rhs: __m256i) -> __m256i { + mm256_add_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn sub(mut lhs: __m256i, rhs: __m256i) -> __m256i { - lhs = unsafe { _mm256_sub_epi16(lhs, rhs) }; - - lhs +pub(crate) fn sub(lhs: __m256i, rhs: __m256i) -> __m256i { + mm256_sub_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn multiply_by_constant(mut vector: __m256i, constant: i16) -> __m256i { - vector = unsafe { _mm256_mullo_epi16(vector, _mm256_set1_epi16(constant)) }; - - vector +pub(crate) fn multiply_by_constant(vector: __m256i, constant: i16) -> __m256i { + mm256_mullo_epi16(vector, mm256_set1_epi16(constant)) } #[inline(always)] -pub(crate) fn bitwise_and_with_constant(mut vector: __m256i, constant: i16) -> __m256i { - vector = unsafe { _mm256_and_si256(vector, _mm256_set1_epi16(constant)) }; - - vector +pub(crate) fn bitwise_and_with_constant(vector: __m256i, constant: i16) -> __m256i { + mm256_and_si256(vector, mm256_set1_epi16(constant)) } #[inline(always)] -pub(crate) fn shift_right(mut vector: __m256i) -> __m256i { - vector = unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }; - - vector +pub(crate) fn shift_right(vector: __m256i) -> __m256i { + mm256_srai_epi16::<{ SHIFT_BY }>(vector) } #[inline(always)] -pub(crate) fn shift_left(mut vector: __m256i) -> __m256i { - vector = unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }; - - vector +pub(crate) fn shift_left(vector: __m256i) -> __m256i { + mm256_slli_epi16::<{SHIFT_BY}>(vector) } #[inline(always)] -pub(crate) fn cond_subtract_3329(mut vector: __m256i) -> __m256i { - vector = unsafe { - let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); - - let v_minus_field_modulus = _mm256_sub_epi16(vector, field_modulus); +pub(crate) fn cond_subtract_3329(vector: __m256i) -> __m256i { + let field_modulus = mm256_set1_epi16(FIELD_MODULUS); - let sign_mask = _mm256_srai_epi16(v_minus_field_modulus, 15); - let conditional_add_field_modulus = _mm256_and_si256(sign_mask, field_modulus); + let v_minus_field_modulus = mm256_sub_epi16(vector, field_modulus); - _mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus) - }; + let sign_mask = mm256_srai_epi16::<15>(v_minus_field_modulus); + let conditional_add_field_modulus = mm256_and_si256(sign_mask, field_modulus); - vector + mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus) } const BARRETT_MULTIPLIER: i16 = 20159; #[inline(always)] -pub(crate) fn barrett_reduce(mut vector: __m256i) -> __m256i { - vector = unsafe { - let t = _mm256_mulhi_epi16(vector, _mm256_set1_epi16(BARRETT_MULTIPLIER)); - let t = _mm256_add_epi16(t, _mm256_set1_epi16(512)); 
+pub(crate) fn barrett_reduce(vector: __m256i) -> __m256i { + let t = mm256_mulhi_epi16(vector, mm256_set1_epi16(BARRETT_MULTIPLIER)); + let t = mm256_add_epi16(t, mm256_set1_epi16(512)); - let quotient = _mm256_srai_epi16(t, 10); + let quotient = mm256_srai_epi16::<10>(t); - let quotient_times_field_modulus = - _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); + let quotient_times_field_modulus = + mm256_mullo_epi16(quotient, mm256_set1_epi16(FIELD_MODULUS)); - _mm256_sub_epi16(vector, quotient_times_field_modulus) - }; - - vector + mm256_sub_epi16(vector, quotient_times_field_modulus) } #[inline(always)] -pub(crate) fn montgomery_multiply_by_constant(mut vector: __m256i, constant: i16) -> __m256i { - vector = unsafe { - let constant = _mm256_set1_epi16(constant); - let value_low = _mm256_mullo_epi16(vector, constant); - - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); +pub(crate) fn montgomery_multiply_by_constant(vector: __m256i, constant: i16) -> __m256i { + let constant = mm256_set1_epi16(constant); + let value_low = mm256_mullo_epi16(vector, constant); - let value_high = _mm256_mulhi_epi16(vector, constant); + let k = mm256_mullo_epi16( + value_low, + mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS)); - _mm256_sub_epi16(value_high, k_times_modulus) - }; + let value_high = mm256_mulhi_epi16(vector, constant); - vector + mm256_sub_epi16(value_high, k_times_modulus) } diff --git a/polynomials-avx2/src/compress.rs b/polynomials-avx2/src/compress.rs index 8f4e239fe..57f0f98d2 100644 --- a/polynomials-avx2/src/compress.rs +++ b/polynomials-avx2/src/compress.rs @@ -1,8 +1,4 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; - +use crate::intrinsics::*; use libcrux_traits::FIELD_MODULUS; // This implementation was taken from: @@ -11,120 +7,104 @@ use libcrux_traits::FIELD_MODULUS; // TODO: Optimize this implementation if performance numbers suggest doing so. 
#[inline(always)] fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { - let result = unsafe { - let prod02 = _mm256_mul_epu32(lhs, rhs); - let prod13 = _mm256_mul_epu32( - _mm256_shuffle_epi32(lhs, 0b11_11_01_01), - _mm256_shuffle_epi32(rhs, 0b11_11_01_01), - ); - - _mm256_unpackhi_epi64( - _mm256_unpacklo_epi32(prod02, prod13), - _mm256_unpackhi_epi32(prod02, prod13), - ) - }; - - result + let prod02 = mm256_mul_epu32(lhs, rhs); + let prod13 = mm256_mul_epu32( + mm256_shuffle_epi32::<0b11_11_01_01>(lhs), + mm256_shuffle_epi32::<0b11_11_01_01>(rhs), + ); + + mm256_unpackhi_epi64( + mm256_unpacklo_epi32(prod02, prod13), + mm256_unpackhi_epi32(prod02, prod13), + ) } #[inline(always)] -pub(crate) fn compress_message_coefficient(mut vector: __m256i) -> __m256i { - vector = unsafe { - let field_modulus_halved = _mm256_set1_epi16((FIELD_MODULUS - 1) / 2); - let field_modulus_quartered = _mm256_set1_epi16((FIELD_MODULUS - 1) / 4); - - let shifted = _mm256_sub_epi16(field_modulus_halved, vector); - let mask = _mm256_srai_epi16(shifted, 15); +pub(crate) fn compress_message_coefficient(vector: __m256i) -> __m256i { + let field_modulus_halved = mm256_set1_epi16((FIELD_MODULUS - 1) / 2); + let field_modulus_quartered = mm256_set1_epi16((FIELD_MODULUS - 1) / 4); - let shifted_to_positive = _mm256_xor_si256(mask, shifted); - let shifted_to_positive_in_range = - _mm256_sub_epi16(shifted_to_positive, field_modulus_quartered); + let shifted = mm256_sub_epi16(field_modulus_halved, vector); + let mask = mm256_srai_epi16::<15>(shifted); - _mm256_srli_epi16(shifted_to_positive_in_range, 15) - }; + let shifted_to_positive = mm256_xor_si256(mask, shifted); + let shifted_to_positive_in_range = + mm256_sub_epi16(shifted_to_positive, field_modulus_quartered); - vector + mm256_srli_epi16::<15>(shifted_to_positive_in_range) } #[inline(always)] pub(crate) fn compress_ciphertext_coefficient( - mut vector: __m256i, + vector: __m256i, ) -> __m256i { - vector = unsafe { - let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2); - let compression_factor = _mm256_set1_epi32(10_321_340); - let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1); + let field_modulus_halved = mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2); + let compression_factor = mm256_set1_epi32(10_321_340); + let coefficient_bits_mask = mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1); - // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(vector); - let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); + // Compress the first 8 coefficients + let coefficients_low = mm256_castsi256_si128(vector); + let coefficients_low = mm256_cvtepi16_epi32(coefficients_low); - let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS); - let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved); + let compressed_low = mm256_slli_epi32::<{COEFFICIENT_BITS}>(coefficients_low); + let compressed_low = mm256_add_epi32(compressed_low, field_modulus_halved); - let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor); - let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32); - let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask); + let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor); + let compressed_low = mm256_srli_epi32::<3>(compressed_low); + let compressed_low = mm256_and_si256(compressed_low, coefficient_bits_mask); - // Compress the next 8 coefficients - let 
coefficients_high = _mm256_extracti128_si256(vector, 1); - let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); + // Compress the next 8 coefficients + let coefficients_high = mm256_extracti128_si256::<1>(vector); + let coefficients_high = mm256_cvtepi16_epi32(coefficients_high); - let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS); - let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved); + let compressed_high = mm256_slli_epi32::<{COEFFICIENT_BITS}>(coefficients_high); + let compressed_high = mm256_add_epi32(compressed_high, field_modulus_halved); - let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor); - let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32); - let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask); + let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor); + let compressed_high = mm256_srli_epi32::<3>(compressed_high); + let compressed_high = mm256_and_si256(compressed_high, coefficient_bits_mask); - // Combine them - let compressed = _mm256_packs_epi32(compressed_low, compressed_high); + // Combine them + let compressed = mm256_packs_epi32(compressed_low, compressed_high); - _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) - }; - - vector + mm256_permute4x64_epi64::<0b11_01_10_00>(compressed) } #[inline(always)] pub(crate) fn decompress_ciphertext_coefficient( - mut vector: __m256i, + vector: __m256i, ) -> __m256i { - vector = unsafe { - let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32); - let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS); - - // Compress the first 8 coefficients - let coefficients_low = _mm256_castsi256_si128(vector); - let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low); + let field_modulus = mm256_set1_epi32(FIELD_MODULUS as i32); + let two_pow_coefficient_bits = mm256_set1_epi32(1 << COEFFICIENT_BITS); - let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus); - let decompressed_low = _mm256_slli_epi32(decompressed_low, 1); - let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits); + // Compress the first 8 coefficients + let coefficients_low = mm256_castsi256_si128(vector); + let coefficients_low = mm256_cvtepi16_epi32(coefficients_low); - // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack - // of support for const generic expressions. - let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS); - let decompressed_low = _mm256_srli_epi32(decompressed_low, 1); + let decompressed_low = mm256_mullo_epi32(coefficients_low, field_modulus); + let decompressed_low = mm256_slli_epi32::<1>(decompressed_low); + let decompressed_low = mm256_add_epi32(decompressed_low, two_pow_coefficient_bits); - // Compress the next 8 coefficients - let coefficients_high = _mm256_extracti128_si256(vector, 1); - let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high); + // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack + // of support for const generic expressions. 
+ let decompressed_low = mm256_srli_epi32::<{COEFFICIENT_BITS}>(decompressed_low); + let decompressed_low = mm256_srli_epi32::<1>(decompressed_low); - let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus); - let decompressed_high = _mm256_slli_epi32(decompressed_high, 1); - let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits); + // Compress the next 8 coefficients + let coefficients_high = mm256_extracti128_si256::<1>(vector); + let coefficients_high = mm256_cvtepi16_epi32(coefficients_high); - // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack - // of support for const generic expressions. - let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS); - let decompressed_high = _mm256_srli_epi32(decompressed_high, 1); + let decompressed_high = mm256_mullo_epi32(coefficients_high, field_modulus); + let decompressed_high = mm256_slli_epi32::<1>(decompressed_high); + let decompressed_high = mm256_add_epi32(decompressed_high, two_pow_coefficient_bits); - // Combine them - let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high); + // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack + // of support for const generic expressions. + let decompressed_high = mm256_srli_epi32::<{COEFFICIENT_BITS}>(decompressed_high); + let decompressed_high = mm256_srli_epi32::<1>(decompressed_high); - _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) - }; + // Combine them + let compressed = mm256_packs_epi32(decompressed_low, decompressed_high); - vector + mm256_permute4x64_epi64::<0b11_01_10_00>(compressed) } diff --git a/polynomials-avx2/src/intrinsics.rs b/polynomials-avx2/src/intrinsics.rs new file mode 100644 index 000000000..d28b227c7 --- /dev/null +++ b/polynomials-avx2/src/intrinsics.rs @@ -0,0 +1,107 @@ +#[cfg(target_arch = "x86")] +pub(crate) use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +pub(crate) use core::arch::x86_64::*; + +pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i { + unsafe { _mm256_set1_epi16(constant) } +} +pub(crate) fn mm256_set1_epi32(constant: i32) -> __m256i { + unsafe { _mm256_set1_epi32(constant) } +} + +pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_add_epi16(lhs, rhs) } +} +pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_add_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_sub_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_sub_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_mullo_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mullo_epi16(lhs, rhs) } +} +pub(crate) fn mm256_mullo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mullo_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_mulhi_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mulhi_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_mul_epu32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mul_epu32(lhs, rhs) } +} + +pub(crate) fn mm256_and_si256(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_and_si256(lhs, rhs) } +} + +pub(crate) fn mm256_xor_si256(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_xor_si256(lhs, rhs) } +} + +pub(crate) fn mm256_srai_epi16(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); + unsafe { _mm256_srai_epi16(vector, SHIFT_BY) } +} +pub(crate) fn mm256_srli_epi16(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); + unsafe { 
_mm256_srli_epi16(vector, SHIFT_BY) } +} +pub(crate) fn mm256_srli_epi32(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); + unsafe { _mm256_srli_epi32(vector, SHIFT_BY) } +} + +pub(crate) fn mm256_slli_epi16(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); + unsafe { _mm256_slli_epi16(vector, SHIFT_BY) } +} + +pub(crate) fn mm256_slli_epi32(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); + unsafe { _mm256_slli_epi32(vector, SHIFT_BY) } +} + +pub(crate) fn mm256_shuffle_epi32(vector: __m256i) -> __m256i { + debug_assert!(CONTROL >= 0 && CONTROL < 256); + unsafe { _mm256_shuffle_epi32(vector, CONTROL) } +} + +pub(crate) fn mm256_permute4x64_epi64(vector: __m256i) -> __m256i { + debug_assert!(CONTROL >= 0 && CONTROL < 256); + unsafe { _mm256_permute4x64_epi64(vector, CONTROL) } +} + +pub(crate) fn mm256_unpackhi_epi64(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_unpackhi_epi64(lhs, rhs) } +} + +pub(crate) fn mm256_unpacklo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_unpacklo_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_unpackhi_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_unpackhi_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_castsi256_si128(vector: __m256i) -> __m128i { + unsafe { _mm256_castsi256_si128(vector) } +} + +pub(crate) fn mm256_cvtepi16_epi32(vector: __m128i) -> __m256i { + unsafe { _mm256_cvtepi16_epi32(vector) } +} + +pub(crate) fn mm256_packs_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_packs_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_extracti128_si256(vector: __m256i) -> __m128i { + debug_assert!(CONTROL == 0 || CONTROL == 1); + unsafe { _mm256_extracti128_si256(vector, CONTROL) } +} diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 207f9b6f4..52de02fd0 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -7,6 +7,8 @@ use libcrux_traits::Operations; #[cfg(test)] mod debug; +mod intrinsics; + mod arithmetic; mod compress; mod ntt; From 689a3ecce988d6ebc290394e29eaa781f0ceced6 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Thu, 16 May 2024 15:57:02 +0200 Subject: [PATCH 37/59] Feature detection and cleanup (Ml-KEM/SHA3) (#280) --- .github/workflows/mlkem.yml | 4 +- libcrux-ml-kem/build.rs | 6 +- libcrux-ml-kem/src/constants.rs | 3 + libcrux-ml-kem/src/hash_functions.rs | 836 ++++++++++-------- libcrux-ml-kem/src/ind_cca.rs | 42 +- libcrux-ml-kem/src/ind_cpa.rs | 33 +- libcrux-ml-kem/src/matrix.rs | 8 +- libcrux-ml-kem/src/sampling.rs | 10 +- libcrux-sha3/benches/sha3.rs | 100 +-- libcrux-sha3/build.rs | 4 +- .../sha3_generic.rs => generic_keccak.rs} | 9 +- libcrux-sha3/src/lib.rs | 656 ++++++++++++-- .../sha3_portable.rs => portable_keccak.rs} | 7 +- libcrux-sha3/src/rust_simd.rs | 381 -------- libcrux-sha3/src/simd.rs | 11 + .../sha3_arm64.rs => simd/arm64.rs} | 3 +- .../{rust_simd/sha3_avx2.rs => simd/avx2.rs} | 7 +- .../{rust_simd/sha3_trait.rs => traits.rs} | 3 +- libcrux-sha3/tests/sha3.rs | 4 +- polynomials/build.rs | 6 +- 20 files changed, 1155 insertions(+), 978 deletions(-) rename libcrux-sha3/src/{rust_simd/sha3_generic.rs => generic_keccak.rs} (97%) rename libcrux-sha3/src/{rust_simd/sha3_portable.rs => portable_keccak.rs} (92%) delete mode 100644 libcrux-sha3/src/rust_simd.rs create mode 100644 libcrux-sha3/src/simd.rs rename libcrux-sha3/src/{rust_simd/sha3_arm64.rs => simd/arm64.rs} (99%) rename libcrux-sha3/src/{rust_simd/sha3_avx2.rs => 
simd/avx2.rs} (97%) rename libcrux-sha3/src/{rust_simd/sha3_trait.rs => traits.rs} (89%) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 7abb24116..f0b7040f7 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -4,7 +4,7 @@ on: push: branches: ["main", "dev"] pull_request: - branches: ["main", "dev"] + branches: ["main", "dev", "*"] workflow_dispatch: merge_group: @@ -188,7 +188,7 @@ jobs: - name: 🏃🏻‍♀️ Benchmarks Windows if: ${{ matrix.os == 'windows-latest' }} run: cargo bench --verbose $RUST_TARGET_FLAG - + - name: 🏃🏻‍♀️ Benchmarks Clang if: ${{ matrix.os != 'windows-latest' }} run: CC=clang cargo bench --verbose $RUST_TARGET_FLAG diff --git a/libcrux-ml-kem/build.rs b/libcrux-ml-kem/build.rs index f15f3d581..ef1138666 100644 --- a/libcrux-ml-kem/build.rs +++ b/libcrux-ml-kem/build.rs @@ -16,11 +16,13 @@ fn main() { // We enable simd128 on all aarch64 builds. println!("cargo:rustc-cfg=feature=\"simd128\""); } - if (target_arch == "x86" || target_arch == "x86_64") && !disable_simd256 { - // We enable simd256 on all x86 and x86_64 builds. + if target_arch == "x86_64" && !disable_simd256 { + // We enable simd256 on all x86_64 builds. // Note that this doesn't mean the required CPU features are available. // But the compiler will support them and the runtime checks ensure that // it's only used when available. + // + // We don't enable this on x86 because it seems to generate invalid code. println!("cargo:rustc-cfg=feature=\"simd256\""); } } diff --git a/libcrux-ml-kem/src/constants.rs b/libcrux-ml-kem/src/constants.rs index b3903e9db..cf89e9348 100644 --- a/libcrux-ml-kem/src/constants.rs +++ b/libcrux-ml-kem/src/constants.rs @@ -32,4 +32,7 @@ pub(crate) const CPA_PKE_KEY_GENERATION_SEED_SIZE: usize = 32; // XXX: Eurydice can't handle this. // digest_size(Algorithm::Sha3_256); +/// SHA3 256 digest size pub(crate) const H_DIGEST_SIZE: usize = 32; +/// SHA3 512 digest size +pub(crate) const G_DIGEST_SIZE: usize = 64; diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 9a3427fad..cc00df6fd 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -1,406 +1,530 @@ #![allow(non_snake_case)] -use crate::constants::H_DIGEST_SIZE; +use crate::constants::{G_DIGEST_SIZE, H_DIGEST_SIZE}; -#[cfg(feature = "simd256")] -use libcrux_sha3::rust_simd::KeccakState4; -use libcrux_sha3::*; +/// The SHA3 block size. +pub(crate) const BLOCK_SIZE: usize = 168; -#[inline(always)] -pub(crate) fn G(input: &[u8]) -> [u8; 64] { - rust_simd::sha3_512(input) -} +/// The size of 3 SHA3 blocks. +pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; -#[inline(always)] -pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - rust_simd::sha3_256(input) -} +/// Abstraction for the hashing, to pick the fastest version depending on the +/// platform features available. +/// +/// There are 3 instantiations of this trait right now, using the libcrux-sha3 crate. 
+/// - AVX2 +/// - NEON +/// - Portable +pub(crate) trait Hash { + /// G aka SHA3 512 + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE]; -#[inline(always)] -pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - rust_simd::shake256::(input) -} + /// H aka SHA3 256 + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE]; -#[cfg(feature = "simd256")] -#[inline(always)] -pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { - let mut out = [[0u8; LEN]; K]; - let mut dummy_out0 = [0u8; LEN]; - let mut dummy_out1 = [0u8; LEN]; - - match K { - 2 => { - let (out0, out1) = out.split_at_mut(1); - rust_simd::shake256x4( - &input[0], - &input[1], - &input[0], - &input[0], - &mut out0[0], - &mut out1[0], - &mut dummy_out0, - &mut dummy_out1, - ); - } - 3 => { - let (out0, out12) = out.split_at_mut(1); - let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake256x4( - &input[0], - &input[1], - &input[2], - &input[0], - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut dummy_out0, - ); - } - _ => { - let (out0, out123) = out.split_at_mut(1); - let (out1, out23) = out123.split_at_mut(1); - let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake256x4( - &input[0], - &input[1], - &input[2], - &input[3], - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut out3[0], - ); - } - } - out + /// PRF aka SHAKE256 + fn PRF(input: &[u8]) -> [u8; LEN]; + + /// PRFxN aka N SHAKE256 + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K]; + + /// Create a SHAKE128 state and absorb the input. + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self; + + /// Squeeze 3 blocks out of the SHAKE128 state. + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K]; + + /// Squeeze 1 block out of the SHAKE128 state. + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K]; } -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { - let mut out = [[0u8; LEN]; K]; - let mut extra = [0u8; LEN]; +/// A portable implementation of [`Hash`] +pub(crate) mod portable { + use super::*; + use libcrux_sha3::portable::{ + self, + incremental::{ + shake128_absorb_final, shake128_init, shake128_squeeze_first_three_blocks, + shake128_squeeze_next_block, + }, + KeccakState1, + }; + + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. 
+ pub(crate) struct PortableHash { + shake128_state: [KeccakState1; K], + } - match K { - 2 => { - let (out0, out1) = out.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); + impl Hash for PortableHash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + portable::sha512(&mut digest, input); + digest } - 3 => { - let (out0, out12) = out.split_at_mut(1); - let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); - rust_simd::shake256x2(&input[2], &input[2], &mut out2[0], &mut extra); + + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + portable::sha256(&mut digest, input); + digest } - _ => { - let (out0, out123) = out.split_at_mut(1); - let (out1, out23) = out123.split_at_mut(1); - let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake256x2(&input[0], &input[1], &mut out0[0], &mut out1[0]); - rust_simd::shake256x2(&input[2], &input[3], &mut out2[0], &mut out3[0]); + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + portable::shake256(&mut digest, input); + digest } - } - out -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -#[inline(always)] -pub(crate) fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { - core::array::from_fn(|i| rust_simd::shake256::(&input[i])) -} + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); -#[cfg(feature = "simd128")] -pub(crate) type Shake128x4State = KeccakState4; + let mut out = [[0u8; LEN]; K]; + for i in 0..K { + portable::shake256::(&mut out[i], &input[i]); + } + out + } -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128x4State { - debug_assert!(K == 2 || K == 3 || K == 4); + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); - let mut states = rust_simd::shake128x4_init(); - match K { - 2 => { - rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); + let mut state = [shake128_init(); K]; + for i in 0..K { + shake128_absorb_final(&mut state[i], &input[i]); + } + Self { + shake128_state: state, + } } - 3 => { - rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); - rust_simd::shake128x2_absorb_final(&mut states[1], &input[2], &input[2]); + + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + for i in 0..K { + shake128_squeeze_first_three_blocks(&mut self.shake128_state[i], &mut out[i]); + } + out } - _ => { - rust_simd::shake128x2_absorb_final(&mut states[0], &input[0], &input[1]); - rust_simd::shake128x2_absorb_final(&mut states[1], &input[2], &input[3]); + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; BLOCK_SIZE]; K]; + for i in 0..K { + shake128_squeeze_next_block(&mut self.shake128_state[i], &mut out[i]); + } + out } } - states } -#[cfg(not(any(feature = "simd256", feature = "simd128")))] -#[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> [rust_simd::KeccakState1; K] { - debug_assert!(K == 2 || K == 3 || K == 4); - let mut states = [rust_simd::shake128_init(); K]; - for i in 0..K { - rust_simd::shake128_absorb_final(&mut states[i], &input[i]); +/// A SIMD256 implementation of [`Hash`] for AVX2 +pub(crate) mod avx2 { + use super::*; + use 
libcrux_sha3::{ + avx2::x4::{self, incremental::KeccakState4}, + portable, + }; + + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. + pub(crate) struct Simd256Hash { + shake128_state: KeccakState4, } - states -} -#[cfg(feature = "simd256")] -#[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> KeccakState4 { - debug_assert!(K == 2 || K == 3 || K == 4); - let mut states = rust_simd::shake128x4_init(); - - match K { - 2 => { - rust_simd::shake128x4_absorb_final( - &mut states, - &input[0], - &input[1], - &input[0], - &input[0], - ); + impl Hash for Simd256Hash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + portable::sha512(&mut digest, input); + digest } - 3 => { - rust_simd::shake128x4_absorb_final( - &mut states, - &input[0], - &input[1], - &input[2], - &input[0], - ); + + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + portable::sha256(&mut digest, input); + digest } - 4 => { - rust_simd::shake128x4_absorb_final( - &mut states, - &input[0], - &input[1], - &input[2], - &input[3], - ); + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + portable::shake256(&mut digest, input); + digest } - _ => unreachable!(), - } - states -} -pub(crate) const BLOCK_SIZE: usize = 168; -pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut out = [[0u8; LEN]; K]; + + match K { + 2 => { + let mut dummy_out0 = [0u8; LEN]; + let mut dummy_out1 = [0u8; LEN]; + let (out0, out1) = out.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[0], + &input[0], + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let mut dummy_out0 = [0u8; LEN]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[2], + &input[0], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[2], + &input[3], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn squeeze_three_blocks( - state: &mut Shake128x4State, -) -> [[u8; THREE_BLOCKS]; K] { - let mut out = [[0u8; THREE_BLOCKS]; K]; - let mut extra = [0u8; THREE_BLOCKS]; - - match K { - 2 => { - let (out0, out1) = out.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks( - &mut state[0], - &mut out0[0], - &mut out1[0], - ); + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut state = x4::incremental::shake128_init(); + + match K { + 2 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[0], &input[0], + ); + } + 3 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[2], &input[0], + ); + } + 4 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[2], &input[3], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + Self { + shake128_state: state, + } } - 3 => { - let (out0, out12) = 
out.split_at_mut(1); - let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks( - &mut state[0], - &mut out0[0], - &mut out1[0], - ); - rust_simd::shake128x2_squeeze_first_three_blocks( - &mut state[1], - &mut out2[0], - &mut extra, - ); + + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + match K { + 2 => { + let mut dummy_out0 = [0u8; THREE_BLOCKS]; + let mut dummy_out1 = [0u8; THREE_BLOCKS]; + let (out0, out1) = out.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let mut dummy_out0 = [0u8; THREE_BLOCKS]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out } - _ => { - let (out0, out123) = out.split_at_mut(1); - let (out1, out23) = out123.split_at_mut(1); - let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake128x2_squeeze_first_three_blocks( - &mut state[0], - &mut out0[0], - &mut out1[0], - ); - rust_simd::shake128x2_squeeze_first_three_blocks( - &mut state[1], - &mut out2[0], - &mut out3[0], - ); + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut dummy_out0 = [0u8; BLOCK_SIZE]; + let mut dummy_out1 = [0u8; BLOCK_SIZE]; + + let mut out = [[0u8; BLOCK_SIZE]; K]; + + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out } } - out } -#[cfg(not(any(feature = "simd256", feature = "simd128")))] -#[inline(always)] -pub(crate) fn squeeze_three_blocks( - state: &mut [rust_simd::KeccakState1], -) -> [[u8; THREE_BLOCKS]; K] { - let mut out = [[0u8; THREE_BLOCKS]; K]; - for i in 0..K { - rust_simd::shake128_squeeze_first_three_blocks(&mut state[i], &mut out[i]); +/// A SIMD128 implementation of [`Hash`] for NEON +pub(crate) mod neon { + use super::*; + use libcrux_sha3::neon::x2::{self, incremental::KeccakState2}; + + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. 
+ pub(crate) struct Simd128Hash { + shake128_state: [KeccakState2; 2], } - out -} -#[cfg(feature = "simd256")] -#[inline(always)] -pub(crate) fn squeeze_three_blocks( - state: &mut KeccakState4, -) -> [[u8; THREE_BLOCKS]; K] { - let mut out = [[0u8; THREE_BLOCKS]; K]; - let mut dummy_out0 = [0u8; THREE_BLOCKS]; - let mut dummy_out1 = [0u8; THREE_BLOCKS]; - - match K { - 2 => { - let (out0, out1) = out.split_at_mut(1); - rust_simd::shake128x4_squeeze_first_three_blocks( - state, - &mut out0[0], - &mut out1[0], - &mut dummy_out0, - &mut dummy_out1, - ); - } - 3 => { - let (out0, out12) = out.split_at_mut(1); - let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake128x4_squeeze_first_three_blocks( - state, - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut dummy_out0, - ); + impl Hash for Simd128Hash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + libcrux_sha3::neon::sha512(&mut digest, input); + digest } - 4 => { - let (out0, out123) = out.split_at_mut(1); - let (out1, out23) = out123.split_at_mut(1); - let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake128x4_squeeze_first_three_blocks( - state, - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut out3[0], - ); - } - _ => unreachable!(), - } - out -} -#[cfg(feature = "simd128")] -#[inline(always)] -pub(crate) fn squeeze_block(state: &mut Shake128x4State) -> [[u8; BLOCK_SIZE]; K] { - let mut out0 = [0u8; BLOCK_SIZE]; - let mut out1 = [0u8; BLOCK_SIZE]; - let mut out2 = [0u8; BLOCK_SIZE]; - let mut out3 = [0u8; BLOCK_SIZE]; - - let mut out = [[0u8; BLOCK_SIZE]; K]; - - match K { - 2 => { - rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - out[0] = out0; - out[1] = out1; + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + libcrux_sha3::neon::sha256(&mut digest, input); + digest } - 3 => { - rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + libcrux_sha3::neon::shake256(&mut digest, input); + digest } - _ => { - rust_simd::shake128x2_squeeze_next_block(&mut state[0], &mut out0, &mut out1); - rust_simd::shake128x2_squeeze_next_block(&mut state[1], &mut out2, &mut out3); - out[0] = out0; - out[1] = out1; - out[2] = out2; - out[3] = out3; + + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; LEN]; K]; + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + } + 3 => { + let mut extra = [0u8; LEN]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + x2::shake256(&input[2], &input[2], &mut out2[0], &mut extra); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + x2::shake256(&input[2], &input[3], &mut out2[0], &mut out3[0]); + } + _ => unreachable!(), + } + out } - } - out -} -#[cfg(not(any(feature = "simd256", feature = "simd128")))] -#[inline(always)] -pub(crate) fn squeeze_block( - state: &mut [rust_simd::KeccakState1; K], -) -> [[u8; BLOCK_SIZE]; K] { - let mut out = [[0u8; 
BLOCK_SIZE]; K]; - for i in 0..K { - rust_simd::shake128_squeeze_next_block(&mut state[i], &mut out[i]); - } - out -} + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut state = [ + x2::incremental::shake128_init(), + x2::incremental::shake128_init(), + ]; -#[cfg(feature = "simd256")] -#[inline(always)] -pub(crate) fn squeeze_block(state: &mut KeccakState4) -> [[u8; BLOCK_SIZE]; K] { - let mut dummy_out0 = [0u8; BLOCK_SIZE]; - let mut dummy_out1 = [0u8; BLOCK_SIZE]; - - let mut out = [[0u8; BLOCK_SIZE]; K]; - - match K { - 2 => { - let (out0, out1) = out.split_at_mut(1); - rust_simd::shake128x4_squeeze_next_block( - state, - &mut out0[0], - &mut out1[0], - &mut dummy_out0, - &mut dummy_out1, - ); + match K { + 2 => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + } + 3 => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[2]); + } + _ => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[3]); + } + } + // _ => unreachable!("This function is only called with 2, 3, 4"), + Self { + shake128_state: state, + } } - 3 => { - let (out0, out12) = out.split_at_mut(1); - let (out1, out2) = out12.split_at_mut(1); - rust_simd::shake128x4_squeeze_next_block( - state, - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut dummy_out0, - ); + + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + } + 3 => { + let mut extra = [0u8; THREE_BLOCKS]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[1], + &mut out2[0], + &mut extra, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[1], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out } - 4 => { - let (out0, out123) = out.split_at_mut(1); - let (out1, out23) = out123.split_at_mut(1); - let (out2, out3) = out23.split_at_mut(1); - rust_simd::shake128x4_squeeze_next_block( - state, - &mut out0[0], - &mut out1[0], - &mut out2[0], - &mut out3[0], - ); + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; BLOCK_SIZE]; K]; + match K { + 2 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + out[0] = out0; + out[1] = out1; + } + 3 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + let mut out2 = [0u8; BLOCK_SIZE]; + let mut out3 
= [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[1], + &mut out2, + &mut out3, + ); + out[0] = out0; + out[1] = out1; + out[2] = out2; + } + 4 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + let mut out2 = [0u8; BLOCK_SIZE]; + let mut out3 = [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[1], + &mut out2, + &mut out3, + ); + out[0] = out0; + out[1] = out1; + out[2] = out2; + out[3] = out3; + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out } - _ => unreachable!(), } - out } - -/// Free the memory of the state. -/// -/// **NOTE:** That this needs to be done manually for now. -#[cfg(not(any(feature = "simd256", feature = "simd128")))] -#[inline(always)] -pub(crate) fn free_state(_xof_state: [rust_simd::KeccakState1; K]) {} - -/// Free the memory of the state. -/// -/// **NOTE:** That this needs to be done manually for now. -#[cfg(any(feature = "simd256", feature = "simd128"))] -#[inline(always)] -pub(crate) fn free_state(_xof_state: KeccakState4) {} diff --git a/libcrux-ml-kem/src/ind_cca.rs b/libcrux-ml-kem/src/ind_cca.rs index e7b6343db..ac0cb8b98 100644 --- a/libcrux-ml-kem/src/ind_cca.rs +++ b/libcrux-ml-kem/src/ind_cca.rs @@ -5,7 +5,7 @@ use crate::{ compare_ciphertexts_in_constant_time, select_shared_secret_in_constant_time, }, constants::{CPA_PKE_KEY_GENERATION_SEED_SIZE, H_DIGEST_SIZE, SHARED_SECRET_SIZE}, - hash_functions::{G, H, PRF}, + hash_functions::{self, Hash}, ind_cpa::{into_padded_array, serialize_public_key}, serialize::deserialize_ring_elements_reduced, types::{MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey}, @@ -27,7 +27,7 @@ pub type MlKemSharedSecret = [u8; SHARED_SECRET_SIZE]; /// Serialize the secret key. 
#[inline(always)] -fn serialize_kem_secret_key( +fn serialize_kem_secret_key>( private_key: &[u8], public_key: &[u8], implicit_rejection_value: &[u8], @@ -38,7 +38,7 @@ fn serialize_kem_secret_key( pointer += private_key.len(); out[pointer..pointer + public_key.len()].copy_from_slice(public_key); pointer += public_key.len(); - out[pointer..pointer + H_DIGEST_SIZE].copy_from_slice(&H(public_key)); + out[pointer..pointer + H_DIGEST_SIZE].copy_from_slice(&Hasher::H(public_key)); pointer += H_DIGEST_SIZE; out[pointer..pointer + implicit_rejection_value.len()] .copy_from_slice(implicit_rejection_value); @@ -138,6 +138,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(ind_cpa_keypair_randomness, implicit_rejection_value); #[cfg(not(feature = "simd256"))] generate_keypair_generic::< @@ -149,6 +150,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { #[cfg(feature = "simd128")] @@ -161,6 +163,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(ind_cpa_keypair_randomness, implicit_rejection_value); #[cfg(not(feature = "simd128"))] generate_keypair_generic::< @@ -172,6 +175,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) } else { generate_keypair_generic::< @@ -183,6 +187,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) } } @@ -196,6 +201,7 @@ fn generate_keypair_generic< const ETA1: usize, const ETA1_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( ind_cpa_keypair_randomness: &[u8], implicit_rejection_value: &[u8], @@ -208,10 +214,14 @@ fn generate_keypair_generic< ETA1, ETA1_RANDOMNESS_SIZE, Vector, + Hasher, >(ind_cpa_keypair_randomness); - let secret_key_serialized = - serialize_kem_secret_key(&ind_cpa_private_key, &public_key, implicit_rejection_value); + let secret_key_serialized = serialize_kem_secret_key::( + &ind_cpa_private_key, + &public_key, + implicit_rejection_value, + ); let private_key: MlKemPrivateKey = MlKemPrivateKey::from(secret_key_serialized); @@ -253,6 +263,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(public_key, randomness); #[cfg(not(feature = "simd256"))] encapsulate_generic::< @@ -270,6 +281,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(public_key, randomness) } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { #[cfg(not(feature = "simd128"))] @@ -288,6 +300,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(public_key, randomness); #[cfg(feature = "simd128")] encapsulate_generic::< @@ -305,6 +318,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(public_key, randomness) } else { encapsulate_generic::< @@ -322,6 +336,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + 
hash_functions::portable::PortableHash, >(public_key, randomness) } } @@ -341,14 +356,15 @@ fn encapsulate_generic< const ETA2: usize, const ETA2_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( public_key: &MlKemPublicKey, randomness: [u8; SHARED_SECRET_SIZE], ) -> (MlKemCiphertext, MlKemSharedSecret) { let mut to_hash: [u8; 2 * H_DIGEST_SIZE] = into_padded_array(&randomness); - to_hash[H_DIGEST_SIZE..].copy_from_slice(&H(public_key.as_slice())); + to_hash[H_DIGEST_SIZE..].copy_from_slice(&Hasher::H(public_key.as_slice())); - let hashed = G(&to_hash); + let hashed = Hasher::G(&to_hash); let (shared_secret, pseudorandomness) = hashed.split_at(SHARED_SECRET_SIZE); let ciphertext = crate::ind_cpa::encrypt::< @@ -365,6 +381,7 @@ fn encapsulate_generic< ETA2, ETA2_RANDOMNESS_SIZE, Vector, + Hasher, >(public_key.as_slice(), randomness, pseudorandomness); let mut shared_secret_array = [0u8; SHARED_SECRET_SIZE]; shared_secret_array.copy_from_slice(shared_secret); @@ -412,6 +429,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(private_key, ciphertext); #[cfg(not(feature = "simd256"))] return decapsulate_generic::< @@ -432,6 +450,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(private_key, ciphertext); } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { #[cfg(feature = "simd128")] @@ -453,6 +472,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(private_key, ciphertext); #[cfg(not(feature = "simd128"))] return decapsulate_generic::< @@ -473,6 +493,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(private_key, ciphertext); } else { decapsulate_generic::< @@ -493,6 +514,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(private_key, ciphertext) } } @@ -515,6 +537,7 @@ pub(crate) fn decapsulate_generic< const ETA2_RANDOMNESS_SIZE: usize, const IMPLICIT_REJECTION_HASH_INPUT_SIZE: usize, Vector: Operations, + Hasher: Hash, >( private_key: &MlKemPrivateKey, ciphertext: &MlKemCiphertext, @@ -535,13 +558,13 @@ pub(crate) fn decapsulate_generic< let mut to_hash: [u8; SHARED_SECRET_SIZE + H_DIGEST_SIZE] = into_padded_array(&decrypted); to_hash[SHARED_SECRET_SIZE..].copy_from_slice(ind_cpa_public_key_hash); - let hashed = G(&to_hash); + let hashed = Hasher::G(&to_hash); let (shared_secret, pseudorandomness) = hashed.split_at(SHARED_SECRET_SIZE); let mut to_hash: [u8; IMPLICIT_REJECTION_HASH_INPUT_SIZE] = into_padded_array(implicit_rejection_value); to_hash[SHARED_SECRET_SIZE..].copy_from_slice(ciphertext.as_ref()); - let implicit_rejection_shared_secret: [u8; SHARED_SECRET_SIZE] = PRF(&to_hash); + let implicit_rejection_shared_secret: [u8; SHARED_SECRET_SIZE] = Hasher::PRF(&to_hash); let expected_ciphertext = crate::ind_cpa::encrypt::< K, @@ -557,6 +580,7 @@ pub(crate) fn decapsulate_generic< ETA2, ETA2_RANDOMNESS_SIZE, Vector, + Hasher, >(ind_cpa_public_key, decrypted, pseudorandomness); let selector = compare_ciphertexts_in_constant_time::( diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index 67ac0f557..2271b9ed9 100644 
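For orientation, here is a sketch of the `Hash<K>` trait bound that the new `Hasher` type parameter satisfies, reconstructed from the call sites in this patch (`Hasher::G`, `Hasher::H`, `Hasher::PRF`, `Hasher::PRFxN`, and the SHAKE128 XOF methods used during rejection sampling). It is not part of the patch; the output sizes and constant values are assumptions based on how the results are consumed.

// Sketch only: the real trait lives in libcrux-ml-kem/src/hash_functions.rs.
const BLOCK_SIZE: usize = 168; // assumed SHAKE128 rate in bytes
const THREE_BLOCKS: usize = 3 * BLOCK_SIZE;

#[allow(non_snake_case)]
pub(crate) trait Hash<const K: usize>: Sized {
    /// G, i.e. SHA3-512 (assumed 64-byte output)
    fn G(input: &[u8]) -> [u8; 64];
    /// H, i.e. SHA3-256 (assumed 32-byte output, H_DIGEST_SIZE)
    fn H(input: &[u8]) -> [u8; 32];
    /// PRF, i.e. SHAKE256 with a caller-chosen output length
    fn PRF<const LEN: usize>(input: &[u8]) -> [u8; LEN];
    /// K PRF invocations over K domain-separated 33-byte inputs
    fn PRFxN<const LEN: usize>(input: &[[u8; 33]; K]) -> [[u8; LEN]; K];
    /// Absorb the K 34-byte seeds into a (possibly multiplexed) SHAKE128 state
    fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self;
    /// Squeeze the first three blocks for each of the K lanes
    fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K];
    /// Squeeze one further block for each of the K lanes
    fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K];
}

The portable, neon, and avx2 instantiations referenced in the diff (`PortableHash`, `Simd128Hash`, `Simd256Hash`) then multiplex the underlying x1/x2/x4 Keccak implementations behind this single interface.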
--- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -2,7 +2,7 @@ use libcrux_polynomials::Operations; use crate::{ constants::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, SHARED_SECRET_SIZE}, - hash_functions::{PRFxN, G, PRF}, + hash_functions::Hash, helper::cloop, matrix::*, ntt::{ntt_binomially_sampled_ring_element, ntt_vector_u}, @@ -69,6 +69,7 @@ fn sample_ring_element_cbd< const ETA2_RANDOMNESS_SIZE: usize, const ETA2: usize, Vector: Operations, + Hasher: Hash, >( prf_input: [u8; 33], mut domain_separator: u8, @@ -79,7 +80,7 @@ fn sample_ring_element_cbd< prf_inputs[i][32] = domain_separator; domain_separator += 1; } - let prf_outputs: [[u8; ETA2_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + let prf_outputs: [[u8; ETA2_RANDOMNESS_SIZE]; K] = Hasher::PRFxN(&prf_inputs); for i in 0..K { error_1[i] = sample_from_binomial_distribution::(&prf_outputs[i]); } @@ -94,6 +95,7 @@ fn sample_vector_cbd_then_ntt< const ETA: usize, const ETA_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( prf_input: [u8; 33], mut domain_separator: u8, @@ -104,7 +106,7 @@ fn sample_vector_cbd_then_ntt< prf_inputs[i][32] = domain_separator; domain_separator += 1; } - let prf_outputs: [[u8; ETA_RANDOMNESS_SIZE]; K] = PRFxN(&prf_inputs); + let prf_outputs: [[u8; ETA_RANDOMNESS_SIZE]; K] = Hasher::PRFxN(&prf_inputs); for i in 0..K { re_as_ntt[i] = sample_from_binomial_distribution::(&prf_outputs[i]); ntt_binomially_sampled_ring_element(&mut re_as_ntt[i]); @@ -159,22 +161,24 @@ pub(crate) fn generate_keypair< const ETA1: usize, const ETA1_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( key_generation_seed: &[u8], ) -> ([u8; PRIVATE_KEY_SIZE], [u8; PUBLIC_KEY_SIZE]) { // (ρ,σ) := G(d) - let hashed = G(key_generation_seed); + let hashed = Hasher::G(key_generation_seed); let (seed_for_A, seed_for_secret_and_error) = hashed.split_at(32); - let A_transpose = sample_matrix_A(into_padded_array(seed_for_A), true); + let A_transpose = sample_matrix_A::(into_padded_array(seed_for_A), true); let prf_input: [u8; 33] = into_padded_array(seed_for_secret_and_error); let (secret_as_ntt, domain_separator) = - sample_vector_cbd_then_ntt::(prf_input, 0); - let (error_as_ntt, _) = sample_vector_cbd_then_ntt::( - prf_input, - domain_separator, - ); + sample_vector_cbd_then_ntt::(prf_input, 0); + let (error_as_ntt, _) = + sample_vector_cbd_then_ntt::( + prf_input, + domain_separator, + ); // tˆ := Aˆ ◦ sˆ + eˆ let t_as_ntt = compute_As_plus_e(&A_transpose, &secret_as_ntt, &error_as_ntt); @@ -265,6 +269,7 @@ pub(crate) fn encrypt< const ETA2: usize, const ETA2_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( public_key: &[u8], message: [u8; SHARED_SECRET_SIZE], @@ -282,7 +287,7 @@ pub(crate) fn encrypt< // end for // end for let seed = &public_key[T_AS_NTT_ENCODED_SIZE..]; - let A_transpose = sample_matrix_A(into_padded_array(seed), false); + let A_transpose = sample_matrix_A::(into_padded_array(seed), false); // for i from 0 to k−1 do // r[i] := CBD{η1}(PRF(r, N)) @@ -291,21 +296,21 @@ pub(crate) fn encrypt< // rˆ := NTT(r) let mut prf_input: [u8; 33] = into_padded_array(randomness); let (r_as_ntt, domain_separator) = - sample_vector_cbd_then_ntt::(prf_input, 0); + sample_vector_cbd_then_ntt::(prf_input, 0); // for i from 0 to k−1 do // e1[i] := CBD_{η2}(PRF(r,N)) // N := N + 1 // end for let (error_1, domain_separator) = - sample_ring_element_cbd::( + sample_ring_element_cbd::( prf_input, domain_separator, ); // e_2 := CBD{η2}(PRF(r, N)) prf_input[32] = 
domain_separator; - let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = PRF(&prf_input); + let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = Hasher::PRF(&prf_input); let error_2 = sample_from_binomial_distribution::(&prf_output); // u := NTT^{-1}(AˆT ◦ rˆ) + e_1 diff --git a/libcrux-ml-kem/src/matrix.rs b/libcrux-ml-kem/src/matrix.rs index e8e03253b..2b4d4ed85 100644 --- a/libcrux-ml-kem/src/matrix.rs +++ b/libcrux-ml-kem/src/matrix.rs @@ -1,13 +1,13 @@ use libcrux_polynomials::Operations; use crate::{ - helper::cloop, invert_ntt::invert_ntt_montgomery, polynomial::PolynomialRingElement, - sampling::sample_from_xof, + hash_functions::Hash, helper::cloop, invert_ntt::invert_ntt_montgomery, + polynomial::PolynomialRingElement, sampling::sample_from_xof, }; #[inline(always)] #[allow(non_snake_case)] -pub(crate) fn sample_matrix_A( +pub(crate) fn sample_matrix_A>( seed: [u8; 34], transpose: bool, ) -> [[PolynomialRingElement; K]; K] { @@ -21,7 +21,7 @@ pub(crate) fn sample_matrix_A( seeds[j][32] = i as u8; seeds[j][33] = j as u8; } - let sampled = sample_from_xof(seeds); + let sampled = sample_from_xof::(seeds); for (j, sample) in sampled.into_iter().enumerate() { // A[i][j] = A_transpose[j][i] if transpose { diff --git a/libcrux-ml-kem/src/sampling.rs b/libcrux-ml-kem/src/sampling.rs index 4b7d6c4cf..26e6ff216 100644 --- a/libcrux-ml-kem/src/sampling.rs +++ b/libcrux-ml-kem/src/sampling.rs @@ -71,14 +71,14 @@ fn sample_from_uniform_distribution_next( +pub(super) fn sample_from_xof>( seeds: [[u8; 34]; K], ) -> [PolynomialRingElement; K] { let mut sampled_coefficients: [usize; K] = [0; K]; let mut out: [[i16; 272]; K] = [[0; 272]; K]; - let mut xof_state = absorb(seeds); - let randomness = squeeze_three_blocks(&mut xof_state); + let mut xof_state = Hasher::shake128_init_absorb(seeds); + let randomness = xof_state.shake128_squeeze_three_blocks(); let mut done = sample_from_uniform_distribution_next::( randomness, @@ -92,15 +92,13 @@ pub(super) fn sample_from_xof( // To avoid failing here, we squeeze more blocks out of the state until // we have enough. while !done { - let randomness = squeeze_block(&mut xof_state); + let randomness = xof_state.shake128_squeeze_block(); done = sample_from_uniform_distribution_next::( randomness, &mut sampled_coefficients, &mut out, ); } - // XXX: We have to manually free the state here due to a Eurydice issue. - free_state(xof_state); out.map(|s| PolynomialRingElement::::from_i16_array(&s[0..256])) } diff --git a/libcrux-sha3/benches/sha3.rs b/libcrux-sha3/benches/sha3.rs index ce7ca2a58..0195560aa 100644 --- a/libcrux-sha3/benches/sha3.rs +++ b/libcrux-sha3/benches/sha3.rs @@ -19,10 +19,10 @@ pub fn fmt(x: usize) -> String { } macro_rules! impl_comp { - ($fun:ident, $libcrux:expr, $arm64:ident, $rust_crypto:ty, $openssl:expr) => { + ($fun:ident, $libcrux:expr) => { // Comparing libcrux performance for different payload sizes and other implementations. fn $fun(c: &mut Criterion) { - const PAYLOAD_SIZES: [usize; 1] = [1024 * 1024 * 10]; + const PAYLOAD_SIZES: [usize; 3] = [128, 1024, 1024 * 1024 * 10]; let mut group = c.benchmark_group(stringify!($fun).replace("_", " ")); @@ -43,111 +43,29 @@ macro_rules! 
impl_comp { }, ); + #[cfg(feature = "simd128")] group.bench_with_input( - BenchmarkId::new("rust version (simd)", fmt(*payload_size)), + BenchmarkId::new("rust version (simd128)", fmt(*payload_size)), payload_size, |b, payload_size| { b.iter_batched( || randombytes(*payload_size), |payload| { - let _d: [u8; digest_size($libcrux)] = rust_simd::$arm64(&payload); + let _d: [u8; digest_size($libcrux)] = neon::$fun(&payload); }, BatchSize::SmallInput, ) }, ); - - // group.bench_with_input( - // BenchmarkId::new("RustCrypto", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // use sha3::Digest; - - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let mut hasher = <$rust_crypto>::new(); - // hasher.update(&payload); - // let _result = hasher.finalize(); - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - - // #[cfg(all(not(windows), not(target_arch = "wasm32"), not(target_arch = "x86")))] - // group.bench_with_input( - // BenchmarkId::new("OpenSSL", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // use openssl::hash::*; - - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let _result = hash($openssl, &payload); - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - - // #[cfg(not(target_arch = "wasm32"))] - // if stringify!($fun) != "Sha3_224" { - // group.bench_with_input( - // BenchmarkId::new("PQClean", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let mut digest = [0; libcrux::digest::digest_size($libcrux)]; - // unsafe { - // $pqclean( - // digest.as_mut_ptr(), - // payload.as_ptr() as _, - // payload.len(), - // ) - // }; - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - // } } } }; } -impl_comp!( - Sha3_224, - Algorithm::Sha3_224, - sha3_224, - sha3::Sha3_224, - MessageDigest::sha3_224() // libcrux_pqclean::sha3_256 // This is wrong, but it's not actually used. -); -impl_comp!( - Sha3_256, - Algorithm::Sha3_256, - sha3_256, - sha3::Sha3_256, - MessageDigest::sha3_256() // libcrux_pqclean::sha3_256 -); -impl_comp!( - Sha3_384, - Algorithm::Sha3_384, - sha3_384, - sha3::Sha3_384, - MessageDigest::sha3_384() // libcrux_pqclean::sha3_384 -); -impl_comp!( - Sha3_512, - Algorithm::Sha3_512, - sha3_512, - sha3::Sha3_512, - MessageDigest::sha3_512() // libcrux_pqclean::sha3_512 -); +impl_comp!(Sha3_224, Algorithm::Sha224); +impl_comp!(Sha3_256, Algorithm::Sha256); +impl_comp!(Sha3_384, Algorithm::Sha384); +impl_comp!(Sha3_512, Algorithm::Sha512); fn benchmarks(c: &mut Criterion) { Sha3_224(c); diff --git a/libcrux-sha3/build.rs b/libcrux-sha3/build.rs index a5dce81b2..ef1138666 100644 --- a/libcrux-sha3/build.rs +++ b/libcrux-sha3/build.rs @@ -17,10 +17,12 @@ fn main() { println!("cargo:rustc-cfg=feature=\"simd128\""); } if target_arch == "x86_64" && !disable_simd256 { - // We enable simd256 on all x86 and x86_64 builds. + // We enable simd256 on all x86_64 builds. // Note that this doesn't mean the required CPU features are available. // But the compiler will support them and the runtime checks ensure that // it's only used when available. + // + // We don't enable this on x86 because it seems to generate invalid code. 
println!("cargo:rustc-cfg=feature=\"simd256\""); } } diff --git a/libcrux-sha3/src/rust_simd/sha3_generic.rs b/libcrux-sha3/src/generic_keccak.rs similarity index 97% rename from libcrux-sha3/src/rust_simd/sha3_generic.rs rename to libcrux-sha3/src/generic_keccak.rs index d9a46718d..e5f6ea9f1 100644 --- a/libcrux-sha3/src/rust_simd/sha3_generic.rs +++ b/libcrux-sha3/src/generic_keccak.rs @@ -1,8 +1,12 @@ -use std::ops::Index; +//! The generic SHA3 implementation that uses portable or platform specific +//! sub-routines. -use crate::rust_simd::sha3_trait::*; +use core::ops::Index; + +use crate::traits::*; #[cfg_attr(hax, hax_lib::opaque_type)] +#[allow(private_bounds)] // TODO: figure out visibility #[derive(Clone, Copy)] pub struct KeccakState> { pub st: [[T; 5]; 5], @@ -16,6 +20,7 @@ impl> Index for KeccakState { } } +#[allow(private_bounds)] // TODO: figure out visibility impl> KeccakState { /// Create a new Shake128 x4 state. #[inline(always)] diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 4064e9e6f..0e5e63ddc 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -1,30 +1,44 @@ -// XXX: Can't do no_std -// #![no_std] +//! # SHA3 +//! +//! A SHA3 implementation with optional simd optimisations. -pub mod rust_simd; +#![no_std] +pub mod simd; + +mod generic_keccak; +mod portable_keccak; +mod traits; + +/// A SHA3 224 Digest pub type Sha3_224Digest = [u8; 28]; + +/// A SHA3 256 Digest pub type Sha3_256Digest = [u8; 32]; + +/// A SHA3 384 Digest pub type Sha3_384Digest = [u8; 48]; + +/// A SHA3 512 Digest pub type Sha3_512Digest = [u8; 64]; /// The Digest Algorithm. #[derive(Copy, Clone, Debug, PartialEq)] #[repr(u32)] pub enum Algorithm { - Sha3_224 = 1, - Sha3_256 = 2, - Sha3_384 = 3, - Sha3_512 = 4, + Sha224 = 1, + Sha256 = 2, + Sha384 = 3, + Sha512 = 4, } impl From for Algorithm { fn from(v: u32) -> Algorithm { match v { - 1 => Algorithm::Sha3_224, - 2 => Algorithm::Sha3_256, - 3 => Algorithm::Sha3_384, - 4 => Algorithm::Sha3_512, + 1 => Algorithm::Sha224, + 2 => Algorithm::Sha256, + 3 => Algorithm::Sha384, + 4 => Algorithm::Sha512, _ => panic!("Unknown Digest mode {}", v), } } @@ -33,10 +47,10 @@ impl From for Algorithm { impl From for u32 { fn from(v: Algorithm) -> u32 { match v { - Algorithm::Sha3_224 => 1, - Algorithm::Sha3_256 => 2, - Algorithm::Sha3_384 => 3, - Algorithm::Sha3_512 => 4, + Algorithm::Sha224 => 1, + Algorithm::Sha256 => 2, + Algorithm::Sha384 => 3, + Algorithm::Sha512 => 4, } } } @@ -44,115 +58,104 @@ impl From for u32 { /// Returns the output size of a digest. 
pub const fn digest_size(mode: Algorithm) -> usize { match mode { - Algorithm::Sha3_224 => 28, - Algorithm::Sha3_256 => 32, - Algorithm::Sha3_384 => 48, - Algorithm::Sha3_512 => 64, + Algorithm::Sha224 => 28, + Algorithm::Sha256 => 32, + Algorithm::Sha384 => 48, + Algorithm::Sha512 => 64, } } -// /// SHA3 -// pub fn hash(algorithm: Algorithm, payload: &[u8]) -> [u8; LEN] { -// debug_assert!(payload.len() <= u32::MAX as usize); +/// SHA3 +pub fn hash(algorithm: Algorithm, payload: &[u8]) -> [u8; LEN] { + debug_assert!(payload.len() <= u32::MAX as usize); -// let mut out = [0u8; LEN]; -// match algorithm { -// Algorithm::Sha3_224 => sha224_ema(&mut out, payload), -// Algorithm::Sha3_256 => sha256_ema(&mut out, payload), -// Algorithm::Sha3_384 => sha384_ema(&mut out, payload), -// Algorithm::Sha3_512 => sha512_ema(&mut out, payload), -// } -// out -// } + let mut out = [0u8; LEN]; + match algorithm { + Algorithm::Sha224 => portable::sha224(&mut out, payload), + Algorithm::Sha256 => portable::sha256(&mut out, payload), + Algorithm::Sha384 => portable::sha384(&mut out, payload), + Algorithm::Sha512 => portable::sha512(&mut out, payload), + } + out +} /// SHA3 224 #[inline(always)] -pub fn sha224(data: &[u8]) -> [u8; 28] { - rust_simd::sha3_224(data) +pub fn sha224(data: &[u8]) -> Sha3_224Digest { + let mut out = [0u8; 28]; + sha224_ema(&mut out, data); + out } -// /// SHA3 224 -// #[inline(always)] -// pub fn sha224_ema(digest: &mut [u8], payload: &[u8]) { -// debug_assert!(payload.len() <= u32::MAX as usize); -// debug_assert!(digest.len() == 28); +/// SHA3 224 +/// +/// Preconditions: +/// - `digest.len() == 28` +#[inline(always)] +pub fn sha224_ema(digest: &mut [u8], payload: &[u8]) { + debug_assert!(payload.len() <= u32::MAX as usize); + debug_assert!(digest.len() == 28); -// unsafe { -// Hacl_Hash_SHA3_sha3_224( -// digest.as_mut_ptr(), -// payload.as_ptr() as _, -// payload.len().try_into().unwrap(), -// ); -// } -// } + portable::sha224(digest, payload) +} /// SHA3 256 #[inline(always)] -pub fn sha256(data: &[u8]) -> [u8; 32] { - rust_simd::sha3_256(data) +pub fn sha256(data: &[u8]) -> Sha3_256Digest { + let mut out = [0u8; 32]; + sha256_ema(&mut out, data); + out } -// /// SHA3 256 -// #[inline(always)] -// pub fn sha256_ema(digest: &mut [u8], payload: &[u8]) { -// debug_assert!(payload.len() <= u32::MAX as usize); -// debug_assert!(digest.len() == 32); +/// SHA3 256 +#[inline(always)] +pub fn sha256_ema(digest: &mut [u8], payload: &[u8]) { + debug_assert!(payload.len() <= u32::MAX as usize); + debug_assert!(digest.len() == 32); -// unsafe { -// Hacl_Hash_SHA3_sha3_256( -// digest.as_mut_ptr(), -// payload.as_ptr() as _, -// payload.len().try_into().unwrap(), -// ); -// } -// } + portable::sha256(digest, payload) +} /// SHA3 384 #[inline(always)] -pub fn sha384(data: &[u8]) -> [u8; 48] { - rust_simd::sha3_384(data) +pub fn sha384(data: &[u8]) -> Sha3_384Digest { + let mut out = [0u8; 48]; + sha384_ema(&mut out, data); + out } -// /// SHA3 384 -// #[inline(always)] -// pub fn sha384_ema(digest: &mut [u8], payload: &[u8]) { -// debug_assert!(payload.len() <= u32::MAX as usize); -// debug_assert!(digest.len() == 48); +/// SHA3 384 +#[inline(always)] +pub fn sha384_ema(digest: &mut [u8], payload: &[u8]) { + debug_assert!(payload.len() <= u32::MAX as usize); + debug_assert!(digest.len() == 48); -// unsafe { -// Hacl_Hash_SHA3_sha3_384( -// digest.as_mut_ptr(), -// payload.as_ptr() as _, -// payload.len().try_into().unwrap(), -// ); -// } -// } + portable::sha384(digest, payload) +} /// SHA3 
512 #[inline(always)] -pub fn sha512(data: &[u8]) -> [u8; 64] { - rust_simd::sha3_512(data) +pub fn sha512(data: &[u8]) -> Sha3_512Digest { + let mut out = [0u8; 64]; + sha512_ema(&mut out, data); + out } -// /// SHA3 512 -// #[inline(always)] -// pub fn sha512_ema(digest: &mut [u8], payload: &[u8]) { -// debug_assert!(payload.len() <= u32::MAX as usize); -// debug_assert!(digest.len() == 64); +/// SHA3 512 +#[inline(always)] +pub fn sha512_ema(digest: &mut [u8], payload: &[u8]) { + debug_assert!(payload.len() <= u32::MAX as usize); + debug_assert!(digest.len() == 64); -// unsafe { -// Hacl_Hash_SHA3_sha3_512( -// digest.as_mut_ptr(), -// payload.as_ptr() as _, -// payload.len().try_into().unwrap(), -// ); -// } -// } + portable::sha512(digest, payload) +} /// SHAKE 128 #[inline(always)] pub fn shake128(data: &[u8]) -> [u8; BYTES] { - rust_simd::shake128(data) + let mut out = [0u8; BYTES]; + portable::shake128(&mut out, data); + out } /// SHAKE 256 @@ -161,5 +164,472 @@ pub fn shake128(data: &[u8]) -> [u8; BYTES] { /// the output will only return `u32::MAX` bytes. #[inline(always)] pub fn shake256(data: &[u8]) -> [u8; BYTES] { - rust_simd::shake256(data) + let mut out = [0u8; BYTES]; + portable::shake256(&mut out, data); + out +} + +mod incremental {} + +// === The portable instantiation === // + +/// A portable SHA3 implementations without platform dependent optimisations. +pub mod portable { + use super::*; + use generic_keccak::{keccak, KeccakState}; + + pub type KeccakState1 = KeccakState<1, u64>; + + #[inline(always)] + fn keccakx1(data: [&[u8]; 1], out: [&mut [u8]; 1]) { + keccak::<1, u64, RATE, DELIM>(data, out) + } + + /// A portable SHA3 224 implementation. + pub fn sha224(digest: &mut [u8], data: &[u8]) { + keccakx1::<144, 0x06u8>([data], [digest]); + } + + /// A portable SHA3 256 implementation. + pub fn sha256(digest: &mut [u8], data: &[u8]) { + keccakx1::<136, 0x06u8>([data], [digest]); + } + + /// A portable SHA3 384 implementation. + pub fn sha384(digest: &mut [u8], data: &[u8]) { + keccakx1::<104, 0x06u8>([data], [digest]); + } + + /// A portable SHA3 512 implementation. + pub fn sha512(digest: &mut [u8], data: &[u8]) { + keccakx1::<72, 0x06u8>([data], [digest]); + } + + /// A portable SHAKE128 implementation. + pub fn shake128(digest: &mut [u8; LEN], data: &[u8]) { + keccakx1::<168, 0x1fu8>([data], [digest]); + } + + /// A portable SHAKE256 implementation. + pub fn shake256(digest: &mut [u8; LEN], data: &[u8]) { + keccakx1::<136, 0x1fu8>([data], [digest]); + } + + /// An incremental API for SHAKE + pub mod incremental { + use generic_keccak::{absorb_final, squeeze_first_three_blocks, squeeze_next_block}; + + use super::*; + + /// Initialise the SHAKE state. + pub fn shake128_init() -> KeccakState1 { + KeccakState1::new() + } + + /// Absorb + pub fn shake128_absorb_final(s: &mut KeccakState1, data0: &[u8]) { + absorb_final::<1, u64, 168, 0x1fu8>(s, [data0]); + } + + /// Squeeze three blocks + pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0: &mut [u8]) { + squeeze_first_three_blocks::<1, u64, 168>(s, [out0]) + } + + /// Squeeze another block + pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) { + squeeze_next_block::<1, u64, 168>(s, [out0]) + } + } +} + +/// A neon optimised implementation. +/// +/// When this is compiled for non-neon architectures, the functions panic. +/// The caller must make sure to check for hardware feature before calling these +/// functions and compile them in. 
+/// +/// Feature `simd128` enables the implementations in this module. +pub mod neon { + use crate::generic_keccak::keccak; + + #[cfg(feature = "simd128")] + #[inline(always)] + fn keccakx2(data: [&[u8]; 2], out: [&mut [u8]; 2]) { + keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data, out) + } + + /// A portable SHA3 224 implementation. + #[allow(unused_variables)] + pub fn sha224(digest: &mut [u8], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; 28]; + keccakx2::<144, 0x06u8>([data, data], [digest, &mut dummy]); + } + } + + /// A portable SHA3 256 implementation. + #[allow(unused_variables)] + pub fn sha256(digest: &mut [u8], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; 32]; + keccakx2::<136, 0x06u8>([data, data], [digest, &mut dummy]); + } + } + + /// A portable SHA3 384 implementation. + #[allow(unused_variables)] + pub fn sha384(digest: &mut [u8], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; 48]; + keccakx2::<104, 0x06u8>([data, data], [digest, &mut dummy]); + } + } + + /// A portable SHA3 512 implementation. + #[allow(unused_variables)] + pub fn sha512(digest: &mut [u8], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; 64]; + keccakx2::<72, 0x06u8>([data, data], [digest, &mut dummy]); + } + } + + /// A portable SHAKE128 implementation. + #[allow(unused_variables)] + pub fn shake128(digest: &mut [u8; LEN], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; LEN]; + keccakx2::<168, 0x1fu8>([data, data], [digest, &mut dummy]); + } + } + + /// A portable SHAKE256 implementation. + #[allow(unused_variables)] + pub fn shake256(digest: &mut [u8; LEN], data: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + { + let mut dummy = [0u8; LEN]; + keccakx2::<136, 0x1fu8>([data, data], [digest, &mut dummy]); + } + } + + /// Performing 2 operations in parallel + pub mod x2 { + use super::*; + + /// Run SHAKE256 on both inputs in parallel. 
+ /// + /// Writes the two results into `out0` and `out1` + #[allow(unused_variables)] + pub fn shake256(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + // TODO: make argument ordering consistent + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(feature = "simd128")] + keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); + } + + /// An incremental API to perform 2 operations in parallel + pub mod incremental { + use crate::generic_keccak::{ + absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, + }; + + #[cfg(feature = "simd128")] + pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; + #[cfg(not(feature = "simd128"))] + pub type KeccakState2 = [crate::portable::KeccakState1; 2]; + + pub fn shake128_init() -> KeccakState2 { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let s0 = KeccakState1::new(); + // let s1 = KeccakState1::new(); + // [s0, s1] + // } + #[cfg(feature = "simd128")] + KeccakState2::new() + } + + #[allow(unused_variables)] + pub fn shake128_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_absorb_final(&mut s0, data0); + // shake128_absorb_final(&mut s1, data1); + // } + #[cfg(feature = "simd128")] + absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]); + } + + #[allow(unused_variables)] + pub fn shake128_squeeze_first_three_blocks( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], + ) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_squeeze_first_three_blocks(&mut s0, out0); + // shake128_squeeze_first_three_blocks(&mut s1, out1); + // } + #[cfg(feature = "simd128")] + squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( + s, + [out0, out1], + ) + } + + #[allow(unused_variables)] + pub fn shake128_squeeze_next_block( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], + ) { + #[cfg(not(feature = "simd128"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_squeeze_next_block(&mut s0, out0); + // shake128_squeeze_next_block(&mut s1, out1); + // } + #[cfg(feature = "simd128")] + squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) + } + } + } +} + +/// An AVX2 optimised implementation. +/// +/// When this is compiled for non-neon architectures, the functions panic. +/// The caller must make sure to check for hardware feature before calling these +/// functions and compile them in. +/// +/// Feature `simd256` enables the implementations in this module. 
+pub mod avx2 { + + /// Performing 4 operations in parallel + pub mod x4 { + + /// Perform 4 SHAKE256 operations in parallel + #[allow(unused_variables)] // TODO: decide if we want to fall back here + pub fn shake256( + input0: &[u8], + input1: &[u8], + input2: &[u8], + input3: &[u8], + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], + ) { + #[cfg(not(feature = "simd256"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // #[cfg(feature = "simd128")] + // { + // keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); + // keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); + // } + // { + // keccakx1::<136, 0x1fu8>([input0], [out0]); + // keccakx1::<136, 0x1fu8>([input1], [out1]); + // keccakx1::<136, 0x1fu8>([input2], [out2]); + // keccakx1::<136, 0x1fu8>([input3], [out3]); + // } + #[cfg(feature = "simd256")] + keccak::<4, core::arch::x86_64::__m256i, 136, 0x1fu8>( + [input0, input1, input2, input3], + [out0, out1, out2, out3], + ); + } + + /// An incremental API to perform 4 operations in parallel + pub mod incremental { + #[cfg(feature = "simd256")] + use crate::generic_keccak::{ + absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, + }; + + #[cfg(feature = "simd256")] + pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; + #[cfg(feature = "simd128")] + pub type KeccakState4 = [crate::neon::x2::incremental::KeccakState2; 2]; + #[cfg(not(any(feature = "simd256", feature = "simd128")))] + pub type KeccakState4 = [crate::portable::KeccakState1; 4]; + + pub fn shake128_init() -> KeccakState4 { + #[cfg(not(feature = "simd256"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // #[cfg(feature = "simd128")] + // { + // let s0 = KeccakState2::new(); + // let s1 = KeccakState2::new(); + // [s0, s1] + // } + // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // { + // let s0 = KeccakState1::new(); + // let s1 = KeccakState1::new(); + // let s2 = KeccakState1::new(); + // let s3 = KeccakState1::new(); + // [s0, s1, s2, s3] + // } + #[cfg(feature = "simd256")] + KeccakState4::new() + } + + #[allow(unused_variables)] // TODO: decide if we want to fall back here + pub fn shake128_absorb_final( + s: &mut KeccakState4, + data0: &[u8], + data1: &[u8], + data2: &[u8], + data3: &[u8], + ) { + #[cfg(not(feature = "simd256"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // #[cfg(feature = "simd128")] + // { + // let [mut s0, mut s1] = s; + // absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>( + // &mut s0, + // [data0, data1], + // ); + // absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>( + // &mut s1, + // [data2, data3], + // ); + // } + // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // { + // let [mut s0, mut s1, mut s2, mut s3] = s; + // shake128_absorb_final(&mut s0, data0); + // shake128_absorb_final(&mut s1, data1); + // shake128_absorb_final(&mut s2, data2); + // shake128_absorb_final(&mut s3, data3); + // } + #[cfg(feature = "simd256")] + absorb_final::<4, core::arch::x86_64::__m256i, 168, 0x1fu8>( + s, + [data0, data1, data2, data3], + ); + } + + 
#[allow(unused_variables)] // TODO: decide if we want to fall back here + pub fn shake128_squeeze_first_three_blocks( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], + ) { + #[cfg(not(feature = "simd256"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // #[cfg(feature = "simd128")] + // { + // let [mut s0, mut s1] = s; + // squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( + // &mut s0, + // [out0, out1], + // ); + // squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( + // &mut s1, + // [out2, out3], + // ); + // } + // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // { + // let [mut s0, mut s1, mut s2, mut s3] = s; + // shake128_squeeze_first_three_blocks(&mut s0, out0); + // shake128_squeeze_first_three_blocks(&mut s1, out1); + // shake128_squeeze_first_three_blocks(&mut s2, out2); + // shake128_squeeze_first_three_blocks(&mut s3, out3); + // } + #[cfg(feature = "simd256")] + squeeze_first_three_blocks::<4, core::arch::x86_64::__m256i, 168>( + s, + [out0, out1, out2, out3], + ); + } + + #[allow(unused_variables)] // TODO: decide if we want to fall back here + pub fn shake128_squeeze_next_block( + s: &mut KeccakState4, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], + ) { + #[cfg(not(feature = "simd256"))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // #[cfg(feature = "simd128")] + // { + // let [mut s0, mut s1] = s; + // squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>( + // &mut s0, + // [out0, out1], + // ); + // squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>( + // &mut s1, + // [out2, out3], + // ); + // } + // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // { + // let [mut s0, mut s1, mut s2, mut s3] = s; + // shake128_squeeze_next_block(&mut s0, out0); + // shake128_squeeze_next_block(&mut s1, out1); + // shake128_squeeze_next_block(&mut s2, out2); + // shake128_squeeze_next_block(&mut s3, out3); + // } + #[cfg(feature = "simd256")] + squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>( + s, + [out0, out1, out2, out3], + ); + } + } + } } diff --git a/libcrux-sha3/src/rust_simd/sha3_portable.rs b/libcrux-sha3/src/portable_keccak.rs similarity index 92% rename from libcrux-sha3/src/rust_simd/sha3_portable.rs rename to libcrux-sha3/src/portable_keccak.rs index 8d08ed7b8..341399985 100644 --- a/libcrux-sha3/src/rust_simd/sha3_portable.rs +++ b/libcrux-sha3/src/portable_keccak.rs @@ -1,9 +1,6 @@ -use crate::rust_simd::sha3_trait::*; +//! A portable SHA3 implementation using the generic implementation. -// This file optimizes for the stable Rust Neon Intrinsics -// If we want to use the unstable neon-sha3 instructions, we could use: -// veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 -// These instructions might speed up our code even more. 
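As a usage note for the 4-way incremental SHAKE128 API introduced above, the following sketch shows how a caller would drive it. It is illustrative only and assumes the `simd256` feature is compiled in and AVX2 support has been confirmed at runtime; the 34-byte seed length and 168-byte block size mirror the ML-KEM sampling code but are not prescribed by this API.

// Sketch: squeeze three SHAKE128 blocks from four seeds in parallel.
use libcrux_sha3::avx2::x4::incremental as shake128x4;

fn squeeze_four_streams(seeds: [[u8; 34]; 4]) -> [[u8; 3 * 168]; 4] {
    let mut out0 = [0u8; 3 * 168];
    let mut out1 = [0u8; 3 * 168];
    let mut out2 = [0u8; 3 * 168];
    let mut out3 = [0u8; 3 * 168];

    let mut state = shake128x4::shake128_init();
    shake128x4::shake128_absorb_final(&mut state, &seeds[0], &seeds[1], &seeds[2], &seeds[3]);
    shake128x4::shake128_squeeze_first_three_blocks(
        &mut state, &mut out0, &mut out1, &mut out2, &mut out3,
    );
    // Further output is squeezed one 168-byte block at a time with
    // shake128_squeeze_next_block, exactly as the rejection-sampling loop does.
    [out0, out1, out2, out3]
}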
+use crate::traits::*; #[inline(always)] fn rotate_left(x: u64) -> u64 { diff --git a/libcrux-sha3/src/rust_simd.rs b/libcrux-sha3/src/rust_simd.rs deleted file mode 100644 index 9d6bf05d2..000000000 --- a/libcrux-sha3/src/rust_simd.rs +++ /dev/null @@ -1,381 +0,0 @@ -mod sha3_generic; -mod sha3_portable; -mod sha3_trait; -pub use sha3_generic::*; - -pub type KeccakState1 = KeccakState<1, u64>; -#[inline(always)] -fn keccakx1(data: [&[u8]; 1], out: [&mut [u8]; 1]) { - keccak::<1, u64, RATE, DELIM>(data, out) -} - -#[cfg(feature = "simd128")] -mod sha3_arm64; -#[cfg(feature = "simd128")] -pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; -#[cfg(feature = "simd128")] -#[inline(always)] -fn keccakx2(data: [&[u8]; 2], out: [&mut [u8]; 2]) { - keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data, out) -} -#[cfg(feature = "simd128")] -pub type KeccakState4 = [KeccakState2; 2]; - -#[cfg(feature = "simd256")] -mod sha3_avx2; -// #[cfg(feature = "simd256")] -// #[inline(always)] -// fn keccakx4(data: [&[u8]; 4], out: [&mut [u8]; 4]) { -// keccak::<4, core::arch::x86_64::__m256i, RATE, DELIM>(data, out) -// } - -#[cfg(feature = "simd256")] -pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; - -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub type KeccakState2 = [KeccakState1; 2]; -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub type KeccakState4 = [KeccakState1; 4]; - -#[cfg(feature = "simd128")] -pub fn sha3_224(data: &[u8]) -> [u8; 28] { - let mut d0 = [0u8; 28]; - let mut d1 = [0u8; 28]; - keccakx2::<144, 0x06u8>([data, data], [&mut d0, &mut d1]); - d0 -} - -#[cfg(not(feature = "simd128"))] -pub fn sha3_224(data: &[u8]) -> [u8; 28] { - let mut d0 = [0u8; 28]; - keccakx1::<144, 0x06u8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn sha3_256(data: &[u8]) -> [u8; 32] { - let mut d0 = [0u8; 32]; - let mut d1 = [0u8; 32]; - keccakx2::<136, 0x06u8>([data, data], [&mut d0, &mut d1]); - d0 -} - -#[cfg(not(feature = "simd128"))] -pub fn sha3_256(data: &[u8]) -> [u8; 32] { - let mut d0 = [0u8; 32]; - keccakx1::<136, 0x06u8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn sha3_384(data: &[u8]) -> [u8; 48] { - let mut d0 = [0u8; 48]; - let mut d1 = [0u8; 48]; - keccakx2::<104, 0x06u8>([data, data], [&mut d0, &mut d1]); - d0 -} -#[cfg(not(feature = "simd128"))] -pub fn sha3_384(data: &[u8]) -> [u8; 48] { - let mut d0 = [0u8; 48]; - keccakx1::<104, 0x06u8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn sha3_512(data: &[u8]) -> [u8; 64] { - let mut d0 = [0u8; 64]; - let mut d1 = [0u8; 64]; - keccakx2::<72, 0x06u8>([data, data], [&mut d0, &mut d1]); - d0 -} -#[cfg(not(feature = "simd128"))] -pub fn sha3_512(data: &[u8]) -> [u8; 64] { - let mut d0 = [0u8; 64]; - keccakx1::<72, 0x06u8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn shake128(data: &[u8]) -> [u8; LEN] { - let mut d0 = [0u8; LEN]; - let mut d1 = [0u8; LEN]; - keccakx2::<168, 0x1fu8>([data, data], [&mut d0, &mut d1]); - d0 -} -#[cfg(not(feature = "simd128"))] -pub fn shake128(data: &[u8]) -> [u8; LEN] { - let mut d0 = [0u8; LEN]; - keccakx1::<168, 0x1fu8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn shake256(data: &[u8]) -> [u8; LEN] { - let mut d0 = [0u8; LEN]; - let mut d1 = [0u8; LEN]; - keccakx2::<136, 0x1fu8>([data, data], [&mut d0, &mut d1]); - d0 -} -#[cfg(not(feature = "simd128"))] -pub fn shake256(data: &[u8]) -> [u8; LEN] { - let mut d0 = [0u8; LEN]; - 
keccakx1::<136, 0x1fu8>([data], [&mut d0]); - d0 -} - -#[cfg(feature = "simd128")] -pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { - keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); -} -#[cfg(not(feature = "simd128"))] -pub fn shake256x2(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { - keccakx1::<136, 0x1fu8>([input0], [out0]); - keccakx1::<136, 0x1fu8>([input1], [out1]); -} - -#[cfg(feature = "simd256")] -pub fn shake256x4( - input0: &[u8], - input1: &[u8], - input2: &[u8], - input3: &[u8], - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - keccak::<4, core::arch::x86_64::__m256i, 136, 0x1fu8>( - [input0, input1, input2, input3], - [out0, out1, out2, out3], - ); -} -#[cfg(feature = "simd128")] -pub fn shake256x4( - input0: &[u8], - input1: &[u8], - input2: &[u8], - input3: &[u8], - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); - keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake256x4( - input0: &[u8], - input1: &[u8], - input2: &[u8], - input3: &[u8], - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - keccakx1::<136, 0x1fu8>([input0], [out0]); - keccakx1::<136, 0x1fu8>([input1], [out1]); - keccakx1::<136, 0x1fu8>([input2], [out2]); - keccakx1::<136, 0x1fu8>([input3], [out3]); -} - -/// Incremental API - -pub fn shake128_init() -> KeccakState1 { - KeccakState1::new() -} - -pub fn shake128_absorb_final(s: &mut KeccakState1, data0: &[u8]) { - absorb_final::<1, u64, 168, 0x1fu8>(s, [data0]); -} - -pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0: &mut [u8]) { - squeeze_first_three_blocks::<1, u64, 168>(s, [out0]) -} - -pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) { - squeeze_next_block::<1, u64, 168>(s, [out0]) -} - -#[cfg(feature = "simd128")] -pub fn shake128x2_init() -> KeccakState2 { - KeccakState2::new() -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x2_init() -> KeccakState2 { - let s0 = KeccakState1::new(); - let s1 = KeccakState1::new(); - [s0, s1] -} - -#[cfg(feature = "simd128")] -pub fn shake128x2_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { - absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]); -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x2_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { - let [mut s0, mut s1] = s; - shake128_absorb_final(&mut s0, data0); - shake128_absorb_final(&mut s1, data1); -} - -#[cfg(feature = "simd128")] -pub fn shake128x2_squeeze_first_three_blocks( - s: &mut KeccakState2, - out0: &mut [u8], - out1: &mut [u8], -) { - squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x2_squeeze_first_three_blocks( - s: &mut KeccakState2, - out0: &mut [u8], - out1: &mut [u8], -) { - let [mut s0, mut s1] = s; - shake128_squeeze_first_three_blocks(&mut s0, out0); - shake128_squeeze_first_three_blocks(&mut s1, out1); -} - -#[cfg(feature = "simd128")] -pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { - squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) -} -#[cfg(not(any(feature = "simd128", feature = 
"simd256")))] -pub fn shake128x2_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) { - let [mut s0, mut s1] = s; - shake128_squeeze_next_block(&mut s0, out0); - shake128_squeeze_next_block(&mut s1, out1); -} - -#[cfg(feature = "simd256")] -pub fn shake128x4_init() -> KeccakState4 { - KeccakState4::new() -} -#[cfg(feature = "simd128")] -pub fn shake128x4_init() -> KeccakState4 { - let s0 = KeccakState2::new(); - let s1 = KeccakState2::new(); - [s0, s1] -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x4_init() -> KeccakState4 { - let s0 = KeccakState1::new(); - let s1 = KeccakState1::new(); - let s2 = KeccakState1::new(); - let s3 = KeccakState1::new(); - [s0, s1, s2, s3] -} - -#[cfg(feature = "simd256")] -pub fn shake128x4_absorb_final( - s: &mut KeccakState4, - data0: &[u8], - data1: &[u8], - data2: &[u8], - data3: &[u8], -) { - absorb_final::<4, core::arch::x86_64::__m256i, 168, 0x1fu8>(s, [data0, data1, data2, data3]); -} -#[cfg(feature = "simd128")] -pub fn shake128x4_absorb_final( - s: &mut KeccakState4, - data0: &[u8], - data1: &[u8], - data2: &[u8], - data3: &[u8], -) { - let [mut s0, mut s1] = s; - absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(&mut s0, [data0, data1]); - absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(&mut s1, [data2, data3]); -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x4_absorb_final( - s: &mut KeccakState4, - data0: &[u8], - data1: &[u8], - data2: &[u8], - data3: &[u8], -) { - let [mut s0, mut s1, mut s2, mut s3] = s; - shake128_absorb_final(&mut s0, data0); - shake128_absorb_final(&mut s1, data1); - shake128_absorb_final(&mut s2, data2); - shake128_absorb_final(&mut s3, data3); -} - -#[cfg(feature = "simd256")] -pub fn shake128x4_squeeze_first_three_blocks( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - squeeze_first_three_blocks::<4, core::arch::x86_64::__m256i, 168>(s, [out0, out1, out2, out3]); -} -#[cfg(feature = "simd128")] -pub fn shake128x4_squeeze_first_three_blocks( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - let [mut s0, mut s1] = s; - squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s0, [out0, out1]); - squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s1, [out2, out3]); -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x4_squeeze_first_three_blocks( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - let [mut s0, mut s1, mut s2, mut s3] = s; - shake128_squeeze_first_three_blocks(&mut s0, out0); - shake128_squeeze_first_three_blocks(&mut s1, out1); - shake128_squeeze_first_three_blocks(&mut s2, out2); - shake128_squeeze_first_three_blocks(&mut s3, out3); -} - -#[cfg(feature = "simd256")] -pub fn shake128x4_squeeze_next_block( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>(s, [out0, out1, out2, out3]); -} -#[cfg(feature = "simd128")] -pub fn shake128x4_squeeze_next_block( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - let [mut s0, mut s1] = s; - squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s0, [out0, out1]); - squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(&mut s1, 
[out2, out3]); -} -#[cfg(not(any(feature = "simd128", feature = "simd256")))] -pub fn shake128x4_squeeze_next_block( - s: &mut KeccakState4, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], -) { - let [mut s0, mut s1, mut s2, mut s3] = s; - shake128_squeeze_next_block(&mut s0, out0); - shake128_squeeze_next_block(&mut s1, out1); - shake128_squeeze_next_block(&mut s2, out2); - shake128_squeeze_next_block(&mut s3, out3); -} diff --git a/libcrux-sha3/src/simd.rs b/libcrux-sha3/src/simd.rs new file mode 100644 index 000000000..7dc25011a --- /dev/null +++ b/libcrux-sha3/src/simd.rs @@ -0,0 +1,11 @@ +//! SIMD implementations of SHA3 +//! +//! Runtime feature detection must be performed before calling these functions. +//! +//! If the caller does not perform feature detection, the top level functions +//! must be used. + +#[cfg(feature = "simd128")] +mod arm64; +#[cfg(feature = "simd256")] +mod avx2; diff --git a/libcrux-sha3/src/rust_simd/sha3_arm64.rs b/libcrux-sha3/src/simd/arm64.rs similarity index 99% rename from libcrux-sha3/src/rust_simd/sha3_arm64.rs rename to libcrux-sha3/src/simd/arm64.rs index a4b3a301b..5d847ca7c 100644 --- a/libcrux-sha3/src/rust_simd/sha3_arm64.rs +++ b/libcrux-sha3/src/simd/arm64.rs @@ -1,6 +1,7 @@ -use crate::rust_simd::sha3_trait::*; use core::arch::aarch64::*; +use crate::traits::KeccakItem; + // This file optimizes for the stable Rust Neon Intrinsics // If we want to use the unstable neon-sha3 instructions, we could use: // veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 diff --git a/libcrux-sha3/src/rust_simd/sha3_avx2.rs b/libcrux-sha3/src/simd/avx2.rs similarity index 97% rename from libcrux-sha3/src/rust_simd/sha3_avx2.rs rename to libcrux-sha3/src/simd/avx2.rs index 8146ccec6..153270906 100644 --- a/libcrux-sha3/src/rust_simd/sha3_avx2.rs +++ b/libcrux-sha3/src/simd/avx2.rs @@ -1,11 +1,6 @@ use core::arch::x86_64::*; -use crate::rust_simd::sha3_trait::*; - -// This file optimizes for the stable Rust Neon Intrinsics -// If we want to use the unstable neon-sha3 instructions, we could use: -// veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64 -// These instructions might speed up our code even more. +use crate::traits::*; #[inline(always)] fn rotate_left(x: __m256i) -> __m256i { diff --git a/libcrux-sha3/src/rust_simd/sha3_trait.rs b/libcrux-sha3/src/traits.rs similarity index 89% rename from libcrux-sha3/src/rust_simd/sha3_trait.rs rename to libcrux-sha3/src/traits.rs index 0ad85dae8..a499305e3 100644 --- a/libcrux-sha3/src/rust_simd/sha3_trait.rs +++ b/libcrux-sha3/src/traits.rs @@ -1,4 +1,5 @@ -pub trait KeccakItem: Clone + Copy { +/// A trait for multiplexing implementations. 
+pub(crate) trait KeccakItem: Clone + Copy { fn zero() -> Self; fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self; fn rotate_left1_and_xor(a: Self, b: Self) -> Self; diff --git a/libcrux-sha3/tests/sha3.rs b/libcrux-sha3/tests/sha3.rs index dfc5f35dd..a4b7e3248 100644 --- a/libcrux-sha3/tests/sha3.rs +++ b/libcrux-sha3/tests/sha3.rs @@ -12,11 +12,11 @@ fn sha3_kat_oneshot() { #[test] fn sha3_simd_kat_oneshot() { - let d256 = libcrux_sha3::rust_simd::sha3_256(b"Hello, World!"); + let d256 = libcrux_sha3::sha256(b"Hello, World!"); let expected256 = "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"; assert_eq!(hex::encode(&d256), expected256); - let dshake = libcrux_sha3::rust_simd::shake128::<42>(b"Hello, World!"); + let dshake = libcrux_sha3::shake128::<42>(b"Hello, World!"); let expectedshake = "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; assert_eq!(hex::encode(&dshake), expectedshake); diff --git a/polynomials/build.rs b/polynomials/build.rs index f15f3d581..ef1138666 100644 --- a/polynomials/build.rs +++ b/polynomials/build.rs @@ -16,11 +16,13 @@ fn main() { // We enable simd128 on all aarch64 builds. println!("cargo:rustc-cfg=feature=\"simd128\""); } - if (target_arch == "x86" || target_arch == "x86_64") && !disable_simd256 { - // We enable simd256 on all x86 and x86_64 builds. + if target_arch == "x86_64" && !disable_simd256 { + // We enable simd256 on all x86_64 builds. // Note that this doesn't mean the required CPU features are available. // But the compiler will support them and the runtime checks ensure that // it's only used when available. + // + // We don't enable this on x86 because it seems to generate invalid code. println!("cargo:rustc-cfg=feature=\"simd256\""); } } From d51781e7354f9c2da30e1bf0ed91101e5a20fd79 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Thu, 16 May 2024 15:59:35 +0200 Subject: [PATCH 38/59] rustfmt --- polynomials-avx2/src/arithmetic.rs | 5 ++--- polynomials-avx2/src/compress.rs | 8 ++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/polynomials-avx2/src/arithmetic.rs b/polynomials-avx2/src/arithmetic.rs index 701d76ad6..e51fd5b5b 100644 --- a/polynomials-avx2/src/arithmetic.rs +++ b/polynomials-avx2/src/arithmetic.rs @@ -28,7 +28,7 @@ pub(crate) fn shift_right(vector: __m256i) -> __m256i { #[inline(always)] pub(crate) fn shift_left(vector: __m256i) -> __m256i { - mm256_slli_epi16::<{SHIFT_BY}>(vector) + mm256_slli_epi16::<{ SHIFT_BY }>(vector) } #[inline(always)] @@ -52,8 +52,7 @@ pub(crate) fn barrett_reduce(vector: __m256i) -> __m256i { let quotient = mm256_srai_epi16::<10>(t); - let quotient_times_field_modulus = - mm256_mullo_epi16(quotient, mm256_set1_epi16(FIELD_MODULUS)); + let quotient_times_field_modulus = mm256_mullo_epi16(quotient, mm256_set1_epi16(FIELD_MODULUS)); mm256_sub_epi16(vector, quotient_times_field_modulus) } diff --git a/polynomials-avx2/src/compress.rs b/polynomials-avx2/src/compress.rs index 57f0f98d2..858a5b278 100644 --- a/polynomials-avx2/src/compress.rs +++ b/polynomials-avx2/src/compress.rs @@ -46,7 +46,7 @@ pub(crate) fn compress_ciphertext_coefficient( let coefficients_low = mm256_castsi256_si128(vector); let coefficients_low = mm256_cvtepi16_epi32(coefficients_low); - let compressed_low = mm256_slli_epi32::<{COEFFICIENT_BITS}>(coefficients_low); + let compressed_low = mm256_slli_epi32::<{ COEFFICIENT_BITS }>(coefficients_low); let compressed_low = mm256_add_epi32(compressed_low, field_modulus_halved); let compressed_low = 
mulhi_mm256_epi32(compressed_low, compression_factor); @@ -57,7 +57,7 @@ pub(crate) fn compress_ciphertext_coefficient( let coefficients_high = mm256_extracti128_si256::<1>(vector); let coefficients_high = mm256_cvtepi16_epi32(coefficients_high); - let compressed_high = mm256_slli_epi32::<{COEFFICIENT_BITS}>(coefficients_high); + let compressed_high = mm256_slli_epi32::<{ COEFFICIENT_BITS }>(coefficients_high); let compressed_high = mm256_add_epi32(compressed_high, field_modulus_halved); let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor); @@ -87,7 +87,7 @@ pub(crate) fn decompress_ciphertext_coefficient( // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack // of support for const generic expressions. - let decompressed_low = mm256_srli_epi32::<{COEFFICIENT_BITS}>(decompressed_low); + let decompressed_low = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_low); let decompressed_low = mm256_srli_epi32::<1>(decompressed_low); // Compress the next 8 coefficients @@ -100,7 +100,7 @@ pub(crate) fn decompress_ciphertext_coefficient( // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack // of support for const generic expressions. - let decompressed_high = mm256_srli_epi32::<{COEFFICIENT_BITS}>(decompressed_high); + let decompressed_high = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_high); let decompressed_high = mm256_srli_epi32::<1>(decompressed_high); // Combine them From 82ed694babe8ada20a0a7746663f7e7fa008a8d6 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Thu, 16 May 2024 16:51:25 +0200 Subject: [PATCH 39/59] fixup sha3 features --- libcrux-sha3/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 0e5e63ddc..d6a5bedd7 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -251,6 +251,7 @@ pub mod portable { /// /// Feature `simd128` enables the implementations in this module. pub mod neon { + #[cfg(feature = "simd128")] use crate::generic_keccak::keccak; #[cfg(feature = "simd128")] @@ -333,6 +334,7 @@ pub mod neon { /// Performing 2 operations in parallel pub mod x2 { + #[cfg(feature = "simd128")] use super::*; /// Run SHAKE256 on both inputs in parallel. 
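To make the feature gating above concrete, here is a minimal caller-side sketch of the dispatch pattern used in libcrux-ml-kem/src/ind_cca.rs: a `cfg!` check plus a runtime probe selects the SIMD backend, while `#[cfg]` attributes keep the SIMD call out of builds that lack the feature. This is an illustration, not code from the patch, and it assumes `libcrux_platform` is available as in the rest of the workspace.

// Sketch: pick the neon or portable SHA3-256 backend at runtime.
fn sha256_dispatch(digest: &mut [u8; 32], data: &[u8]) {
    if cfg!(feature = "simd128") && libcrux_platform::simd128_support() {
        // Only compiled when the neon backend is built in.
        #[cfg(feature = "simd128")]
        libcrux_sha3::neon::sha256(digest, data);
        #[cfg(not(feature = "simd128"))]
        libcrux_sha3::portable::sha256(digest, data);
    } else {
        libcrux_sha3::portable::sha256(digest, data);
    }
}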
@@ -349,6 +351,7 @@ pub mod neon { /// An incremental API to perform 2 operations in parallel pub mod incremental { + #[cfg(feature = "simd128")] use crate::generic_keccak::{ absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, }; @@ -442,6 +445,8 @@ pub mod avx2 { /// Performing 4 operations in parallel pub mod x4 { + #[cfg(feature = "simd256")] + use crate::generic_keccak::keccak; /// Perform 4 SHAKE256 operations in parallel #[allow(unused_variables)] // TODO: decide if we want to fall back here From f39d840116e1ce12c5937d5843060574071ae0f0 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Thu, 16 May 2024 20:24:00 +0200 Subject: [PATCH 40/59] check for target_arch in ml-kem --- libcrux-ml-kem/src/ind_cca.rs | 56 +++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/libcrux-ml-kem/src/ind_cca.rs b/libcrux-ml-kem/src/ind_cca.rs index ac0cb8b98..a538e313b 100644 --- a/libcrux-ml-kem/src/ind_cca.rs +++ b/libcrux-ml-kem/src/ind_cca.rs @@ -52,8 +52,11 @@ pub(crate) fn validate_public_key< >( public_key: &[u8; PUBLIC_KEY_SIZE], ) -> bool { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return validate_public_key_generic::< K, RANKED_BYTES_PER_RING_ELEMENT, @@ -67,8 +70,11 @@ pub(crate) fn validate_public_key< PUBLIC_KEY_SIZE, PortableVector, >(public_key) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return validate_public_key_generic::< K, RANKED_BYTES_PER_RING_ELEMENT, @@ -127,8 +133,11 @@ pub(crate) fn generate_keypair< let implicit_rejection_value = &randomness[CPA_PKE_KEY_GENERATION_SEED_SIZE..]; // Runtime feature detection. 
- if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return generate_keypair_generic::< K, CPA_PRIVATE_KEY_SIZE, @@ -152,8 +161,11 @@ pub(crate) fn generate_keypair< PortableVector, hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return generate_keypair_generic::< K, CPA_PRIVATE_KEY_SIZE, @@ -246,8 +258,11 @@ pub(crate) fn encapsulate< public_key: &MlKemPublicKey, randomness: [u8; SHARED_SECRET_SIZE], ) -> (MlKemCiphertext, MlKemSharedSecret) { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return encapsulate_generic::< K, CIPHERTEXT_SIZE, @@ -283,7 +298,10 @@ pub(crate) fn encapsulate< PortableVector, hash_functions::portable::PortableHash, >(public_key, randomness) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { #[cfg(not(feature = "simd128"))] return encapsulate_generic::< K, @@ -302,7 +320,7 @@ pub(crate) fn encapsulate< PortableVector, hash_functions::portable::PortableHash, >(public_key, randomness); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] encapsulate_generic::< K, CIPHERTEXT_SIZE, @@ -409,8 +427,11 @@ pub(crate) fn decapsulate< private_key: &MlKemPrivateKey, ciphertext: &MlKemCiphertext, ) -> MlKemSharedSecret { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return decapsulate_generic::< K, SECRET_KEY_SIZE, @@ -452,8 +473,11 @@ pub(crate) fn decapsulate< PortableVector, hash_functions::portable::PortableHash, >(private_key, ciphertext); - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return decapsulate_generic::< K, SECRET_KEY_SIZE, From 8fd44b76421e5d90835feb86044fc1ebfd71308f Mon Sep 17 00:00:00 2001 From: Karthikeyan Bhargavan Date: Thu, 16 May 2024 20:45:51 +0200 Subject: [PATCH 41/59] added pointer annotations for neon --- polynomials-aarch64/src/neon.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/polynomials-aarch64/src/neon.rs b/polynomials-aarch64/src/neon.rs index d233c6cc6..d0454cfef 100644 --- a/polynomials-aarch64/src/neon.rs +++ b/polynomials-aarch64/src/neon.rs @@ -193,7 +193,7 @@ pub(crate) fn _vmlal_high_s16(a: int32x4_t, b: int16x8_t, c: int16x8_t) -> int32 } #[inline(always)] pub(crate) fn _vld1q_u8(ptr: &[u8]) -> uint8x16_t { 
- unsafe { vld1q_u8(ptr.as_ptr()) } + unsafe { vld1q_u8(ptr.as_ptr() as *const u8) } } #[inline(always)] pub(crate) fn _vreinterpretq_u8_s16(a: int16x8_t) -> uint8x16_t { @@ -254,7 +254,7 @@ pub(super) fn _vreinterpretq_u8_s64(a: int64x2_t) -> uint8x16_t { #[inline(always)] pub(super) fn _vst1q_u8(out: &mut [u8], v: uint8x16_t) { - unsafe { vst1q_u8(out.as_mut_ptr(), v) } + unsafe { vst1q_u8(out.as_mut_ptr() as *mut u8, v) } } #[inline(always)] pub(crate) fn _vdupq_n_u16(value: u16) -> uint16x8_t { @@ -270,7 +270,7 @@ pub(crate) fn _vreinterpretq_u16_u8(a: uint8x16_t) -> uint16x8_t { } #[inline(always)] pub(crate) fn _vld1q_u16(ptr: &[u16]) -> uint16x8_t { - unsafe { vld1q_u16(ptr.as_ptr()) } + unsafe { vld1q_u16(ptr.as_ptr() as *const u16) } } #[inline(always)] pub(crate) fn _vcleq_s16(a: int16x8_t, b: int16x8_t) -> uint16x8_t { From 1065a337876b10e793060e425b3e3ed7f802cfed Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 07:56:34 +0200 Subject: [PATCH 42/59] bump libc --- sys/platform/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sys/platform/Cargo.toml b/sys/platform/Cargo.toml index 6c9f512af..3907c49f8 100644 --- a/sys/platform/Cargo.toml +++ b/sys/platform/Cargo.toml @@ -10,4 +10,4 @@ readme.workspace = true description = "Platform detection crate for libcrux." [dependencies] -libc = "0.2.147" +libc = "0.2.154" From 30696f888a7a35d6db37ae120ff8976635e9478b Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 08:25:39 +0200 Subject: [PATCH 43/59] check for arm on macos platform --- sys/platform/Cargo.toml | 2 +- sys/platform/src/lib.rs | 3 --- sys/platform/src/macos_arm.rs | 32 +++++++++++++++++++++++++++++++- sys/platform/src/test.rs | 3 +++ 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/sys/platform/Cargo.toml b/sys/platform/Cargo.toml index 3907c49f8..6766bf9b9 100644 --- a/sys/platform/Cargo.toml +++ b/sys/platform/Cargo.toml @@ -10,4 +10,4 @@ readme.workspace = true description = "Platform detection crate for libcrux." [dependencies] -libc = "0.2.154" +libc = "0.2.153" diff --git a/sys/platform/src/lib.rs b/sys/platform/src/lib.rs index 96619cafa..d2e8b0310 100644 --- a/sys/platform/src/lib.rs +++ b/sys/platform/src/lib.rs @@ -9,9 +9,6 @@ #[macro_use] extern crate std; -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// sflwnwc - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod x86; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] diff --git a/sys/platform/src/macos_arm.rs b/sys/platform/src/macos_arm.rs index 4d92236df..2edafdbdc 100644 --- a/sys/platform/src/macos_arm.rs +++ b/sys/platform/src/macos_arm.rs @@ -1,6 +1,30 @@ //! Obtain particular CPU features for AArch64 on macOS -use libc::{c_void, sysctlbyname}; +use libc::{c_char, c_void, sysctlbyname, uname, utsname}; + +#[allow(dead_code)] +fn cstr(src: &[i8]) -> &str { + // default to length if no `0` present + let end = src.iter().position(|&c| c == 0).unwrap_or(src.len()); + unsafe { core::str::from_utf8_unchecked(core::mem::transmute(&src[0..end])) } +} + +/// Check that we're actually on an ARM mac. +/// When this returns false, no other function in here must be called. 
+pub(crate) fn actually_arm() -> bool { + let mut buf = utsname { + sysname: [c_char::default(); 256], + nodename: [c_char::default(); 256], + release: [c_char::default(); 256], + version: [c_char::default(); 256], + machine: [c_char::default(); 256], + }; + let error = unsafe { uname(&mut buf) }; + // buf.machine == "arm" + // It could also be "arm64". + let arm = buf.machine[0] == 97 && buf.machine[1] == 114 && buf.machine[2] == 109; + error == 0 && arm +} #[inline(always)] fn check(feature: &[i8]) -> bool { @@ -20,6 +44,12 @@ fn check(feature: &[i8]) -> bool { #[inline(always)] fn sysctl() { + // Check that we're actually on arm and set everything to false if not. + // This may happen when running on an intel mac. + if !actually_arm() { + return; + } + // hw.optional.AdvSIMD const ADV_SIMD_STR: [i8; 20] = [ 0x68, 0x77, 0x2e, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2e, 0x41, 0x64, 0x76, diff --git a/sys/platform/src/test.rs b/sys/platform/src/test.rs index b2b66c9ff..c23d60b05 100644 --- a/sys/platform/src/test.rs +++ b/sys/platform/src/test.rs @@ -1,9 +1,12 @@ //! Test functions for CPU feature detection +use crate::macos_arm::actually_arm; + use super::*; #[test] fn dump_features() { + eprintln!("arm\t\t{:?}", actually_arm()); eprintln!("simd128\t\t{:?}", simd128_support()); eprintln!("simd256\t\t{:?}", simd256_support()); eprintln!("x25519\t\t{:?}", x25519_support()); From 4b1c8a2155e3ce1c933a6257badefccfdd7c9636 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 09:01:30 +0200 Subject: [PATCH 44/59] check platform on ci --- .github/workflows/platform.yml | 87 ++++++++++++++++++++++++++++++++++ sys/platform/src/test.rs | 3 -- 2 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/platform.yml diff --git a/.github/workflows/platform.yml b/.github/workflows/platform.yml new file mode 100644 index 000000000..899e16353 --- /dev/null +++ b/.github/workflows/platform.yml @@ -0,0 +1,87 @@ +name: Platform + +on: + push: + branches: ["main", "dev"] + pull_request: + branches: ["main", "dev", "*"] + workflow_dispatch: + merge_group: + +env: + CARGO_TERM_COLOR: always + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + bits: [32, 64] + os: + - macos-latest + - ubuntu-latest + - windows-latest + exclude: + - bits: 32 + os: "macos-latest" + + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: sys/platform + + steps: + - uses: actions/checkout@v4 + + - name: 🔨 Build + run: cargo build --verbose + + - name: 🏃🏻‍♀️ Test + run: cargo test --verbose + + - name: 🏃🏻‍♀️ Test Release + run: cargo test --verbose --release + + - name: 🛠️ Setup Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: | + rustup target add i686-unknown-linux-gnu + + - name: 🏃🏻‍♀️ Test Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --target i686-unknown-linux-gnu + + - name: 🏃🏻‍♀️ Test Release Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --release --target i686-unknown-linux-gnu + + - name: 🛠️ Setup Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: | + rustup target add i686-pc-windows-msvc + + - name: 🏃🏻‍♀️ Test Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --target i686-pc-windows-msvc + + - name: 🏃🏻‍♀️ Test Release 
Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --release --target i686-pc-windows-msvc + + - name: 🛠️ Setup MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: | + rustup target add x86_64-apple-darwin + + - name: 🏃🏻‍♀️ Test MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: cargo test --verbose --target x86_64-apple-darwin + + - name: 🏃🏻‍♀️ Test Release MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: cargo test --verbose --release --target x86_64-apple-darwin diff --git a/sys/platform/src/test.rs b/sys/platform/src/test.rs index c23d60b05..b2b66c9ff 100644 --- a/sys/platform/src/test.rs +++ b/sys/platform/src/test.rs @@ -1,12 +1,9 @@ //! Test functions for CPU feature detection -use crate::macos_arm::actually_arm; - use super::*; #[test] fn dump_features() { - eprintln!("arm\t\t{:?}", actually_arm()); eprintln!("simd128\t\t{:?}", simd128_support()); eprintln!("simd256\t\t{:?}", simd256_support()); eprintln!("x25519\t\t{:?}", x25519_support()); From 9841f6f62a4fa86a4d3ec41d04e08911dec29a7d Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 09:04:10 +0200 Subject: [PATCH 45/59] work around gh actions --- .github/workflows/mlkem.yml | 70 +++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index f0b7040f7..f2bf139f9 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -192,3 +192,73 @@ jobs: - name: 🏃🏻‍♀️ Benchmarks Clang if: ${{ matrix.os != 'windows-latest' }} run: CC=clang cargo bench --verbose $RUST_TARGET_FLAG + + platform: + strategy: + fail-fast: false + matrix: + bits: [32, 64] + os: + - macos-latest + - ubuntu-latest + - windows-latest + exclude: + - bits: 32 + os: "macos-latest" + + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: sys/platform + + steps: + - uses: actions/checkout@v4 + + - name: 🔨 Build + run: cargo build --verbose + + - name: 🏃🏻‍♀️ Test + run: cargo test --verbose + + - name: 🏃🏻‍♀️ Test Release + run: cargo test --verbose --release + + - name: 🛠️ Setup Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: | + rustup target add i686-unknown-linux-gnu + + - name: 🏃🏻‍♀️ Test Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --target i686-unknown-linux-gnu + + - name: 🏃🏻‍♀️ Test Release Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --release --target i686-unknown-linux-gnu + + - name: 🛠️ Setup Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: | + rustup target add i686-pc-windows-msvc + + - name: 🏃🏻‍♀️ Test Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --target i686-pc-windows-msvc + + - name: 🏃🏻‍♀️ Test Release Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --release --target i686-pc-windows-msvc + + - name: 🛠️ Setup MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: | + rustup target add x86_64-apple-darwin + + - name: 🏃🏻‍♀️ Test MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: cargo test --verbose --target x86_64-apple-darwin + + - name: 🏃🏻‍♀️ Test Release MacOS x86_64 + if: ${{ matrix.os == 'macos-latest' }} + run: cargo test --verbose --release --target x86_64-apple-darwin From 
3a3071d969735fa36da8686c62e10570e7d2ac1c Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 09:07:10 +0200 Subject: [PATCH 46/59] update platforms --- .github/workflows/mlkem.yml | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index f2bf139f9..2313c1e83 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -199,7 +199,8 @@ jobs: matrix: bits: [32, 64] os: - - macos-latest + - macos-13 # Intel mac + - macos-latest # macos-14 m1 - ubuntu-latest - windows-latest exclude: @@ -219,10 +220,10 @@ jobs: run: cargo build --verbose - name: 🏃🏻‍♀️ Test - run: cargo test --verbose + run: cargo test --verbose -- --nocapture - name: 🏃🏻‍♀️ Test Release - run: cargo test --verbose --release + run: cargo test --verbose --release -- --nocapture - name: 🛠️ Setup Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} @@ -231,34 +232,21 @@ jobs: - name: 🏃🏻‍♀️ Test Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --target i686-unknown-linux-gnu + run: cargo test --verbose --target i686-unknown-linux-gnu -- --nocapture - name: 🏃🏻‍♀️ Test Release Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --release --target i686-unknown-linux-gnu + run: cargo test --verbose --release --target i686-unknown-linux-gnu -- --nocapture - name: 🛠️ Setup Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} run: | - rustup target add i686-pc-windows-msvc + rustup target add i686-pc-windows-msvc -- --nocapture - name: 🏃🏻‍♀️ Test Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --target i686-pc-windows-msvc + run: cargo test --verbose --target i686-pc-windows-msvc -- --nocapture - name: 🏃🏻‍♀️ Test Release Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --release --target i686-pc-windows-msvc - - - name: 🛠️ Setup MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: | - rustup target add x86_64-apple-darwin - - - name: 🏃🏻‍♀️ Test MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: cargo test --verbose --target x86_64-apple-darwin - - - name: 🏃🏻‍♀️ Test Release MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: cargo test --verbose --release --target x86_64-apple-darwin + run: cargo test --verbose --release --target i686-pc-windows-msvc -- --nocapture From ad16ac69afe12f23cdeeabd30a07a42c5c33d0b9 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 09:21:33 +0200 Subject: [PATCH 47/59] asan --- .github/workflows/mlkem.yml | 40 ++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 2313c1e83..9a82a79b7 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -44,23 +44,18 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} - - name: ⚙️ Setup Ubuntu x86 + - name: 🛠️ Setup Ubuntu x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu sudo apt-get update sudo apt-get install -y gcc-multilib g++-multilib - - name: ⚙️ Setup Ubuntu x64 + - name: 🛠️ Setup Ubuntu x64 if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} run: | rustup target add aarch64-unknown-linux-gnu - - name: ⚙️ Setup 
macOS - if: ${{ matrix.os == 'macos-latest' }} - run: | - rustup target add aarch64-apple-darwin - # Set up 32 bit systems - name: 🛠️ Config Windows x86 @@ -74,14 +69,14 @@ jobs: # Set up windows - - name: ⚙️ Setup Windows x86 + - name: 🛠️ Setup Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} shell: pwsh run: | echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append vcpkg install openssl:x86-windows-static-md - - name: ⚙️ Setup Windows x64 + - name: 🛠️ Setup Windows x64 if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} shell: pwsh run: | @@ -96,11 +91,13 @@ jobs: - name: 🔨 Build Release run: cargo build --verbose --release $RUST_TARGET_FLAG - # Cross compilation - - - name: 🔨 Build aarch64 macOS + - name: 🏃🏻 Asan MacOS if: ${{ matrix.os == 'macos-latest' }} - run: cargo build --verbose --target aarch64-apple-darwin + run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin + + - name: 🏃🏻 Asan Linux + if: ${{ matrix.os == 'ubuntu-latest' }} + run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu # Test ... @@ -139,19 +136,19 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} - - name: ⚙️ Setup Ubuntu x86 + - name: 🛠️ Setup Ubuntu x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu sudo apt-get update sudo apt-get install -y gcc-multilib g++-multilib - - name: ⚙️ Setup Ubuntu x64 + - name: 🛠️ Setup Ubuntu x64 if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} run: | rustup target add aarch64-unknown-linux-gnu - - name: ⚙️ Setup macOS + - name: 🛠️ Setup macOS if: ${{ matrix.os == 'macos-latest' }} run: | rustup target add aarch64-apple-darwin @@ -169,14 +166,14 @@ jobs: # Set up windows - - name: ⚙️ Setup Windows x86 + - name: 🛠️ Setup Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} shell: pwsh run: | echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append vcpkg install openssl:x86-windows-static-md - - name: ⚙️ Setup Windows x64 + - name: 🛠️ Setup Windows x64 if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} shell: pwsh run: | @@ -229,6 +226,8 @@ jobs: if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu + sudo apt-get update + sudo apt-get install -y gcc-multilib g++-multilib - name: 🏃🏻‍♀️ Test Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} @@ -238,11 +237,6 @@ jobs: if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: cargo test --verbose --release --target i686-unknown-linux-gnu -- --nocapture - - name: 🛠️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: | - rustup target add i686-pc-windows-msvc -- --nocapture - - name: 🏃🏻‍♀️ Test Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} run: cargo test --verbose --target i686-pc-windows-msvc -- --nocapture From 0c3d8448a09d8b8a00a320aa48565005f39eb46c Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 09:25:53 +0200 Subject: [PATCH 48/59] install nightly on ci --- .github/workflows/mlkem.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 9a82a79b7..89d1686c2 100644 --- a/.github/workflows/mlkem.yml +++ 
b/.github/workflows/mlkem.yml @@ -43,6 +43,9 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} + + - name: 🛠️ Setup Rust Nightly + run: rustup toolchain install nightly - name: 🛠️ Setup Ubuntu x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} From ef9de75def1958355df750e788578c2f37fcdb5b Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 10:00:24 +0200 Subject: [PATCH 49/59] clean before test --- .github/workflows/mlkem.yml | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 89d1686c2..03244c5cf 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -22,15 +22,15 @@ jobs: matrix: bits: [32, 64] os: - - macos-latest + - macos-13 # Intel mac + - macos-latest # macos-14 m1 - ubuntu-latest - windows-latest exclude: - bits: 32 os: "macos-latest" - # FIXME: Linking isn't working here yet for hacl #42 - bits: 32 - os: "windows-latest" + os: "macos-13" runs-on: ${{ matrix.os }} defaults: @@ -43,7 +43,7 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} - + - name: 🛠️ Setup Rust Nightly run: rustup toolchain install nightly @@ -99,16 +99,16 @@ jobs: run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin - name: 🏃🏻 Asan Linux - if: ${{ matrix.os == 'ubuntu-latest' }} + if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu # Test ... - name: 🏃🏻‍♀️ Test - run: cargo test --verbose $RUST_TARGET_FLAG + run: cargo clean && cargo test --verbose $RUST_TARGET_FLAG - name: 🏃🏻‍♀️ Test Release - run: cargo test --verbose --release $RUST_TARGET_FLAG + run: cargo clean && cargo test --verbose --release $RUST_TARGET_FLAG benchmarks: strategy: @@ -116,6 +116,7 @@ jobs: matrix: bits: [32, 64] os: + - macos-13 - macos-latest - ubuntu-latest - windows-latest @@ -123,9 +124,8 @@ jobs: # There's no such thing as 32-bit macOS - bits: 32 os: "macos-latest" - # FIXME: Linking isn't working here yet for hacl #42 - bits: 32 - os: "windows-latest" + os: "macos-13" runs-on: ${{ matrix.os }} defaults: @@ -206,6 +206,8 @@ jobs: exclude: - bits: 32 os: "macos-latest" + - bits: 32 + os: "macos-13" runs-on: ${{ matrix.os }} defaults: From 5b39c285b8abc1bd13db436e544f95cde7df3500 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 11:23:42 +0200 Subject: [PATCH 50/59] disable linux asan for now --- .github/workflows/mlkem.yml | 40 +++++-------------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 03244c5cf..004bfe164 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -70,22 +70,6 @@ jobs: echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - # Set up windows - - - name: 🛠️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x86-windows-static-md - - - name: 🛠️ Setup Windows x64 - if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | 
Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - # Build ... - name: 🔨 Build @@ -98,9 +82,11 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin - - name: 🏃🏻 Asan Linux - if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} - run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu + # We get false positives here. + # TODO: Figure out what is going on here + # - name: 🏃🏻 Asan Linux + # if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} + # run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu # Test ... @@ -167,22 +153,6 @@ jobs: echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - # Set up windows - - - name: 🛠️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x86-windows-static-md - - - name: 🛠️ Setup Windows x64 - if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - # Benchmarks ... - name: 🏃🏻‍♀️ Benchmarks Windows From 3d75761f855458eb4c950c45cce31700d1e0449b Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 11:40:35 +0200 Subject: [PATCH 51/59] more guards --- .github/workflows/mlkem.yml | 8 ++- libcrux-sha3/benches/sha3.rs | 15 ++--- libcrux-sha3/src/lib.rs | 107 ++++++++++++++++++----------------- 3 files changed, 69 insertions(+), 61 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 004bfe164..2783a9a17 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -91,10 +91,14 @@ jobs: # Test ... - name: 🏃🏻‍♀️ Test - run: cargo clean && cargo test --verbose $RUST_TARGET_FLAG + run: | + cargo test --verbose $RUST_TARGET_FLAG + cd ../libcrux-sha3 && cargo test --verbose $RUST_TARGET_FLAG - name: 🏃🏻‍♀️ Test Release - run: cargo clean && cargo test --verbose --release $RUST_TARGET_FLAG + run: | + cargo test --verbose --release $RUST_TARGET_FLAG + cd ../libcrux-sha3 && cargo test --verbose $RUST_TARGET_FLAG benchmarks: strategy: diff --git a/libcrux-sha3/benches/sha3.rs b/libcrux-sha3/benches/sha3.rs index 0195560aa..93e427551 100644 --- a/libcrux-sha3/benches/sha3.rs +++ b/libcrux-sha3/benches/sha3.rs @@ -19,7 +19,7 @@ pub fn fmt(x: usize) -> String { } macro_rules! impl_comp { - ($fun:ident, $libcrux:expr) => { + ($fun:ident, $libcrux:expr, $neon_fun:ident) => { // Comparing libcrux performance for different payload sizes and other implementations. fn $fun(c: &mut Criterion) { const PAYLOAD_SIZES: [usize; 3] = [128, 1024, 1024 * 1024 * 10]; @@ -43,7 +43,7 @@ macro_rules! impl_comp { }, ); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] group.bench_with_input( BenchmarkId::new("rust version (simd128)", fmt(*payload_size)), payload_size, @@ -51,7 +51,8 @@ macro_rules! 
impl_comp { b.iter_batched( || randombytes(*payload_size), |payload| { - let _d: [u8; digest_size($libcrux)] = neon::$fun(&payload); + let mut digest = [0u8; digest_size($libcrux)]; + neon::$neon_fun(&mut digest, &payload); }, BatchSize::SmallInput, ) @@ -62,10 +63,10 @@ macro_rules! impl_comp { }; } -impl_comp!(Sha3_224, Algorithm::Sha224); -impl_comp!(Sha3_256, Algorithm::Sha256); -impl_comp!(Sha3_384, Algorithm::Sha384); -impl_comp!(Sha3_512, Algorithm::Sha512); +impl_comp!(Sha3_224, Algorithm::Sha224, sha224); +impl_comp!(Sha3_256, Algorithm::Sha256, sha256); +impl_comp!(Sha3_384, Algorithm::Sha384, sha384); +impl_comp!(Sha3_512, Algorithm::Sha512, sha512); fn benchmarks(c: &mut Criterion) { Sha3_224(c); diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index d6a5bedd7..563dd4716 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -251,10 +251,10 @@ pub mod portable { /// /// Feature `simd128` enables the implementations in this module. pub mod neon { - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] use crate::generic_keccak::keccak; - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] #[inline(always)] fn keccakx2(data: [&[u8]; 2], out: [&mut [u8]; 2]) { keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data, out) @@ -263,9 +263,9 @@ pub mod neon { /// A portable SHA3 224 implementation. #[allow(unused_variables)] pub fn sha224(digest: &mut [u8], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; 28]; keccakx2::<144, 0x06u8>([data, data], [digest, &mut dummy]); @@ -275,9 +275,9 @@ pub mod neon { /// A portable SHA3 256 implementation. #[allow(unused_variables)] pub fn sha256(digest: &mut [u8], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; 32]; keccakx2::<136, 0x06u8>([data, data], [digest, &mut dummy]); @@ -287,9 +287,9 @@ pub mod neon { /// A portable SHA3 384 implementation. #[allow(unused_variables)] pub fn sha384(digest: &mut [u8], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; 48]; keccakx2::<104, 0x06u8>([data, data], [digest, &mut dummy]); @@ -299,9 +299,9 @@ pub mod neon { /// A portable SHA3 512 implementation. #[allow(unused_variables)] pub fn sha512(digest: &mut [u8], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; 64]; keccakx2::<72, 0x06u8>([data, data], [digest, &mut dummy]); @@ -311,9 +311,9 @@ pub mod neon { /// A portable SHAKE128 implementation. 
#[allow(unused_variables)] pub fn shake128(digest: &mut [u8; LEN], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; LEN]; keccakx2::<168, 0x1fu8>([data, data], [digest, &mut dummy]); @@ -323,9 +323,9 @@ pub mod neon { /// A portable SHAKE256 implementation. #[allow(unused_variables)] pub fn shake256(digest: &mut [u8; LEN], data: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] { let mut dummy = [0u8; LEN]; keccakx2::<136, 0x1fu8>([data, data], [digest, &mut dummy]); @@ -334,7 +334,7 @@ pub mod neon { /// Performing 2 operations in parallel pub mod x2 { - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] use super::*; /// Run SHAKE256 on both inputs in parallel. @@ -343,26 +343,26 @@ pub mod neon { #[allow(unused_variables)] pub fn shake256(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { // TODO: make argument ordering consistent - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); } /// An incremental API to perform 2 operations in parallel pub mod incremental { - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] use crate::generic_keccak::{ absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, }; - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] pub type KeccakState2 = [crate::portable::KeccakState1; 2]; pub fn shake128_init() -> KeccakState2 { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation @@ -371,13 +371,13 @@ pub mod neon { // let s1 = KeccakState1::new(); // [s0, s1] // } - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] KeccakState2::new() } #[allow(unused_variables)] pub fn shake128_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation @@ -386,7 +386,7 @@ pub mod neon { // shake128_absorb_final(&mut s0, data0); // shake128_absorb_final(&mut s1, data1); // } - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]); } @@ -396,7 +396,7 @@ pub mod neon { 
out0: &mut [u8], out1: &mut [u8], ) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation @@ -405,7 +405,7 @@ pub mod neon { // shake128_squeeze_first_three_blocks(&mut s0, out0); // shake128_squeeze_first_three_blocks(&mut s1, out1); // } - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( s, [out0, out1], @@ -418,7 +418,7 @@ pub mod neon { out0: &mut [u8], out1: &mut [u8], ) { - #[cfg(not(feature = "simd128"))] + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation @@ -427,7 +427,7 @@ pub mod neon { // shake128_squeeze_next_block(&mut s0, out0); // shake128_squeeze_next_block(&mut s1, out1); // } - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) } } @@ -445,7 +445,7 @@ pub mod avx2 { /// Performing 4 operations in parallel pub mod x4 { - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] use crate::generic_keccak::keccak; /// Perform 4 SHAKE256 operations in parallel @@ -460,11 +460,11 @@ pub mod avx2 { out2: &mut [u8], out3: &mut [u8], ) { - #[cfg(not(feature = "simd256"))] + #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation - // #[cfg(feature = "simd128")] + // #[cfg(all(feature = "simd128", target_arch = "aarch64"))] // { // keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); // keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]); @@ -475,7 +475,7 @@ pub mod avx2 { // keccakx1::<136, 0x1fu8>([input2], [out2]); // keccakx1::<136, 0x1fu8>([input3], [out3]); // } - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] keccak::<4, core::arch::x86_64::__m256i, 136, 0x1fu8>( [input0, input1, input2, input3], [out0, out1, out2, out3], @@ -484,30 +484,33 @@ pub mod avx2 { /// An incremental API to perform 4 operations in parallel pub mod incremental { - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] use crate::generic_keccak::{ absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, }; - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>; - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] pub type KeccakState4 = [crate::neon::x2::incremental::KeccakState2; 2]; - #[cfg(not(any(feature = "simd256", feature = "simd128")))] + #[cfg(not(any( + all(feature = "simd256", target_arch = "x86_64"), + all(feature = "simd128", target_arch = "aarch64") + )))] pub type KeccakState4 = [crate::portable::KeccakState1; 4]; pub fn shake128_init() -> KeccakState4 { - #[cfg(not(feature = "simd256"))] + #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))] unimplemented!("The target architecture does not support neon 
instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation - // #[cfg(feature = "simd128")] + // #[cfg(all(feature = "simd128", target_arch = "aarch64"))] // { // let s0 = KeccakState2::new(); // let s1 = KeccakState2::new(); // [s0, s1] // } - // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))] // { // let s0 = KeccakState1::new(); // let s1 = KeccakState1::new(); @@ -515,7 +518,7 @@ pub mod avx2 { // let s3 = KeccakState1::new(); // [s0, s1, s2, s3] // } - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] KeccakState4::new() } @@ -527,11 +530,11 @@ pub mod avx2 { data2: &[u8], data3: &[u8], ) { - #[cfg(not(feature = "simd256"))] + #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation - // #[cfg(feature = "simd128")] + // #[cfg(all(feature = "simd128", target_arch = "aarch64"))] // { // let [mut s0, mut s1] = s; // absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>( @@ -543,7 +546,7 @@ pub mod avx2 { // [data2, data3], // ); // } - // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))] // { // let [mut s0, mut s1, mut s2, mut s3] = s; // shake128_absorb_final(&mut s0, data0); @@ -551,7 +554,7 @@ pub mod avx2 { // shake128_absorb_final(&mut s2, data2); // shake128_absorb_final(&mut s3, data3); // } - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] absorb_final::<4, core::arch::x86_64::__m256i, 168, 0x1fu8>( s, [data0, data1, data2, data3], @@ -566,11 +569,11 @@ pub mod avx2 { out2: &mut [u8], out3: &mut [u8], ) { - #[cfg(not(feature = "simd256"))] + #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation - // #[cfg(feature = "simd128")] + // #[cfg(all(feature = "simd128", target_arch = "aarch64"))] // { // let [mut s0, mut s1] = s; // squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( @@ -582,7 +585,7 @@ pub mod avx2 { // [out2, out3], // ); // } - // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))] // { // let [mut s0, mut s1, mut s2, mut s3] = s; // shake128_squeeze_first_three_blocks(&mut s0, out0); @@ -590,7 +593,7 @@ pub mod avx2 { // shake128_squeeze_first_three_blocks(&mut s2, out2); // shake128_squeeze_first_three_blocks(&mut s3, out3); // } - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] squeeze_first_three_blocks::<4, core::arch::x86_64::__m256i, 168>( s, [out0, out1, out2, out3], @@ -605,11 +608,11 @@ pub mod avx2 { out2: &mut [u8], out3: &mut [u8], ) { - #[cfg(not(feature = "simd256"))] + #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))] unimplemented!("The target architecture does not support neon instructions."); // XXX: These functions could alternatively implement the same with // the portable implementation 
- // #[cfg(feature = "simd128")] + // #[cfg(all(feature = "simd128", target_arch = "aarch64"))] // { // let [mut s0, mut s1] = s; // squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>( @@ -621,7 +624,7 @@ pub mod avx2 { // [out2, out3], // ); // } - // #[cfg(not(any(feature = "simd128", feature = "simd256")))] + // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))] // { // let [mut s0, mut s1, mut s2, mut s3] = s; // shake128_squeeze_next_block(&mut s0, out0); @@ -629,7 +632,7 @@ pub mod avx2 { // shake128_squeeze_next_block(&mut s2, out2); // shake128_squeeze_next_block(&mut s3, out3); // } - #[cfg(feature = "simd256")] + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>( s, [out0, out1, out2, out3], From b9fa21c6d53dc2b7de2e50d1694af284cb55f769 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 11:48:23 +0200 Subject: [PATCH 52/59] more ci --- .github/workflows/mlkem.yml | 103 +++++++++++++-------------------- .github/workflows/platform.yml | 39 +++++-------- 2 files changed, 53 insertions(+), 89 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 2783a9a17..92e5ca6ed 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -88,17 +88,51 @@ jobs: # if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} # run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu + # Test SHA3 + - name: 🏃🏻‍♀️ Test SHA3 + working-directory: libcrux-sha3 + run: cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Release + working-directory: libcrux-sha3 + run: cargo test --release --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable + working-directory: libcrux-sha3 + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable Release + working-directory: libcrux-sha3 + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose --release $RUST_TARGET_FLAG + # Test ... - name: 🏃🏻‍♀️ Test run: | cargo test --verbose $RUST_TARGET_FLAG - cd ../libcrux-sha3 && cargo test --verbose $RUST_TARGET_FLAG - name: 🏃🏻‍♀️ Test Release run: | cargo test --verbose --release $RUST_TARGET_FLAG - cd ../libcrux-sha3 && cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Release MacOS +neon + if: ${{ matrix.os == 'macos-latest' }} + run: | + RUSTFLAGS="-C target_feature=+neon" cargo test --verbose --release $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable Release + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose --release $RUST_TARGET_FLAG benchmarks: strategy: @@ -159,67 +193,10 @@ jobs: # Benchmarks ... 
- - name: 🏃🏻‍♀️ Benchmarks Windows - if: ${{ matrix.os == 'windows-latest' }} + - name: 🏃🏻‍♀️ Benchmarks run: cargo bench --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Benchmarks Clang - if: ${{ matrix.os != 'windows-latest' }} - run: CC=clang cargo bench --verbose $RUST_TARGET_FLAG - - platform: - strategy: - fail-fast: false - matrix: - bits: [32, 64] - os: - - macos-13 # Intel mac - - macos-latest # macos-14 m1 - - ubuntu-latest - - windows-latest - exclude: - - bits: 32 - os: "macos-latest" - - bits: 32 - os: "macos-13" - - runs-on: ${{ matrix.os }} - defaults: - run: - shell: bash - working-directory: sys/platform - - steps: - - uses: actions/checkout@v4 - - - name: 🔨 Build - run: cargo build --verbose - - - name: 🏃🏻‍♀️ Test - run: cargo test --verbose -- --nocapture - - - name: 🏃🏻‍♀️ Test Release - run: cargo test --verbose --release -- --nocapture - - - name: 🛠️ Setup Linux x86 - if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + - name: 🏃🏻‍♀️ Benchmarks Portable run: | - rustup target add i686-unknown-linux-gnu - sudo apt-get update - sudo apt-get install -y gcc-multilib g++-multilib - - - name: 🏃🏻‍♀️ Test Linux x86 - if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --target i686-unknown-linux-gnu -- --nocapture - - - name: 🏃🏻‍♀️ Test Release Linux x86 - if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --release --target i686-unknown-linux-gnu -- --nocapture - - - name: 🏃🏻‍♀️ Test Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --target i686-pc-windows-msvc -- --nocapture - - - name: 🏃🏻‍♀️ Test Release Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --release --target i686-pc-windows-msvc -- --nocapture + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo bench --verbose $RUST_TARGET_FLAG diff --git a/.github/workflows/platform.yml b/.github/workflows/platform.yml index 899e16353..84c4bf71f 100644 --- a/.github/workflows/platform.yml +++ b/.github/workflows/platform.yml @@ -16,18 +16,21 @@ concurrency: cancel-in-progress: true jobs: - build: + platform: strategy: fail-fast: false matrix: bits: [32, 64] os: - - macos-latest + - macos-13 # Intel mac + - macos-latest # macos-14 m1 - ubuntu-latest - windows-latest exclude: - bits: 32 os: "macos-latest" + - bits: 32 + os: "macos-13" runs-on: ${{ matrix.os }} defaults: @@ -42,46 +45,30 @@ jobs: run: cargo build --verbose - name: 🏃🏻‍♀️ Test - run: cargo test --verbose + run: cargo test --verbose -- --nocapture - name: 🏃🏻‍♀️ Test Release - run: cargo test --verbose --release + run: cargo test --verbose --release -- --nocapture - name: 🛠️ Setup Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu + sudo apt-get update + sudo apt-get install -y gcc-multilib g++-multilib - name: 🏃🏻‍♀️ Test Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --target i686-unknown-linux-gnu + run: cargo test --verbose --target i686-unknown-linux-gnu -- --nocapture - name: 🏃🏻‍♀️ Test Release Linux x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - run: cargo test --verbose --release --target i686-unknown-linux-gnu - - - name: 🛠️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: | - rustup target add i686-pc-windows-msvc + run: cargo test --verbose --release --target 
i686-unknown-linux-gnu -- --nocapture - name: 🏃🏻‍♀️ Test Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --target i686-pc-windows-msvc + run: cargo test --verbose --target i686-pc-windows-msvc -- --nocapture - name: 🏃🏻‍♀️ Test Release Windows x86 if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - run: cargo test --verbose --release --target i686-pc-windows-msvc - - - name: 🛠️ Setup MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: | - rustup target add x86_64-apple-darwin - - - name: 🏃🏻‍♀️ Test MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: cargo test --verbose --target x86_64-apple-darwin - - - name: 🏃🏻‍♀️ Test Release MacOS x86_64 - if: ${{ matrix.os == 'macos-latest' }} - run: cargo test --verbose --release --target x86_64-apple-darwin + run: cargo test --verbose --release --target i686-pc-windows-msvc -- --nocapture From df5b91c66a17bcddbe2fed786bf975f029b99344 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 11:59:52 +0200 Subject: [PATCH 53/59] clean before test --- .github/workflows/mlkem.yml | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 92e5ca6ed..b6425d370 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -89,21 +89,25 @@ jobs: # run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu # Test SHA3 - - name: 🏃🏻‍♀️ Test SHA3 + - name: 🏃🏻‍♀️ SHA3 Test working-directory: libcrux-sha3 - run: cargo test --verbose $RUST_TARGET_FLAG + run: | + cargo clean + cargo test --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Release + - name: 🏃🏻‍♀️ SHA3 Test Release working-directory: libcrux-sha3 - run: cargo test --release --verbose $RUST_TARGET_FLAG + run: | + cargo clean + cargo test --release --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Portable + - name: 🏃🏻‍♀️ SHA3 Test Portable working-directory: libcrux-sha3 run: | cargo clean LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Portable Release + - name: 🏃🏻‍♀️ SHA3 Test Portable Release working-directory: libcrux-sha3 run: | cargo clean @@ -113,17 +117,20 @@ jobs: - name: 🏃🏻‍♀️ Test run: | + cargo clean cargo test --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Release - run: | - cargo test --verbose --release $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Release MacOS +neon if: ${{ matrix.os == 'macos-latest' }} run: | + cargo clean RUSTFLAGS="-C target_feature=+neon" cargo test --verbose --release $RUST_TARGET_FLAG + - name: 🏃🏻‍♀️ Test Release + run: | + cargo clean + cargo test --verbose --release $RUST_TARGET_FLAG + - name: 🏃🏻‍♀️ Test Portable run: | cargo clean From 0240aacc6fdc82d9ec3daf0bf9420498d0937d1a Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 13:30:53 +0200 Subject: [PATCH 54/59] print cfg --- .github/workflows/mlkem.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index b6425d370..d6678e902 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -73,7 +73,9 @@ jobs: # Build ... 
- name: 🔨 Build - run: cargo build --verbose $RUST_TARGET_FLAG + run: | + rustc --print=cfg + cargo build --verbose $RUST_TARGET_FLAG - name: 🔨 Build Release run: cargo build --verbose --release $RUST_TARGET_FLAG From 9ee7d1d7057cf4c99452945f7e9b4bd9d38ef9c5 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 13:42:32 +0200 Subject: [PATCH 55/59] check sysctl --- .github/workflows/mlkem.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index d6678e902..e6dcb1600 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -75,6 +75,7 @@ jobs: - name: 🔨 Build run: | rustc --print=cfg + sysctl hw cargo build --verbose $RUST_TARGET_FLAG - name: 🔨 Build Release From 6f7f943fb4e43bae278e24a1d420540e0a293f65 Mon Sep 17 00:00:00 2001 From: Goutam Tamvada Date: Fri, 17 May 2024 08:58:52 -0400 Subject: [PATCH 56/59] More safe wrappers around avx2 intrinsics (#283). --- polynomials-avx2/src/intrinsics.rs | 218 ++++++++++ polynomials-avx2/src/lib.rs | 22 +- polynomials-avx2/src/ntt.rs | 366 +++++++--------- polynomials-avx2/src/sampling.rs | 47 +- polynomials-avx2/src/serialize.rs | 670 ++++++++++++++--------------- 5 files changed, 733 insertions(+), 590 deletions(-) diff --git a/polynomials-avx2/src/intrinsics.rs b/polynomials-avx2/src/intrinsics.rs index d28b227c7..93133cc18 100644 --- a/polynomials-avx2/src/intrinsics.rs +++ b/polynomials-avx2/src/intrinsics.rs @@ -3,16 +3,171 @@ pub(crate) use core::arch::x86::*; #[cfg(target_arch = "x86_64")] pub(crate) use core::arch::x86_64::*; +pub(crate) fn mm256_storeu_si256(output: &mut [i16], vector: __m256i) { + debug_assert_eq!(output.len(), 16); + unsafe { + _mm256_storeu_si256(output.as_mut_ptr() as *mut __m256i, vector); + } +} +pub(crate) fn mm_storeu_si128(output: &mut [i16], vector: __m128i) { + debug_assert_eq!(output.len(), 8); + unsafe { + _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector); + } +} + +pub(crate) fn mm_storeu_bytes_si128(output: &mut [u8], vector: __m128i) { + debug_assert_eq!(output.len(), 16); + unsafe { + _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector); + } +} + +pub(crate) fn mm_loadu_si128(input: &[u8]) -> __m128i { + debug_assert_eq!(input.len(), 16); + unsafe { _mm_loadu_si128(input.as_ptr() as *const __m128i) } +} + +pub(crate) fn mm256_loadu_si256(input: &[i16]) -> __m256i { + debug_assert_eq!(input.len(), 16); + unsafe { _mm256_loadu_si256(input.as_ptr() as *const __m256i) } +} + +pub(crate) fn mm256_setzero_si256() -> __m256i { + unsafe { _mm256_setzero_si256() } +} + +pub(crate) fn mm_set_epi8( + byte15: i8, + byte14: i8, + byte13: i8, + byte12: i8, + byte11: i8, + byte10: i8, + byte9: i8, + byte8: i8, + byte7: i8, + byte6: i8, + byte5: i8, + byte4: i8, + byte3: i8, + byte2: i8, + byte1: i8, + byte0: i8, +) -> __m128i { + unsafe { + _mm_set_epi8( + byte15, byte14, byte13, byte12, byte11, byte10, + byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + ) + } +} + +pub(crate) fn mm256_set_epi8( + byte31: i8, + byte30: i8, + byte29: i8, + byte28: i8, + byte27: i8, + byte26: i8, + byte25: i8, + byte24: i8, + byte23: i8, + byte22: i8, + byte21: i8, + byte20: i8, + byte19: i8, + byte18: i8, + byte17: i8, + byte16: i8, + byte15: i8, + byte14: i8, + byte13: i8, + byte12: i8, + byte11: i8, + byte10: i8, + byte9: i8, + byte8: i8, + byte7: i8, + byte6: i8, + byte5: i8, + byte4: i8, + byte3: i8, + byte2: i8, + byte1: i8, + byte0: i8, +) -> __m256i { + unsafe { + _mm256_set_epi8( + byte31, 
byte30, byte29, byte28, byte27, byte26, byte25, byte24, byte23, byte22, byte21, + byte20, byte19, byte18, byte17, byte16, byte15, byte14, byte13, byte12, byte11, byte10, + byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + ) + } +} + pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i { unsafe { _mm256_set1_epi16(constant) } } +pub(crate) fn mm256_set_epi16( + input15: i16, + input14: i16, + input13: i16, + input12: i16, + input11: i16, + input10: i16, + input9: i16, + input8: i16, + input7: i16, + input6: i16, + input5: i16, + input4: i16, + input3: i16, + input2: i16, + input1: i16, + input0: i16, +) -> __m256i { + unsafe { + _mm256_set_epi16( + input15, input14, input13, input12, input11, input10, input9, input8, input7, input6, + input5, input4, input3, input2, input1, input0, + ) + } +} + +pub(crate) fn mm_set1_epi16(constant: i16) -> __m128i { + unsafe { _mm_set1_epi16(constant) } +} + pub(crate) fn mm256_set1_epi32(constant: i32) -> __m256i { unsafe { _mm256_set1_epi32(constant) } } +pub(crate) fn mm256_set_epi32( + input7: i32, + input6: i32, + input5: i32, + input4: i32, + input3: i32, + input2: i32, + input1: i32, + input0: i32, +) -> __m256i { + unsafe { + _mm256_set_epi32( + input7, input6, input5, input4, input3, input2, input1, input0, + ) + } +} +pub(crate) fn mm_add_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_add_epi16(lhs, rhs) } +} pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_add_epi16(lhs, rhs) } } +pub(crate) fn mm256_madd_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_madd_epi16(lhs, rhs) } +} pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_add_epi32(lhs, rhs) } } @@ -20,10 +175,26 @@ pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { pub(crate) fn mm256_sub_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_sub_epi16(lhs, rhs) } } +pub(crate) fn mm_sub_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_sub_epi16(lhs, rhs) } +} pub(crate) fn mm256_mullo_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_mullo_epi16(lhs, rhs) } } + +pub(crate) fn mm_mullo_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_mullo_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_cmpgt_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_cmpgt_epi16(lhs, rhs) } +} + +pub(crate) fn mm_mulhi_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_mulhi_epi16(lhs, rhs) } +} + pub(crate) fn mm256_mullo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_mullo_epi32(lhs, rhs) } } @@ -48,6 +219,11 @@ pub(crate) fn mm256_srai_epi16(vector: __m256i) -> __m256i debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { _mm256_srai_epi16(vector, SHIFT_BY) } } +pub(crate) fn mm256_srai_epi32(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32); + unsafe { _mm256_srai_epi32(vector, SHIFT_BY) } +} + pub(crate) fn mm256_srli_epi16(vector: __m256i) -> __m256i { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { _mm256_srli_epi16(vector, SHIFT_BY) } @@ -57,6 +233,11 @@ pub(crate) fn mm256_srli_epi32(vector: __m256i) -> __m256i unsafe { _mm256_srli_epi32(vector, SHIFT_BY) } } +pub(crate) fn mm256_srli_epi64(vector: __m256i) -> __m256i { + debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64); + unsafe { _mm256_srli_epi64(vector, SHIFT_BY) } +} + pub(crate) fn mm256_slli_epi16(vector: __m256i) -> __m256i { debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16); unsafe { 
_mm256_slli_epi16(vector, SHIFT_BY) } @@ -67,6 +248,12 @@ pub(crate) fn mm256_slli_epi32(vector: __m256i) -> __m256i unsafe { _mm256_slli_epi32(vector, SHIFT_BY) } } +pub(crate) fn mm_shuffle_epi8(vector: __m128i, control: __m128i) -> __m128i { + unsafe { _mm_shuffle_epi8(vector, control) } +} +pub(crate) fn mm256_shuffle_epi8(vector: __m256i, control: __m256i) -> __m256i { + unsafe { _mm256_shuffle_epi8(vector, control) } +} pub(crate) fn mm256_shuffle_epi32(vector: __m256i) -> __m256i { debug_assert!(CONTROL >= 0 && CONTROL < 256); unsafe { _mm256_shuffle_epi32(vector, CONTROL) } @@ -92,11 +279,17 @@ pub(crate) fn mm256_unpackhi_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { pub(crate) fn mm256_castsi256_si128(vector: __m256i) -> __m128i { unsafe { _mm256_castsi256_si128(vector) } } +pub(crate) fn mm256_castsi128_si256(vector: __m128i) -> __m256i { + unsafe { _mm256_castsi128_si256(vector) } +} pub(crate) fn mm256_cvtepi16_epi32(vector: __m128i) -> __m256i { unsafe { _mm256_cvtepi16_epi32(vector) } } +pub(crate) fn mm_packs_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_packs_epi16(lhs, rhs) } +} pub(crate) fn mm256_packs_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { unsafe { _mm256_packs_epi32(lhs, rhs) } } @@ -105,3 +298,28 @@ pub(crate) fn mm256_extracti128_si256(vector: __m256i) -> __ debug_assert!(CONTROL == 0 || CONTROL == 1); unsafe { _mm256_extracti128_si256(vector, CONTROL) } } + +pub(crate) fn mm256_inserti128_si256( + vector: __m256i, + vector_i128: __m128i, +) -> __m256i { + debug_assert!(CONTROL == 0 || CONTROL == 1); + unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) } +} + +pub(crate) fn mm256_blend_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + debug_assert!(CONTROL >= 0 && CONTROL < 256); + unsafe { _mm256_blend_epi16(lhs, rhs, CONTROL) } +} + +pub(crate) fn mm_movemask_epi8(vector: __m128i) -> i32 { + unsafe { _mm_movemask_epi8(vector) } +} + +pub(crate) fn mm256_permutevar8x32_epi32(vector: __m256i, control: __m256i) -> __m256i { + unsafe { _mm256_permutevar8x32_epi32(vector, control) } +} + +pub(crate) fn mm256_sllv_epi32(vector: __m256i, counts: __m256i) -> __m256i { + unsafe { _mm256_sllv_epi32(vector, counts) } +} diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs index 52de02fd0..e24519afd 100644 --- a/polynomials-avx2/src/lib.rs +++ b/polynomials-avx2/src/lib.rs @@ -1,7 +1,4 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; use libcrux_traits::Operations; #[cfg(test)] @@ -24,24 +21,21 @@ pub struct SIMD256Vector { #[inline(always)] fn zero() -> SIMD256Vector { SIMD256Vector { - elements: unsafe { _mm256_setzero_si256() }, + elements: mm256_setzero_si256(), } } #[inline(always)] fn to_i16_array(v: SIMD256Vector) -> [i16; 16] { - let mut out = [0i16; 16]; + let mut output = [0i16; 16]; + mm256_storeu_si256(&mut output[..], v.elements); - unsafe { - _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, v.elements); - } - - out + output } #[inline(always)] fn from_i16_array(array: &[i16]) -> SIMD256Vector { SIMD256Vector { - elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) }, + elements: mm256_loadu_si256(array), } } @@ -187,9 +181,9 @@ impl Operations for SIMD256Vector { serialize::serialize_1(vector.elements) } - fn deserialize_1(input: &[u8]) -> Self { + fn deserialize_1(bytes: &[u8]) -> Self { Self { - elements: serialize::deserialize_1(input), + elements: serialize::deserialize_1(bytes), } } diff 
--git a/polynomials-avx2/src/ntt.rs b/polynomials-avx2/src/ntt.rs index 28377dbfc..2ebb1561d 100644 --- a/polynomials-avx2/src/ntt.rs +++ b/polynomials-avx2/src/ntt.rs @@ -1,210 +1,172 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; use crate::arithmetic; use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; #[inline(always)] -fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { - v = unsafe { - let value_low = _mm256_mullo_epi16(v, c); +fn montgomery_multiply_by_constants(v: __m256i, c: __m256i) -> __m256i { + let value_low = mm256_mullo_epi16(v, c); - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); + let k = mm256_mullo_epi16( + value_low, + mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS)); - let value_high = _mm256_mulhi_epi16(v, c); + let value_high = mm256_mulhi_epi16(v, c); - _mm256_sub_epi16(value_high, k_times_modulus) - }; - - v + mm256_sub_epi16(value_high, k_times_modulus) } #[inline(always)] -fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { - v = unsafe { - let k = _mm256_mullo_epi16( - v, - _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); +fn montgomery_reduce_i32s(v: __m256i) -> __m256i { + let k = mm256_mullo_epi16( + v, + mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi32(FIELD_MODULUS as i32)); - let value_high = _mm256_srli_epi32(v, 16); + let value_high = mm256_srli_epi32::<16>(v); - let result = _mm256_sub_epi16(value_high, k_times_modulus); + let result = mm256_sub_epi16(value_high, k_times_modulus); - let result = _mm256_slli_epi32(result, 16); - _mm256_srai_epi32(result, 16) - }; + let result = mm256_slli_epi32::<16>(result); - v + mm256_srai_epi32::<16>(result) } #[inline(always)] -fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { - v = unsafe { - let value_low = _mm_mullo_epi16(v, c); - - let k = _mm_mullo_epi16( - value_low, - _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); +fn montgomery_multiply_m128i_by_constants(v: __m128i, c: __m128i) -> __m128i { + let value_low = mm_mullo_epi16(v, c); - let value_high = _mm_mulhi_epi16(v, c); + let k = mm_mullo_epi16( + value_low, + mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm_mulhi_epi16(k, mm_set1_epi16(FIELD_MODULUS)); - _mm_sub_epi16(value_high, k_times_modulus) - }; + let value_high = mm_mulhi_epi16(v, c); - v + mm_sub_epi16(value_high, k_times_modulus) } #[inline(always)] pub(crate) fn ntt_layer_1_step( - mut vector: __m256i, + vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> __m256i { - vector = unsafe { - let zetas = _mm256_set_epi16( - -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, - zeta1, -zeta0, -zeta0, zeta0, zeta0, - ); + let zetas = mm256_set_epi16( + -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, zeta1, + -zeta0, -zeta0, zeta0, zeta0, + ); - let rhs = _mm256_shuffle_epi32(vector, 
0b11_11_01_01); - let rhs = montgomery_multiply_by_constants(rhs, zetas); + let rhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); - let lhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); + let lhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); - _mm256_add_epi16(lhs, rhs) - }; - - vector + mm256_add_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { - vector = unsafe { - let zetas = _mm256_set_epi16( - -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, - -zeta0, zeta0, zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(vector, 0b11_10_11_10); - let rhs = montgomery_multiply_by_constants(rhs, zetas); +pub(crate) fn ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let zetas = mm256_set_epi16( + -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, -zeta0, + zeta0, zeta0, zeta0, zeta0, + ); - let lhs = _mm256_shuffle_epi32(vector, 0b01_00_01_00); + let rhs = mm256_shuffle_epi32::<0b11_10_11_10>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); - _mm256_add_epi16(lhs, rhs) - }; + let lhs = mm256_shuffle_epi32::<0b01_00_01_00>(vector); - vector + mm256_add_epi16(lhs, rhs) } #[inline(always)] -pub(crate) fn ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { - vector = unsafe { - let rhs = _mm256_extracti128_si256(vector, 1); - let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); +pub(crate) fn ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let rhs = mm256_extracti128_si256::<1>(vector); + let rhs = montgomery_multiply_m128i_by_constants(rhs, mm_set1_epi16(zeta)); - let lhs = _mm256_castsi256_si128(vector); + let lhs = mm256_castsi256_si128(vector); - let lower_coefficients = _mm_add_epi16(lhs, rhs); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); + let lower_coefficients = mm_add_epi16(lhs, rhs); + let upper_coefficients = mm_sub_epi16(lhs, rhs); - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); - combined - }; - - vector + combined } #[inline(always)] pub(crate) fn inv_ntt_layer_1_step( - mut vector: __m256i, + vector: __m256i, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16, ) -> __m256i { - vector = unsafe { - let lhs = _mm256_shuffle_epi32(vector, 0b11_11_01_01); - - let rhs = _mm256_shuffle_epi32(vector, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), - ); + let lhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, - ), - ); + let rhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), + ); - let sum = arithmetic::barrett_reduce(sum); + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + mm256_set_epi16( + zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, + ), + ); - _mm256_blend_epi16(sum, sum_times_zetas, 
0b1_1_0_0_1_1_0_0) - }; + let sum = arithmetic::barrett_reduce(sum); - vector + mm256_blend_epi16::<0b1_1_0_0_1_1_0_0>(sum, sum_times_zetas) } #[inline(always)] -pub(crate) fn inv_ntt_layer_2_step(mut vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { - vector = unsafe { - let lhs = _mm256_permute4x64_epi64(vector, 0b11_11_01_01); - - let rhs = _mm256_permute4x64_epi64(vector, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, - ), - ); - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) - }; - - vector +pub(crate) fn inv_ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let lhs = mm256_permute4x64_epi64::<0b11_11_01_01>(vector); + + let rhs = mm256_permute4x64_epi64::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), + ); + + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + mm256_set_epi16( + zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, + ), + ); + + mm256_blend_epi16::<0b1_1_1_1_0_0_0_0>(sum, sum_times_zetas) } #[inline(always)] -pub(crate) fn inv_ntt_layer_3_step(mut vector: __m256i, zeta: i16) -> __m256i { - vector = unsafe { - let lhs = _mm256_extracti128_si256(vector, 1); - let rhs = _mm256_castsi256_si128(vector); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); +pub(crate) fn inv_ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let lhs = mm256_extracti128_si256::<1>(vector); + let rhs = mm256_castsi256_si128(vector); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - let upper_coefficients = - montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); + let lower_coefficients = mm_add_epi16(lhs, rhs); - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); + let upper_coefficients = mm_sub_epi16(lhs, rhs); + let upper_coefficients = + montgomery_multiply_m128i_by_constants(upper_coefficients, mm_set1_epi16(zeta)); - combined - }; + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); - vector + combined } #[inline(always)] @@ -216,69 +178,67 @@ pub(crate) fn ntt_multiply( zeta2: i16, zeta3: i16, ) -> __m256i { - return unsafe { - // Compute the first term of the product - let shuffle_with = _mm256_set_epi8( - 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, - 12, 9, 8, 5, 4, 1, 0, - ); - const PERMUTE_WITH: i32 = 0b11_01_10_00; - - // Prepare the left hand side - let lhs_shuffled = _mm256_shuffle_epi8(lhs, shuffle_with); - let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); - - let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); - let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); - - let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); - let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); - - // Prepare the right hand side - let rhs_shuffled = _mm256_shuffle_epi8(rhs, shuffle_with); - let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); - - let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); - let 
rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); - - let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); - let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); - - // Start operating with them - let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); - - let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); - let right = montgomery_reduce_i32s(right); - let right = _mm256_mullo_epi32( - right, - _mm256_set_epi32( - -(zeta3 as i32), - zeta3 as i32, - -(zeta2 as i32), - zeta2 as i32, - -(zeta1 as i32), - zeta1 as i32, - -(zeta0 as i32), - zeta0 as i32, - ), - ); - - let products_left = _mm256_add_epi32(left, right); - let products_left = montgomery_reduce_i32s(products_left); - - // Compute the second term of the product - let rhs_adjacent_swapped = _mm256_shuffle_epi8( - rhs, - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, - 5, 4, 7, 6, 1, 0, 3, 2, - ), - ); - let products_right = _mm256_madd_epi16(lhs, rhs_adjacent_swapped); - let products_right = montgomery_reduce_i32s(products_right); - let products_right = _mm256_slli_epi32(products_right, 16); - - // Combine them into one vector - _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) - }; + // Compute the first term of the product + let shuffle_with = mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, + 9, 8, 5, 4, 1, 0, + ); + const PERMUTE_WITH: i32 = 0b11_01_10_00; + + // Prepare the left hand side + let lhs_shuffled = mm256_shuffle_epi8(lhs, shuffle_with); + let lhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(lhs_shuffled); + + let lhs_evens = mm256_castsi256_si128(lhs_shuffled); + let lhs_evens = mm256_cvtepi16_epi32(lhs_evens); + + let lhs_odds = mm256_extracti128_si256::<1>(lhs_shuffled); + let lhs_odds = mm256_cvtepi16_epi32(lhs_odds); + + // Prepare the right hand side + let rhs_shuffled = mm256_shuffle_epi8(rhs, shuffle_with); + let rhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(rhs_shuffled); + + let rhs_evens = mm256_castsi256_si128(rhs_shuffled); + let rhs_evens = mm256_cvtepi16_epi32(rhs_evens); + + let rhs_odds = mm256_extracti128_si256::<1>(rhs_shuffled); + let rhs_odds = mm256_cvtepi16_epi32(rhs_odds); + + // Start operating with them + let left = mm256_mullo_epi32(lhs_evens, rhs_evens); + + let right = mm256_mullo_epi32(lhs_odds, rhs_odds); + let right = montgomery_reduce_i32s(right); + let right = mm256_mullo_epi32( + right, + mm256_set_epi32( + -(zeta3 as i32), + zeta3 as i32, + -(zeta2 as i32), + zeta2 as i32, + -(zeta1 as i32), + zeta1 as i32, + -(zeta0 as i32), + zeta0 as i32, + ), + ); + + let products_left = mm256_add_epi32(left, right); + let products_left = montgomery_reduce_i32s(products_left); + + // Compute the second term of the product + let rhs_adjacent_swapped = mm256_shuffle_epi8( + rhs, + mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5, + 4, 7, 6, 1, 0, 3, 2, + ), + ); + let products_right = mm256_madd_epi16(lhs, rhs_adjacent_swapped); + let products_right = montgomery_reduce_i32s(products_right); + let products_right = mm256_slli_epi32::<16>(products_right); + + // Combine them into one vector + mm256_blend_epi16::<0b1_0_1_0_1_0_1_0>(products_left, products_right) } diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs index aa6cc6ca6..40542efec 100644 --- a/polynomials-avx2/src/sampling.rs +++ b/polynomials-avx2/src/sampling.rs @@ -1,7 +1,4 @@ -#[cfg(target_arch = "x86")] 
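
The montgomery_multiply_by_constants and montgomery_reduce_i32s rewrites above keep the same lane-wise arithmetic as before; only the unsafe intrinsic calls move behind the safe wrappers. As a reading aid, the following scalar sketch models what each 16-bit lane computes. The helper name and the concrete constants are mine, not from the patch; they assume the usual ML-KEM parameters (q = 3329 and 62209 = q^-1 mod 2^16, which is what FIELD_MODULUS and INVERSE_OF_MODULUS_MOD_MONTGOMERY_R are expected to hold). The result is congruent to a*c*2^(-16) modulo q:

    // Illustrative scalar model of one 16-bit lane of montgomery_multiply_by_constants.
    const FIELD_MODULUS: i32 = 3329;
    const INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: i16 = 62209u16 as i16; // q^-1 mod 2^16

    fn montgomery_multiply_lane(a: i16, c: i16) -> i16 {
        let product = (a as i32) * (c as i32);
        let value_low = product as i16; // lane of mm256_mullo_epi16(v, c)
        let k = value_low.wrapping_mul(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R);
        let k_times_modulus = (((k as i32) * FIELD_MODULUS) >> 16) as i16; // mm256_mulhi_epi16
        let value_high = (product >> 16) as i16; // lane of mm256_mulhi_epi16(v, c)
        // The low 16 bits of a*c and of k*q agree, so this difference is exactly
        // (a*c - k*q) / 2^16, i.e. congruent to a*c*2^(-16) mod q.
        value_high.wrapping_sub(k_times_modulus)
    }
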
-use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; use crate::serialize::{deserialize_12, serialize_1}; use libcrux_traits::FIELD_MODULUS; @@ -756,34 +753,30 @@ const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ #[inline(always)] pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize { - let count = unsafe { - let field_modulus = _mm256_set1_epi16(FIELD_MODULUS); + let field_modulus = mm256_set1_epi16(FIELD_MODULUS); - let potential_coefficients = deserialize_12(input); + let potential_coefficients = deserialize_12(input); - let compare_with_field_modulus = _mm256_cmpgt_epi16(field_modulus, potential_coefficients); - let good = serialize_1(compare_with_field_modulus); + let compare_with_field_modulus = mm256_cmpgt_epi16(field_modulus, potential_coefficients); + let good = serialize_1(compare_with_field_modulus); - let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; - let lower_shuffles = _mm_loadu_si128(lower_shuffles.as_ptr() as *const __m128i); - let lower_coefficients = _mm256_castsi256_si128(potential_coefficients); - let lower_coefficients = _mm_shuffle_epi8(lower_coefficients, lower_shuffles); + let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize]; + let lower_shuffles = mm_loadu_si128(&lower_shuffles); + let lower_coefficients = mm256_castsi256_si128(potential_coefficients); + let lower_coefficients = mm_shuffle_epi8(lower_coefficients, lower_shuffles); - _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, lower_coefficients); - let sampled_count = good[0].count_ones(); + mm_storeu_si128(&mut output[0..8], lower_coefficients); + let sampled_count = good[0].count_ones() as usize; - let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; - let upper_shuffles = _mm_loadu_si128(upper_shuffles.as_ptr() as *const __m128i); - let upper_coefficients = _mm256_extractf128_si256(potential_coefficients, 1); - let upper_coefficients = _mm_shuffle_epi8(upper_coefficients, upper_shuffles); + let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize]; + let upper_shuffles = mm_loadu_si128(&upper_shuffles); + let upper_coefficients = mm256_extracti128_si256::<1>(potential_coefficients); + let upper_coefficients = mm_shuffle_epi8(upper_coefficients, upper_shuffles); - _mm_storeu_si128( - output.as_mut_ptr().offset(sampled_count as isize) as *mut __m128i, - upper_coefficients, - ); + mm_storeu_si128( + &mut output[sampled_count..sampled_count + 8], + upper_coefficients, + ); - sampled_count + good[1].count_ones() - }; - - count as usize + sampled_count + (good[1].count_ones() as usize) } diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs index 39b75ea2c..3483c7afb 100644 --- a/polynomials-avx2/src/serialize.rs +++ b/polynomials-avx2/src/serialize.rs @@ -1,23 +1,17 @@ -#[cfg(target_arch = "x86")] -use core::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use core::arch::x86_64::*; +use crate::intrinsics::*; -use crate::portable; -use crate::SIMD256Vector; +use crate::{portable, SIMD256Vector}; #[inline(always)] pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] { - let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(vector, 15); + let lsb_shifted_up = mm256_slli_epi16::<15>(vector); - let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); - let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); + let low_lanes = mm256_castsi256_si128(lsb_shifted_up); + let high_lanes = 
mm256_extracti128_si256::<1>(lsb_shifted_up); - let msbs = _mm_packs_epi16(low_lanes, high_lanes); + let msbs = mm_packs_epi16(low_lanes, high_lanes); - _mm_movemask_epi8(msbs) - }; + let bits_packed = mm_movemask_epi8(msbs); let mut serialized = [0u8; 2]; serialized[0] = bits_packed as u8; @@ -28,193 +22,185 @@ pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] { #[inline(always)] pub(crate) fn deserialize_1(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsb_to_msb = _mm256_set_epi16( - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - ); - - let coefficients = _mm256_set_epi16( - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) - }; + let shift_lsb_to_msb = mm256_set_epi16( + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + 1 << 0, + 1 << 1, + 1 << 2, + 1 << 3, + 1 << 4, + 1 << 5, + 1 << 6, + 1 << 7, + ); + + let coefficients = mm256_set_epi16( + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsb_to_msb); + let coefficients_in_lsb = mm256_srli_epi16::<7>(coefficients_in_msb); + + mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 1) - 1)) } #[inline(always)] pub(crate) fn serialize_4(vector: __m256i) -> [u8; 8] { let mut serialized = [0u8; 16]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - ), - ); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_2_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, - ), - ); - - let combined = _mm256_permutevar8x32_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), - ); - let combined = _mm256_castsi256_si128(combined); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, combined); - } - - serialized[0..8].try_into().unwrap() -} - -#[inline(always)] -pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - 1 << 0, + 1, 1 << 4, - ); - - let coefficients = _mm256_set_epi16( - bytes[7] as i16, - bytes[7] as i16, - bytes[6] as i16, - bytes[6] as i16, - bytes[5] as i16, - bytes[5] as i16, - bytes[4] as i16, - bytes[4] as i16, - bytes[3] as i16, - bytes[3] as i16, - bytes[2] as i16, - bytes[2] as i16, - 
bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) - }; + 1, + ), + ); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_2_combined, + mm256_set_epi8( + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, + ), + ); + + let combined = mm256_permutevar8x32_epi32( + adjacent_8_combined, + mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), + ); + let combined = mm256_castsi256_si128(combined); + + mm_storeu_bytes_si128(&mut serialized[..], combined); + + serialized[0..8].try_into().unwrap() +} + +#[inline(always)] +pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i { + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let coefficients = mm256_set_epi16( + bytes[7] as i16, + bytes[7] as i16, + bytes[6] as i16, + bytes[6] as i16, + bytes[5] as i16, + bytes[5] as i16, + bytes[4] as i16, + bytes[4] as i16, + bytes[3] as i16, + bytes[3] as i16, + bytes[2] as i16, + bytes[2] as i16, + bytes[1] as i16, + bytes[1] as i16, + bytes[0] as i16, + bytes[0] as i16, + ); + + let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb); + + mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1)) } #[inline(always)] pub(crate) fn serialize_5(vector: __m256i) -> [u8; 10] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 22); - - let adjacent_8_combined = _mm256_shuffle_epi32(adjacent_4_combined, 0b00_00_10_00); - let adjacent_8_combined = _mm256_sllv_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_8_combined = _mm256_srli_epi64(adjacent_8_combined, 12); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(5) as *mut __m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + 1 << 5, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), + ); + let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined); + let adjacent_8_combined = mm256_sllv_epi32( + adjacent_8_combined, + mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = 
mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[5..21], upper_8); serialized[0..10].try_into().unwrap() } @@ -230,95 +216,91 @@ pub(crate) fn deserialize_5(bytes: &[u8]) -> __m256i { pub(crate) fn serialize_10(vector: __m256i) -> [u8; 20] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 12); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, - 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(10) as *mut __m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + 1 << 10, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), + ); + let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_4_combined, + mm256_set_epi8( + -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, + 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[10..26], upper_8); serialized[0..20].try_into().unwrap() } #[inline(always)] pub(crate) fn deserialize_10(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - ); - - let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(4) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 6); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); - - coefficients - }; + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 << 4, + 1 << 6, + 1 << 0, + 1 << 2, + 1 
<< 4, + 1 << 6, + ); + + let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap()); + let lower_coefficients = mm_shuffle_epi8( + lower_coefficients, + mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), + ); + let upper_coefficients = mm_loadu_si128(bytes[4..20].try_into().unwrap()); + let upper_coefficients = mm_shuffle_epi8( + upper_coefficients, + mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), + ); + + let coefficients = mm256_castsi128_si256(lower_coefficients); + let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients); + + let coefficients = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = mm256_srli_epi16::<6>(coefficients); + let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 10) - 1)); + + coefficients } #[inline(always)] @@ -339,93 +321,89 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> __m256i { pub(crate) fn serialize_12(vector: __m256i) -> [u8; 24] { let mut serialized = [0u8; 32]; - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - vector, - _mm256_set_epi16( - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 8); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, - 10, 9, 8, 5, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(12) as *mut __m128i, upper_8); - } + let adjacent_2_combined = mm256_madd_epi16( + vector, + mm256_set_epi16( + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + 1 << 12, + 1, + ), + ); + + let adjacent_4_combined = mm256_sllv_epi32( + adjacent_2_combined, + mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), + ); + let adjacent_4_combined = mm256_srli_epi64::<8>(adjacent_4_combined); + + let adjacent_8_combined = mm256_shuffle_epi8( + adjacent_4_combined, + mm256_set_epi8( + -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, + 10, 9, 8, 5, 4, 3, 2, 1, 0, + ), + ); + + let lower_8 = mm256_castsi256_si128(adjacent_8_combined); + let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined); + + mm_storeu_bytes_si128(&mut serialized[0..16], lower_8); + mm_storeu_bytes_si128(&mut serialized[12..28], upper_8); serialized[0..24].try_into().unwrap() } #[inline(always)] pub(crate) fn deserialize_12(bytes: &[u8]) -> __m256i { - return unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - - let lower_coefficients = _mm_loadu_si128(bytes.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(bytes.as_ptr().offset(8) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 
14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 4); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1)); - - coefficients - }; + let shift_lsbs_to_msbs = mm256_set_epi16( + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + 1 << 0, + 1 << 4, + ); + + let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap()); + let lower_coefficients = mm_shuffle_epi8( + lower_coefficients, + mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0), + ); + let upper_coefficients = mm_loadu_si128(bytes[8..24].try_into().unwrap()); + let upper_coefficients = mm_shuffle_epi8( + upper_coefficients, + mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4), + ); + + let coefficients = mm256_castsi128_si256(lower_coefficients); + let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients); + + let coefficients = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); + let coefficients = mm256_srli_epi16::<4>(coefficients); + let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 12) - 1)); + + coefficients } From e2592de574c230a0fedcd458bc199fc4be68ddd1 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 15:25:33 +0200 Subject: [PATCH 57/59] disable macos release tests for now --- .github/workflows/mlkem.yml | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index e6dcb1600..13c639a43 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -75,7 +75,6 @@ jobs: - name: 🔨 Build run: | rustc --print=cfg - sysctl hw cargo build --verbose $RUST_TARGET_FLAG - name: 🔨 Build Release @@ -123,13 +122,8 @@ jobs: cargo clean cargo test --verbose $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Release MacOS +neon - if: ${{ matrix.os == 'macos-latest' }} - run: | - cargo clean - RUSTFLAGS="-C target_feature=+neon" cargo test --verbose --release $RUST_TARGET_FLAG - - name: 🏃🏻‍♀️ Test Release + if: ${{ matrix.os != 'macos-latest' }} run: | cargo clean cargo test --verbose --release $RUST_TARGET_FLAG From 13de53dd1ffe9164aba3da9fb1dafc41a293b9b5 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 15:37:57 +0200 Subject: [PATCH 58/59] disable hax extraction for now on ci --- .github/workflows/hax.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index c0c3c14f7..9f7b6fc2d 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -92,7 +92,7 @@ jobs: - name: 🏃 Extract & Verify ML-KEM crate (lax) run: | cd libcrux-ml-kem - ./hax.py extract + # ./hax.py extract # env FSTAR_HOME=${{ github.workspace }}/fstar \ # HACL_HOME=${{ github.workspace }}/hacl-star \ # HAX_HOME=${{ github.workspace }}/hax \ From c0ebb12f65c1e53ef0f6bfac12cf1d68b9880695 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Fri, 17 May 2024 15:39:04 +0200 Subject: [PATCH 59/59] rustfmt --- polynomials-avx2/src/intrinsics.rs | 4 ++-- polynomials-avx2/src/serialize.rs | 24 ++++++++++-------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git 
a/polynomials-avx2/src/intrinsics.rs b/polynomials-avx2/src/intrinsics.rs index 93133cc18..db7981ae4 100644 --- a/polynomials-avx2/src/intrinsics.rs +++ b/polynomials-avx2/src/intrinsics.rs @@ -57,8 +57,8 @@ pub(crate) fn mm_set_epi8( ) -> __m128i { unsafe { _mm_set_epi8( - byte15, byte14, byte13, byte12, byte11, byte10, - byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + byte15, byte14, byte13, byte12, byte11, byte10, byte9, byte8, byte7, byte6, byte5, + byte4, byte3, byte2, byte1, byte0, ) } } diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs index 3483c7afb..7e5303b01 100644 --- a/polynomials-avx2/src/serialize.rs +++ b/polynomials-avx2/src/serialize.rs @@ -95,15 +95,13 @@ pub(crate) fn serialize_4(vector: __m256i) -> [u8; 8] { let adjacent_8_combined = mm256_shuffle_epi8( adjacent_2_combined, mm256_set_epi8( - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, ), ); - let combined = mm256_permutevar8x32_epi32( - adjacent_8_combined, - mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), - ); + let combined = + mm256_permutevar8x32_epi32(adjacent_8_combined, mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0)); let combined = mm256_castsi256_si128(combined); mm_storeu_bytes_si128(&mut serialized[..], combined); @@ -247,8 +245,8 @@ pub(crate) fn serialize_10(vector: __m256i) -> [u8; 20] { let adjacent_8_combined = mm256_shuffle_epi8( adjacent_4_combined, mm256_set_epi8( - -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, - 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, + -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, + 11, 10, 9, 8, 4, 3, 2, 1, 0, ), ); @@ -343,17 +341,15 @@ pub(crate) fn serialize_12(vector: __m256i) -> [u8; 24] { ), ); - let adjacent_4_combined = mm256_sllv_epi32( - adjacent_2_combined, - mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), - ); + let adjacent_4_combined = + mm256_sllv_epi32(adjacent_2_combined, mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8)); let adjacent_4_combined = mm256_srli_epi64::<8>(adjacent_4_combined); let adjacent_8_combined = mm256_shuffle_epi8( adjacent_4_combined, mm256_set_epi8( - -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, - 10, 9, 8, 5, 4, 3, 2, 1, 0, + -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11, 10, + 9, 8, 5, 4, 3, 2, 1, 0, ), );
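
The final rustfmt patch only reflows the code above; the byte layout produced by serialize_12/deserialize_12 is unchanged. As a reading aid, a portable scalar sketch of the same 12-bit little-endian packing could look like the following. The helper names are illustrative rather than part of the patches, and the sketch assumes 16 coefficients that are already reduced below 2^12, packed into 24 bytes:

    // Portable model of the 12-bit packing that serialize_12/deserialize_12
    // implement with AVX2 shuffles: each pair of coefficients occupies 3 bytes.
    fn serialize_12_scalar(coefficients: &[i16; 16]) -> [u8; 24] {
        let mut out = [0u8; 24];
        for (i, pair) in coefficients.chunks_exact(2).enumerate() {
            let c0 = (pair[0] as u16) & 0x0fff;
            let c1 = (pair[1] as u16) & 0x0fff;
            out[3 * i] = c0 as u8;
            out[3 * i + 1] = ((c0 >> 8) | ((c1 & 0x0f) << 4)) as u8;
            out[3 * i + 2] = (c1 >> 4) as u8;
        }
        out
    }

    fn deserialize_12_scalar(bytes: &[u8; 24]) -> [i16; 16] {
        let mut out = [0i16; 16];
        for i in 0..8 {
            let b0 = bytes[3 * i] as u16;
            let b1 = bytes[3 * i + 1] as u16;
            let b2 = bytes[3 * i + 2] as u16;
            out[2 * i] = (b0 | ((b1 & 0x0f) << 8)) as i16;
            out[2 * i + 1] = ((b1 >> 4) | (b2 << 4)) as i16;
        }
        out
    }

Any input whose coefficients are all below 4096 round-trips through these two helpers, which makes them convenient in a quick property test against the vectorized path.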