Skip to content

Commit

Permalink
move neon hash functions into sha3 for ml-kem
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed May 22, 2024
1 parent c90a3d0 commit da5ed9d
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 147 deletions.
145 changes: 4 additions & 141 deletions libcrux-ml-kem/src/hash_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,163 +368,26 @@ pub(crate) mod neon {
fn PRFxN<const LEN: usize>(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] {
debug_assert!(K == 2 || K == 3 || K == 4);

let mut out = [[0u8; LEN]; K];
match K {
2 => {
let (out0, out1) = out.split_at_mut(1);
x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
}
3 => {
let mut extra = [0u8; LEN];
let (out0, out12) = out.split_at_mut(1);
let (out1, out2) = out12.split_at_mut(1);
x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
x2::shake256(&input[2], &input[2], &mut out2[0], &mut extra);
}
4 => {
let (out0, out123) = out.split_at_mut(1);
let (out1, out23) = out123.split_at_mut(1);
let (out2, out3) = out23.split_at_mut(1);
x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
x2::shake256(&input[2], &input[3], &mut out2[0], &mut out3[0]);
}
_ => unreachable!(),
}
out
x2::shake256xN(input)
}

fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self {
debug_assert!(K == 2 || K == 3 || K == 4);
let mut state = [
x2::incremental::shake128_init(),
x2::incremental::shake128_init(),
];
let state = x2::incremental::shake128_absorb_finalxN(input);

match K {
2 => {
x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]);
}
3 => {
x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]);
x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[2]);
}
_ => {
x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]);
x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[3]);
}
}
// _ => unreachable!("This function is only called with 2, 3, 4"),
Self {
shake128_state: state,
}
}

fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] {
debug_assert!(K == 2 || K == 3 || K == 4);

let mut out = [[0u8; THREE_BLOCKS]; K];
match K {
2 => {
let (out0, out1) = out.split_at_mut(1);
x2::incremental::shake128_squeeze_first_three_blocks(
&mut self.shake128_state[0],
&mut out0[0],
&mut out1[0],
);
}
3 => {
let mut extra = [0u8; THREE_BLOCKS];
let (out0, out12) = out.split_at_mut(1);
let (out1, out2) = out12.split_at_mut(1);
x2::incremental::shake128_squeeze_first_three_blocks(
&mut self.shake128_state[0],
&mut out0[0],
&mut out1[0],
);
x2::incremental::shake128_squeeze_first_three_blocks(
&mut self.shake128_state[1],
&mut out2[0],
&mut extra,
);
}
4 => {
let (out0, out123) = out.split_at_mut(1);
let (out1, out23) = out123.split_at_mut(1);
let (out2, out3) = out23.split_at_mut(1);
x2::incremental::shake128_squeeze_first_three_blocks(
&mut self.shake128_state[0],
&mut out0[0],
&mut out1[0],
);
x2::incremental::shake128_squeeze_first_three_blocks(
&mut self.shake128_state[1],
&mut out2[0],
&mut out3[0],
);
}
_ => unreachable!("This function is only called with 2, 3, 4"),
}
out
x2::incremental::shake128_squeeze3xN(&mut self.shake128_state)
}

fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] {
debug_assert!(K == 2 || K == 3 || K == 4);

let mut out = [[0u8; BLOCK_SIZE]; K];
match K {
2 => {
let mut out0 = [0u8; BLOCK_SIZE];
let mut out1 = [0u8; BLOCK_SIZE];
x2::incremental::shake128_squeeze_next_block(
&mut self.shake128_state[0],
&mut out0,
&mut out1,
);
out[0] = out0;
out[1] = out1;
}
3 => {
let mut out0 = [0u8; BLOCK_SIZE];
let mut out1 = [0u8; BLOCK_SIZE];
let mut out2 = [0u8; BLOCK_SIZE];
let mut out3 = [0u8; BLOCK_SIZE];
x2::incremental::shake128_squeeze_next_block(
&mut self.shake128_state[0],
&mut out0,
&mut out1,
);
x2::incremental::shake128_squeeze_next_block(
&mut self.shake128_state[1],
&mut out2,
&mut out3,
);
out[0] = out0;
out[1] = out1;
out[2] = out2;
}
4 => {
let mut out0 = [0u8; BLOCK_SIZE];
let mut out1 = [0u8; BLOCK_SIZE];
let mut out2 = [0u8; BLOCK_SIZE];
let mut out3 = [0u8; BLOCK_SIZE];
x2::incremental::shake128_squeeze_next_block(
&mut self.shake128_state[0],
&mut out0,
&mut out1,
);
x2::incremental::shake128_squeeze_next_block(
&mut self.shake128_state[1],
&mut out2,
&mut out3,
);
out[0] = out0;
out[1] = out1;
out[2] = out2;
out[3] = out3;
}
_ => unreachable!("This function is only called with 2, 3, 4"),
}
out
x2::incremental::shake128_squeezexN(&mut self.shake128_state)
}
}
}
172 changes: 166 additions & 6 deletions libcrux-sha3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,40 @@ pub mod neon {
keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]);
}

/// Run up to 4 SHAKE256 operations in parallel.
///
/// **PANICS** when `N` is not 2, 3, or 4.
#[allow(non_snake_case)]
pub fn shake256xN<const LEN: usize, const N: usize>(
input: &[[u8; 33]; N],
) -> [[u8; LEN]; N] {
debug_assert!(N == 2 || N == 3 || N == 4);

let mut out = [[0u8; LEN]; N];
match N {
2 => {
let (out0, out1) = out.split_at_mut(1);
shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
}
3 => {
let mut extra = [0u8; LEN];
let (out0, out12) = out.split_at_mut(1);
let (out1, out2) = out12.split_at_mut(1);
shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
shake256(&input[2], &input[2], &mut out2[0], &mut extra);
}
4 => {
let (out0, out123) = out.split_at_mut(1);
let (out1, out23) = out123.split_at_mut(1);
let (out2, out3) = out23.split_at_mut(1);
shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]);
shake256(&input[2], &input[3], &mut out2[0], &mut out3[0]);
}
_ => unreachable!("Only 2, 3, or 4 are supported for N"),
}
out
}

/// An incremental API to perform 2 operations in parallel
pub mod incremental {
#[cfg(all(feature = "simd128", target_arch = "aarch64"))]
Expand Down Expand Up @@ -390,8 +424,36 @@ pub mod neon {
absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]);
}

/// Perform up to 4 absorbs at the same time, using two [`KeccakState2`].
///
/// **PANICS** when `N` is not 2, 3, or 4.
#[allow(unused_variables, non_snake_case)]
pub fn shake128_absorb_finalxN<const N: usize>(
input: [[u8; 34]; N],
) -> [KeccakState2; 2] {
debug_assert!(N == 2 || N == 3 || N == 4);
let mut state = [shake128_init(), shake128_init()];

match N {
2 => {
shake128_absorb_final(&mut state[0], &input[0], &input[1]);
}
3 => {
shake128_absorb_final(&mut state[0], &input[0], &input[1]);
shake128_absorb_final(&mut state[1], &input[2], &input[2]);
}
4 => {
shake128_absorb_final(&mut state[0], &input[0], &input[1]);
shake128_absorb_final(&mut state[1], &input[2], &input[3]);
}
_ => unreachable!("This function can only called be called with N = 2, 3, 4"),
}

state
}

#[allow(unused_variables)]
pub fn shake128_squeeze_first_three_blocks(
fn shake128_squeeze_first_three_blocks(
s: &mut KeccakState2,
out0: &mut [u8],
out1: &mut [u8],
Expand All @@ -412,12 +474,63 @@ pub mod neon {
)
}

/// Squeeze up to 3 x 4 (N) blocks in parallel, using two [`KeccakState2`].
/// Each block is of size `LEN`.
///
/// **PANICS** when `N` is not 2, 3, or 4.
#[allow(unused_variables, non_snake_case)]
pub fn shake128_squeeze3xN<const LEN: usize, const N: usize>(
state: &mut [KeccakState2; 2],
) -> [[u8; LEN]; N] {
debug_assert!(N == 2 || N == 3 || N == 4);

let mut out = [[0u8; LEN]; N];
match N {
2 => {
let (out0, out1) = out.split_at_mut(1);
shake128_squeeze_first_three_blocks(
&mut state[0],
&mut out0[0],
&mut out1[0],
);
}
3 => {
let mut extra = [0u8; LEN];
let (out0, out12) = out.split_at_mut(1);
let (out1, out2) = out12.split_at_mut(1);
shake128_squeeze_first_three_blocks(
&mut state[0],
&mut out0[0],
&mut out1[0],
);
shake128_squeeze_first_three_blocks(
&mut state[1],
&mut out2[0],
&mut extra,
);
}
4 => {
let (out0, out123) = out.split_at_mut(1);
let (out1, out23) = out123.split_at_mut(1);
let (out2, out3) = out23.split_at_mut(1);
shake128_squeeze_first_three_blocks(
&mut state[0],
&mut out0[0],
&mut out1[0],
);
shake128_squeeze_first_three_blocks(
&mut state[1],
&mut out2[0],
&mut out3[0],
);
}
_ => unreachable!("This function can only called be called with N = 2, 3, 4"),
}
out
}

#[allow(unused_variables)]
pub fn shake128_squeeze_next_block(
s: &mut KeccakState2,
out0: &mut [u8],
out1: &mut [u8],
) {
fn shake128_squeeze_next_block(s: &mut KeccakState2, out0: &mut [u8], out1: &mut [u8]) {
#[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
unimplemented!("The target architecture does not support neon instructions.");
// XXX: These functions could alternatively implement the same with
Expand All @@ -430,6 +543,53 @@ pub mod neon {
#[cfg(all(feature = "simd128", target_arch = "aarch64"))]
squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1])
}

/// Squeeze up to 4 (N) blocks in parallel, using two [`KeccakState2`].
/// Each block is of size `LEN`.
///
/// **PANICS** when `N` is not 2, 3, or 4.
#[allow(unused_variables, non_snake_case)]
pub fn shake128_squeezexN<const LEN: usize, const N: usize>(
state: &mut [KeccakState2; 2],
) -> [[u8; LEN]; N] {
debug_assert!(N == 2 || N == 3 || N == 4);

let mut out = [[0u8; LEN]; N];
match N {
2 => {
let mut out0 = [0u8; LEN];
let mut out1 = [0u8; LEN];
shake128_squeeze_next_block(&mut state[0], &mut out0, &mut out1);
out[0] = out0;
out[1] = out1;
}
3 => {
let mut out0 = [0u8; LEN];
let mut out1 = [0u8; LEN];
let mut out2 = [0u8; LEN];
let mut out3 = [0u8; LEN];
shake128_squeeze_next_block(&mut state[0], &mut out0, &mut out1);
shake128_squeeze_next_block(&mut state[1], &mut out2, &mut out3);
out[0] = out0;
out[1] = out1;
out[2] = out2;
}
4 => {
let mut out0 = [0u8; LEN];
let mut out1 = [0u8; LEN];
let mut out2 = [0u8; LEN];
let mut out3 = [0u8; LEN];
shake128_squeeze_next_block(&mut state[0], &mut out0, &mut out1);
shake128_squeeze_next_block(&mut state[1], &mut out2, &mut out3);
out[0] = out0;
out[1] = out1;
out[2] = out2;
out[3] = out3;
}
_ => unreachable!("This function is only called with N = 2, 3, 4"),
}
out
}
}
}
}
Expand Down

0 comments on commit da5ed9d

Please sign in to comment.