diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index c0c3c14f7..9f7b6fc2d 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -92,7 +92,7 @@ jobs: - name: 🏃 Extract & Verify ML-KEM crate (lax) run: | cd libcrux-ml-kem - ./hax.py extract + # ./hax.py extract # env FSTAR_HOME=${{ github.workspace }}/fstar \ # HACL_HOME=${{ github.workspace }}/hacl-star \ # HAX_HOME=${{ github.workspace }}/hax \ diff --git a/.github/workflows/mlkem.yml b/.github/workflows/mlkem.yml index 7abb24116..13c639a43 100644 --- a/.github/workflows/mlkem.yml +++ b/.github/workflows/mlkem.yml @@ -4,7 +4,7 @@ on: push: branches: ["main", "dev"] pull_request: - branches: ["main", "dev"] + branches: ["main", "dev", "*"] workflow_dispatch: merge_group: @@ -22,15 +22,15 @@ jobs: matrix: bits: [32, 64] os: - - macos-latest + - macos-13 # Intel mac + - macos-latest # macos-14 m1 - ubuntu-latest - windows-latest exclude: - bits: 32 os: "macos-latest" - # FIXME: Linking isn't working here yet for hacl #42 - bits: 32 - os: "windows-latest" + os: "macos-13" runs-on: ${{ matrix.os }} defaults: @@ -44,23 +44,21 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} - - name: ⚙️ Setup Ubuntu x86 + - name: 🛠️ Setup Rust Nightly + run: rustup toolchain install nightly + + - name: 🛠️ Setup Ubuntu x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu sudo apt-get update sudo apt-get install -y gcc-multilib g++-multilib - - name: ⚙️ Setup Ubuntu x64 + - name: 🛠️ Setup Ubuntu x64 if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} run: | rustup target add aarch64-unknown-linux-gnu - - name: ⚙️ Setup macOS - if: ${{ matrix.os == 'macos-latest' }} - run: | - rustup target add aarch64-apple-darwin - # Set up 32 bit systems - name: 🛠️ Config Windows x86 @@ -72,43 +70,73 @@ jobs: echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - # Set up windows - - - name: ⚙️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x86-windows-static-md - - - name: ⚙️ Setup Windows x64 - if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - # Build ... - name: 🔨 Build - run: cargo build --verbose $RUST_TARGET_FLAG + run: | + rustc --print=cfg + cargo build --verbose $RUST_TARGET_FLAG - name: 🔨 Build Release run: cargo build --verbose --release $RUST_TARGET_FLAG - # Cross compilation - - - name: 🔨 Build aarch64 macOS + - name: 🏃🏻 Asan MacOS if: ${{ matrix.os == 'macos-latest' }} - run: cargo build --verbose --target aarch64-apple-darwin + run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin + + # We get false positives here. 
+ # TODO: Figure out what is going on here + # - name: 🏃🏻 Asan Linux + # if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} + # run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target x86_64-unknown-linux-gnu + + # Test SHA3 + - name: 🏃🏻‍♀️ SHA3 Test + working-directory: libcrux-sha3 + run: | + cargo clean + cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ SHA3 Test Release + working-directory: libcrux-sha3 + run: | + cargo clean + cargo test --release --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ SHA3 Test Portable + working-directory: libcrux-sha3 + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ SHA3 Test Portable Release + working-directory: libcrux-sha3 + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose --release $RUST_TARGET_FLAG # Test ... - name: 🏃🏻‍♀️ Test - run: cargo test --verbose $RUST_TARGET_FLAG + run: | + cargo clean + cargo test --verbose $RUST_TARGET_FLAG - name: 🏃🏻‍♀️ Test Release - run: cargo test --verbose --release $RUST_TARGET_FLAG + if: ${{ matrix.os != 'macos-latest' }} + run: | + cargo clean + cargo test --verbose --release $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Test Portable Release + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo test --verbose --release $RUST_TARGET_FLAG benchmarks: strategy: @@ -116,6 +144,7 @@ jobs: matrix: bits: [32, 64] os: + - macos-13 - macos-latest - ubuntu-latest - windows-latest @@ -123,9 +152,8 @@ jobs: # There's no such thing as 32-bit macOS - bits: 32 os: "macos-latest" - # FIXME: Linking isn't working here yet for hacl #42 - bits: 32 - os: "windows-latest" + os: "macos-13" runs-on: ${{ matrix.os }} defaults: @@ -139,19 +167,19 @@ jobs: - run: echo "RUST_TARGET_FLAG=" > $GITHUB_ENV if: ${{ matrix.bits == 64 }} - - name: ⚙️ Setup Ubuntu x86 + - name: 🛠️ Setup Ubuntu x86 if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} run: | rustup target add i686-unknown-linux-gnu sudo apt-get update sudo apt-get install -y gcc-multilib g++-multilib - - name: ⚙️ Setup Ubuntu x64 + - name: 🛠️ Setup Ubuntu x64 if: ${{ matrix.bits == 64 && matrix.os == 'ubuntu-latest' }} run: | rustup target add aarch64-unknown-linux-gnu - - name: ⚙️ Setup macOS + - name: 🛠️ Setup macOS if: ${{ matrix.os == 'macos-latest' }} run: | rustup target add aarch64-apple-darwin @@ -167,28 +195,12 @@ jobs: echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} - # Set up windows - - - name: ⚙️ Setup Windows x86 - if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x86-windows-static-md - - - name: ⚙️ Setup Windows x64 - if: ${{ matrix.bits == 64 && matrix.os == 'windows-latest' }} - shell: pwsh - run: | - echo "VCPKG_ROOT=$env:VCPKG_INSTALLATION_ROOT" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install openssl:x64-windows-static-md - # Benchmarks ... 
- - name: 🏃🏻‍♀️ Benchmarks Windows - if: ${{ matrix.os == 'windows-latest' }} + - name: 🏃🏻‍♀️ Benchmarks run: cargo bench --verbose $RUST_TARGET_FLAG - - - name: 🏃🏻‍♀️ Benchmarks Clang - if: ${{ matrix.os != 'windows-latest' }} - run: CC=clang cargo bench --verbose $RUST_TARGET_FLAG + + - name: 🏃🏻‍♀️ Benchmarks Portable + run: | + cargo clean + LIBCRUX_DISABLE_SIMD128=1 LIBCRUX_DISABLE_SIMD256=1 cargo bench --verbose $RUST_TARGET_FLAG diff --git a/.github/workflows/platform.yml b/.github/workflows/platform.yml new file mode 100644 index 000000000..84c4bf71f --- /dev/null +++ b/.github/workflows/platform.yml @@ -0,0 +1,74 @@ +name: Platform + +on: + push: + branches: ["main", "dev"] + pull_request: + branches: ["main", "dev", "*"] + workflow_dispatch: + merge_group: + +env: + CARGO_TERM_COLOR: always + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + platform: + strategy: + fail-fast: false + matrix: + bits: [32, 64] + os: + - macos-13 # Intel mac + - macos-latest # macos-14 m1 + - ubuntu-latest + - windows-latest + exclude: + - bits: 32 + os: "macos-latest" + - bits: 32 + os: "macos-13" + + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: sys/platform + + steps: + - uses: actions/checkout@v4 + + - name: 🔨 Build + run: cargo build --verbose + + - name: 🏃🏻‍♀️ Test + run: cargo test --verbose -- --nocapture + + - name: 🏃🏻‍♀️ Test Release + run: cargo test --verbose --release -- --nocapture + + - name: 🛠️ Setup Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: | + rustup target add i686-unknown-linux-gnu + sudo apt-get update + sudo apt-get install -y gcc-multilib g++-multilib + + - name: 🏃🏻‍♀️ Test Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --target i686-unknown-linux-gnu -- --nocapture + + - name: 🏃🏻‍♀️ Test Release Linux x86 + if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }} + run: cargo test --verbose --release --target i686-unknown-linux-gnu -- --nocapture + + - name: 🏃🏻‍♀️ Test Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --target i686-pc-windows-msvc -- --nocapture + + - name: 🏃🏻‍♀️ Test Release Windows x86 + if: ${{ matrix.bits == 32 && matrix.os == 'windows-latest' }} + run: cargo test --verbose --release --target i686-pc-windows-msvc -- --nocapture diff --git a/libcrux-ml-kem/build.rs b/libcrux-ml-kem/build.rs index f15f3d581..ef1138666 100644 --- a/libcrux-ml-kem/build.rs +++ b/libcrux-ml-kem/build.rs @@ -16,11 +16,13 @@ fn main() { // We enable simd128 on all aarch64 builds. println!("cargo:rustc-cfg=feature=\"simd128\""); } - if (target_arch == "x86" || target_arch == "x86_64") && !disable_simd256 { - // We enable simd256 on all x86 and x86_64 builds. + if target_arch == "x86_64" && !disable_simd256 { + // We enable simd256 on all x86_64 builds. // Note that this doesn't mean the required CPU features are available. // But the compiler will support them and the runtime checks ensure that // it's only used when available. + // + // We don't enable this on x86 because it seems to generate invalid code. 
println!("cargo:rustc-cfg=feature=\"simd256\""); } } diff --git a/libcrux-ml-kem/src/constants.rs b/libcrux-ml-kem/src/constants.rs index b3903e9db..cf89e9348 100644 --- a/libcrux-ml-kem/src/constants.rs +++ b/libcrux-ml-kem/src/constants.rs @@ -32,4 +32,7 @@ pub(crate) const CPA_PKE_KEY_GENERATION_SEED_SIZE: usize = 32; // XXX: Eurydice can't handle this. // digest_size(Algorithm::Sha3_256); +/// SHA3 256 digest size pub(crate) const H_DIGEST_SIZE: usize = 32; +/// SHA3 512 digest size +pub(crate) const G_DIGEST_SIZE: usize = 64; diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 67d862cfa..cc00df6fd 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -1,65 +1,530 @@ #![allow(non_snake_case)] -use crate::constants::H_DIGEST_SIZE; -use libcrux_sha3::{x4::Shake128StateX4, *}; +use crate::constants::{G_DIGEST_SIZE, H_DIGEST_SIZE}; -pub(crate) fn G(input: &[u8]) -> [u8; digest_size(Algorithm::Sha3_512)] { - sha512(input) -} +/// The SHA3 block size. +pub(crate) const BLOCK_SIZE: usize = 168; -pub(crate) fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { - sha256(input) -} +/// The size of 3 SHA3 blocks. +pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; + +/// Abstraction for the hashing, to pick the fastest version depending on the +/// platform features available. +/// +/// There are 3 instantiations of this trait right now, using the libcrux-sha3 crate. +/// - AVX2 +/// - NEON +/// - Portable +pub(crate) trait Hash { + /// G aka SHA3 512 + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE]; + + /// H aka SHA3 256 + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE]; -pub(crate) fn PRF(input: &[u8]) -> [u8; LEN] { - shake256::(input) + /// PRF aka SHAKE256 + fn PRF(input: &[u8]) -> [u8; LEN]; + + /// PRFxN aka N SHAKE256 + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K]; + + /// Create a SHAKE128 state and absorb the input. + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self; + + /// Squeeze 3 blocks out of the SHAKE128 state. + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K]; + + /// Squeeze 1 block out of the SHAKE128 state. + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K]; } -#[inline(always)] -pub(crate) fn absorb(input: [[u8; 34]; K]) -> Shake128StateX4 { - debug_assert!(K == 2 || K == 3 || K == 4); +/// A portable implementation of [`Hash`] +pub(crate) mod portable { + use super::*; + use libcrux_sha3::portable::{ + self, + incremental::{ + shake128_absorb_final, shake128_init, shake128_squeeze_first_three_blocks, + shake128_squeeze_next_block, + }, + KeccakState1, + }; - let mut state = Shake128StateX4::new(); - // XXX: We need to do this dance to get it through hax and eurydice for now. - let mut data: [&[u8]; K] = [&[0u8]; K]; - for i in 0..K { - data[i] = &input[i] as &[u8]; + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. 
+ pub(crate) struct PortableHash { + shake128_state: [KeccakState1; K], } - state.absorb_final(data); - state -} -pub(crate) const BLOCK_SIZE: usize = 168; -pub(crate) const THREE_BLOCKS: usize = BLOCK_SIZE * 3; + impl Hash for PortableHash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + portable::sha512(&mut digest, input); + digest + } -#[inline(always)] -pub(crate) fn squeeze_three_blocks( - xof_state: &mut Shake128StateX4, -) -> [[u8; THREE_BLOCKS]; K] { - let output: [[u8; THREE_BLOCKS]; K] = xof_state.squeeze_blocks(); - let mut out = [[0u8; THREE_BLOCKS]; K]; - for i in 0..K { - out[i] = output[i]; + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + portable::sha256(&mut digest, input); + digest + } + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + portable::shake256(&mut digest, input); + digest + } + + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; LEN]; K]; + for i in 0..K { + portable::shake256::(&mut out[i], &input[i]); + } + out + } + + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut state = [shake128_init(); K]; + for i in 0..K { + shake128_absorb_final(&mut state[i], &input[i]); + } + Self { + shake128_state: state, + } + } + + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + for i in 0..K { + shake128_squeeze_first_three_blocks(&mut self.shake128_state[i], &mut out[i]); + } + out + } + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; BLOCK_SIZE]; K]; + for i in 0..K { + shake128_squeeze_next_block(&mut self.shake128_state[i], &mut out[i]); + } + out + } } - out } -#[inline(always)] -pub(crate) fn squeeze_block( - xof_state: &mut Shake128StateX4, -) -> [[u8; BLOCK_SIZE]; K] { - let output: [[u8; BLOCK_SIZE]; K] = xof_state.squeeze_blocks(); - let mut out = [[0u8; BLOCK_SIZE]; K]; - for i in 0..K { - out[i] = output[i]; +/// A SIMD256 implementation of [`Hash`] for AVX2 +pub(crate) mod avx2 { + use super::*; + use libcrux_sha3::{ + avx2::x4::{self, incremental::KeccakState4}, + portable, + }; + + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. 
+ pub(crate) struct Simd256Hash { + shake128_state: KeccakState4, + } + + impl Hash for Simd256Hash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + portable::sha512(&mut digest, input); + digest + } + + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + portable::sha256(&mut digest, input); + digest + } + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + portable::shake256(&mut digest, input); + digest + } + + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut out = [[0u8; LEN]; K]; + + match K { + 2 => { + let mut dummy_out0 = [0u8; LEN]; + let mut dummy_out1 = [0u8; LEN]; + let (out0, out1) = out.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[0], + &input[0], + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let mut dummy_out0 = [0u8; LEN]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[2], + &input[0], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::shake256( + &input[0], + &input[1], + &input[2], + &input[3], + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } + + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut state = x4::incremental::shake128_init(); + + match K { + 2 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[0], &input[0], + ); + } + 3 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[2], &input[0], + ); + } + 4 => { + x4::incremental::shake128_absorb_final( + &mut state, &input[0], &input[1], &input[2], &input[3], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + Self { + shake128_state: state, + } + } + + fn shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + match K { + 2 => { + let mut dummy_out0 = [0u8; THREE_BLOCKS]; + let mut dummy_out1 = [0u8; THREE_BLOCKS]; + let (out0, out1) = out.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let mut dummy_out0 = [0u8; THREE_BLOCKS]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut dummy_out0 = [0u8; BLOCK_SIZE]; + let mut dummy_out1 = [0u8; 
BLOCK_SIZE]; + + let mut out = [[0u8; BLOCK_SIZE]; K]; + + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut dummy_out0, + &mut dummy_out1, + ); + } + 3 => { + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut dummy_out0, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x4::incremental::shake128_squeeze_next_block( + &mut self.shake128_state, + &mut out0[0], + &mut out1[0], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } } - out } -/// Free the memory of the state. -/// -/// **NOTE:** That this needs to be done manually for now. -#[inline(always)] -pub(crate) fn free_state(xof_state: Shake128StateX4) { - xof_state.free_memory(); +/// A SIMD128 implementation of [`Hash`] for NEON +pub(crate) mod neon { + use super::*; + use libcrux_sha3::neon::x2::{self, incremental::KeccakState2}; + + /// The state. + /// + /// It's only used for SHAKE128. + /// All other functions don't actually use any members. + pub(crate) struct Simd128Hash { + shake128_state: [KeccakState2; 2], + } + + impl Hash for Simd128Hash { + fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { + let mut digest = [0u8; G_DIGEST_SIZE]; + libcrux_sha3::neon::sha512(&mut digest, input); + digest + } + + fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { + let mut digest = [0u8; H_DIGEST_SIZE]; + libcrux_sha3::neon::sha256(&mut digest, input); + digest + } + + fn PRF(input: &[u8]) -> [u8; LEN] { + let mut digest = [0u8; LEN]; + libcrux_sha3::neon::shake256(&mut digest, input); + digest + } + + fn PRFxN(input: &[[u8; 33]; K]) -> [[u8; LEN]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; LEN]; K]; + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + } + 3 => { + let mut extra = [0u8; LEN]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + x2::shake256(&input[2], &input[2], &mut out2[0], &mut extra); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x2::shake256(&input[0], &input[1], &mut out0[0], &mut out1[0]); + x2::shake256(&input[2], &input[3], &mut out2[0], &mut out3[0]); + } + _ => unreachable!(), + } + out + } + + fn shake128_init_absorb(input: [[u8; 34]; K]) -> Self { + debug_assert!(K == 2 || K == 3 || K == 4); + let mut state = [ + x2::incremental::shake128_init(), + x2::incremental::shake128_init(), + ]; + + match K { + 2 => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + } + 3 => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[2]); + } + _ => { + x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[3]); + } + } + // _ => unreachable!("This function is only called with 2, 3, 4"), + Self { + shake128_state: state, + } + } + + fn 
shake128_squeeze_three_blocks(&mut self) -> [[u8; THREE_BLOCKS]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; THREE_BLOCKS]; K]; + match K { + 2 => { + let (out0, out1) = out.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + } + 3 => { + let mut extra = [0u8; THREE_BLOCKS]; + let (out0, out12) = out.split_at_mut(1); + let (out1, out2) = out12.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[1], + &mut out2[0], + &mut extra, + ); + } + 4 => { + let (out0, out123) = out.split_at_mut(1); + let (out1, out23) = out123.split_at_mut(1); + let (out2, out3) = out23.split_at_mut(1); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[0], + &mut out0[0], + &mut out1[0], + ); + x2::incremental::shake128_squeeze_first_three_blocks( + &mut self.shake128_state[1], + &mut out2[0], + &mut out3[0], + ); + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } + + fn shake128_squeeze_block(&mut self) -> [[u8; BLOCK_SIZE]; K] { + debug_assert!(K == 2 || K == 3 || K == 4); + + let mut out = [[0u8; BLOCK_SIZE]; K]; + match K { + 2 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + out[0] = out0; + out[1] = out1; + } + 3 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + let mut out2 = [0u8; BLOCK_SIZE]; + let mut out3 = [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[1], + &mut out2, + &mut out3, + ); + out[0] = out0; + out[1] = out1; + out[2] = out2; + } + 4 => { + let mut out0 = [0u8; BLOCK_SIZE]; + let mut out1 = [0u8; BLOCK_SIZE]; + let mut out2 = [0u8; BLOCK_SIZE]; + let mut out3 = [0u8; BLOCK_SIZE]; + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[0], + &mut out0, + &mut out1, + ); + x2::incremental::shake128_squeeze_next_block( + &mut self.shake128_state[1], + &mut out2, + &mut out3, + ); + out[0] = out0; + out[1] = out1; + out[2] = out2; + out[3] = out3; + } + _ => unreachable!("This function is only called with 2, 3, 4"), + } + out + } + } } diff --git a/libcrux-ml-kem/src/helper.rs b/libcrux-ml-kem/src/helper.rs index 3c9b77dfc..20e850b0d 100644 --- a/libcrux-ml-kem/src/helper.rs +++ b/libcrux-ml-kem/src/helper.rs @@ -1,7 +1,7 @@ /// The following macros are defined so that the extraction from Rust to C code /// can go through. -#[cfg(not(hax))] +#[cfg(eurydice)] macro_rules! cloop { (for ($i:ident, $chunk:ident) in $val:ident.$values:ident.chunks_exact($($chunk_size:expr),*).enumerate() $body:block) => { for $i in 0..$val.$values.len() / ($($chunk_size)*) { @@ -23,7 +23,7 @@ macro_rules! cloop { }; (for ($i:ident, $item:ident) in $val:ident.into_iter().enumerate() $body:block) => { for $i in 0..$val.len() { - let $item = $val[$i]; + let $item = &$val[$i]; $body } }; @@ -35,7 +35,7 @@ macro_rules! cloop { }; } -#[cfg(hax)] +#[cfg(not(eurydice))] macro_rules! 
cloop { (for ($i:ident, $chunk:ident) in $val:ident.$values:ident.chunks_exact($($chunk_size:expr),*).enumerate() $body:block) => { for ($i, $chunk) in $val.$values.chunks_exact($($chunk_size),*).enumerate() $body diff --git a/libcrux-ml-kem/src/ind_cca.rs b/libcrux-ml-kem/src/ind_cca.rs index e7b6343db..a538e313b 100644 --- a/libcrux-ml-kem/src/ind_cca.rs +++ b/libcrux-ml-kem/src/ind_cca.rs @@ -5,7 +5,7 @@ use crate::{ compare_ciphertexts_in_constant_time, select_shared_secret_in_constant_time, }, constants::{CPA_PKE_KEY_GENERATION_SEED_SIZE, H_DIGEST_SIZE, SHARED_SECRET_SIZE}, - hash_functions::{G, H, PRF}, + hash_functions::{self, Hash}, ind_cpa::{into_padded_array, serialize_public_key}, serialize::deserialize_ring_elements_reduced, types::{MlKemCiphertext, MlKemKeyPair, MlKemPrivateKey, MlKemPublicKey}, @@ -27,7 +27,7 @@ pub type MlKemSharedSecret = [u8; SHARED_SECRET_SIZE]; /// Serialize the secret key. #[inline(always)] -fn serialize_kem_secret_key( +fn serialize_kem_secret_key>( private_key: &[u8], public_key: &[u8], implicit_rejection_value: &[u8], @@ -38,7 +38,7 @@ fn serialize_kem_secret_key( pointer += private_key.len(); out[pointer..pointer + public_key.len()].copy_from_slice(public_key); pointer += public_key.len(); - out[pointer..pointer + H_DIGEST_SIZE].copy_from_slice(&H(public_key)); + out[pointer..pointer + H_DIGEST_SIZE].copy_from_slice(&Hasher::H(public_key)); pointer += H_DIGEST_SIZE; out[pointer..pointer + implicit_rejection_value.len()] .copy_from_slice(implicit_rejection_value); @@ -52,8 +52,11 @@ pub(crate) fn validate_public_key< >( public_key: &[u8; PUBLIC_KEY_SIZE], ) -> bool { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return validate_public_key_generic::< K, RANKED_BYTES_PER_RING_ELEMENT, @@ -67,8 +70,11 @@ pub(crate) fn validate_public_key< PUBLIC_KEY_SIZE, PortableVector, >(public_key) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return validate_public_key_generic::< K, RANKED_BYTES_PER_RING_ELEMENT, @@ -127,8 +133,11 @@ pub(crate) fn generate_keypair< let implicit_rejection_value = &randomness[CPA_PKE_KEY_GENERATION_SEED_SIZE..]; // Runtime feature detection. 
- if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return generate_keypair_generic::< K, CPA_PRIVATE_KEY_SIZE, @@ -138,6 +147,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(ind_cpa_keypair_randomness, implicit_rejection_value); #[cfg(not(feature = "simd256"))] generate_keypair_generic::< @@ -149,9 +159,13 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return generate_keypair_generic::< K, CPA_PRIVATE_KEY_SIZE, @@ -161,6 +175,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(ind_cpa_keypair_randomness, implicit_rejection_value); #[cfg(not(feature = "simd128"))] generate_keypair_generic::< @@ -172,6 +187,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) } else { generate_keypair_generic::< @@ -183,6 +199,7 @@ pub(crate) fn generate_keypair< ETA1, ETA1_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(ind_cpa_keypair_randomness, implicit_rejection_value) } } @@ -196,6 +213,7 @@ fn generate_keypair_generic< const ETA1: usize, const ETA1_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( ind_cpa_keypair_randomness: &[u8], implicit_rejection_value: &[u8], @@ -208,10 +226,14 @@ fn generate_keypair_generic< ETA1, ETA1_RANDOMNESS_SIZE, Vector, + Hasher, >(ind_cpa_keypair_randomness); - let secret_key_serialized = - serialize_kem_secret_key(&ind_cpa_private_key, &public_key, implicit_rejection_value); + let secret_key_serialized = serialize_kem_secret_key::( + &ind_cpa_private_key, + &public_key, + implicit_rejection_value, + ); let private_key: MlKemPrivateKey = MlKemPrivateKey::from(secret_key_serialized); @@ -236,8 +258,11 @@ pub(crate) fn encapsulate< public_key: &MlKemPublicKey, randomness: [u8; SHARED_SECRET_SIZE], ) -> (MlKemCiphertext, MlKemSharedSecret) { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return encapsulate_generic::< K, CIPHERTEXT_SIZE, @@ -253,6 +278,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(public_key, randomness); #[cfg(not(feature = "simd256"))] encapsulate_generic::< @@ -270,8 +296,12 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(public_key, randomness) - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && 
libcrux_platform::simd128_support() + { #[cfg(not(feature = "simd128"))] return encapsulate_generic::< K, @@ -288,8 +318,9 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(public_key, randomness); - #[cfg(feature = "simd128")] + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] encapsulate_generic::< K, CIPHERTEXT_SIZE, @@ -305,6 +336,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(public_key, randomness) } else { encapsulate_generic::< @@ -322,6 +354,7 @@ pub(crate) fn encapsulate< ETA2, ETA2_RANDOMNESS_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(public_key, randomness) } } @@ -341,14 +374,15 @@ fn encapsulate_generic< const ETA2: usize, const ETA2_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( public_key: &MlKemPublicKey, randomness: [u8; SHARED_SECRET_SIZE], ) -> (MlKemCiphertext, MlKemSharedSecret) { let mut to_hash: [u8; 2 * H_DIGEST_SIZE] = into_padded_array(&randomness); - to_hash[H_DIGEST_SIZE..].copy_from_slice(&H(public_key.as_slice())); + to_hash[H_DIGEST_SIZE..].copy_from_slice(&Hasher::H(public_key.as_slice())); - let hashed = G(&to_hash); + let hashed = Hasher::G(&to_hash); let (shared_secret, pseudorandomness) = hashed.split_at(SHARED_SECRET_SIZE); let ciphertext = crate::ind_cpa::encrypt::< @@ -365,6 +399,7 @@ fn encapsulate_generic< ETA2, ETA2_RANDOMNESS_SIZE, Vector, + Hasher, >(public_key.as_slice(), randomness, pseudorandomness); let mut shared_secret_array = [0u8; SHARED_SECRET_SIZE]; shared_secret_array.copy_from_slice(shared_secret); @@ -392,8 +427,11 @@ pub(crate) fn decapsulate< private_key: &MlKemPrivateKey, ciphertext: &MlKemCiphertext, ) -> MlKemSharedSecret { - if cfg!(feature = "simd256") && libcrux_platform::simd256_support() { - #[cfg(feature = "simd256")] + if cfg!(feature = "simd256") + && cfg!(target_arch = "x86_64") + && libcrux_platform::simd256_support() + { + #[cfg(all(feature = "simd256", target_arch = "x86_64"))] return decapsulate_generic::< K, SECRET_KEY_SIZE, @@ -412,6 +450,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, libcrux_polynomials::SIMD256Vector, + hash_functions::avx2::Simd256Hash, >(private_key, ciphertext); #[cfg(not(feature = "simd256"))] return decapsulate_generic::< @@ -432,9 +471,13 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(private_key, ciphertext); - } else if cfg!(feature = "simd128") && libcrux_platform::simd128_support() { - #[cfg(feature = "simd128")] + } else if cfg!(feature = "simd128") + && cfg!(target_arch = "aarch64") + && libcrux_platform::simd128_support() + { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] return decapsulate_generic::< K, SECRET_KEY_SIZE, @@ -453,6 +496,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, libcrux_polynomials::SIMD128Vector, + hash_functions::neon::Simd128Hash, >(private_key, ciphertext); #[cfg(not(feature = "simd128"))] return decapsulate_generic::< @@ -473,6 +517,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + hash_functions::portable::PortableHash, >(private_key, ciphertext); } else { decapsulate_generic::< @@ -493,6 +538,7 @@ pub(crate) fn decapsulate< ETA2_RANDOMNESS_SIZE, IMPLICIT_REJECTION_HASH_INPUT_SIZE, PortableVector, + 
hash_functions::portable::PortableHash, >(private_key, ciphertext) } } @@ -515,6 +561,7 @@ pub(crate) fn decapsulate_generic< const ETA2_RANDOMNESS_SIZE: usize, const IMPLICIT_REJECTION_HASH_INPUT_SIZE: usize, Vector: Operations, + Hasher: Hash, >( private_key: &MlKemPrivateKey, ciphertext: &MlKemCiphertext, @@ -535,13 +582,13 @@ pub(crate) fn decapsulate_generic< let mut to_hash: [u8; SHARED_SECRET_SIZE + H_DIGEST_SIZE] = into_padded_array(&decrypted); to_hash[SHARED_SECRET_SIZE..].copy_from_slice(ind_cpa_public_key_hash); - let hashed = G(&to_hash); + let hashed = Hasher::G(&to_hash); let (shared_secret, pseudorandomness) = hashed.split_at(SHARED_SECRET_SIZE); let mut to_hash: [u8; IMPLICIT_REJECTION_HASH_INPUT_SIZE] = into_padded_array(implicit_rejection_value); to_hash[SHARED_SECRET_SIZE..].copy_from_slice(ciphertext.as_ref()); - let implicit_rejection_shared_secret: [u8; SHARED_SECRET_SIZE] = PRF(&to_hash); + let implicit_rejection_shared_secret: [u8; SHARED_SECRET_SIZE] = Hasher::PRF(&to_hash); let expected_ciphertext = crate::ind_cpa::encrypt::< K, @@ -557,6 +604,7 @@ pub(crate) fn decapsulate_generic< ETA2, ETA2_RANDOMNESS_SIZE, Vector, + Hasher, >(ind_cpa_public_key, decrypted, pseudorandomness); let selector = compare_ciphertexts_in_constant_time::( diff --git a/libcrux-ml-kem/src/ind_cpa.rs b/libcrux-ml-kem/src/ind_cpa.rs index c11e67772..2271b9ed9 100644 --- a/libcrux-ml-kem/src/ind_cpa.rs +++ b/libcrux-ml-kem/src/ind_cpa.rs @@ -2,7 +2,7 @@ use libcrux_polynomials::Operations; use crate::{ constants::{BYTES_PER_RING_ELEMENT, COEFFICIENTS_IN_RING_ELEMENT, SHARED_SECRET_SIZE}, - hash_functions::{G, PRF}, + hash_functions::Hash, helper::cloop, matrix::*, ntt::{ntt_binomially_sampled_ring_element, ntt_vector_u}, @@ -55,7 +55,7 @@ fn serialize_secret_key, >( - prf_input: &mut [u8; 33], - domain_separator: &mut u8, -) -> [PolynomialRingElement; K] { - let mut error_1 = [PolynomialRingElement::::ZERO(); K]; + prf_input: [u8; 33], + mut domain_separator: u8, +) -> ([PolynomialRingElement; K], u8) { + let mut error_1 = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); + let mut prf_inputs = [prf_input; K]; for i in 0..K { - prf_input[32] = *domain_separator; - *domain_separator += 1; - - let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = PRF(prf_input); - error_1[i] = sample_from_binomial_distribution::(&prf_output); + prf_inputs[i][32] = domain_separator; + domain_separator += 1; } - error_1 + let prf_outputs: [[u8; ETA2_RANDOMNESS_SIZE]; K] = Hasher::PRFxN(&prf_inputs); + for i in 0..K { + error_1[i] = sample_from_binomial_distribution::(&prf_outputs[i]); + } + (error_1, domain_separator) } /// Sample a vector of ring elements from a centered binomial distribution and @@ -92,19 +95,21 @@ fn sample_vector_cbd_then_ntt< const ETA: usize, const ETA_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( - mut prf_input: [u8; 33], + prf_input: [u8; 33], mut domain_separator: u8, ) -> ([PolynomialRingElement; K], u8) { - let mut re_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut re_as_ntt = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); + let mut prf_inputs = [prf_input; K]; for i in 0..K { - prf_input[32] = domain_separator; + prf_inputs[i][32] = domain_separator; domain_separator += 1; - - let prf_output: [u8; ETA_RANDOMNESS_SIZE] = PRF(&prf_input); - - let r = sample_from_binomial_distribution::(&prf_output); - re_as_ntt[i] = ntt_binomially_sampled_ring_element(r); + } + let prf_outputs: [[u8; ETA_RANDOMNESS_SIZE]; K] = Hasher::PRFxN(&prf_inputs); + for 
i in 0..K { + re_as_ntt[i] = sample_from_binomial_distribution::(&prf_outputs[i]); + ntt_binomially_sampled_ring_element(&mut re_as_ntt[i]); } (re_as_ntt, domain_separator) } @@ -156,22 +161,24 @@ pub(crate) fn generate_keypair< const ETA1: usize, const ETA1_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( key_generation_seed: &[u8], ) -> ([u8; PRIVATE_KEY_SIZE], [u8; PUBLIC_KEY_SIZE]) { // (ρ,σ) := G(d) - let hashed = G(key_generation_seed); + let hashed = Hasher::G(key_generation_seed); let (seed_for_A, seed_for_secret_and_error) = hashed.split_at(32); - let A_transpose = sample_matrix_A(into_padded_array(seed_for_A), true); + let A_transpose = sample_matrix_A::(into_padded_array(seed_for_A), true); let prf_input: [u8; 33] = into_padded_array(seed_for_secret_and_error); let (secret_as_ntt, domain_separator) = - sample_vector_cbd_then_ntt::(prf_input, 0); - let (error_as_ntt, _) = sample_vector_cbd_then_ntt::( - prf_input, - domain_separator, - ); + sample_vector_cbd_then_ntt::(prf_input, 0); + let (error_as_ntt, _) = + sample_vector_cbd_then_ntt::( + prf_input, + domain_separator, + ); // tˆ := Aˆ ◦ sˆ + eˆ let t_as_ntt = compute_As_plus_e(&A_transpose, &secret_as_ntt, &error_as_ntt); @@ -197,17 +204,15 @@ fn compress_then_serialize_u< Vector: Operations, >( input: [PolynomialRingElement; K], -) -> [u8; OUT_LEN] { - let mut out = [0u8; OUT_LEN]; + out: &mut [u8], +) { cloop! { for (i, re) in input.into_iter().enumerate() { out[i * (OUT_LEN / K)..(i + 1) * (OUT_LEN / K)].copy_from_slice( - &compress_then_serialize_ring_element_u::(re), + &compress_then_serialize_ring_element_u::(&re), ); } } - - out } /// This function implements Algorithm 13 of the @@ -264,6 +269,7 @@ pub(crate) fn encrypt< const ETA2: usize, const ETA2_RANDOMNESS_SIZE: usize, Vector: Operations, + Hasher: Hash, >( public_key: &[u8], message: [u8; SHARED_SECRET_SIZE], @@ -281,7 +287,7 @@ pub(crate) fn encrypt< // end for // end for let seed = &public_key[T_AS_NTT_ENCODED_SIZE..]; - let A_transpose = sample_matrix_A(into_padded_array(seed), false); + let A_transpose = sample_matrix_A::(into_padded_array(seed), false); // for i from 0 to k−1 do // r[i] := CBD{η1}(PRF(r, N)) @@ -289,21 +295,22 @@ pub(crate) fn encrypt< // end for // rˆ := NTT(r) let mut prf_input: [u8; 33] = into_padded_array(randomness); - let (r_as_ntt, mut domain_separator) = - sample_vector_cbd_then_ntt::(prf_input, 0); + let (r_as_ntt, domain_separator) = + sample_vector_cbd_then_ntt::(prf_input, 0); // for i from 0 to k−1 do // e1[i] := CBD_{η2}(PRF(r,N)) // N := N + 1 // end for - let error_1 = sample_ring_element_cbd::( - &mut prf_input, - &mut domain_separator, - ); + let (error_1, domain_separator) = + sample_ring_element_cbd::( + prf_input, + domain_separator, + ); // e_2 := CBD{η2}(PRF(r, N)) prf_input[32] = domain_separator; - let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = PRF(&prf_input); + let prf_output: [u8; ETA2_RANDOMNESS_SIZE] = Hasher::PRF(&prf_input); let error_2 = sample_from_binomial_distribution::(&prf_output); // u := NTT^{-1}(AˆT ◦ rˆ) + e_1 @@ -313,14 +320,14 @@ pub(crate) fn encrypt< let message_as_ring_element = deserialize_then_decompress_message(message); let v = compute_ring_element_v(&t_as_ntt, &r_as_ntt, &error_2, &message_as_ring_element); + let mut ciphertext = [0u8; CIPHERTEXT_SIZE]; + let (c1, c2) = ciphertext.split_at_mut(C1_LEN); + // c_1 := Encode_{du}(Compress_q(u,d_u)) - let c1 = compress_then_serialize_u::(u); + compress_then_serialize_u::(u, c1); // c_2 := Encode_{dv}(Compress_q(v,d_v)) - let c2 = 
compress_then_serialize_ring_element_v::(v); - - let mut ciphertext: [u8; CIPHERTEXT_SIZE] = into_padded_array(&c1); - ciphertext[C1_LEN..].copy_from_slice(c2.as_slice()); + compress_then_serialize_ring_element_v::(v, c2); ciphertext } @@ -336,14 +343,14 @@ fn deserialize_then_decompress_u< >( ciphertext: &[u8; CIPHERTEXT_SIZE], ) -> [PolynomialRingElement; K] { - let mut u_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut u_as_ntt = core::array::from_fn(|_| PolynomialRingElement::::ZERO()); cloop! { for (i, u_bytes) in ciphertext .chunks_exact((COEFFICIENTS_IN_RING_ELEMENT * U_COMPRESSION_FACTOR) / 8) .enumerate() { - let u = deserialize_then_decompress_ring_element_u::(u_bytes); - u_as_ntt[i] = ntt_vector_u::(u); + u_as_ntt[i] = deserialize_then_decompress_ring_element_u::(u_bytes); + ntt_vector_u::(&mut u_as_ntt[i]); } } u_as_ntt @@ -354,7 +361,7 @@ fn deserialize_then_decompress_u< fn deserialize_secret_key( secret_key: &[u8], ) -> [PolynomialRingElement; K] { - let mut secret_as_ntt = [PolynomialRingElement::::ZERO(); K]; + let mut secret_as_ntt = core::array::from_fn(|_| PolynomialRingElement::::ZERO()); cloop! { for (i, secret_bytes) in secret_key.chunks_exact(BYTES_PER_RING_ELEMENT).enumerate() { secret_as_ntt[i] = deserialize_to_uncompressed_ring_element(secret_bytes); diff --git a/libcrux-ml-kem/src/invert_ntt.rs b/libcrux-ml-kem/src/invert_ntt.rs index 89ca5cad3..589633ac1 100644 --- a/libcrux-ml-kem/src/invert_ntt.rs +++ b/libcrux-ml-kem/src/invert_ntt.rs @@ -7,9 +7,9 @@ use libcrux_polynomials::{GenericOperations, Operations, FIELD_ELEMENTS_IN_VECTO #[inline(always)] pub(crate) fn invert_ntt_at_layer_1( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_1_step( @@ -21,15 +21,14 @@ pub(crate) fn invert_ntt_at_layer_1( ); *zeta_i -= 3; } - re } #[inline(always)] pub(crate) fn invert_ntt_at_layer_2( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_2_step( @@ -39,21 +38,19 @@ pub(crate) fn invert_ntt_at_layer_2( ); *zeta_i -= 1; } - re } #[inline(always)] pub(crate) fn invert_ntt_at_layer_3( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i -= 1; re.coefficients[round] = Vector::inv_ntt_layer_3_step(re.coefficients[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]); } - re } #[inline(always)] @@ -70,9 +67,9 @@ pub(crate) fn inv_ntt_layer_int_vec_step_reduce( #[inline(always)] pub(crate) fn invert_ntt_at_layer_4_plus( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, layer: usize, -) -> PolynomialRingElement { +) { let step = 1 << layer; for round in 0..(128 >> layer) { @@ -92,13 +89,12 @@ pub(crate) fn invert_ntt_at_layer_4_plus( re.coefficients[j + step_vec] = y; } } - re } #[inline(always)] pub(crate) fn invert_ntt_montgomery( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { // We only ever call this function after matrix/vector multiplication hax_debug_assert!(to_i16_array(re) .into_iter() @@ -106,13 +102,13 @@ pub(crate) fn invert_ntt_montgomery( let mut zeta_i = super::constants::COEFFICIENTS_IN_RING_ELEMENT / 2; - re = 
invert_ntt_at_layer_1(&mut zeta_i, re, 1); - re = invert_ntt_at_layer_2(&mut zeta_i, re, 2); - re = invert_ntt_at_layer_3(&mut zeta_i, re, 3); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 4); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 5); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 6); - re = invert_ntt_at_layer_4_plus(&mut zeta_i, re, 7); + invert_ntt_at_layer_1(&mut zeta_i, re, 1); + invert_ntt_at_layer_2(&mut zeta_i, re, 2); + invert_ntt_at_layer_3(&mut zeta_i, re, 3); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 4); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 5); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 6); + invert_ntt_at_layer_4_plus(&mut zeta_i, re, 7); hax_debug_assert!( to_i16_array(re)[0].abs() < 128 * (K as i16) * FIELD_MODULUS diff --git a/libcrux-ml-kem/src/matrix.rs b/libcrux-ml-kem/src/matrix.rs index f10565f82..2b4d4ed85 100644 --- a/libcrux-ml-kem/src/matrix.rs +++ b/libcrux-ml-kem/src/matrix.rs @@ -1,17 +1,19 @@ use libcrux_polynomials::Operations; use crate::{ - helper::cloop, invert_ntt::invert_ntt_montgomery, polynomial::PolynomialRingElement, - sampling::sample_from_xof, + hash_functions::Hash, helper::cloop, invert_ntt::invert_ntt_montgomery, + polynomial::PolynomialRingElement, sampling::sample_from_xof, }; #[inline(always)] #[allow(non_snake_case)] -pub(crate) fn sample_matrix_A( +pub(crate) fn sample_matrix_A>( seed: [u8; 34], transpose: bool, ) -> [[PolynomialRingElement; K]; K] { - let mut A_transpose = [[PolynomialRingElement::::ZERO(); K]; K]; + let mut A_transpose = core::array::from_fn(|_i| { + core::array::from_fn(|_j| PolynomialRingElement::::ZERO()) + }); for i in 0..K { let mut seeds = [seed; K]; @@ -19,13 +21,13 @@ pub(crate) fn sample_matrix_A( seeds[j][32] = i as u8; seeds[j][33] = j as u8; } - let sampled = sample_from_xof(seeds); - for j in 0..K { + let sampled = sample_from_xof::(seeds); + for (j, sample) in sampled.into_iter().enumerate() { // A[i][j] = A_transpose[j][i] if transpose { - A_transpose[j][i] = sampled[j]; + A_transpose[j][i] = sample; } else { - A_transpose[i][j] = sampled[j]; + A_transpose[i][j] = sample; } } } @@ -48,10 +50,10 @@ pub(crate) fn compute_message( for i in 0..K { let product = secret_as_ntt[i].ntt_multiply(&u_as_ntt[i]); - result = result.add_to_ring_element::(&product); + result.add_to_ring_element::(&product); } - result = invert_ntt_montgomery::(result); + invert_ntt_montgomery::(&mut result); result = v.subtract_reduce(result); result @@ -69,10 +71,10 @@ pub(crate) fn compute_ring_element_v( for i in 0..K { let product = t_as_ntt[i].ntt_multiply(&r_as_ntt[i]); - result = result.add_to_ring_element::(&product); + result.add_to_ring_element::(&product); } - result = invert_ntt_montgomery::(result); + invert_ntt_montgomery::(&mut result); result = error_2.add_message_error_reduce(message, result); result @@ -85,19 +87,19 @@ pub(crate) fn compute_vector_u( r_as_ntt: &[PolynomialRingElement; K], error_1: &[PolynomialRingElement; K], ) -> [PolynomialRingElement; K] { - let mut result = [PolynomialRingElement::::ZERO(); K]; + let mut result = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! { for (i, row) in a_as_ntt.iter().enumerate() { cloop! 
{ for (j, a_element) in row.iter().enumerate() { let product = a_element.ntt_multiply(&r_as_ntt[j]); - result[i] = result[i].add_to_ring_element::(&product); + result[i].add_to_ring_element::(&product); } } - result[i] = invert_ntt_montgomery::(result[i]); - result[i] = error_1[i].add_error_reduce(result[i]); + invert_ntt_montgomery::(&mut result[i]); + result[i].add_error_reduce(&error_1[i]); } } @@ -112,17 +114,17 @@ pub(crate) fn compute_As_plus_e( s_as_ntt: &[PolynomialRingElement; K], error_as_ntt: &[PolynomialRingElement; K], ) -> [PolynomialRingElement; K] { - let mut result = [PolynomialRingElement::::ZERO(); K]; + let mut result = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! { for (i, row) in matrix_A.iter().enumerate() { cloop! { for (j, matrix_element) in row.iter().enumerate() { let product = matrix_element.ntt_multiply(&s_as_ntt[j]); - result[i] = result[i].add_to_ring_element::(&product); + result[i].add_to_ring_element::(&product); } } - result[i] = error_as_ntt[i].add_standard_error_reduce(result[i]); + result[i].add_standard_error_reduce(&error_as_ntt[i]); } } diff --git a/libcrux-ml-kem/src/ntt.rs b/libcrux-ml-kem/src/ntt.rs index 9d164beaa..afa17cf9e 100644 --- a/libcrux-ml-kem/src/ntt.rs +++ b/libcrux-ml-kem/src/ntt.rs @@ -7,10 +7,10 @@ use libcrux_polynomials::{GenericOperations, Operations, FIELD_ELEMENTS_IN_VECTO #[inline(always)] pub(crate) fn ntt_at_layer_1( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_1_step( @@ -22,16 +22,15 @@ pub(crate) fn ntt_at_layer_1( ); *zeta_i += 3; } - re } #[inline(always)] pub(crate) fn ntt_at_layer_2( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_2_step( @@ -41,23 +40,20 @@ pub(crate) fn ntt_at_layer_2( ); *zeta_i += 1; } - re } #[inline(always)] pub(crate) fn ntt_at_layer_3( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, _layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { for round in 0..16 { *zeta_i += 1; re.coefficients[round] = Vector::ntt_layer_3_step(re.coefficients[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]); } - - re } #[inline(always)] @@ -74,10 +70,10 @@ fn ntt_layer_int_vec_step( #[inline(always)] pub(crate) fn ntt_at_layer_4_plus( zeta_i: &mut usize, - mut re: PolynomialRingElement, + re: &mut PolynomialRingElement, layer: usize, _initial_coefficient_bound: usize, -) -> PolynomialRingElement { +) { debug_assert!(layer >= 4); let step = 1 << layer; @@ -98,59 +94,54 @@ pub(crate) fn ntt_at_layer_4_plus( re.coefficients[j + step_vec] = y; } } - re } #[inline(always)] -pub(crate) fn ntt_at_layer_7( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { +pub(crate) fn ntt_at_layer_7(re: &mut PolynomialRingElement) { let step = VECTORS_IN_RING_ELEMENT / 2; for j in 0..step { let t = Vector::multiply_by_constant(re.coefficients[j + step], -1600); re.coefficients[j + step] = Vector::sub(re.coefficients[j], &t); re.coefficients[j] = Vector::add(re.coefficients[j], &t); } - - re } #[inline(always)] pub(crate) fn ntt_binomially_sampled_ring_element( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut 
PolynomialRingElement, +) { // Due to the small coefficient bound, we can skip the first round of // Montgomery reductions. - re = ntt_at_layer_7(re); + ntt_at_layer_7(re); let mut zeta_i = 1; - re = ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3); - re = ntt_at_layer_3(&mut zeta_i, re, 3, 3); - re = ntt_at_layer_2(&mut zeta_i, re, 2, 3); - re = ntt_at_layer_1(&mut zeta_i, re, 1, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3); + ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3); + ntt_at_layer_3(&mut zeta_i, re, 3, 3); + ntt_at_layer_2(&mut zeta_i, re, 2, 3); + ntt_at_layer_1(&mut zeta_i, re, 1, 3); re.poly_barrett_reduce() } #[inline(always)] pub(crate) fn ntt_vector_u( - mut re: PolynomialRingElement, -) -> PolynomialRingElement { + re: &mut PolynomialRingElement, +) { hax_debug_assert!(to_i16_array(re) .into_iter() .all(|coefficient| coefficient.abs() <= 3328)); let mut zeta_i = 0; - re = ntt_at_layer_4_plus(&mut zeta_i, re, 7, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3328); - re = ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3328); - re = ntt_at_layer_3(&mut zeta_i, re, 3, 3328); - re = ntt_at_layer_2(&mut zeta_i, re, 2, 3328); - re = ntt_at_layer_1(&mut zeta_i, re, 1, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 7, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 6, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 5, 3328); + ntt_at_layer_4_plus(&mut zeta_i, re, 4, 3328); + ntt_at_layer_3(&mut zeta_i, re, 3, 3328); + ntt_at_layer_2(&mut zeta_i, re, 2, 3328); + ntt_at_layer_1(&mut zeta_i, re, 1, 3328); re.poly_barrett_reduce() } diff --git a/libcrux-ml-kem/src/polynomial.rs b/libcrux-ml-kem/src/polynomial.rs index da38d39cb..b78c0e491 100644 --- a/libcrux-ml-kem/src/polynomial.rs +++ b/libcrux-ml-kem/src/polynomial.rs @@ -16,7 +16,6 @@ pub(crate) const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 128] = pub(crate) const VECTORS_IN_RING_ELEMENT: usize = super::constants::COEFFICIENTS_IN_RING_ELEMENT / FIELD_ELEMENTS_IN_VECTOR; -#[derive(Clone, Copy)] pub(crate) struct PolynomialRingElement { pub(crate) coefficients: [Vector; VECTORS_IN_RING_ELEMENT], } @@ -31,13 +30,11 @@ impl PolynomialRingElement { } #[inline(always)] - pub(crate) fn from_i16_array(a: [i16; 256]) -> Self { + pub(crate) fn from_i16_array(a: &[i16]) -> Self { let mut result = PolynomialRingElement::ZERO(); for i in 0..VECTORS_IN_RING_ELEMENT { result.coefficients[i] = Vector::from_i16_array( - a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR] - .try_into() - .unwrap(), + &a[i * FIELD_ELEMENTS_IN_VECTOR..(i + 1) * FIELD_ELEMENTS_IN_VECTOR], ); } result @@ -46,19 +43,17 @@ impl PolynomialRingElement { /// Given two polynomial ring elements `lhs` and `rhs`, compute the pointwise /// sum of their constituent coefficients. 
#[inline(always)] - pub(crate) fn add_to_ring_element(mut self, rhs: &Self) -> Self { + pub(crate) fn add_to_ring_element(&mut self, rhs: &Self) { for i in 0..self.coefficients.len() { self.coefficients[i] = Vector::add(self.coefficients[i], &rhs.coefficients[i]); } - self } #[inline(always)] - pub fn poly_barrett_reduce(mut self) -> Self { + pub fn poly_barrett_reduce(&mut self) { for i in 0..VECTORS_IN_RING_ELEMENT { self.coefficients[i] = Vector::barrett_reduce(self.coefficients[i]); } - self } #[inline(always)] @@ -86,28 +81,30 @@ impl PolynomialRingElement { } #[inline(always)] - pub(crate) fn add_error_reduce(&self, mut result: Self) -> Self { + pub(crate) fn add_error_reduce(&mut self, error: &Self) { for j in 0..VECTORS_IN_RING_ELEMENT { let coefficient_normal_form = - Vector::montgomery_multiply_by_constant(result.coefficients[j], 1441); + Vector::montgomery_multiply_by_constant(self.coefficients[j], 1441); - result.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &self.coefficients[j])); + self.coefficients[j] = Vector::barrett_reduce(Vector::add( + coefficient_normal_form, + &error.coefficients[j], + )); } - result } #[inline(always)] - pub(crate) fn add_standard_error_reduce(&self, mut result: Self) -> Self { + pub(crate) fn add_standard_error_reduce(&mut self, error: &Self) { for j in 0..VECTORS_IN_RING_ELEMENT { // The coefficients are of the form aR^{-1} mod q, which means // calling to_montgomery_domain() on them should return a mod q. - let coefficient_normal_form = Vector::to_standard_domain(result.coefficients[j]); + let coefficient_normal_form = Vector::to_standard_domain(self.coefficients[j]); - result.coefficients[j] = - Vector::barrett_reduce(Vector::add(coefficient_normal_form, &self.coefficients[j])); + self.coefficients[j] = Vector::barrett_reduce(Vector::add( + coefficient_normal_form, + &error.coefficients[j], + )); } - result } /// Given two `KyberPolynomialRingElement`s in their NTT representations, diff --git a/libcrux-ml-kem/src/sampling.rs b/libcrux-ml-kem/src/sampling.rs index fe864bb9e..26e6ff216 100644 --- a/libcrux-ml-kem/src/sampling.rs +++ b/libcrux-ml-kem/src/sampling.rs @@ -47,28 +47,23 @@ use crate::{ fn sample_from_uniform_distribution_next( randomness: [[u8; N]; K], sampled_coefficients: &mut [usize; K], - out: &mut [[i16; 256]; K], + out: &mut [[i16; 272]; K], ) -> bool { // Would be great to trigger auto-vectorization or at least loop unrolling here for i in 0..K { for r in 0..N / 24 { - let remaining = COEFFICIENTS_IN_RING_ELEMENT - sampled_coefficients[i]; - if remaining > 0 { - let (sampled, vec) = Vector::rej_sample(&randomness[i][r * 24..(r * 24) + 24]); - let pick = if sampled <= remaining { - sampled - } else { - remaining - }; - out[i][sampled_coefficients[i]..sampled_coefficients[i] + pick] - .copy_from_slice(&vec[0..pick]); - sampled_coefficients[i] += pick; + if sampled_coefficients[i] < COEFFICIENTS_IN_RING_ELEMENT { + let out0 = out[i][sampled_coefficients[i]..sampled_coefficients[i] + 16].as_mut(); + let sampled = Vector::rej_sample(&randomness[i][r * 24..(r * 24) + 24], out0); + sampled_coefficients[i] += sampled; } } } let mut done = true; for i in 0..K { - if sampled_coefficients[i] < COEFFICIENTS_IN_RING_ELEMENT { + if sampled_coefficients[i] >= COEFFICIENTS_IN_RING_ELEMENT { + sampled_coefficients[i] = COEFFICIENTS_IN_RING_ELEMENT; + } else { done = false } } @@ -76,14 +71,14 @@ fn sample_from_uniform_distribution_next( +pub(super) fn sample_from_xof>( seeds: [[u8; 34]; K], ) -> 
[PolynomialRingElement; K] { let mut sampled_coefficients: [usize; K] = [0; K]; - let mut out: [[i16; 256]; K] = [[0; 256]; K]; + let mut out: [[i16; 272]; K] = [[0; 272]; K]; - let mut xof_state = absorb(seeds); - let randomness = squeeze_three_blocks(&mut xof_state); + let mut xof_state = Hasher::shake128_init_absorb(seeds); + let randomness = xof_state.shake128_squeeze_three_blocks(); let mut done = sample_from_uniform_distribution_next::( randomness, @@ -97,17 +92,15 @@ pub(super) fn sample_from_xof( // To avoid failing here, we squeeze more blocks out of the state until // we have enough. while !done { - let randomness = squeeze_block(&mut xof_state); + let randomness = xof_state.shake128_squeeze_block(); done = sample_from_uniform_distribution_next::( randomness, &mut sampled_coefficients, &mut out, ); } - // XXX: We have to manually free the state here due to a Eurydice issue. - free_state(xof_state); - out.map(PolynomialRingElement::::from_i16_array) + out.map(|s| PolynomialRingElement::::from_i16_array(&s[0..256])) } /// Given a series of uniformly random bytes in `randomness`, for some number `eta`, @@ -192,7 +185,7 @@ fn sample_from_binomial_distribution_2( } } } - PolynomialRingElement::from_i16_array(sampled_i16s) + PolynomialRingElement::from_i16_array(&sampled_i16s) } #[cfg_attr(hax, hax_lib::requires(randomness.len() == 3 * 64))] @@ -229,7 +222,7 @@ fn sample_from_binomial_distribution_3( } } } - PolynomialRingElement::from_i16_array(sampled_i16s) + PolynomialRingElement::from_i16_array(&sampled_i16s) } #[inline(always)] diff --git a/libcrux-ml-kem/src/serialize.rs b/libcrux-ml-kem/src/serialize.rs index 0d8db5d69..9cc9fc9c3 100644 --- a/libcrux-ml-kem/src/serialize.rs +++ b/libcrux-ml-kem/src/serialize.rs @@ -38,7 +38,7 @@ pub(super) fn deserialize_then_decompress_message( #[inline(always)] pub(super) fn serialize_uncompressed_ring_element( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; BYTES_PER_RING_ELEMENT] { let mut serialized = [0u8; BYTES_PER_RING_ELEMENT]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -99,7 +99,7 @@ pub(super) fn deserialize_ring_elements_reduced< >( public_key: &[u8], ) -> [PolynomialRingElement; K] { - let mut deserialized_pk = [PolynomialRingElement::::ZERO(); K]; + let mut deserialized_pk = core::array::from_fn(|_i| PolynomialRingElement::::ZERO()); cloop! 
{ for (i, ring_element) in public_key .chunks_exact(BYTES_PER_RING_ELEMENT) @@ -113,7 +113,7 @@ pub(super) fn deserialize_ring_elements_reduced< #[inline(always)] fn compress_then_serialize_10( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { let mut serialized = [0u8; OUT_LEN]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -128,7 +128,7 @@ fn compress_then_serialize_10( #[inline(always)] fn compress_then_serialize_11( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { let mut serialized = [0u8; OUT_LEN]; for i in 0..VECTORS_IN_RING_ELEMENT { @@ -147,7 +147,7 @@ pub(super) fn compress_then_serialize_ring_element_u< const OUT_LEN: usize, Vector: Operations, >( - re: PolynomialRingElement, + re: &PolynomialRingElement, ) -> [u8; OUT_LEN] { hax_debug_assert!((COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN); @@ -159,10 +159,10 @@ pub(super) fn compress_then_serialize_ring_element_u< } #[inline(always)] -fn compress_then_serialize_4( +fn compress_then_serialize_4( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { - let mut serialized = [0u8; OUT_LEN]; + serialized: &mut [u8], +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficient = Vector::compress::<4>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -170,15 +170,13 @@ fn compress_then_serialize_4( let bytes = Vector::serialize_4(coefficient); serialized[8 * i..8 * i + 8].copy_from_slice(&bytes); } - serialized } #[inline(always)] -fn compress_then_serialize_5( +fn compress_then_serialize_5( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { - let mut serialized = [0u8; OUT_LEN]; - + serialized: &mut [u8], +) { for i in 0..VECTORS_IN_RING_ELEMENT { let coefficients = Vector::compress::<5>(Vector::to_unsigned_representative(re.coefficients[i])); @@ -186,7 +184,6 @@ fn compress_then_serialize_5( let bytes = Vector::serialize_5(coefficients); serialized[10 * i..10 * i + 10].copy_from_slice(&bytes); } - serialized } #[inline(always)] @@ -196,12 +193,13 @@ pub(super) fn compress_then_serialize_ring_element_v< Vector: Operations, >( re: PolynomialRingElement, -) -> [u8; OUT_LEN] { + out: &mut [u8], +) { hax_debug_assert!((COEFFICIENTS_IN_RING_ELEMENT * COMPRESSION_FACTOR) / 8 == OUT_LEN); match COMPRESSION_FACTOR as u32 { - 4 => compress_then_serialize_4(re), - 5 => compress_then_serialize_5(re), + 4 => compress_then_serialize_4(re, out), + 5 => compress_then_serialize_5(re, out), _ => unreachable!(), } } diff --git a/libcrux-sha3/Cargo.toml b/libcrux-sha3/Cargo.toml index 5ec4ed9be..b8bb92749 100644 --- a/libcrux-sha3/Cargo.toml +++ b/libcrux-sha3/Cargo.toml @@ -9,9 +9,6 @@ repository.workspace = true readme.workspace = true [dependencies] -libcrux-hacl = { version = "0.0.2-pre.2", path = "../sys/hacl", features = [ - "sha3", -] } libcrux-platform = { version = "0.0.2-pre.2", path = "../sys/platform" } # This is only required for verification. @@ -29,4 +26,5 @@ harness = false [dev-dependencies] criterion = "0.5.1" +hex = "0.4.3" rand = "0.8.5" diff --git a/libcrux-sha3/benches/sha3.rs b/libcrux-sha3/benches/sha3.rs index 6ff6628e7..93e427551 100644 --- a/libcrux-sha3/benches/sha3.rs +++ b/libcrux-sha3/benches/sha3.rs @@ -19,10 +19,10 @@ pub fn fmt(x: usize) -> String { } macro_rules! impl_comp { - ($fun:ident, $libcrux:expr, $rust_crypto:ty, $openssl:expr) => { + ($fun:ident, $libcrux:expr, $neon_fun:ident) => { // Comparing libcrux performance for different payload sizes and other implementations. 
fn $fun(c: &mut Criterion) { - const PAYLOAD_SIZES: [usize; 1] = [1024 * 1024 * 10]; + const PAYLOAD_SIZES: [usize; 3] = [128, 1024, 1024 * 1024 * 10]; let mut group = c.benchmark_group(stringify!($fun).replace("_", " ")); @@ -43,93 +43,30 @@ macro_rules! impl_comp { }, ); - // group.bench_with_input( - // BenchmarkId::new("RustCrypto", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // use sha3::Digest; - - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let mut hasher = <$rust_crypto>::new(); - // hasher.update(&payload); - // let _result = hasher.finalize(); - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - - // #[cfg(all(not(windows), not(target_arch = "wasm32"), not(target_arch = "x86")))] - // group.bench_with_input( - // BenchmarkId::new("OpenSSL", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // use openssl::hash::*; - - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let _result = hash($openssl, &payload); - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - - // #[cfg(not(target_arch = "wasm32"))] - // if stringify!($fun) != "Sha3_224" { - // group.bench_with_input( - // BenchmarkId::new("PQClean", fmt(*payload_size)), - // payload_size, - // |b, payload_size| { - // b.iter_batched( - // || randombytes(*payload_size), - // |payload| { - // let mut digest = [0; libcrux::digest::digest_size($libcrux)]; - // unsafe { - // $pqclean( - // digest.as_mut_ptr(), - // payload.as_ptr() as _, - // payload.len(), - // ) - // }; - // }, - // BatchSize::SmallInput, - // ) - // }, - // ); - // } + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + group.bench_with_input( + BenchmarkId::new("rust version (simd128)", fmt(*payload_size)), + payload_size, + |b, payload_size| { + b.iter_batched( + || randombytes(*payload_size), + |payload| { + let mut digest = [0u8; digest_size($libcrux)]; + neon::$neon_fun(&mut digest, &payload); + }, + BatchSize::SmallInput, + ) + }, + ); } } }; } -impl_comp!( - Sha3_224, - Algorithm::Sha3_224, - sha3::Sha3_224, - MessageDigest::sha3_224() // libcrux_pqclean::sha3_256 // This is wrong, but it's not actually used. -); -impl_comp!( - Sha3_256, - Algorithm::Sha3_256, - sha3::Sha3_256, - MessageDigest::sha3_256() // libcrux_pqclean::sha3_256 -); -impl_comp!( - Sha3_384, - Algorithm::Sha3_384, - sha3::Sha3_384, - MessageDigest::sha3_384() // libcrux_pqclean::sha3_384 -); -impl_comp!( - Sha3_512, - Algorithm::Sha3_512, - sha3::Sha3_512, - MessageDigest::sha3_512() // libcrux_pqclean::sha3_512 -); +impl_comp!(Sha3_224, Algorithm::Sha224, sha224); +impl_comp!(Sha3_256, Algorithm::Sha256, sha256); +impl_comp!(Sha3_384, Algorithm::Sha384, sha384); +impl_comp!(Sha3_512, Algorithm::Sha512, sha512); fn benchmarks(c: &mut Criterion) { Sha3_224(c); diff --git a/libcrux-sha3/build.rs b/libcrux-sha3/build.rs index a5dce81b2..ef1138666 100644 --- a/libcrux-sha3/build.rs +++ b/libcrux-sha3/build.rs @@ -17,10 +17,12 @@ fn main() { println!("cargo:rustc-cfg=feature=\"simd128\""); } if target_arch == "x86_64" && !disable_simd256 { - // We enable simd256 on all x86 and x86_64 builds. + // We enable simd256 on all x86_64 builds. // Note that this doesn't mean the required CPU features are available. // But the compiler will support them and the runtime checks ensure that // it's only used when available. + // + // We don't enable this on x86 because it seems to generate invalid code. 
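+        // `disable_simd256` is read from the environment earlier in this script
+        // (the LIBCRUX_DISABLE_SIMD256 variable used by the portable CI jobs),
+        // so the portable fallback can be exercised even on AVX2-capable hosts.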
println!("cargo:rustc-cfg=feature=\"simd256\""); } } diff --git a/libcrux-sha3/src/generic_keccak.rs b/libcrux-sha3/src/generic_keccak.rs new file mode 100644 index 000000000..e5f6ea9f1 --- /dev/null +++ b/libcrux-sha3/src/generic_keccak.rs @@ -0,0 +1,272 @@ +//! The generic SHA3 implementation that uses portable or platform specific +//! sub-routines. + +use core::ops::Index; + +use crate::traits::*; + +#[cfg_attr(hax, hax_lib::opaque_type)] +#[allow(private_bounds)] // TODO: figure out visibility +#[derive(Clone, Copy)] +pub struct KeccakState> { + pub st: [[T; 5]; 5], +} + +impl> Index for KeccakState { + type Output = [T; 5]; + + fn index(&self, index: usize) -> &Self::Output { + &self.st[index] + } +} + +#[allow(private_bounds)] // TODO: figure out visibility +impl> KeccakState { + /// Create a new Shake128 x4 state. + #[inline(always)] + pub(crate) fn new() -> Self { + Self { + st: [[T::zero(); 5]; 5], + } + } +} + +/// From here, everything is generic +/// +const _ROTC: [usize; 24] = [ + 1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43, 25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14, +]; + +#[inline(always)] +pub(crate) fn theta_rho>(s: &mut KeccakState) { + let c: [T; 5] = core::array::from_fn(|j| { + T::xor5(s.st[0][j], s.st[1][j], s.st[2][j], s.st[3][j], s.st[4][j]) + }); + let t: [T; 5] = + core::array::from_fn(|j| T::rotate_left1_and_xor(c[(j + 4) % 5], c[(j + 1) % 5])); + + s.st[0][0] = T::xor(s.st[0][0], t[0]); + s.st[1][0] = T::xor_and_rotate::<36, 28>(s.st[1][0], t[0]); + s.st[2][0] = T::xor_and_rotate::<3, 61>(s.st[2][0], t[0]); + s.st[3][0] = T::xor_and_rotate::<41, 23>(s.st[3][0], t[0]); + s.st[4][0] = T::xor_and_rotate::<18, 46>(s.st[4][0], t[0]); + + s.st[0][1] = T::xor_and_rotate::<1, 63>(s.st[0][1], t[1]); + s.st[1][1] = T::xor_and_rotate::<44, 20>(s.st[1][1], t[1]); + s.st[2][1] = T::xor_and_rotate::<10, 54>(s.st[2][1], t[1]); + s.st[3][1] = T::xor_and_rotate::<45, 19>(s.st[3][1], t[1]); + s.st[4][1] = T::xor_and_rotate::<2, 62>(s.st[4][1], t[1]); + + s.st[0][2] = T::xor_and_rotate::<62, 2>(s.st[0][2], t[2]); + s.st[1][2] = T::xor_and_rotate::<6, 58>(s.st[1][2], t[2]); + s.st[2][2] = T::xor_and_rotate::<43, 21>(s.st[2][2], t[2]); + s.st[3][2] = T::xor_and_rotate::<15, 49>(s.st[3][2], t[2]); + s.st[4][2] = T::xor_and_rotate::<61, 3>(s.st[4][2], t[2]); + + s.st[0][3] = T::xor_and_rotate::<28, 36>(s.st[0][3], t[3]); + s.st[1][3] = T::xor_and_rotate::<55, 9>(s.st[1][3], t[3]); + s.st[2][3] = T::xor_and_rotate::<25, 39>(s.st[2][3], t[3]); + s.st[3][3] = T::xor_and_rotate::<21, 43>(s.st[3][3], t[3]); + s.st[4][3] = T::xor_and_rotate::<56, 8>(s.st[4][3], t[3]); + + s.st[0][4] = T::xor_and_rotate::<27, 37>(s.st[0][4], t[4]); + s.st[1][4] = T::xor_and_rotate::<20, 44>(s.st[1][4], t[4]); + s.st[2][4] = T::xor_and_rotate::<39, 25>(s.st[2][4], t[4]); + s.st[3][4] = T::xor_and_rotate::<8, 56>(s.st[3][4], t[4]); + s.st[4][4] = T::xor_and_rotate::<14, 50>(s.st[4][4], t[4]); +} + +const _PI: [usize; 24] = [ + 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7, 13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21, +]; + +#[inline(always)] +pub(crate) fn pi>(s: &mut KeccakState) { + let old = s.st.clone(); + s.st[0][1] = old[1][1]; + s.st[0][2] = old[2][2]; + s.st[0][3] = old[3][3]; + s.st[0][4] = old[4][4]; + s.st[1][0] = old[0][3]; + s.st[1][1] = old[1][4]; + s.st[1][2] = old[2][0]; + s.st[1][3] = old[3][1]; + s.st[1][4] = old[4][2]; + s.st[2][0] = old[0][1]; + s.st[2][1] = old[1][2]; + s.st[2][2] = old[2][3]; + s.st[2][3] = old[3][4]; + s.st[2][4] = old[4][0]; + s.st[3][0] = old[0][4]; + s.st[3][1] = 
old[1][0]; + s.st[3][2] = old[2][1]; + s.st[3][3] = old[3][2]; + s.st[3][4] = old[4][3]; + s.st[4][0] = old[0][2]; + s.st[4][1] = old[1][3]; + s.st[4][2] = old[2][4]; + s.st[4][3] = old[3][0]; + s.st[4][4] = old[4][1]; +} + +#[inline(always)] +pub(crate) fn chi>(s: &mut KeccakState) { + let old = s.st; + for i in 0..5 { + for j in 0..5 { + s.st[i][j] = T::and_not_xor(s.st[i][j], old[i][(j + 2) % 5], old[i][(j + 1) % 5]); + } + } +} + +const ROUNDCONSTANTS: [u64; 24] = [ + 0x0000_0000_0000_0001u64, + 0x0000_0000_0000_8082u64, + 0x8000_0000_0000_808au64, + 0x8000_0000_8000_8000u64, + 0x0000_0000_0000_808bu64, + 0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8009u64, + 0x0000_0000_0000_008au64, + 0x0000_0000_0000_0088u64, + 0x0000_0000_8000_8009u64, + 0x0000_0000_8000_000au64, + 0x0000_0000_8000_808bu64, + 0x8000_0000_0000_008bu64, + 0x8000_0000_0000_8089u64, + 0x8000_0000_0000_8003u64, + 0x8000_0000_0000_8002u64, + 0x8000_0000_0000_0080u64, + 0x0000_0000_0000_800au64, + 0x8000_0000_8000_000au64, + 0x8000_0000_8000_8081u64, + 0x8000_0000_0000_8080u64, + 0x0000_0000_8000_0001u64, + 0x8000_0000_8000_8008u64, +]; + +#[inline(always)] +pub(crate) fn iota>(s: &mut KeccakState, i: usize) { + s.st[0][0] = T::xor_constant(s.st[0][0], ROUNDCONSTANTS[i]); +} + +#[inline(always)] +pub(crate) fn keccakf1600>(s: &mut KeccakState) { + for i in 0..24 { + theta_rho(s); + pi(s); + chi(s); + iota(s, i); + } +} + +#[inline(always)] +pub(crate) fn absorb_block, const RATE: usize>( + s: &mut KeccakState, + blocks: [&[u8]; N], +) { + T::load_block::(&mut s.st, blocks); + keccakf1600(s) +} + +#[inline(always)] +pub(crate) fn absorb_final, const RATE: usize, const DELIM: u8>( + s: &mut KeccakState, + last: [&[u8]; N], +) { + debug_assert!(N > 0 && last[0].len() < RATE); + let last_len = last[0].len(); + let mut blocks = [[0u8; 200]; N]; + for i in 0..N { + blocks[i][0..last_len].copy_from_slice(&last[i]); + blocks[i][last_len] = DELIM; + blocks[i][RATE - 1] = blocks[i][RATE - 1] | 128u8; + } + T::load_block_full::(&mut s.st, blocks); + keccakf1600(s) +} + +#[inline(always)] +pub(crate) fn squeeze_first_block, const RATE: usize>( + s: &KeccakState, + out: [&mut [u8]; N], +) { + T::store_block::(&s.st, out) +} + +#[inline(always)] +pub(crate) fn squeeze_next_block, const RATE: usize>( + s: &mut KeccakState, + out: [&mut [u8]; N], +) { + keccakf1600(s); + T::store_block::(&s.st, out) +} + +#[inline(always)] +pub(crate) fn squeeze_first_three_blocks, const RATE: usize>( + s: &mut KeccakState, + out: [&mut [u8]; N], +) { + let (o0, o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(s, o0); + let (o1, o2) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(s, o1); + squeeze_next_block::(s, o2); +} + +#[inline(always)] +pub(crate) fn squeeze_last, const RATE: usize>( + mut s: KeccakState, + out: [&mut [u8]; N], +) { + keccakf1600(&mut s); + let b = T::store_block_full::(&s.st); + for i in 0..N { + out[i].copy_from_slice(&b[i][0..out[i].len()]); + } +} + +#[inline(always)] +pub(crate) fn squeeze_first_and_last, const RATE: usize>( + s: &KeccakState, + out: [&mut [u8]; N], +) { + let b = T::store_block_full::(&s.st); + for i in 0..N { + out[i].copy_from_slice(&b[i][0..out[i].len()]); + } +} + +#[inline(always)] +pub(crate) fn keccak, const RATE: usize, const DELIM: u8>( + data: [&[u8]; N], + out: [&mut [u8]; N], +) { + let mut s = KeccakState::::new(); + for i in 0..data[0].len() / RATE { + absorb_block::(&mut s, T::slice_n(data, i * RATE, RATE)); + } + let rem = data[0].len() % RATE; + 
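+    // `rem` is the length of the trailing partial block. `absorb_final` pads it
+    // with the domain-separation byte `DELIM` (0x06 for SHA3, 0x1f for SHAKE)
+    // and the final bit of the pad10*1 rule (the `| 128u8` above), then runs
+    // one more permutation before squeezing starts.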
absorb_final::(&mut s, T::slice_n(data, data[0].len() - rem, rem)); + + let outlen = out[0].len(); + let blocks = outlen / RATE; + let last = outlen - (outlen % RATE); + + if blocks == 0 { + squeeze_first_and_last::(&s, out) + } else { + let (o0, mut o1) = T::split_at_mut_n(out, RATE); + squeeze_first_block::(&s, o0); + for _i in 1..blocks { + let (o, orest) = T::split_at_mut_n(o1, RATE); + squeeze_next_block::(&mut s, o); + o1 = orest; + } + if last < outlen { + squeeze_last::(s, o1) + } + } +} diff --git a/libcrux-sha3/src/hacl.rs b/libcrux-sha3/src/hacl.rs deleted file mode 100644 index 7fa89d0d1..000000000 --- a/libcrux-sha3/src/hacl.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod streaming_types; - -pub(crate) mod hash_sha3; diff --git a/libcrux-sha3/src/hacl/hash_sha3.rs b/libcrux-sha3/src/hacl/hash_sha3.rs deleted file mode 100644 index c54d0458b..000000000 --- a/libcrux-sha3/src/hacl/hash_sha3.rs +++ /dev/null @@ -1,644 +0,0 @@ -#![allow(non_snake_case)] -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(unused_assignments)] -#![allow(unused_mut)] -#![allow(unreachable_patterns)] -#![allow(const_item_mutation)] - -#[inline(always)] -fn block_len(a: crate::hacl::streaming_types::hash_alg) -> u32 { - match a { - crate::hacl::streaming_types::hash_alg::SHA3_224 => 144u32, - crate::hacl::streaming_types::hash_alg::SHA3_256 => 136u32, - crate::hacl::streaming_types::hash_alg::SHA3_384 => 104u32, - crate::hacl::streaming_types::hash_alg::SHA3_512 => 72u32, - crate::hacl::streaming_types::hash_alg::Shake128 => 168u32, - crate::hacl::streaming_types::hash_alg::Shake256 => 136u32, - _ => panic!("Precondition of the function most likely violated"), - } -} - -#[inline(always)] -fn hash_len(a: crate::hacl::streaming_types::hash_alg) -> u32 { - match a { - crate::hacl::streaming_types::hash_alg::SHA3_224 => 28u32, - crate::hacl::streaming_types::hash_alg::SHA3_256 => 32u32, - crate::hacl::streaming_types::hash_alg::SHA3_384 => 48u32, - crate::hacl::streaming_types::hash_alg::SHA3_512 => 64u32, - _ => panic!("Precondition of the function most likely violated"), - } -} - -#[inline(always)] -pub fn update_multi_sha3( - a: crate::hacl::streaming_types::hash_alg, - s: &mut [u64], - blocks: &mut [u8], - n_blocks: u32, -) -> () { - for i in 0u32..n_blocks { - let block: (&mut [u8], &mut [u8]) = - blocks.split_at_mut(i.wrapping_mul(block_len(a)) as usize); - absorb_inner(block_len(a), block.1, s) - } -} - -#[inline(always)] -pub fn update_last_sha3( - a: crate::hacl::streaming_types::hash_alg, - s: &mut [u64], - input: &mut [u8], - input_len: u32, -) -> () { - let suffix: u8 = if a == crate::hacl::streaming_types::hash_alg::Shake128 - || a == crate::hacl::streaming_types::hash_alg::Shake256 - { - 0x1fu8 - } else { - 0x06u8 - }; - let len: u32 = block_len(a); - if input_len == len { - absorb_inner(len, input, s); - let mut lastBlock_: [u8; 200] = [0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..0usize]) - .copy_from_slice(&(&mut input[input_len as usize..])[0usize..0usize]); - lastBlock.1[0usize] = suffix; - loadState(len, lastBlock.1, s); - if !(suffix & 0x80u8 == 0u8) && 0u32 == len.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[len.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(len, nextBlock.1, s); - state_permute(s) - } else { - let mut lastBlock_: [u8; 200] = 
[0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..input_len as usize]) - .copy_from_slice(&input[0usize..input_len as usize]); - lastBlock.1[input_len as usize] = suffix; - loadState(len, lastBlock.1, s); - if !(suffix & 0x80u8 == 0u8) && input_len == len.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[len.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(len, nextBlock.1, s); - state_permute(s) - } -} - -pub struct hash_buf { - pub fst: crate::hacl::streaming_types::hash_alg, - pub snd: Vec, -} - -pub struct state_t { - pub block_state: hash_buf, - pub buf: Vec, - pub total_len: u64, -} - -#[inline(always)] -pub fn get_alg(s: &mut [state_t]) -> crate::hacl::streaming_types::hash_alg { - let mut block_state: &mut hash_buf = &mut (s[0usize]).block_state; - (*block_state).fst -} - -#[inline(always)] -pub fn malloc(a: crate::hacl::streaming_types::hash_alg) -> Vec { - let mut buf: Vec = vec![0u8; block_len(a) as usize]; - let mut buf0: Vec = vec![0u64; 25usize]; - let mut block_state: hash_buf = hash_buf { fst: a, snd: buf0 }; - let s: &mut [u64] = &mut block_state.snd; - (s[0usize..25usize]).copy_from_slice(&[0u64; 25usize]); - let mut s0: state_t = state_t { - block_state: block_state, - buf: buf, - total_len: 0u32 as u64, - }; - let mut p: Vec = { - let mut tmp: Vec = Vec::new(); - tmp.push(s0); - tmp - }; - p -} - -#[inline(always)] -pub fn copy(state: &mut [state_t]) -> Vec { - let mut block_state0: &mut hash_buf = &mut (state[0usize]).block_state; - let buf0: &mut [u8] = &mut (state[0usize]).buf; - let total_len0: u64 = (state[0usize]).total_len; - let i: crate::hacl::streaming_types::hash_alg = (*block_state0).fst; - let mut buf: Vec = vec![0u8; block_len(i) as usize]; - ((&mut buf)[0usize..block_len(i) as usize]) - .copy_from_slice(&buf0[0usize..block_len(i) as usize]); - let mut buf1: Vec = vec![0u64; 25usize]; - let mut block_state: hash_buf = hash_buf { fst: i, snd: buf1 }; - let s_src: &mut [u64] = &mut (*block_state0).snd; - let s_dst: &mut [u64] = &mut block_state.snd; - (s_dst[0usize..25usize]).copy_from_slice(&s_src[0usize..25usize]); - let mut s: state_t = state_t { - block_state: block_state, - buf: buf, - total_len: total_len0, - }; - let mut p: Vec = { - let mut tmp: Vec = Vec::new(); - tmp.push(s); - tmp - }; - p -} - -#[inline(always)] -pub fn reset(state: &mut [state_t]) -> () { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let i: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - crate::lowstar::ignore::ignore::(i); - let s: &mut [u64] = &mut (*block_state).snd; - (s[0usize..25usize]).copy_from_slice(&[0u64; 25usize]); - (state[0usize]).total_len = 0u32 as u64 -} - -#[inline(always)] -pub fn update( - state: &mut [state_t], - chunk: &mut [u8], - chunk_len: u32, -) -> crate::hacl::streaming_types::error_code { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let total_len: u64 = (state[0usize]).total_len; - let i: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - if chunk_len as u64 > 0xFFFFFFFFFFFFFFFFu64.wrapping_sub(total_len) { - crate::hacl::streaming_types::error_code::MaximumLengthExceeded - } else { - let sz: u32 = if total_len.wrapping_rem(block_len(i) as u64) == 0u64 && total_len > 0u64 { - block_len(i) - } else { - total_len.wrapping_rem(block_len(i) as u64) as u32 - 
}; - if chunk_len <= (block_len(i)).wrapping_sub(sz) { - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - let buf2: (&mut [u8], &mut [u8]) = buf.split_at_mut(sz1 as usize); - (buf2.1[0usize..chunk_len as usize]) - .copy_from_slice(&chunk[0usize..chunk_len as usize]); - let total_len2: u64 = total_len1.wrapping_add(chunk_len as u64); - (state[0usize]).total_len = total_len2 - } else if sz == 0u32 { - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - if !(sz1 == 0u32) { - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, buf, (block_len(i)).wrapping_div(block_len(a1))) - }; - let ite: u32 = if (chunk_len as u64).wrapping_rem(block_len(i) as u64) == 0u64 - && chunk_len as u64 > 0u64 - { - block_len(i) - } else { - (chunk_len as u64).wrapping_rem(block_len(i) as u64) as u32 - }; - let n_blocks: u32 = chunk_len.wrapping_sub(ite).wrapping_div(block_len(i)); - let data1_len: u32 = n_blocks.wrapping_mul(block_len(i)); - let data2_len: u32 = chunk_len.wrapping_sub(data1_len); - let data1: (&mut [u8], &mut [u8]) = chunk.split_at_mut(0usize); - let data2: (&mut [u8], &mut [u8]) = data1.1.split_at_mut(data1_len as usize); - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, data2.0, data1_len.wrapping_div(block_len(a1))); - let dst: (&mut [u8], &mut [u8]) = buf.split_at_mut(0usize); - (dst.1[0usize..data2_len as usize]) - .copy_from_slice(&data2.1[0usize..data2_len as usize]); - (state[0usize]).total_len = total_len1.wrapping_add(chunk_len as u64) - } else { - let diff: u32 = (block_len(i)).wrapping_sub(sz); - let chunk1: (&mut [u8], &mut [u8]) = chunk.split_at_mut(0usize); - let chunk2: (&mut [u8], &mut [u8]) = chunk1.1.split_at_mut(diff as usize); - let buf: &mut [u8] = &mut (state[0usize]).buf; - let total_len1: u64 = (state[0usize]).total_len; - let sz1: u32 = - if total_len1.wrapping_rem(block_len(i) as u64) == 0u64 && total_len1 > 0u64 { - block_len(i) - } else { - total_len1.wrapping_rem(block_len(i) as u64) as u32 - }; - let buf2: (&mut [u8], &mut [u8]) = buf.split_at_mut(sz1 as usize); - (buf2.1[0usize..diff as usize]).copy_from_slice(&chunk2.0[0usize..diff as usize]); - let total_len2: u64 = total_len1.wrapping_add(diff as u64); - (state[0usize]).total_len = total_len2; - let buf0: &mut [u8] = &mut (state[0usize]).buf; - let total_len10: u64 = (state[0usize]).total_len; - let sz10: u32 = - if total_len10.wrapping_rem(block_len(i) as u64) == 0u64 && total_len10 > 0u64 { - block_len(i) - } else { - total_len10.wrapping_rem(block_len(i) as u64) as u32 - }; - if !(sz10 == 0u32) { - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, buf0, (block_len(i)).wrapping_div(block_len(a1))) - }; - let ite: u32 = - if (chunk_len.wrapping_sub(diff) as u64).wrapping_rem(block_len(i) as u64) == 0u64 - && chunk_len.wrapping_sub(diff) as u64 > 0u64 - { - block_len(i) - } else 
{ - (chunk_len.wrapping_sub(diff) as u64).wrapping_rem(block_len(i) as u64) as u32 - }; - let n_blocks: u32 = chunk_len - .wrapping_sub(diff) - .wrapping_sub(ite) - .wrapping_div(block_len(i)); - let data1_len: u32 = n_blocks.wrapping_mul(block_len(i)); - let data2_len: u32 = chunk_len.wrapping_sub(diff).wrapping_sub(data1_len); - let data1: (&mut [u8], &mut [u8]) = chunk2.1.split_at_mut(0usize); - let data2: (&mut [u8], &mut [u8]) = data1.1.split_at_mut(data1_len as usize); - let a1: crate::hacl::streaming_types::hash_alg = (*block_state).fst; - let s1: &mut [u64] = &mut (*block_state).snd; - update_multi_sha3(a1, s1, data2.0, data1_len.wrapping_div(block_len(a1))); - let dst: (&mut [u8], &mut [u8]) = buf0.split_at_mut(0usize); - (dst.1[0usize..data2_len as usize]) - .copy_from_slice(&data2.1[0usize..data2_len as usize]); - (state[0usize]).total_len = - total_len10.wrapping_add(chunk_len.wrapping_sub(diff) as u64) - }; - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -fn digest_( - a: crate::hacl::streaming_types::hash_alg, - state: &mut [state_t], - output: &mut [u8], - l: u32, -) -> () { - let mut block_state: &mut hash_buf = &mut (state[0usize]).block_state; - let buf_: &mut [u8] = &mut (state[0usize]).buf; - let total_len: u64 = (state[0usize]).total_len; - let r: u32 = if total_len.wrapping_rem(block_len(a) as u64) == 0u64 && total_len > 0u64 { - block_len(a) - } else { - total_len.wrapping_rem(block_len(a) as u64) as u32 - }; - let buf_1: (&mut [u8], &mut [u8]) = buf_.split_at_mut(0usize); - let mut buf: [u64; 25] = [0u64; 25usize]; - let mut tmp_block_state: hash_buf = hash_buf { - fst: a, - snd: Vec::from(buf), - }; - let s_src: &mut [u64] = &mut (*block_state).snd; - let s_dst: &mut [u64] = &mut tmp_block_state.snd; - (s_dst[0usize..25usize]).copy_from_slice(&s_src[0usize..25usize]); - let buf_multi: (&mut [u8], &mut [u8]) = buf_1.1.split_at_mut(0usize); - let ite: u32 = if r.wrapping_rem(block_len(a)) == 0u32 && r > 0u32 { - block_len(a) - } else { - r.wrapping_rem(block_len(a)) - }; - let buf_last: (&mut [u8], &mut [u8]) = buf_multi.1.split_at_mut(r.wrapping_sub(ite) as usize); - let a1: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s: &mut [u64] = &mut tmp_block_state.snd; - update_multi_sha3(a1, s, buf_last.0, 0u32.wrapping_div(block_len(a1))); - let a10: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s0: &mut [u64] = &mut tmp_block_state.snd; - update_last_sha3(a10, s0, buf_last.1, r); - let a11: crate::hacl::streaming_types::hash_alg = tmp_block_state.fst; - let s1: &mut [u64] = &mut tmp_block_state.snd; - if a11 == crate::hacl::streaming_types::hash_alg::Shake128 - || a11 == crate::hacl::streaming_types::hash_alg::Shake256 - { - squeeze0(s1, block_len(a11), l, output) - } else { - squeeze0(s1, block_len(a11), hash_len(a11), output) - } -} - -#[inline(always)] -pub fn digest( - state: &mut [state_t], - output: &mut [u8], -) -> crate::hacl::streaming_types::error_code { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(state); - if a1 == crate::hacl::streaming_types::hash_alg::Shake128 - || a1 == crate::hacl::streaming_types::hash_alg::Shake256 - { - crate::hacl::streaming_types::error_code::InvalidAlgorithm - } else { - digest_(a1, state, output, hash_len(a1)); - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -pub fn squeeze( - s: &mut [state_t], - dst: &mut [u8], - l: u32, -) -> crate::hacl::streaming_types::error_code { - let a1: 
crate::hacl::streaming_types::hash_alg = get_alg(s); - if !(a1 == crate::hacl::streaming_types::hash_alg::Shake128 - || a1 == crate::hacl::streaming_types::hash_alg::Shake256) - { - crate::hacl::streaming_types::error_code::InvalidAlgorithm - } else if l == 0u32 { - crate::hacl::streaming_types::error_code::InvalidLength - } else { - digest_(a1, s, dst, l); - crate::hacl::streaming_types::error_code::Success - } -} - -#[inline(always)] -pub fn block_len0(s: &mut [state_t]) -> u32 { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(s); - block_len(a1) -} - -#[inline(always)] -pub fn hash_len0(s: &mut [state_t]) -> u32 { - let a1: crate::hacl::streaming_types::hash_alg = get_alg(s); - hash_len(a1) -} - -#[inline(always)] -pub fn is_shake(s: &mut [state_t]) -> bool { - let uu____0: crate::hacl::streaming_types::hash_alg = get_alg(s); - uu____0 == crate::hacl::streaming_types::hash_alg::Shake128 - || uu____0 == crate::hacl::streaming_types::hash_alg::Shake256 -} - -#[inline(always)] -pub fn shake128_hacl( - inputByteLen: u32, - input: &mut [u8], - outputByteLen: u32, - output: &mut [u8], -) -> () { - keccak( - 1344u32, - 256u32, - inputByteLen, - input, - 0x1Fu8, - outputByteLen, - output, - ) -} - -#[inline(always)] -pub fn shake256_hacl( - inputByteLen: u32, - input: &mut [u8], - outputByteLen: u32, - output: &mut [u8], -) -> () { - keccak( - 1088u32, - 512u32, - inputByteLen, - input, - 0x1Fu8, - outputByteLen, - output, - ) -} - -#[inline(always)] -pub fn sha3_224(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(1152u32, 448u32, input_len, input, 0x06u8, 28u32, output) -} - -#[inline(always)] -pub fn sha3_256(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(1088u32, 512u32, input_len, input, 0x06u8, 32u32, output) -} - -#[inline(always)] -pub fn sha3_384(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(832u32, 768u32, input_len, input, 0x06u8, 48u32, output) -} - -#[inline(always)] -pub fn sha3_512(output: &mut [u8], input: &mut [u8], input_len: u32) -> () { - keccak(576u32, 1024u32, input_len, input, 0x06u8, 64u32, output) -} - -const keccak_rotc: [u32; 24] = [ - 1u32, 3u32, 6u32, 10u32, 15u32, 21u32, 28u32, 36u32, 45u32, 55u32, 2u32, 14u32, 27u32, 41u32, - 56u32, 8u32, 25u32, 43u32, 62u32, 18u32, 39u32, 61u32, 20u32, 44u32, -]; - -const keccak_piln: [u32; 24] = [ - 10u32, 7u32, 11u32, 17u32, 18u32, 3u32, 5u32, 16u32, 8u32, 21u32, 24u32, 4u32, 15u32, 23u32, - 19u32, 13u32, 12u32, 2u32, 20u32, 14u32, 22u32, 9u32, 6u32, 1u32, -]; - -const keccak_rndc: [u64; 24] = [ - 0x0000000000000001u64, - 0x0000000000008082u64, - 0x800000000000808au64, - 0x8000000080008000u64, - 0x000000000000808bu64, - 0x0000000080000001u64, - 0x8000000080008081u64, - 0x8000000000008009u64, - 0x000000000000008au64, - 0x0000000000000088u64, - 0x0000000080008009u64, - 0x000000008000000au64, - 0x000000008000808bu64, - 0x800000000000008bu64, - 0x8000000000008089u64, - 0x8000000000008003u64, - 0x8000000000008002u64, - 0x8000000000000080u64, - 0x000000000000800au64, - 0x800000008000000au64, - 0x8000000080008081u64, - 0x8000000000008080u64, - 0x0000000080000001u64, - 0x8000000080008008u64, -]; - -#[inline(always)] -pub fn state_permute(s: &mut [u64]) -> () { - for i in 0u32..24u32 { - let mut _C: [u64; 5] = [0u64; 5usize]; - for i0 in 0u32..5u32 { - (&mut _C)[i0 as usize] = s[i0.wrapping_add(0u32) as usize] - ^ (s[i0.wrapping_add(5u32) as usize] - ^ (s[i0.wrapping_add(10u32) as usize] - ^ (s[i0.wrapping_add(15u32) as usize] - ^ s[i0.wrapping_add(20u32) as 
usize]))) - } - for i0 in 0u32..5u32 { - let uu____0: u64 = (&mut _C)[i0.wrapping_add(1u32).wrapping_rem(5u32) as usize]; - let _D: u64 = (&mut _C)[i0.wrapping_add(4u32).wrapping_rem(5u32) as usize] - ^ (uu____0.wrapping_shl(1u32) | uu____0.wrapping_shr(63u32)); - for i1 in 0u32..5u32 { - s[i0.wrapping_add(5u32.wrapping_mul(i1)) as usize] = - s[i0.wrapping_add(5u32.wrapping_mul(i1)) as usize] ^ _D - } - } - let x: u64 = s[1usize]; - let mut current: [u64; 1] = [x; 1usize]; - for i0 in 0u32..24u32 { - let _Y: u32 = (&keccak_piln)[i0 as usize]; - let r: u32 = (&keccak_rotc)[i0 as usize]; - let temp: u64 = s[_Y as usize]; - let uu____1: u64 = (&mut current)[0usize]; - s[_Y as usize] = uu____1.wrapping_shl(r) | uu____1.wrapping_shr(64u32.wrapping_sub(r)); - (&mut current)[0usize] = temp - } - for i0 in 0u32..5u32 { - let v0: u64 = s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v1: u64 = s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v2: u64 = s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v3: u64 = s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - let v4: u64 = s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - ^ !s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] - & s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize]; - s[0u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v0; - s[1u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v1; - s[2u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v2; - s[3u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v3; - s[4u32.wrapping_add(5u32.wrapping_mul(i0)) as usize] = v4 - } - let c: u64 = (&keccak_rndc)[i as usize]; - s[0usize] = s[0usize] ^ c - } -} - -#[inline(always)] -pub fn loadState(rateInBytes: u32, input: &mut [u8], s: &mut [u64]) -> () { - let mut block: [u8; 200] = [0u8; 200usize]; - ((&mut block)[0usize..rateInBytes as usize]) - .copy_from_slice(&input[0usize..rateInBytes as usize]); - for i in 0u32..25u32 { - let u: u64 = crate::lowstar::endianness::load64_le( - &mut (&mut block)[i.wrapping_mul(8u32) as usize..], - ); - let x: u64 = u; - s[i as usize] = s[i as usize] ^ x - } -} - -#[inline(always)] -fn storeState(rateInBytes: u32, s: &mut [u64], res: &mut [u8]) -> () { - let mut block: [u8; 200] = [0u8; 200usize]; - for i in 0u32..25u32 { - let sj: u64 = s[i as usize]; - crate::lowstar::endianness::store64_le( - &mut (&mut block)[i.wrapping_mul(8u32) as usize..], - sj, - ) - } - (res[0usize..rateInBytes as usize]) - .copy_from_slice(&(&mut (&mut block)[0usize..])[0usize..rateInBytes as usize]) -} - -#[inline(always)] -pub fn absorb_inner(rateInBytes: u32, block: &mut [u8], s: &mut [u64]) -> () { - loadState(rateInBytes, block, s); - state_permute(s) -} - -#[inline(always)] -fn absorb( - s: &mut [u64], - rateInBytes: u32, - inputByteLen: u32, - input: &mut [u8], - delimitedSuffix: u8, -) -> () { - let n_blocks: u32 = inputByteLen.wrapping_div(rateInBytes); - let rem: u32 = inputByteLen.wrapping_rem(rateInBytes); - for i in 0u32..n_blocks { - let block: (&mut [u8], &mut [u8]) = - input.split_at_mut(i.wrapping_mul(rateInBytes) as 
usize); - absorb_inner(rateInBytes, block.1, s) - } - let last: (&mut [u8], &mut [u8]) = - input.split_at_mut(n_blocks.wrapping_mul(rateInBytes) as usize); - let mut lastBlock_: [u8; 200] = [0u8; 200usize]; - let lastBlock: (&mut [u8], &mut [u8]) = (&mut lastBlock_).split_at_mut(0usize); - (lastBlock.1[0usize..rem as usize]).copy_from_slice(&last.1[0usize..rem as usize]); - lastBlock.1[rem as usize] = delimitedSuffix; - loadState(rateInBytes, lastBlock.1, s); - if !(delimitedSuffix & 0x80u8 == 0u8) && rem == rateInBytes.wrapping_sub(1u32) { - state_permute(s) - }; - let mut nextBlock_: [u8; 200] = [0u8; 200usize]; - let nextBlock: (&mut [u8], &mut [u8]) = (&mut nextBlock_).split_at_mut(0usize); - nextBlock.1[rateInBytes.wrapping_sub(1u32) as usize] = 0x80u8; - loadState(rateInBytes, nextBlock.1, s); - state_permute(s) -} - -#[inline(always)] -pub fn squeeze0(s: &mut [u64], rateInBytes: u32, outputByteLen: u32, output: &mut [u8]) -> () { - let outBlocks: u32 = outputByteLen.wrapping_div(rateInBytes); - let remOut: u32 = outputByteLen.wrapping_rem(rateInBytes); - let blocks: (&mut [u8], &mut [u8]) = output.split_at_mut(0usize); - let last: (&mut [u8], &mut [u8]) = blocks - .1 - .split_at_mut(outputByteLen.wrapping_sub(remOut) as usize); - for i in 0u32..outBlocks { - storeState( - rateInBytes, - s, - &mut last.0[i.wrapping_mul(rateInBytes) as usize..], - ); - state_permute(s) - } - storeState(remOut, s, last.1) -} - -#[inline(always)] -pub fn keccak( - rate: u32, - capacity: u32, - inputByteLen: u32, - input: &mut [u8], - delimitedSuffix: u8, - outputByteLen: u32, - output: &mut [u8], -) -> () { - crate::lowstar::ignore::ignore::(capacity); - let rateInBytes: u32 = rate.wrapping_div(8u32); - let mut s: [u64; 25] = [0u64; 25usize]; - absorb(&mut s, rateInBytes, inputByteLen, input, delimitedSuffix); - squeeze0(&mut s, rateInBytes, outputByteLen, output) -} diff --git a/libcrux-sha3/src/hacl/streaming_types.rs b/libcrux-sha3/src/hacl/streaming_types.rs deleted file mode 100644 index 559af9cdb..000000000 --- a/libcrux-sha3/src/hacl/streaming_types.rs +++ /dev/null @@ -1,39 +0,0 @@ -#![allow(non_snake_case)] -#![allow(non_upper_case_globals)] -#![allow(non_camel_case_types)] -#![allow(unused_assignments)] -#![allow(unused_mut)] -#![allow(unreachable_patterns)] -#![allow(const_item_mutation)] - -#[derive(PartialEq, Clone, Copy)] -pub enum hash_alg -{ - SHA2_224, - SHA2_256, - SHA2_384, - SHA2_512, - SHA1, - MD5, - Blake2S, - Blake2B, - SHA3_256, - SHA3_224, - SHA3_384, - SHA3_512, - Shake128, - Shake256 -} - -#[derive(PartialEq, Clone, Copy)] -pub enum error_code -{ - Success, - InvalidAlgorithm, - InvalidLength, - MaximumLengthExceeded -} - -pub struct state_32 { pub block_state: Vec, pub buf: Vec, pub total_len: u64 } - -pub struct state_64 { pub block_state: Vec, pub buf: Vec, pub total_len: u64 } diff --git a/libcrux-sha3/src/lib.rs b/libcrux-sha3/src/lib.rs index 63c1e61c0..563dd4716 100644 --- a/libcrux-sha3/src/lib.rs +++ b/libcrux-sha3/src/lib.rs @@ -1,38 +1,44 @@ -// XXX: Can't do no_std -// #![no_std] +//! # SHA3 +//! +//! A SHA3 implementation with optional simd optimisations. 
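The rewritten `lib.rs` that follows replaces the HACL FFI calls with the portable Rust implementation behind the same one-shot API. A usage sketch against the signatures in this hunk (the input bytes are arbitrary; `hash` is const-generic over the output length, which must match `digest_size`):

```rust
use libcrux_sha3::{digest_size, hash, shake128, Algorithm};

fn demo() {
    // Fixed-size SHA3 digest; the const LEN is inferred from the annotation
    // and must equal digest_size(Algorithm::Sha256) == 32.
    let digest: [u8; 32] = hash(Algorithm::Sha256, b"payload");
    assert_eq!(digest.len(), digest_size(Algorithm::Sha256));

    // SHAKE128 is an XOF: the caller picks the output length via the
    // const parameter on the return type.
    let xof_out: [u8; 42] = shake128(b"payload");
    assert_eq!(xof_out.len(), 42);
}
```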
-// // Low* library code -// mod lowstar; +#![no_std] -// // SHA3 plus helpers -// mod hacl; -// use crate::hacl::hash_sha3::{self, shake128_hacl, shake256_hacl}; +pub mod simd; -/// A Sha3x4 API -pub mod x4; +mod generic_keccak; +mod portable_keccak; +mod traits; +/// A SHA3 224 Digest pub type Sha3_224Digest = [u8; 28]; + +/// A SHA3 256 Digest pub type Sha3_256Digest = [u8; 32]; + +/// A SHA3 384 Digest pub type Sha3_384Digest = [u8; 48]; + +/// A SHA3 512 Digest pub type Sha3_512Digest = [u8; 64]; /// The Digest Algorithm. #[derive(Copy, Clone, Debug, PartialEq)] #[repr(u32)] pub enum Algorithm { - Sha3_224 = 1, - Sha3_256 = 2, - Sha3_384 = 3, - Sha3_512 = 4, + Sha224 = 1, + Sha256 = 2, + Sha384 = 3, + Sha512 = 4, } impl From for Algorithm { fn from(v: u32) -> Algorithm { match v { - 1 => Algorithm::Sha3_224, - 2 => Algorithm::Sha3_256, - 3 => Algorithm::Sha3_384, - 4 => Algorithm::Sha3_512, + 1 => Algorithm::Sha224, + 2 => Algorithm::Sha256, + 3 => Algorithm::Sha384, + 4 => Algorithm::Sha512, _ => panic!("Unknown Digest mode {}", v), } } @@ -41,10 +47,10 @@ impl From for Algorithm { impl From for u32 { fn from(v: Algorithm) -> u32 { match v { - Algorithm::Sha3_224 => 1, - Algorithm::Sha3_256 => 2, - Algorithm::Sha3_384 => 3, - Algorithm::Sha3_512 => 4, + Algorithm::Sha224 => 1, + Algorithm::Sha256 => 2, + Algorithm::Sha384 => 3, + Algorithm::Sha512 => 4, } } } @@ -52,10 +58,10 @@ impl From for u32 { /// Returns the output size of a digest. pub const fn digest_size(mode: Algorithm) -> usize { match mode { - Algorithm::Sha3_224 => 28, - Algorithm::Sha3_256 => 32, - Algorithm::Sha3_384 => 48, - Algorithm::Sha3_512 => 64, + Algorithm::Sha224 => 28, + Algorithm::Sha256 => 32, + Algorithm::Sha384 => 48, + Algorithm::Sha512 => 64, } } @@ -65,48 +71,40 @@ pub fn hash(algorithm: Algorithm, payload: &[u8]) -> [u8; LEN] let mut out = [0u8; LEN]; match algorithm { - Algorithm::Sha3_224 => sha224_ema(&mut out, payload), - Algorithm::Sha3_256 => sha256_ema(&mut out, payload), - Algorithm::Sha3_384 => sha384_ema(&mut out, payload), - Algorithm::Sha3_512 => sha512_ema(&mut out, payload), + Algorithm::Sha224 => portable::sha224(&mut out, payload), + Algorithm::Sha256 => portable::sha256(&mut out, payload), + Algorithm::Sha384 => portable::sha384(&mut out, payload), + Algorithm::Sha512 => portable::sha512(&mut out, payload), } out } -use libcrux_hacl::{ - Hacl_Hash_SHA3_sha3_224, Hacl_Hash_SHA3_sha3_256, Hacl_Hash_SHA3_sha3_384, - Hacl_Hash_SHA3_sha3_512, Hacl_Hash_SHA3_shake128_hacl, Hacl_Hash_SHA3_shake256_hacl, -}; - /// SHA3 224 #[inline(always)] -pub fn sha224(payload: &[u8]) -> [u8; 28] { - let mut digest = [0u8; 28]; - sha224_ema(&mut digest, payload); - digest +pub fn sha224(data: &[u8]) -> Sha3_224Digest { + let mut out = [0u8; 28]; + sha224_ema(&mut out, data); + out } /// SHA3 224 +/// +/// Preconditions: +/// - `digest.len() == 28` #[inline(always)] pub fn sha224_ema(digest: &mut [u8], payload: &[u8]) { debug_assert!(payload.len() <= u32::MAX as usize); debug_assert!(digest.len() == 28); - unsafe { - Hacl_Hash_SHA3_sha3_224( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } + portable::sha224(digest, payload) } /// SHA3 256 #[inline(always)] -pub fn sha256(payload: &[u8]) -> [u8; 32] { - let mut digest = [0u8; 32]; - sha256_ema(&mut digest, payload); - digest +pub fn sha256(data: &[u8]) -> Sha3_256Digest { + let mut out = [0u8; 32]; + sha256_ema(&mut out, data); + out } /// SHA3 256 @@ -115,21 +113,15 @@ pub fn sha256_ema(digest: &mut [u8], 
payload: &[u8]) { debug_assert!(payload.len() <= u32::MAX as usize); debug_assert!(digest.len() == 32); - unsafe { - Hacl_Hash_SHA3_sha3_256( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } + portable::sha256(digest, payload) } /// SHA3 384 #[inline(always)] -pub fn sha384(payload: &[u8]) -> [u8; 48] { - let mut digest = [0u8; 48]; - sha384_ema(&mut digest, payload); - digest +pub fn sha384(data: &[u8]) -> Sha3_384Digest { + let mut out = [0u8; 48]; + sha384_ema(&mut out, data); + out } /// SHA3 384 @@ -138,21 +130,15 @@ pub fn sha384_ema(digest: &mut [u8], payload: &[u8]) { debug_assert!(payload.len() <= u32::MAX as usize); debug_assert!(digest.len() == 48); - unsafe { - Hacl_Hash_SHA3_sha3_384( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } + portable::sha384(digest, payload) } /// SHA3 512 #[inline(always)] -pub fn sha512(payload: &[u8]) -> [u8; 64] { - let mut digest = [0u8; 64]; - sha512_ema(&mut digest, payload); - digest +pub fn sha512(data: &[u8]) -> Sha3_512Digest { + let mut out = [0u8; 64]; + sha512_ema(&mut out, data); + out } /// SHA3 512 @@ -161,27 +147,14 @@ pub fn sha512_ema(digest: &mut [u8], payload: &[u8]) { debug_assert!(payload.len() <= u32::MAX as usize); debug_assert!(digest.len() == 64); - unsafe { - Hacl_Hash_SHA3_sha3_512( - digest.as_mut_ptr(), - payload.as_ptr() as _, - payload.len().try_into().unwrap(), - ); - } + portable::sha512(digest, payload) } /// SHAKE 128 #[inline(always)] pub fn shake128(data: &[u8]) -> [u8; BYTES] { let mut out = [0u8; BYTES]; - unsafe { - Hacl_Hash_SHA3_shake128_hacl( - data.len() as u32, - data.as_ptr() as _, - BYTES as u32, - out.as_mut_ptr(), - ); - } + portable::shake128(&mut out, data); out } @@ -192,96 +165,479 @@ pub fn shake128(data: &[u8]) -> [u8; BYTES] { #[inline(always)] pub fn shake256(data: &[u8]) -> [u8; BYTES] { let mut out = [0u8; BYTES]; - unsafe { - Hacl_Hash_SHA3_shake256_hacl( - data.len() as u32, - data.as_ptr() as _, - BYTES as u32, - out.as_mut_ptr(), - ); - } + portable::shake256(&mut out, data); out } -// mod pure { - -// /// SHA3 224 -// pub fn sha3_224(payload: &[u8]) -> Sha3_224Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 28]; - -// hash_sha3::sha3_224(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 256 -// pub fn sha3_256(payload: &[u8]) -> Sha3_256Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 32]; - -// hash_sha3::sha3_256(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 384 -// pub fn sha3_384(payload: &[u8]) -> Sha3_384Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 48]; - -// hash_sha3::sha3_384(&mut out, payload, payload.len() as u32); - -// out -// } - -// /// SHA3 512 -// pub fn sha3_512(payload: &[u8]) -> Sha3_512Digest { -// debug_assert!(payload.len() <= u32::MAX as usize); -// let payload = unsafe { -// &mut *(core::ptr::slice_from_raw_parts_mut(payload.as_ptr() as *mut u8, payload.len())) -// }; -// let mut out = [0u8; 64]; - -// hash_sha3::sha3_512(&mut out, 
payload, payload.len() as u32);
-
-//         out
-//     }
-
-//     /// SHAKE 128
-//     ///
-//     /// The caller must define the size of the output in the return type.
-//     pub fn shake128(data: &[u8]) -> [u8; LEN] {
-//         debug_assert!(LEN <= u32::MAX as usize && data.len() <= u32::MAX as usize);
-//         let data = unsafe {
-//             &mut *(core::ptr::slice_from_raw_parts_mut(data.as_ptr() as *mut u8, data.len()))
-//         };
-//         let mut out = [0u8; LEN];
-//         shake128_hacl(data.len() as u32, data, LEN as u32, &mut out);
-
-//         out
-//     }
-
-//     /// SHAKE 256
-//     ///
-//     /// The caller must define the size of the output in the return type.
-//     pub fn shake256(data: &[u8]) -> [u8; LEN] {
-//         debug_assert!(LEN <= u32::MAX as usize && data.len() <= u32::MAX as usize);
-//         let data = unsafe {
-//             &mut *(core::ptr::slice_from_raw_parts_mut(data.as_ptr() as *mut u8, data.len()))
-//         };
-//         let mut out = [0u8; LEN];
-//         shake256_hacl(data.len() as u32, data, LEN as u32, &mut out);
-
-//         out
-//     }
-// }
+mod incremental {}
+
+// === The portable instantiation === //
+
+/// A portable SHA3 implementation without platform-dependent optimisations.
+pub mod portable {
+    use super::*;
+    use generic_keccak::{keccak, KeccakState};
+
+    pub type KeccakState1 = KeccakState<1, u64>;
+
+    #[inline(always)]
+    fn keccakx1(data: [&[u8]; 1], out: [&mut [u8]; 1]) {
+        keccak::<1, u64, RATE, DELIM>(data, out)
+    }
+
+    /// A portable SHA3 224 implementation.
+    pub fn sha224(digest: &mut [u8], data: &[u8]) {
+        keccakx1::<144, 0x06u8>([data], [digest]);
+    }
+
+    /// A portable SHA3 256 implementation.
+    pub fn sha256(digest: &mut [u8], data: &[u8]) {
+        keccakx1::<136, 0x06u8>([data], [digest]);
+    }
+
+    /// A portable SHA3 384 implementation.
+    pub fn sha384(digest: &mut [u8], data: &[u8]) {
+        keccakx1::<104, 0x06u8>([data], [digest]);
+    }
+
+    /// A portable SHA3 512 implementation.
+    pub fn sha512(digest: &mut [u8], data: &[u8]) {
+        keccakx1::<72, 0x06u8>([data], [digest]);
+    }
+
+    /// A portable SHAKE128 implementation.
+    pub fn shake128(digest: &mut [u8; LEN], data: &[u8]) {
+        keccakx1::<168, 0x1fu8>([data], [digest]);
+    }
+
+    /// A portable SHAKE256 implementation.
+    pub fn shake256(digest: &mut [u8; LEN], data: &[u8]) {
+        keccakx1::<136, 0x1fu8>([data], [digest]);
+    }
+
+    /// An incremental API for SHAKE
+    pub mod incremental {
+        use generic_keccak::{absorb_final, squeeze_first_three_blocks, squeeze_next_block};
+
+        use super::*;
+
+        /// Initialise the SHAKE state.
+        pub fn shake128_init() -> KeccakState1 {
+            KeccakState1::new()
+        }
+
+        /// Absorb
+        pub fn shake128_absorb_final(s: &mut KeccakState1, data0: &[u8]) {
+            absorb_final::<1, u64, 168, 0x1fu8>(s, [data0]);
+        }
+
+        /// Squeeze three blocks
+        pub fn shake128_squeeze_first_three_blocks(s: &mut KeccakState1, out0: &mut [u8]) {
+            squeeze_first_three_blocks::<1, u64, 168>(s, [out0])
+        }
+
+        /// Squeeze another block
+        pub fn shake128_squeeze_next_block(s: &mut KeccakState1, out0: &mut [u8]) {
+            squeeze_next_block::<1, u64, 168>(s, [out0])
+        }
+    }
+}
+
+/// A neon optimised implementation.
+///
+/// When this is compiled for non-neon architectures, the functions panic.
+/// The caller must make sure to check for hardware features before calling
+/// these functions, and to compile them in.
+///
+/// Feature `simd128` enables the implementations in this module.
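The `neon` module declared next (and the `avx2` module further down) panics when its SIMD backend was not compiled in, so selection happens at the call site. A dispatch sketch, under the assumption that the build script enabled the matching `simd128`/`simd256` feature for the target (which it does by default unless the `LIBCRUX_DISABLE_SIMD*` variables are set):

```rust
// Compile-time dispatch between the SIMD and portable back ends.
// Assumes libcrux-sha3 was built with `simd128` enabled on aarch64;
// otherwise `neon::sha256` panics with `unimplemented!`.
fn sha256_dispatch(digest: &mut [u8; 32], data: &[u8]) {
    #[cfg(target_arch = "aarch64")]
    libcrux_sha3::neon::sha256(digest, data);

    #[cfg(not(target_arch = "aarch64"))]
    libcrux_sha3::portable::sha256(digest, data);
}
```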
+pub mod neon {
+    #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+    use crate::generic_keccak::keccak;
+
+    #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+    #[inline(always)]
+    fn keccakx2(data: [&[u8]; 2], out: [&mut [u8]; 2]) {
+        keccak::<2, core::arch::aarch64::uint64x2_t, RATE, DELIM>(data, out)
+    }
+
+    /// A neon SHA3 224 implementation.
+    #[allow(unused_variables)]
+    pub fn sha224(digest: &mut [u8], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; 28];
+            keccakx2::<144, 0x06u8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// A neon SHA3 256 implementation.
+    #[allow(unused_variables)]
+    pub fn sha256(digest: &mut [u8], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; 32];
+            keccakx2::<136, 0x06u8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// A neon SHA3 384 implementation.
+    #[allow(unused_variables)]
+    pub fn sha384(digest: &mut [u8], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; 48];
+            keccakx2::<104, 0x06u8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// A neon SHA3 512 implementation.
+    #[allow(unused_variables)]
+    pub fn sha512(digest: &mut [u8], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; 64];
+            keccakx2::<72, 0x06u8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// A neon SHAKE128 implementation.
+    #[allow(unused_variables)]
+    pub fn shake128(digest: &mut [u8; LEN], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; LEN];
+            keccakx2::<168, 0x1fu8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// A neon SHAKE256 implementation.
+    #[allow(unused_variables)]
+    pub fn shake256(digest: &mut [u8; LEN], data: &[u8]) {
+        #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))]
+        unimplemented!("The target architecture does not support neon instructions.");
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        {
+            let mut dummy = [0u8; LEN];
+            keccakx2::<136, 0x1fu8>([data, data], [digest, &mut dummy]);
+        }
+    }
+
+    /// Performing 2 operations in parallel
+    pub mod x2 {
+        #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+        use super::*;
+
+        /// Run SHAKE256 on both inputs in parallel.
+ /// + /// Writes the two results into `out0` and `out1` + #[allow(unused_variables)] + pub fn shake256(input0: &[u8], input1: &[u8], out0: &mut [u8], out1: &mut [u8]) { + // TODO: make argument ordering consistent + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + unimplemented!("The target architecture does not support neon instructions."); + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]); + } + + /// An incremental API to perform 2 operations in parallel + pub mod incremental { + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + use crate::generic_keccak::{ + absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState, + }; + + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + pub type KeccakState2 = KeccakState<2, core::arch::aarch64::uint64x2_t>; + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + pub type KeccakState2 = [crate::portable::KeccakState1; 2]; + + pub fn shake128_init() -> KeccakState2 { + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let s0 = KeccakState1::new(); + // let s1 = KeccakState1::new(); + // [s0, s1] + // } + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + KeccakState2::new() + } + + #[allow(unused_variables)] + pub fn shake128_absorb_final(s: &mut KeccakState2, data0: &[u8], data1: &[u8]) { + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_absorb_final(&mut s0, data0); + // shake128_absorb_final(&mut s1, data1); + // } + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(s, [data0, data1]); + } + + #[allow(unused_variables)] + pub fn shake128_squeeze_first_three_blocks( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], + ) { + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_squeeze_first_three_blocks(&mut s0, out0); + // shake128_squeeze_first_three_blocks(&mut s1, out1); + // } + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>( + s, + [out0, out1], + ) + } + + #[allow(unused_variables)] + pub fn shake128_squeeze_next_block( + s: &mut KeccakState2, + out0: &mut [u8], + out1: &mut [u8], + ) { + #[cfg(not(all(feature = "simd128", target_arch = "aarch64")))] + unimplemented!("The target architecture does not support neon instructions."); + // XXX: These functions could alternatively implement the same with + // the portable implementation + // { + // let [mut s0, mut s1] = s; + // shake128_squeeze_next_block(&mut s0, out0); + // shake128_squeeze_next_block(&mut s1, out1); + // } + #[cfg(all(feature = "simd128", target_arch = "aarch64"))] + squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(s, [out0, out1]) + } + } + } +} + +/// An AVX2 
optimised implementation.
+///
+/// When this is compiled for a target without AVX2 support, the functions panic.
+/// The caller must check for AVX2 hardware support before calling these
+/// functions, and must make sure they are compiled in.
+///
+/// Feature `simd256` enables the implementations in this module.
+pub mod avx2 {
+
+    /// Performing 4 operations in parallel
+    pub mod x4 {
+        #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+        use crate::generic_keccak::keccak;
+
+        /// Perform 4 SHAKE256 operations in parallel
+        #[allow(unused_variables)] // TODO: decide if we want to fall back here
+        pub fn shake256(
+            input0: &[u8],
+            input1: &[u8],
+            input2: &[u8],
+            input3: &[u8],
+            out0: &mut [u8],
+            out1: &mut [u8],
+            out2: &mut [u8],
+            out3: &mut [u8],
+        ) {
+            #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))]
+            unimplemented!("The target architecture does not support avx2 instructions.");
+            // XXX: These functions could alternatively implement the same with
+            // the portable implementation
+            // #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+            // {
+            //     keccakx2::<136, 0x1fu8>([input0, input1], [out0, out1]);
+            //     keccakx2::<136, 0x1fu8>([input2, input3], [out2, out3]);
+            // }
+            // {
+            //     keccakx1::<136, 0x1fu8>([input0], [out0]);
+            //     keccakx1::<136, 0x1fu8>([input1], [out1]);
+            //     keccakx1::<136, 0x1fu8>([input2], [out2]);
+            //     keccakx1::<136, 0x1fu8>([input3], [out3]);
+            // }
+            #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+            keccak::<4, core::arch::x86_64::__m256i, 136, 0x1fu8>(
+                [input0, input1, input2, input3],
+                [out0, out1, out2, out3],
+            );
+        }
+
+        /// An incremental API to perform 4 operations in parallel
+        pub mod incremental {
+            #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+            use crate::generic_keccak::{
+                absorb_final, squeeze_first_three_blocks, squeeze_next_block, KeccakState,
+            };
+
+            #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+            pub type KeccakState4 = KeccakState<4, core::arch::x86_64::__m256i>;
+            #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+            pub type KeccakState4 = [crate::neon::x2::incremental::KeccakState2; 2];
+            #[cfg(not(any(
+                all(feature = "simd256", target_arch = "x86_64"),
+                all(feature = "simd128", target_arch = "aarch64")
+            )))]
+            pub type KeccakState4 = [crate::portable::KeccakState1; 4];
+
+            pub fn shake128_init() -> KeccakState4 {
+                #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))]
+                unimplemented!("The target architecture does not support avx2 instructions.");
+                // XXX: These functions could alternatively implement the same with
+                // the portable implementation
+                // #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+                // {
+                //     let s0 = KeccakState2::new();
+                //     let s1 = KeccakState2::new();
+                //     [s0, s1]
+                // }
+                // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))]
+                // {
+                //     let s0 = KeccakState1::new();
+                //     let s1 = KeccakState1::new();
+                //     let s2 = KeccakState1::new();
+                //     let s3 = KeccakState1::new();
+                //     [s0, s1, s2, s3]
+                // }
+                #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+                KeccakState4::new()
+            }
+
+            #[allow(unused_variables)] // TODO: decide if we want to fall back here
+            pub fn shake128_absorb_final(
+                s: &mut KeccakState4,
+                data0: &[u8],
+                data1: &[u8],
+                data2: &[u8],
+                data3: &[u8],
+            ) {
+                #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))]
+                unimplemented!("The target architecture does not support avx2 instructions.");
+                // XXX: These functions could alternatively implement the same with
+                // the portable implementation
+                // #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+                // {
+                //     let [mut s0, mut s1] = s;
+                //     absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(
+                //         &mut s0,
+                //         [data0, data1],
+                //     );
+                //     absorb_final::<2, core::arch::aarch64::uint64x2_t, 168, 0x1fu8>(
+                //         &mut s1,
+                //         [data2, data3],
+                //     );
+                // }
+                // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))]
+                // {
+                //     let [mut s0, mut s1, mut s2, mut s3] = s;
+                //     shake128_absorb_final(&mut s0, data0);
+                //     shake128_absorb_final(&mut s1, data1);
+                //     shake128_absorb_final(&mut s2, data2);
+                //     shake128_absorb_final(&mut s3, data3);
+                // }
+                #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+                absorb_final::<4, core::arch::x86_64::__m256i, 168, 0x1fu8>(
+                    s,
+                    [data0, data1, data2, data3],
+                );
+            }
+
+            #[allow(unused_variables)] // TODO: decide if we want to fall back here
+            pub fn shake128_squeeze_first_three_blocks(
+                s: &mut KeccakState4,
+                out0: &mut [u8],
+                out1: &mut [u8],
+                out2: &mut [u8],
+                out3: &mut [u8],
+            ) {
+                #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))]
+                unimplemented!("The target architecture does not support avx2 instructions.");
+                // XXX: These functions could alternatively implement the same with
+                // the portable implementation
+                // #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+                // {
+                //     let [mut s0, mut s1] = s;
+                //     squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(
+                //         &mut s0,
+                //         [out0, out1],
+                //     );
+                //     squeeze_first_three_blocks::<2, core::arch::aarch64::uint64x2_t, 168>(
+                //         &mut s1,
+                //         [out2, out3],
+                //     );
+                // }
+                // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))]
+                // {
+                //     let [mut s0, mut s1, mut s2, mut s3] = s;
+                //     shake128_squeeze_first_three_blocks(&mut s0, out0);
+                //     shake128_squeeze_first_three_blocks(&mut s1, out1);
+                //     shake128_squeeze_first_three_blocks(&mut s2, out2);
+                //     shake128_squeeze_first_three_blocks(&mut s3, out3);
+                // }
+                #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+                squeeze_first_three_blocks::<4, core::arch::x86_64::__m256i, 168>(
+                    s,
+                    [out0, out1, out2, out3],
+                );
+            }
+
+            #[allow(unused_variables)] // TODO: decide if we want to fall back here
+            pub fn shake128_squeeze_next_block(
+                s: &mut KeccakState4,
+                out0: &mut [u8],
+                out1: &mut [u8],
+                out2: &mut [u8],
+                out3: &mut [u8],
+            ) {
+                #[cfg(not(all(feature = "simd256", target_arch = "x86_64")))]
+                unimplemented!("The target architecture does not support avx2 instructions.");
+                // XXX: These functions could alternatively implement the same with
+                // the portable implementation
+                // #[cfg(all(feature = "simd128", target_arch = "aarch64"))]
+                // {
+                //     let [mut s0, mut s1] = s;
+                //     squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(
+                //         &mut s0,
+                //         [out0, out1],
+                //     );
+                //     squeeze_next_block::<2, core::arch::aarch64::uint64x2_t, 168>(
+                //         &mut s1,
+                //         [out2, out3],
+                //     );
+                // }
+                // #[cfg(not(any(all(feature = "simd128", target_arch = "aarch64"), all(feature = "simd256", target_arch = "x86_64"))))]
+                // {
+                //     let [mut s0, mut s1, mut s2, mut s3] = s;
+                //     shake128_squeeze_next_block(&mut s0, out0);
+                //     shake128_squeeze_next_block(&mut s1, out1);
+                //     shake128_squeeze_next_block(&mut s2, out2);
+                //     shake128_squeeze_next_block(&mut s3, out3);
+                // }
+                #[cfg(all(feature = "simd256", target_arch = "x86_64"))]
+
squeeze_next_block::<4, core::arch::x86_64::__m256i, 168>( + s, + [out0, out1, out2, out3], + ); + } + } + } +} diff --git a/libcrux-sha3/src/lowstar.rs b/libcrux-sha3/src/lowstar.rs deleted file mode 100644 index f63af5cbe..000000000 --- a/libcrux-sha3/src/lowstar.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod endianness; -pub mod ignore; diff --git a/libcrux-sha3/src/lowstar/endianness.rs b/libcrux-sha3/src/lowstar/endianness.rs deleted file mode 100644 index 00d3ea9c5..000000000 --- a/libcrux-sha3/src/lowstar/endianness.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::convert::TryInto; - -// Little Endian - -pub fn load16_le(bytes: &[u8]) -> u16 { - u16::from_le_bytes(bytes[0..2].try_into().unwrap()) -} - -pub fn store16_le(bytes: &mut[u8], x: u16) { - bytes[0..2].copy_from_slice(&u16::to_le_bytes(x)) -} - -pub fn load32_le(bytes: &[u8]) -> u32 { - u32::from_le_bytes(bytes[0..4].try_into().unwrap()) -} - -pub fn store32_le(bytes: &mut[u8], x: u32) { - bytes[0..4].copy_from_slice(&u32::to_le_bytes(x)) -} - -pub fn load64_le(bytes: &[u8]) -> u64 { - u64::from_le_bytes(bytes[0..8].try_into().unwrap()) -} - -pub fn store64_le(bytes: &mut[u8], x: u64) { - bytes[0..8].copy_from_slice(&u64::to_le_bytes(x)) -} - -// Big Endian - -pub fn load32_be(bytes: &[u8]) -> u32 { - u32::from_be_bytes(bytes[0..4].try_into().unwrap()) -} - -pub fn store32_be(bytes: &mut[u8], x: u32) { - bytes[0..4].copy_from_slice(&u32::to_be_bytes(x)) -} - -pub fn load64_be(bytes: &[u8]) -> u64 { - u64::from_be_bytes(bytes[0..8].try_into().unwrap()) -} - -pub fn store64_be(bytes: &mut[u8], x: u64) { - bytes[0..8].copy_from_slice(&u64::to_be_bytes(x)) -} - -pub fn load128_be(bytes: &[u8]) -> u128 { - u128::from_be_bytes(bytes[0..16].try_into().unwrap()) -} - -pub fn store128_be(bytes: &mut[u8], x: u128) { - bytes[0..16].copy_from_slice(&u128::to_be_bytes(x)) -} diff --git a/libcrux-sha3/src/lowstar/ignore.rs b/libcrux-sha3/src/lowstar/ignore.rs deleted file mode 100644 index 919eb52f9..000000000 --- a/libcrux-sha3/src/lowstar/ignore.rs +++ /dev/null @@ -1 +0,0 @@ -pub fn ignore(_: T) {} diff --git a/libcrux-sha3/src/portable_keccak.rs b/libcrux-sha3/src/portable_keccak.rs new file mode 100644 index 000000000..341399985 --- /dev/null +++ b/libcrux-sha3/src/portable_keccak.rs @@ -0,0 +1,132 @@ +//! A portable SHA3 implementation using the generic implementation. 
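+//!
+//! The scalar `u64` lane type implements the same `KeccakItem` trait as the
+//! NEON and AVX2 vector types, just with a single lane, so the generic Keccak
+//! code processes one input at a time here.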
+
+use crate::traits::*;
+
+#[inline(always)]
+fn rotate_left<const LEFT: i32, const RIGHT: i32>(x: u64) -> u64 {
+    debug_assert!(LEFT + RIGHT == 64);
+    (x << LEFT) | (x >> RIGHT)
+}
+
+#[inline(always)]
+fn _veor5q_u64(a: u64, b: u64, c: u64, d: u64, e: u64) -> u64 {
+    let ab = a ^ b;
+    let cd = c ^ d;
+    let abcd = ab ^ cd;
+    abcd ^ e
+}
+
+#[inline(always)]
+fn _vrax1q_u64(a: u64, b: u64) -> u64 {
+    a ^ rotate_left::<1, 63>(b)
+}
+
+#[inline(always)]
+fn _vxarq_u64<const LEFT: i32, const RIGHT: i32>(a: u64, b: u64) -> u64 {
+    let ab = a ^ b;
+    rotate_left::<LEFT, RIGHT>(ab)
+}
+
+#[inline(always)]
+fn _vbcaxq_u64(a: u64, b: u64, c: u64) -> u64 {
+    a ^ (b & !c)
+}
+
+#[inline(always)]
+fn _veorq_n_u64(a: u64, c: u64) -> u64 {
+    a ^ c
+}
+
+#[inline(always)]
+pub(crate) fn load_block<const RATE: usize>(s: &mut [[u64; 5]; 5], blocks: [&[u8]; 1]) {
+    debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0);
+    for i in 0..RATE / 8 {
+        s[i / 5][i % 5] ^= u64::from_le_bytes(blocks[0][8 * i..8 * i + 8].try_into().unwrap());
+    }
+}
+
+#[inline(always)]
+pub(crate) fn load_block_full<const RATE: usize>(s: &mut [[u64; 5]; 5], blocks: [[u8; 200]; 1]) {
+    load_block::<RATE>(s, [&blocks[0] as &[u8]]);
+}
+
+#[inline(always)]
+pub(crate) fn store_block<const RATE: usize>(s: &[[u64; 5]; 5], out: [&mut [u8]; 1]) {
+    for i in 0..RATE / 8 {
+        out[0][8 * i..8 * i + 8].copy_from_slice(&s[i / 5][i % 5].to_le_bytes());
+    }
+}
+
+#[inline(always)]
+pub(crate) fn store_block_full<const RATE: usize>(s: &[[u64; 5]; 5]) -> [[u8; 200]; 1] {
+    let mut out = [0u8; 200];
+    store_block::<RATE>(s, [&mut out]);
+    [out]
+}
+
+#[inline(always)]
+fn slice_1(a: [&[u8]; 1], start: usize, len: usize) -> [&[u8]; 1] {
+    [&a[0][start..start + len]]
+}
+
+#[inline(always)]
+fn split_at_mut_1(out: [&mut [u8]; 1], mid: usize) -> ([&mut [u8]; 1], [&mut [u8]; 1]) {
+    let [out0] = out;
+    let (out00, out01) = out0.split_at_mut(mid);
+    ([out00], [out01])
+}
+
+impl KeccakItem<1> for u64 {
+    #[inline(always)]
+    fn zero() -> Self {
+        0
+    }
+    #[inline(always)]
+    fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self {
+        _veor5q_u64(a, b, c, d, e)
+    }
+    #[inline(always)]
+    fn rotate_left1_and_xor(a: Self, b: Self) -> Self {
+        _vrax1q_u64(a, b)
+    }
+    #[inline(always)]
+    fn xor_and_rotate<const LEFT: i32, const RIGHT: i32>(a: Self, b: Self) -> Self {
+        _vxarq_u64::<LEFT, RIGHT>(a, b)
+    }
+    #[inline(always)]
+    fn and_not_xor(a: Self, b: Self, c: Self) -> Self {
+        _vbcaxq_u64(a, b, c)
+    }
+    #[inline(always)]
+    fn xor_constant(a: Self, c: u64) -> Self {
+        _veorq_n_u64(a, c)
+    }
+    #[inline(always)]
+    fn xor(a: Self, b: Self) -> Self {
+        a ^ b
+    }
+    #[inline(always)]
+    fn load_block<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [&[u8]; 1]) {
+        load_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5], b: [&mut [u8]; 1]) {
+        store_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn load_block_full<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 1]) {
+        load_block_full::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block_full<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5]) -> [[u8; 200]; 1] {
+        store_block_full::<BLOCKSIZE>(a)
+    }
+    #[inline(always)]
+    fn slice_n(a: [&[u8]; 1], start: usize, len: usize) -> [&[u8]; 1] {
+        slice_1(a, start, len)
+    }
+    #[inline(always)]
+    fn split_at_mut_n(a: [&mut [u8]; 1], mid: usize) -> ([&mut [u8]; 1], [&mut [u8]; 1]) {
+        split_at_mut_1(a, mid)
+    }
+}
diff --git a/libcrux-sha3/src/simd.rs b/libcrux-sha3/src/simd.rs
new file mode 100644
index 000000000..7dc25011a
--- /dev/null
+++ b/libcrux-sha3/src/simd.rs
@@ -0,0 +1,11 @@
+//! SIMD implementations of SHA3
+//!
+//! Runtime feature detection must be performed before calling these functions.
+//!
+//! If the caller does not perform feature detection, the top level functions
+//! must be used.
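+//!
+//! A sketch of such a caller-side check (illustrative only; `can_use_simd256`
+//! is not an API of this crate):
+//!
+//! ```ignore
+//! #[cfg(target_arch = "x86_64")]
+//! fn can_use_simd256() -> bool {
+//!     std::arch::is_x86_feature_detected!("avx2")
+//! }
+//! ```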
+
+#[cfg(feature = "simd128")]
+mod arm64;
+#[cfg(feature = "simd256")]
+mod avx2;
diff --git a/libcrux-sha3/src/simd/arm64.rs b/libcrux-sha3/src/simd/arm64.rs
new file mode 100644
index 000000000..5d847ca7c
--- /dev/null
+++ b/libcrux-sha3/src/simd/arm64.rs
@@ -0,0 +1,196 @@
+use core::arch::aarch64::*;
+
+use crate::traits::KeccakItem;
+
+// This file optimizes for the stable Rust Neon Intrinsics
+// If we want to use the unstable neon-sha3 instructions, we could use:
+// veor3q_u64, vrax1q_u64, vxarq_u64, and vbcaxq_u64
+// These instructions might speed up our code even more.
+
+#[inline(always)]
+fn rotate_left<const LEFT: i32, const RIGHT: i32>(x: uint64x2_t) -> uint64x2_t {
+    debug_assert!(LEFT + RIGHT == 64);
+    // The following looks faster but is actually significantly slower
+    //unsafe { vsriq_n_u64::<RIGHT>(vshlq_n_u64::<LEFT>(x), x) }
+    unsafe { veorq_u64(vshlq_n_u64::<LEFT>(x), vshrq_n_u64::<RIGHT>(x)) }
+}
+
+#[inline(always)]
+fn _veor5q_u64(
+    a: uint64x2_t,
+    b: uint64x2_t,
+    c: uint64x2_t,
+    d: uint64x2_t,
+    e: uint64x2_t,
+) -> uint64x2_t {
+    let ab = unsafe { veorq_u64(a, b) };
+    let cd = unsafe { veorq_u64(c, d) };
+    let abcd = unsafe { veorq_u64(ab, cd) };
+    unsafe { veorq_u64(abcd, e) }
+    // Needs nightly+neon-sha3
+    //unsafe {veor3q_u64(veor3q_u64(a,b,c),d,e)}
+}
+
+#[inline(always)]
+fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t {
+    unsafe { veorq_u64(a, rotate_left::<1, 63>(b)) }
+    // Needs nightly+neon-sha3
+    //unsafe { vrax1q_u64(a, b) }
+}
+
+#[inline(always)]
+fn _vxarq_u64<const LEFT: i32, const RIGHT: i32>(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t {
+    let ab = unsafe { veorq_u64(a, b) };
+    rotate_left::<LEFT, RIGHT>(ab)
+    // Needs nightly+neon-sha3
+    // unsafe { vxarq_u64::<RIGHT>(a, b) }
+}
+
+#[inline(always)]
+fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t {
+    unsafe { veorq_u64(a, vbicq_u64(b, c)) }
+    // Needs nightly+neon-sha3
+    // unsafe{ vbcaxq_u64(a, b, c) }
+}
+
+#[inline(always)]
+fn _veorq_n_u64(a: uint64x2_t, c: u64) -> uint64x2_t {
+    let c = unsafe { vdupq_n_u64(c) };
+    unsafe { veorq_u64(a, c) }
+}
+
+#[inline(always)]
+pub(crate) fn load_block<const RATE: usize>(s: &mut [[uint64x2_t; 5]; 5], blocks: [&[u8]; 2]) {
+    debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0);
+    for i in 0..RATE / 16 {
+        let v0 = unsafe { vld1q_u64(blocks[0][16 * i..16 * (i + 1)].as_ptr() as *const u64) };
+        let v1 = unsafe { vld1q_u64(blocks[1][16 * i..16 * (i + 1)].as_ptr() as *const u64) };
+        s[(2 * i) / 5][(2 * i) % 5] =
+            unsafe { veorq_u64(s[(2 * i) / 5][(2 * i) % 5], vtrn1q_u64(v0, v1)) };
+        s[(2 * i + 1) / 5][(2 * i + 1) % 5] =
+            unsafe { veorq_u64(s[(2 * i + 1) / 5][(2 * i + 1) % 5], vtrn2q_u64(v0, v1)) };
+    }
+    if RATE % 16 != 0 {
+        let i = (RATE / 8 - 1) / 5;
+        let j = (RATE / 8 - 1) % 5;
+        let mut u = [0u64; 2];
+        u[0] = u64::from_le_bytes(blocks[0][RATE - 8..RATE].try_into().unwrap());
+        u[1] = u64::from_le_bytes(blocks[1][RATE - 8..RATE].try_into().unwrap());
+        let uvec = unsafe { vld1q_u64(u.as_ptr() as *const u64) };
+        s[i][j] = unsafe { veorq_u64(s[i][j], uvec) };
+    }
+}
+
+#[inline(always)]
+pub(crate) fn load_block_full<const RATE: usize>(
+    s: &mut [[uint64x2_t; 5]; 5],
+    blocks: [[u8; 200]; 2],
+) {
+    let [b0, b1] = blocks;
+    load_block::<RATE>(s, [&b0 as &[u8], &b1 as &[u8]]);
+}
+
+#[inline(always)]
+pub(crate) fn store_block<const RATE: usize>(s: &[[uint64x2_t; 5]; 5], out: [&mut [u8]; 2]) {
+    for i in 0..RATE / 16 {
+        let v0 = unsafe {
+            vtrn1q_u64(
+                s[(2 * i) / 5][(2 * i) % 5],
+                s[(2 * i + 1) / 5][(2 * i + 1) % 5],
+            )
+        };
+        let v1 = unsafe {
+            vtrn2q_u64(
+                s[(2 * i) / 5][(2 * i) % 5],
+                s[(2 * i + 1) / 5][(2 * i + 1) % 5],
+            )
+        };
+        unsafe { vst1q_u64(out[0][16 * i..16 * (i + 1)].as_mut_ptr() as *mut u64, v0) };
+        unsafe { vst1q_u64(out[1][16 * i..16 * (i + 1)].as_mut_ptr() as *mut u64, v1) };
+    }
+    if RATE % 16 != 0 {
+        debug_assert!(RATE % 8 == 0);
+        let i = (RATE / 8 - 1) / 5;
+        let j = (RATE / 8 - 1) % 5;
+        let mut u = [0u8; 16];
+        unsafe { vst1q_u64(u.as_mut_ptr() as *mut u64, s[i][j]) };
+        out[0][RATE - 8..RATE].copy_from_slice(&u[0..8]);
+        out[1][RATE - 8..RATE].copy_from_slice(&u[8..16]);
+    }
+}
+
+#[inline(always)]
+pub(crate) fn store_block_full<const RATE: usize>(s: &[[uint64x2_t; 5]; 5]) -> [[u8; 200]; 2] {
+    let mut out0 = [0u8; 200];
+    let mut out1 = [0u8; 200];
+    store_block::<RATE>(s, [&mut out0, &mut out1]);
+    [out0, out1]
+}
+
+#[inline(always)]
+fn slice_2(a: [&[u8]; 2], start: usize, len: usize) -> [&[u8]; 2] {
+    [&a[0][start..start + len], &a[1][start..start + len]]
+}
+
+#[inline(always)]
+fn split_at_mut_2(out: [&mut [u8]; 2], mid: usize) -> ([&mut [u8]; 2], [&mut [u8]; 2]) {
+    let [out0, out1] = out;
+    let (out00, out01) = out0.split_at_mut(mid);
+    let (out10, out11) = out1.split_at_mut(mid);
+    ([out00, out10], [out01, out11])
+}
+
+impl KeccakItem<2> for uint64x2_t {
+    #[inline(always)]
+    fn zero() -> Self {
+        unsafe { vdupq_n_u64(0) }
+    }
+    #[inline(always)]
+    fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self {
+        _veor5q_u64(a, b, c, d, e)
+    }
+    #[inline(always)]
+    fn rotate_left1_and_xor(a: Self, b: Self) -> Self {
+        _vrax1q_u64(a, b)
+    }
+    #[inline(always)]
+    fn xor_and_rotate<const LEFT: i32, const RIGHT: i32>(a: Self, b: Self) -> Self {
+        _vxarq_u64::<LEFT, RIGHT>(a, b)
+    }
+    #[inline(always)]
+    fn and_not_xor(a: Self, b: Self, c: Self) -> Self {
+        _vbcaxq_u64(a, b, c)
+    }
+    #[inline(always)]
+    fn xor_constant(a: Self, c: u64) -> Self {
+        _veorq_n_u64(a, c)
+    }
+    #[inline(always)]
+    fn xor(a: Self, b: Self) -> Self {
+        unsafe { veorq_u64(a, b) }
+    }
+    #[inline(always)]
+    fn load_block<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [&[u8]; 2]) {
+        load_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5], b: [&mut [u8]; 2]) {
+        store_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn load_block_full<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 2]) {
+        load_block_full::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block_full<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5]) -> [[u8; 200]; 2] {
+        store_block_full::<BLOCKSIZE>(a)
+    }
+    #[inline(always)]
+    fn slice_n(a: [&[u8]; 2], start: usize, len: usize) -> [&[u8]; 2] {
+        slice_2(a, start, len)
+    }
+    #[inline(always)]
+    fn split_at_mut_n(a: [&mut [u8]; 2], mid: usize) -> ([&mut [u8]; 2], [&mut [u8]; 2]) {
+        split_at_mut_2(a, mid)
+    }
+}
diff --git a/libcrux-sha3/src/simd/avx2.rs b/libcrux-sha3/src/simd/avx2.rs
new file mode 100644
index 000000000..153270906
--- /dev/null
+++ b/libcrux-sha3/src/simd/avx2.rs
@@ -0,0 +1,280 @@
+use core::arch::x86_64::*;
+
+use crate::traits::*;
+
+#[inline(always)]
+fn rotate_left<const LEFT: i32, const RIGHT: i32>(x: __m256i) -> __m256i {
+    debug_assert!(LEFT + RIGHT == 64);
+    // XXX: This could be done more efficiently, if the shift values are multiples of 8.
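+    // (For shift amounts that are multiples of 8 the rotation is a pure byte
+    // permutation, so a single _mm256_shuffle_epi8 with a suitable index vector
+    // would do.)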
+    unsafe { _mm256_xor_si256(_mm256_slli_epi64::<LEFT>(x), _mm256_srli_epi64::<RIGHT>(x)) }
+}
+
+#[inline(always)]
+fn _veor5q_u64(a: __m256i, b: __m256i, c: __m256i, d: __m256i, e: __m256i) -> __m256i {
+    let ab = unsafe { _mm256_xor_si256(a, b) };
+    let cd = unsafe { _mm256_xor_si256(c, d) };
+    let abcd = unsafe { _mm256_xor_si256(ab, cd) };
+    unsafe { _mm256_xor_si256(abcd, e) }
+}
+
+#[inline(always)]
+fn _vrax1q_u64(a: __m256i, b: __m256i) -> __m256i {
+    unsafe { _mm256_xor_si256(a, rotate_left::<1, 63>(b)) }
+}
+
+#[inline(always)]
+fn _vxarq_u64<const LEFT: i32, const RIGHT: i32>(a: __m256i, b: __m256i) -> __m256i {
+    let ab = unsafe { _mm256_xor_si256(a, b) };
+    rotate_left::<LEFT, RIGHT>(ab)
+}
+
+#[inline(always)]
+fn _vbcaxq_u64(a: __m256i, b: __m256i, c: __m256i) -> __m256i {
+    unsafe { _mm256_xor_si256(a, _mm256_andnot_si256(c, b)) }
+}
+
+#[inline(always)]
+fn _veorq_n_u64(a: __m256i, c: u64) -> __m256i {
+    // Casting here is required, doesn't change the value.
+    let c = unsafe { _mm256_set1_epi64x(c as i64) };
+    unsafe { _mm256_xor_si256(a, c) }
+}
+
+#[inline(always)]
+pub(crate) fn load_block<const RATE: usize>(s: &mut [[__m256i; 5]; 5], blocks: [&[u8]; 4]) {
+    debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16));
+    for i in 0..RATE / 32 {
+        let v0 = unsafe {
+            _mm256_loadu_si256(blocks[0][32 * i..32 * (i + 1)].as_ptr() as *const __m256i)
+        };
+        let v1 = unsafe {
+            _mm256_loadu_si256(blocks[1][32 * i..32 * (i + 1)].as_ptr() as *const __m256i)
+        };
+        let v2 = unsafe {
+            _mm256_loadu_si256(blocks[2][32 * i..32 * (i + 1)].as_ptr() as *const __m256i)
+        };
+        let v3 = unsafe {
+            _mm256_loadu_si256(blocks[3][32 * i..32 * (i + 1)].as_ptr() as *const __m256i)
+        };
+
+        let v0l = unsafe { _mm256_unpacklo_epi64(v0, v1) }; // 0 0 2 2
+        let v1h = unsafe { _mm256_unpackhi_epi64(v0, v1) }; // 1 1 3 3
+        let v2l = unsafe { _mm256_unpacklo_epi64(v2, v3) }; // 0 0 2 2
+        let v3h = unsafe { _mm256_unpackhi_epi64(v2, v3) }; // 1 1 3 3
+
+        let v0 = unsafe { _mm256_permute2x128_si256(v0l, v2l, 0x20) }; // 0 0 0 0
+        let v1 = unsafe { _mm256_permute2x128_si256(v1h, v3h, 0x20) }; // 1 1 1 1
+        let v2 = unsafe { _mm256_permute2x128_si256(v0l, v2l, 0x31) }; // 2 2 2 2
+        let v3 = unsafe { _mm256_permute2x128_si256(v1h, v3h, 0x31) }; // 3 3 3 3
+
+        s[(4 * i) / 5][(4 * i) % 5] = unsafe { _mm256_xor_si256(s[(4 * i) / 5][(4 * i) % 5], v0) };
+        s[(4 * i + 1) / 5][(4 * i + 1) % 5] =
+            unsafe { _mm256_xor_si256(s[(4 * i + 1) / 5][(4 * i + 1) % 5], v1) };
+        s[(4 * i + 2) / 5][(4 * i + 2) % 5] =
+            unsafe { _mm256_xor_si256(s[(4 * i + 2) / 5][(4 * i + 2) % 5], v2) };
+        s[(4 * i + 3) / 5][(4 * i + 3) % 5] =
+            unsafe { _mm256_xor_si256(s[(4 * i + 3) / 5][(4 * i + 3) % 5], v3) };
+    }
+
+    let rem = RATE % 32; // has to be 8 or 16
+    let start = 32 * (RATE / 32);
+    let mut u8s = [0u8; 32];
+    u8s[0..8].copy_from_slice(&blocks[0][start..start + 8]);
+    u8s[8..16].copy_from_slice(&blocks[1][start..start + 8]);
+    u8s[16..24].copy_from_slice(&blocks[2][start..start + 8]);
+    u8s[24..32].copy_from_slice(&blocks[3][start..start + 8]);
+    let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i) };
+    let i = (4 * (RATE / 32)) / 5;
+    let j = (4 * (RATE / 32)) % 5;
+    s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u) };
+    if rem == 16 {
+        let mut u8s = [0u8; 32];
+        u8s[0..8].copy_from_slice(&blocks[0][start + 8..start + 16]);
+        u8s[8..16].copy_from_slice(&blocks[1][start + 8..start + 16]);
+        u8s[16..24].copy_from_slice(&blocks[2][start + 8..start + 16]);
+        u8s[24..32].copy_from_slice(&blocks[3][start + 8..start + 16]);
+        let u = unsafe { _mm256_loadu_si256(u8s.as_ptr() as *const __m256i) };
+        let i = (4 * (RATE / 32) + 1) / 5;
+        let j = (4 * (RATE / 32) + 1) % 5;
+        s[i][j] = unsafe { _mm256_xor_si256(s[i][j], u) };
+    }
+}
+
+#[inline(always)]
+pub(crate) fn load_block_full<const RATE: usize>(
+    s: &mut [[__m256i; 5]; 5],
+    blocks: [[u8; 200]; 4],
+) {
+    let [b0, b1, b2, b3] = blocks;
+    load_block::<RATE>(s, [&b0 as &[u8], &b1 as &[u8], &b2 as &[u8], &b3 as &[u8]]);
+}
+
+#[inline(always)]
+pub(crate) fn store_block<const RATE: usize>(s: &[[__m256i; 5]; 5], out: [&mut [u8]; 4]) {
+    for i in 0..RATE / 32 {
+        let v0l = unsafe {
+            _mm256_permute2x128_si256(
+                s[(4 * i) / 5][(4 * i) % 5],
+                s[(4 * i + 2) / 5][(4 * i + 2) % 5],
+                0x20,
+            )
+        }; // 0 0 2 2
+        let v1h = unsafe {
+            _mm256_permute2x128_si256(
+                s[(4 * i + 1) / 5][(4 * i + 1) % 5],
+                s[(4 * i + 3) / 5][(4 * i + 3) % 5],
+                0x20,
+            )
+        }; // 1 1 3 3
+        let v2l = unsafe {
+            _mm256_permute2x128_si256(
+                s[(4 * i) / 5][(4 * i) % 5],
+                s[(4 * i + 2) / 5][(4 * i + 2) % 5],
+                0x31,
+            )
+        }; // 0 0 2 2
+        let v3h = unsafe {
+            _mm256_permute2x128_si256(
+                s[(4 * i + 1) / 5][(4 * i + 1) % 5],
+                s[(4 * i + 3) / 5][(4 * i + 3) % 5],
+                0x31,
+            )
+        }; // 1 1 3 3
+
+        let v0 = unsafe { _mm256_unpacklo_epi64(v0l, v1h) }; // 0 1 2 3
+        let v1 = unsafe { _mm256_unpackhi_epi64(v0l, v1h) }; // 0 1 2 3
+        let v2 = unsafe { _mm256_unpacklo_epi64(v2l, v3h) }; // 0 1 2 3
+        let v3 = unsafe { _mm256_unpackhi_epi64(v2l, v3h) }; // 0 1 2 3
+
+        unsafe {
+            _mm256_storeu_si256(
+                out[0][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i,
+                v0,
+            )
+        };
+        unsafe {
+            _mm256_storeu_si256(
+                out[1][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i,
+                v1,
+            )
+        };
+        unsafe {
+            _mm256_storeu_si256(
+                out[2][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i,
+                v2,
+            )
+        };
+        unsafe {
+            _mm256_storeu_si256(
+                out[3][32 * i..32 * (i + 1)].as_mut_ptr() as *mut __m256i,
+                v3,
+            )
+        };
+    }
+
+    let rem = RATE % 32; // has to be 8 or 16
+    let start = 32 * (RATE / 32);
+    let mut u8s = [0u8; 32];
+    let i = (4 * (RATE / 32)) / 5;
+    let j = (4 * (RATE / 32)) % 5;
+    unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j]) };
+    out[0][start..start + 8].copy_from_slice(&u8s[0..8]);
+    out[1][start..start + 8].copy_from_slice(&u8s[8..16]);
+    out[2][start..start + 8].copy_from_slice(&u8s[16..24]);
+    out[3][start..start + 8].copy_from_slice(&u8s[24..32]);
+    if rem == 16 {
+        let mut u8s = [0u8; 32];
+        let i = (4 * (RATE / 32) + 1) / 5;
+        let j = (4 * (RATE / 32) + 1) % 5;
+        unsafe { _mm256_storeu_si256(u8s.as_mut_ptr() as *mut __m256i, s[i][j]) };
+        out[0][start + 8..start + 16].copy_from_slice(&u8s[0..8]);
+        out[1][start + 8..start + 16].copy_from_slice(&u8s[8..16]);
+        out[2][start + 8..start + 16].copy_from_slice(&u8s[16..24]);
+        out[3][start + 8..start + 16].copy_from_slice(&u8s[24..32]);
+    }
+}
+
+#[inline(always)]
+pub(crate) fn store_block_full<const RATE: usize>(s: &[[__m256i; 5]; 5]) -> [[u8; 200]; 4] {
+    let mut out0 = [0u8; 200];
+    let mut out1 = [0u8; 200];
+    let mut out2 = [0u8; 200];
+    let mut out3 = [0u8; 200];
+    store_block::<RATE>(s, [&mut out0, &mut out1, &mut out2, &mut out3]);
+    [out0, out1, out2, out3]
+}
+
+#[inline(always)]
+fn slice_4(a: [&[u8]; 4], start: usize, len: usize) -> [&[u8]; 4] {
+    [
+        &a[0][start..start + len],
+        &a[1][start..start + len],
+        &a[2][start..start + len],
+        &a[3][start..start + len],
+    ]
+}
+
+#[inline(always)]
+fn split_at_mut_4(out: [&mut [u8]; 4], mid: usize) -> ([&mut [u8]; 4], [&mut [u8]; 4]) {
+    let [out0, out1, out2, out3] = out;
+    let (out00, out01) = out0.split_at_mut(mid);
+    let (out10, out11) = out1.split_at_mut(mid);
+    let (out20, out21) = out2.split_at_mut(mid);
+    let (out30, out31) = out3.split_at_mut(mid);
+    ([out00, out10, out20, out30], [out01, out11, out21, out31])
+}
+
+impl KeccakItem<4> for __m256i {
+    #[inline(always)]
+    fn zero() -> Self {
+        unsafe { _mm256_set1_epi64x(0) }
+    }
+    #[inline(always)]
+    fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self {
+        _veor5q_u64(a, b, c, d, e)
+    }
+    #[inline(always)]
+    fn rotate_left1_and_xor(a: Self, b: Self) -> Self {
+        _vrax1q_u64(a, b)
+    }
+    #[inline(always)]
+    fn xor_and_rotate<const LEFT: i32, const RIGHT: i32>(a: Self, b: Self) -> Self {
+        _vxarq_u64::<LEFT, RIGHT>(a, b)
+    }
+    #[inline(always)]
+    fn and_not_xor(a: Self, b: Self, c: Self) -> Self {
+        _vbcaxq_u64(a, b, c)
+    }
+    #[inline(always)]
+    fn xor_constant(a: Self, c: u64) -> Self {
+        _veorq_n_u64(a, c)
+    }
+    #[inline(always)]
+    fn xor(a: Self, b: Self) -> Self {
+        unsafe { _mm256_xor_si256(a, b) }
+    }
+    #[inline(always)]
+    fn load_block<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [&[u8]; 4]) {
+        load_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5], b: [&mut [u8]; 4]) {
+        store_block::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn load_block_full<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [[u8; 200]; 4]) {
+        load_block_full::<BLOCKSIZE>(a, b)
+    }
+    #[inline(always)]
+    fn store_block_full<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5]) -> [[u8; 200]; 4] {
+        store_block_full::<BLOCKSIZE>(a)
+    }
+    #[inline(always)]
+    fn slice_n(a: [&[u8]; 4], start: usize, len: usize) -> [&[u8]; 4] {
+        slice_4(a, start, len)
+    }
+    #[inline(always)]
+    fn split_at_mut_n(a: [&mut [u8]; 4], mid: usize) -> ([&mut [u8]; 4], [&mut [u8]; 4]) {
+        split_at_mut_4(a, mid)
+    }
+}
diff --git a/libcrux-sha3/src/traits.rs b/libcrux-sha3/src/traits.rs
new file mode 100644
index 000000000..a499305e3
--- /dev/null
+++ b/libcrux-sha3/src/traits.rs
@@ -0,0 +1,16 @@
+/// A trait for multiplexing implementations.
+pub(crate) trait KeccakItem<const N: usize>: Clone + Copy {
+    fn zero() -> Self;
+    fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self;
+    fn rotate_left1_and_xor(a: Self, b: Self) -> Self;
+    fn xor_and_rotate<const LEFT: i32, const RIGHT: i32>(a: Self, b: Self) -> Self;
+    fn and_not_xor(a: Self, b: Self, c: Self) -> Self;
+    fn xor_constant(a: Self, c: u64) -> Self;
+    fn xor(a: Self, b: Self) -> Self;
+    fn load_block<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [&[u8]; N]);
+    fn store_block<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5], b: [&mut [u8]; N]);
+    fn load_block_full<const BLOCKSIZE: usize>(a: &mut [[Self; 5]; 5], b: [[u8; 200]; N]);
+    fn store_block_full<const BLOCKSIZE: usize>(a: &[[Self; 5]; 5]) -> [[u8; 200]; N];
+    fn slice_n(a: [&[u8]; N], start: usize, len: usize) -> [&[u8]; N];
+    fn split_at_mut_n(a: [&mut [u8]; N], mid: usize) -> ([&mut [u8]; N], [&mut [u8]; N]);
+}
diff --git a/libcrux-sha3/src/x4.rs b/libcrux-sha3/src/x4.rs
deleted file mode 100644
index ca501b87b..000000000
--- a/libcrux-sha3/src/x4.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-/// An incremental eXtendable Output Function API for SHA3 (shake).
-///
-/// This x4 variant of the incremental API always processes 4 inputs at a time.
-/// This uses AVX2 when available to run the 4 operations in parallel.
-///
-/// More generic APIs will be added later.
-mod internal;
-
-/// Incremental state
-#[cfg_attr(hax, hax_lib::opaque_type)]
-pub struct Shake128StateX4 {
-    state: internal::Shake128StateX4,
-}
-
-impl Shake128StateX4 {
-    /// Create a new Shake128 x4 state.
-    #[inline(always)]
-    pub fn new() -> Self {
-        Self {
-            state: internal::Shake128StateX4::new(),
-        }
-    }
-
-    /// This is only used internally to work around Eurydice bugs.
-    #[inline(always)]
-    pub fn free_memory(self) {
-        self.state.free();
-    }
-
-    /// Absorb 4 blocks.
-    ///
-    /// A blocks MUST all be the same length.
- /// Each slice MUST be a multiple of the block length 168. - #[inline(always)] - pub fn absorb_4blocks(&mut self, input: [&[u8]; 4]) { - self.state.absorb_blocks(input) - } - - /// Absorb up to 4 blocks. - /// - /// The `input` must be of length 1 to 4. - /// A blocks MUST all be the same length. - /// Each slice MUST be a multiple of the block length 168. - #[inline(always)] - pub fn absorb_final(&mut self, input: [&[u8]; N]) { - // Pad the input to the length of 4 - let data = [ - input[0], - if N > 1 { input[1] } else { &[] }, - if N > 2 { input[2] } else { &[] }, - if N > 3 { input[3] } else { &[] }, - ]; - self.state.absorb_final(data); - } - - /// Squeeze `M` blocks of length `N` - #[inline(always)] - pub fn squeeze_blocks(&mut self) -> [[u8; N]; M] { - self.state.squeeze_blocks() - } -} diff --git a/libcrux-sha3/src/x4/internal.rs b/libcrux-sha3/src/x4/internal.rs deleted file mode 100644 index 12a75821f..000000000 --- a/libcrux-sha3/src/x4/internal.rs +++ /dev/null @@ -1,337 +0,0 @@ -use core::ptr::null_mut; - -use libcrux_hacl::{ - Hacl_Hash_SHA3_Scalar_shake128_absorb_final, Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks, - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks, Hacl_Hash_SHA3_Scalar_state_free, - Hacl_Hash_SHA3_Scalar_state_malloc, -}; -#[cfg(feature = "simd256")] -use libcrux_hacl::{ - Hacl_Hash_SHA3_Simd256_shake128_absorb_final, Hacl_Hash_SHA3_Simd256_shake128_absorb_nblocks, - Hacl_Hash_SHA3_Simd256_shake128_squeeze_nblocks, Hacl_Hash_SHA3_Simd256_state_free, - Hacl_Hash_SHA3_Simd256_state_malloc, Lib_IntVector_Intrinsics_vec256, -}; -#[cfg(feature = "simd256")] -use libcrux_platform::simd256_support; - -/// SHAKE 128 -/// -/// Handle to internal SHAKE 128 state -#[cfg(feature = "simd256")] -pub struct Shake128StateX4 { - statex4: *mut Lib_IntVector_Intrinsics_vec256, - state: [*mut u64; 4], -} - -#[cfg(not(feature = "simd256"))] -pub struct Shake128StateX4 { - state: [*mut u64; 4], -} - -impl Shake128StateX4 { - #[cfg(feature = "simd256")] - pub fn new() -> Self { - if simd256_support() { - Self { - statex4: unsafe { Hacl_Hash_SHA3_Simd256_state_malloc() }, - state: [null_mut(), null_mut(), null_mut(), null_mut()], - } - } else { - Self { - statex4: null_mut(), - state: unsafe { - [ - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - ] - }, - } - } - } - - #[cfg(not(feature = "simd256"))] - pub fn new() -> Self { - Self { - state: unsafe { - [ - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - Hacl_Hash_SHA3_Scalar_state_malloc(), - ] - }, - } - } - - /// Free and consume the state. - /// - /// **NOTE:** This consumes the value. It is not usable after this call! - #[cfg(feature = "simd256")] - pub fn free(mut self) { - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_state_free(self.statex4); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.statex4 = null_mut(); - } - } else { - for i in 0..4 { - unsafe { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.state[i] = null_mut(); - } - } - } - } - - /// Free and consume the state. - /// - /// **NOTE:** This consumes the value. It is not usable after this call! 
- #[cfg(not(feature = "simd256"))] - pub fn free(mut self) { - for i in 0..4 { - unsafe { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]); - // null the pointer (hacl isn't doing that unfortunately) - // This way we can check whether the memory was freed already or not. - self.state[i] = null_mut(); - } - } - } - - /// Absorb up to 4 blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(feature = "simd256")] - pub fn absorb_blocks(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() % 168 == 0); - - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_absorb_nblocks( - self.statex4, - input[0].as_ptr() as _, - input[1].as_ptr() as _, - input[2].as_ptr() as _, - input[3].as_ptr() as _, - input[0].len() as u32, - ) - }; - } else { - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - } - - /// Absorb up to 4 blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(not(feature = "simd256"))] - pub fn absorb_blocks(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() % 168 == 0); - - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_nblocks( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - - /// Absorb up to 4 final blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. - #[cfg(feature = "simd256")] - pub fn absorb_final(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() < 168); - - if simd256_support() { - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_absorb_final( - self.statex4, - input[0].as_ptr() as _, - input[1].as_ptr() as _, - input[2].as_ptr() as _, - input[3].as_ptr() as _, - input[0].len() as u32, - ) - }; - } else { - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_final( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - } - - /// Absorb up to 4 final blocks at a time. - /// - /// The input length must be a multiple of the SHA3 block length of 168. - /// - /// The input is truncated at `u32::MAX`. 
- #[cfg(not(feature = "simd256"))] - pub fn absorb_final(&mut self, input: [&[u8]; 4]) { - debug_assert!( - (input[0].len() == input[1].len() || input[1].len() == 0) - && (input[0].len() == input[2].len() || input[2].len() == 0) - && (input[0].len() == input[3].len() || input[3].len() == 0) - ); - debug_assert!(input[0].len() < 168); - - for i in 0..4 { - if !input[i].is_empty() { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_absorb_final( - self.state[i], - input[i].as_ptr() as _, - input[i].len() as u32, - ); - }; - } - } - } - - #[cfg(feature = "simd256")] - pub fn squeeze_blocks( - &mut self, - ) -> [[u8; OUTPUT_BYTES]; M] { - debug_assert!(OUTPUT_BYTES % 168 == 0); - debug_assert!(M <= self.state.len() && (M == 2 || M == 3 || M == 4)); - - if simd256_support() { - let mut output = [[0u8; OUTPUT_BYTES]; 4]; - unsafe { - Hacl_Hash_SHA3_Simd256_shake128_squeeze_nblocks( - self.statex4, - output[0].as_mut_ptr(), - output[1].as_mut_ptr(), - output[2].as_mut_ptr(), - output[3].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - core::array::from_fn(|i| output[i]) - } else { - let mut output = [[0u8; OUTPUT_BYTES]; M]; - for i in 0..M { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks( - self.state[i], - output[i].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - } - output - } - } - - #[cfg(not(feature = "simd256"))] - pub fn squeeze_blocks( - &mut self, - ) -> [[u8; OUTPUT_BYTES]; M] { - debug_assert!(OUTPUT_BYTES % 168 == 0); - debug_assert!(M <= self.state.len()); - - let mut output = [[0u8; OUTPUT_BYTES]; M]; - - for i in 0..M { - unsafe { - Hacl_Hash_SHA3_Scalar_shake128_squeeze_nblocks( - self.state[i], - output[i].as_mut_ptr(), - OUTPUT_BYTES as u32, - ); - }; - } - - output - } -} - -/// **NOTE:** When generating C code with Eurydice, the state needs to be freed -/// manually for now due to a bug in Eurydice. -impl Drop for Shake128StateX4 { - #[cfg(feature = "simd256")] - fn drop(&mut self) { - if simd256_support() { - // A manual free may have occurred already. - // Avoid double free. - unsafe { - if !self.statex4.is_null() { - Hacl_Hash_SHA3_Simd256_state_free(self.statex4); - } - } - } else { - // A manual free may have occurred already. - // Avoid double free. - for i in 0..4 { - unsafe { - if !self.state[i].is_null() { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]) - } - } - } - } - } - - #[cfg(not(feature = "simd256"))] - fn drop(&mut self) { - // A manual free may have occurred already. - // Avoid double free. 
- for i in 0..4 { - unsafe { - if !self.state[i].is_null() { - Hacl_Hash_SHA3_Scalar_state_free(self.state[i]) - } - } - } - } -} diff --git a/libcrux-sha3/tests/sha3.rs b/libcrux-sha3/tests/sha3.rs new file mode 100644 index 000000000..a4b7e3248 --- /dev/null +++ b/libcrux-sha3/tests/sha3.rs @@ -0,0 +1,23 @@ +#[test] +fn sha3_kat_oneshot() { + let d256 = libcrux_sha3::sha256(b"Hello, World!"); + let expected256 = "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"; + assert_eq!(hex::encode(&d256), expected256); + + let dshake = libcrux_sha3::shake128::<42>(b"Hello, World!"); + let expectedshake = + "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + assert_eq!(hex::encode(&dshake), expectedshake); +} + +#[test] +fn sha3_simd_kat_oneshot() { + let d256 = libcrux_sha3::sha256(b"Hello, World!"); + let expected256 = "1af17a664e3fa8e419b8ba05c2a173169df76162a5a286e0c405b460d478f7ef"; + assert_eq!(hex::encode(&d256), expected256); + + let dshake = libcrux_sha3::shake128::<42>(b"Hello, World!"); + let expectedshake = + "2bf5e6dee6079fad604f573194ba8426bd4d30eb13e8ba2edae70e529b570cbdd588f2c5dd4e465dfbaf"; + assert_eq!(hex::encode(&dshake), expectedshake); +} diff --git a/polynomials-aarch64/src/lib.rs b/polynomials-aarch64/src/lib.rs index 38a2dc578..a37953217 100644 --- a/polynomials-aarch64/src/lib.rs +++ b/polynomials-aarch64/src/lib.rs @@ -22,7 +22,7 @@ impl Operations for SIMD128Vector { to_i16_array(v) } - fn from_i16_array(array: [i16; 16]) -> Self { + fn from_i16_array(array: &[i16]) -> Self { from_i16_array(array) } @@ -157,7 +157,7 @@ impl Operations for SIMD128Vector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rejsample::rej_sample(a) + fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { + rejsample::rej_sample(a, out) } } diff --git a/polynomials-aarch64/src/neon.rs b/polynomials-aarch64/src/neon.rs index d233c6cc6..d0454cfef 100644 --- a/polynomials-aarch64/src/neon.rs +++ b/polynomials-aarch64/src/neon.rs @@ -193,7 +193,7 @@ pub(crate) fn _vmlal_high_s16(a: int32x4_t, b: int16x8_t, c: int16x8_t) -> int32 } #[inline(always)] pub(crate) fn _vld1q_u8(ptr: &[u8]) -> uint8x16_t { - unsafe { vld1q_u8(ptr.as_ptr()) } + unsafe { vld1q_u8(ptr.as_ptr() as *const u8) } } #[inline(always)] pub(crate) fn _vreinterpretq_u8_s16(a: int16x8_t) -> uint8x16_t { @@ -254,7 +254,7 @@ pub(super) fn _vreinterpretq_u8_s64(a: int64x2_t) -> uint8x16_t { #[inline(always)] pub(super) fn _vst1q_u8(out: &mut [u8], v: uint8x16_t) { - unsafe { vst1q_u8(out.as_mut_ptr(), v) } + unsafe { vst1q_u8(out.as_mut_ptr() as *mut u8, v) } } #[inline(always)] pub(crate) fn _vdupq_n_u16(value: u16) -> uint16x8_t { @@ -270,7 +270,7 @@ pub(crate) fn _vreinterpretq_u16_u8(a: uint8x16_t) -> uint16x8_t { } #[inline(always)] pub(crate) fn _vld1q_u16(ptr: &[u16]) -> uint16x8_t { - unsafe { vld1q_u16(ptr.as_ptr()) } + unsafe { vld1q_u16(ptr.as_ptr() as *const u16) } } #[inline(always)] pub(crate) fn _vcleq_s16(a: int16x8_t, b: int16x8_t) -> uint16x8_t { diff --git a/polynomials-aarch64/src/rejsample.rs b/polynomials-aarch64/src/rejsample.rs index e667bc3ab..b25fc0f42 100644 --- a/polynomials-aarch64/src/rejsample.rs +++ b/polynomials-aarch64/src/rejsample.rs @@ -768,7 +768,7 @@ const IDX_TABLE: [[u8; 16]; 256] = [ ]; #[inline(always)] -pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { +pub(crate) fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { let neon_bits: [u16; 8] = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80]; let bits = 
_vld1q_u16(&neon_bits);
     let fm = _vdupq_n_s16(3328);
@@ -787,9 +787,8 @@ pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) {
     let shifted1 =
         _vreinterpretq_s16_u8(_vqtbl1q_u8(_vreinterpretq_u8_s16(input.high), index_vec1));
 
-    let mut out: [i16; 16] = [0i16; 16];
     let idx0 = pick0 as usize;
     _vst1q_s16(&mut out[0..8], shifted0);
     _vst1q_s16(&mut out[idx0..idx0 + 8], shifted1);
-    ((pick0 + pick1) as usize, out)
+    (pick0 + pick1) as usize
 }
diff --git a/polynomials-aarch64/src/simd128ops.rs b/polynomials-aarch64/src/simd128ops.rs
index 3f1b03a16..79be6f4af 100644
--- a/polynomials-aarch64/src/simd128ops.rs
+++ b/polynomials-aarch64/src/simd128ops.rs
@@ -27,7 +27,7 @@ pub(crate) fn to_i16_array(v: SIMD128Vector) -> [i16; 16] {
 }
 
 #[inline(always)]
-pub(crate) fn from_i16_array(array: [i16; 16]) -> SIMD128Vector {
+pub(crate) fn from_i16_array(array: &[i16]) -> SIMD128Vector {
     SIMD128Vector {
         low: _vld1q_s16(&array[0..8]),
         high: _vld1q_s16(&array[8..16]),
diff --git a/polynomials-avx2/src/arithmetic.rs b/polynomials-avx2/src/arithmetic.rs
new file mode 100644
index 000000000..e51fd5b5b
--- /dev/null
+++ b/polynomials-avx2/src/arithmetic.rs
@@ -0,0 +1,74 @@
+use crate::intrinsics::*;
+use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R};
+
+#[inline(always)]
+pub(crate) fn add(lhs: __m256i, rhs: __m256i) -> __m256i {
+    mm256_add_epi16(lhs, rhs)
+}
+
+#[inline(always)]
+pub(crate) fn sub(lhs: __m256i, rhs: __m256i) -> __m256i {
+    mm256_sub_epi16(lhs, rhs)
+}
+
+#[inline(always)]
+pub(crate) fn multiply_by_constant(vector: __m256i, constant: i16) -> __m256i {
+    mm256_mullo_epi16(vector, mm256_set1_epi16(constant))
+}
+
+#[inline(always)]
+pub(crate) fn bitwise_and_with_constant(vector: __m256i, constant: i16) -> __m256i {
+    mm256_and_si256(vector, mm256_set1_epi16(constant))
+}
+
+#[inline(always)]
+pub(crate) fn shift_right<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    mm256_srai_epi16::<{ SHIFT_BY }>(vector)
+}
+
+#[inline(always)]
+pub(crate) fn shift_left<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    mm256_slli_epi16::<{ SHIFT_BY }>(vector)
+}
+
+#[inline(always)]
+pub(crate) fn cond_subtract_3329(vector: __m256i) -> __m256i {
+    let field_modulus = mm256_set1_epi16(FIELD_MODULUS);
+
+    let v_minus_field_modulus = mm256_sub_epi16(vector, field_modulus);
+
+    let sign_mask = mm256_srai_epi16::<15>(v_minus_field_modulus);
+    let conditional_add_field_modulus = mm256_and_si256(sign_mask, field_modulus);
+
+    mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus)
+}
+
+const BARRETT_MULTIPLIER: i16 = 20159;
+
+#[inline(always)]
+pub(crate) fn barrett_reduce(vector: __m256i) -> __m256i {
+    let t = mm256_mulhi_epi16(vector, mm256_set1_epi16(BARRETT_MULTIPLIER));
+    let t = mm256_add_epi16(t, mm256_set1_epi16(512));
+
+    let quotient = mm256_srai_epi16::<10>(t);
+
+    let quotient_times_field_modulus = mm256_mullo_epi16(quotient, mm256_set1_epi16(FIELD_MODULUS));
+
+    mm256_sub_epi16(vector, quotient_times_field_modulus)
+}
+
+#[inline(always)]
+pub(crate) fn montgomery_multiply_by_constant(vector: __m256i, constant: i16) -> __m256i {
+    let constant = mm256_set1_epi16(constant);
+    let value_low = mm256_mullo_epi16(vector, constant);
+
+    let k = mm256_mullo_epi16(
+        value_low,
+        mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16),
+    );
+    let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS));
+
+    let value_high = mm256_mulhi_epi16(vector, constant);
+
+    mm256_sub_epi16(value_high, k_times_modulus)
+}
diff --git a/polynomials-avx2/src/compress.rs b/polynomials-avx2/src/compress.rs
new file mode 100644
index 000000000..858a5b278
--- /dev/null
+++ b/polynomials-avx2/src/compress.rs
@@ -0,0 +1,110 @@
+use crate::intrinsics::*;
+use libcrux_traits::FIELD_MODULUS;
+
+// This implementation was taken from:
+// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html
+//
+// TODO: Optimize this implementation if performance numbers suggest doing so.
+#[inline(always)]
+fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
+    let prod02 = mm256_mul_epu32(lhs, rhs);
+    let prod13 = mm256_mul_epu32(
+        mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
+        mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
+    );
+
+    mm256_unpackhi_epi64(
+        mm256_unpacklo_epi32(prod02, prod13),
+        mm256_unpackhi_epi32(prod02, prod13),
+    )
+}
+
+#[inline(always)]
+pub(crate) fn compress_message_coefficient(vector: __m256i) -> __m256i {
+    let field_modulus_halved = mm256_set1_epi16((FIELD_MODULUS - 1) / 2);
+    let field_modulus_quartered = mm256_set1_epi16((FIELD_MODULUS - 1) / 4);
+
+    let shifted = mm256_sub_epi16(field_modulus_halved, vector);
+    let mask = mm256_srai_epi16::<15>(shifted);
+
+    let shifted_to_positive = mm256_xor_si256(mask, shifted);
+    let shifted_to_positive_in_range =
+        mm256_sub_epi16(shifted_to_positive, field_modulus_quartered);
+
+    mm256_srli_epi16::<15>(shifted_to_positive_in_range)
+}
+
+#[inline(always)]
+pub(crate) fn compress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
+    vector: __m256i,
+) -> __m256i {
+    let field_modulus_halved = mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2);
+    let compression_factor = mm256_set1_epi32(10_321_340);
+    let coefficient_bits_mask = mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);
+
+    // Compress the first 8 coefficients
+    let coefficients_low = mm256_castsi256_si128(vector);
+    let coefficients_low = mm256_cvtepi16_epi32(coefficients_low);
+
+    let compressed_low = mm256_slli_epi32::<{ COEFFICIENT_BITS }>(coefficients_low);
+    let compressed_low = mm256_add_epi32(compressed_low, field_modulus_halved);
+
+    let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor);
+    let compressed_low = mm256_srli_epi32::<3>(compressed_low);
+    let compressed_low = mm256_and_si256(compressed_low, coefficient_bits_mask);
+
+    // Compress the next 8 coefficients
+    let coefficients_high = mm256_extracti128_si256::<1>(vector);
+    let coefficients_high = mm256_cvtepi16_epi32(coefficients_high);
+
+    let compressed_high = mm256_slli_epi32::<{ COEFFICIENT_BITS }>(coefficients_high);
+    let compressed_high = mm256_add_epi32(compressed_high, field_modulus_halved);
+
+    let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor);
+    let compressed_high = mm256_srli_epi32::<3>(compressed_high);
+    let compressed_high = mm256_and_si256(compressed_high, coefficient_bits_mask);
+
+    // Combine them
+    let compressed = mm256_packs_epi32(compressed_low, compressed_high);
+
+    mm256_permute4x64_epi64::<0b11_01_10_00>(compressed)
+}
+
+#[inline(always)]
+pub(crate) fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
+    vector: __m256i,
+) -> __m256i {
+    let field_modulus = mm256_set1_epi32(FIELD_MODULUS as i32);
+    let two_pow_coefficient_bits = mm256_set1_epi32(1 << COEFFICIENT_BITS);
+
+    // Decompress the first 8 coefficients
+    let coefficients_low = mm256_castsi256_si128(vector);
+    let coefficients_low = mm256_cvtepi16_epi32(coefficients_low);
+
+    let decompressed_low = mm256_mullo_epi32(coefficients_low, field_modulus);
+    let decompressed_low = mm256_slli_epi32::<1>(decompressed_low);
+    let decompressed_low = mm256_add_epi32(decompressed_low, two_pow_coefficient_bits);
+
+    // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
+    // of support for const generic expressions.
+    let decompressed_low = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_low);
+    let decompressed_low = mm256_srli_epi32::<1>(decompressed_low);
+
+    // Decompress the next 8 coefficients
+    let coefficients_high = mm256_extracti128_si256::<1>(vector);
+    let coefficients_high = mm256_cvtepi16_epi32(coefficients_high);
+
+    let decompressed_high = mm256_mullo_epi32(coefficients_high, field_modulus);
+    let decompressed_high = mm256_slli_epi32::<1>(decompressed_high);
+    let decompressed_high = mm256_add_epi32(decompressed_high, two_pow_coefficient_bits);
+
+    // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
+    // of support for const generic expressions.
+    let decompressed_high = mm256_srli_epi32::<{ COEFFICIENT_BITS }>(decompressed_high);
+    let decompressed_high = mm256_srli_epi32::<1>(decompressed_high);
+
+    // Combine them
+    let decompressed = mm256_packs_epi32(decompressed_low, decompressed_high);
+
+    mm256_permute4x64_epi64::<0b11_01_10_00>(decompressed)
+}
diff --git a/polynomials-avx2/src/debug.rs b/polynomials-avx2/src/debug.rs
index c49a167b8..7277a76f6 100644
--- a/polynomials-avx2/src/debug.rs
+++ b/polynomials-avx2/src/debug.rs
@@ -9,6 +9,14 @@ pub(crate) fn print_m256i_as_i16s(a: __m256i, prefix: &'static str) {
     unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) };
     println!("{}: {:?}", prefix, a_bytes);
 }
+
+#[allow(dead_code)]
+pub(crate) fn print_m256i_as_i8s(a: __m256i, prefix: &'static str) {
+    let mut a_bytes = [0i8; 32];
+    unsafe { _mm256_store_si256(a_bytes.as_mut_ptr() as *mut __m256i, a) };
+    println!("{}: {:?}", prefix, a_bytes);
+}
+
 #[allow(dead_code)]
 pub(crate) fn print_m256i_as_i32s(a: __m256i, prefix: &'static str) {
     let mut a_bytes = [0i32; 8];
diff --git a/polynomials-avx2/src/intrinsics.rs b/polynomials-avx2/src/intrinsics.rs
new file mode 100644
index 000000000..db7981ae4
--- /dev/null
+++ b/polynomials-avx2/src/intrinsics.rs
@@ -0,0 +1,325 @@
+#[cfg(target_arch = "x86")]
+pub(crate) use core::arch::x86::*;
+#[cfg(target_arch = "x86_64")]
+pub(crate) use core::arch::x86_64::*;
+
+pub(crate) fn mm256_storeu_si256(output: &mut [i16], vector: __m256i) {
+    debug_assert_eq!(output.len(), 16);
+    unsafe {
+        _mm256_storeu_si256(output.as_mut_ptr() as *mut __m256i, vector);
+    }
+}
+pub(crate) fn mm_storeu_si128(output: &mut [i16], vector: __m128i) {
+    debug_assert_eq!(output.len(), 8);
+    unsafe {
+        _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector);
+    }
+}
+
+pub(crate) fn mm_storeu_bytes_si128(output: &mut [u8], vector: __m128i) {
+    debug_assert_eq!(output.len(), 16);
+    unsafe {
+        _mm_storeu_si128(output.as_mut_ptr() as *mut __m128i, vector);
+    }
+}
+
+pub(crate) fn mm_loadu_si128(input: &[u8]) -> __m128i {
+    debug_assert_eq!(input.len(), 16);
+    unsafe { _mm_loadu_si128(input.as_ptr() as *const __m128i) }
+}
+
+pub(crate) fn mm256_loadu_si256(input: &[i16]) -> __m256i {
+    debug_assert_eq!(input.len(), 16);
+    unsafe { _mm256_loadu_si256(input.as_ptr() as *const __m256i) }
+}
+
+pub(crate) fn mm256_setzero_si256() -> __m256i {
+    unsafe { _mm256_setzero_si256() }
+}
+
+pub(crate) fn mm_set_epi8(
+    byte15: i8,
+    byte14: i8,
+    byte13: i8,
+    byte12: i8,
+    byte11: i8,
+    byte10: i8,
+    byte9: i8,
+    byte8: i8,
+    byte7: i8,
+    byte6: i8,
+    byte5: i8,
+    byte4: i8,
+    byte3: i8,
+    byte2: i8,
+    byte1: i8,
+    byte0: i8,
+) -> __m128i {
+    unsafe {
+        _mm_set_epi8(
+ byte15, byte14, byte13, byte12, byte11, byte10, byte9, byte8, byte7, byte6, byte5, + byte4, byte3, byte2, byte1, byte0, + ) + } +} + +pub(crate) fn mm256_set_epi8( + byte31: i8, + byte30: i8, + byte29: i8, + byte28: i8, + byte27: i8, + byte26: i8, + byte25: i8, + byte24: i8, + byte23: i8, + byte22: i8, + byte21: i8, + byte20: i8, + byte19: i8, + byte18: i8, + byte17: i8, + byte16: i8, + byte15: i8, + byte14: i8, + byte13: i8, + byte12: i8, + byte11: i8, + byte10: i8, + byte9: i8, + byte8: i8, + byte7: i8, + byte6: i8, + byte5: i8, + byte4: i8, + byte3: i8, + byte2: i8, + byte1: i8, + byte0: i8, +) -> __m256i { + unsafe { + _mm256_set_epi8( + byte31, byte30, byte29, byte28, byte27, byte26, byte25, byte24, byte23, byte22, byte21, + byte20, byte19, byte18, byte17, byte16, byte15, byte14, byte13, byte12, byte11, byte10, + byte9, byte8, byte7, byte6, byte5, byte4, byte3, byte2, byte1, byte0, + ) + } +} + +pub(crate) fn mm256_set1_epi16(constant: i16) -> __m256i { + unsafe { _mm256_set1_epi16(constant) } +} +pub(crate) fn mm256_set_epi16( + input15: i16, + input14: i16, + input13: i16, + input12: i16, + input11: i16, + input10: i16, + input9: i16, + input8: i16, + input7: i16, + input6: i16, + input5: i16, + input4: i16, + input3: i16, + input2: i16, + input1: i16, + input0: i16, +) -> __m256i { + unsafe { + _mm256_set_epi16( + input15, input14, input13, input12, input11, input10, input9, input8, input7, input6, + input5, input4, input3, input2, input1, input0, + ) + } +} + +pub(crate) fn mm_set1_epi16(constant: i16) -> __m128i { + unsafe { _mm_set1_epi16(constant) } +} + +pub(crate) fn mm256_set1_epi32(constant: i32) -> __m256i { + unsafe { _mm256_set1_epi32(constant) } +} +pub(crate) fn mm256_set_epi32( + input7: i32, + input6: i32, + input5: i32, + input4: i32, + input3: i32, + input2: i32, + input1: i32, + input0: i32, +) -> __m256i { + unsafe { + _mm256_set_epi32( + input7, input6, input5, input4, input3, input2, input1, input0, + ) + } +} + +pub(crate) fn mm_add_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_add_epi16(lhs, rhs) } +} +pub(crate) fn mm256_add_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_add_epi16(lhs, rhs) } +} +pub(crate) fn mm256_madd_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_madd_epi16(lhs, rhs) } +} +pub(crate) fn mm256_add_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_add_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_sub_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_sub_epi16(lhs, rhs) } +} +pub(crate) fn mm_sub_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_sub_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_mullo_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mullo_epi16(lhs, rhs) } +} + +pub(crate) fn mm_mullo_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_mullo_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_cmpgt_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_cmpgt_epi16(lhs, rhs) } +} + +pub(crate) fn mm_mulhi_epi16(lhs: __m128i, rhs: __m128i) -> __m128i { + unsafe { _mm_mulhi_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_mullo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mullo_epi32(lhs, rhs) } +} + +pub(crate) fn mm256_mulhi_epi16(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mulhi_epi16(lhs, rhs) } +} + +pub(crate) fn mm256_mul_epu32(lhs: __m256i, rhs: __m256i) -> __m256i { + unsafe { _mm256_mul_epu32(lhs, rhs) } +} + +pub(crate) fn mm256_and_si256(lhs: __m256i, rhs: __m256i) -> 
__m256i {
+    unsafe { _mm256_and_si256(lhs, rhs) }
+}
+
+pub(crate) fn mm256_xor_si256(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_xor_si256(lhs, rhs) }
+}
+
+pub(crate) fn mm256_srai_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
+    unsafe { _mm256_srai_epi16(vector, SHIFT_BY) }
+}
+pub(crate) fn mm256_srai_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
+    unsafe { _mm256_srai_epi32(vector, SHIFT_BY) }
+}
+
+pub(crate) fn mm256_srli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
+    unsafe { _mm256_srli_epi16(vector, SHIFT_BY) }
+}
+pub(crate) fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
+    unsafe { _mm256_srli_epi32(vector, SHIFT_BY) }
+}
+
+pub(crate) fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
+    unsafe { _mm256_srli_epi64(vector, SHIFT_BY) }
+}
+
+pub(crate) fn mm256_slli_epi16<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 16);
+    unsafe { _mm256_slli_epi16(vector, SHIFT_BY) }
+}
+
+pub(crate) fn mm256_slli_epi32<const SHIFT_BY: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 32);
+    unsafe { _mm256_slli_epi32(vector, SHIFT_BY) }
+}
+
+pub(crate) fn mm_shuffle_epi8(vector: __m128i, control: __m128i) -> __m128i {
+    unsafe { _mm_shuffle_epi8(vector, control) }
+}
+pub(crate) fn mm256_shuffle_epi8(vector: __m256i, control: __m256i) -> __m256i {
+    unsafe { _mm256_shuffle_epi8(vector, control) }
+}
+pub(crate) fn mm256_shuffle_epi32<const CONTROL: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(CONTROL >= 0 && CONTROL < 256);
+    unsafe { _mm256_shuffle_epi32(vector, CONTROL) }
+}
+
+pub(crate) fn mm256_permute4x64_epi64<const CONTROL: i32>(vector: __m256i) -> __m256i {
+    debug_assert!(CONTROL >= 0 && CONTROL < 256);
+    unsafe { _mm256_permute4x64_epi64(vector, CONTROL) }
+}
+
+pub(crate) fn mm256_unpackhi_epi64(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_unpackhi_epi64(lhs, rhs) }
+}
+
+pub(crate) fn mm256_unpacklo_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_unpacklo_epi32(lhs, rhs) }
+}
+
+pub(crate) fn mm256_unpackhi_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_unpackhi_epi32(lhs, rhs) }
+}
+
+pub(crate) fn mm256_castsi256_si128(vector: __m256i) -> __m128i {
+    unsafe { _mm256_castsi256_si128(vector) }
+}
+pub(crate) fn mm256_castsi128_si256(vector: __m128i) -> __m256i {
+    unsafe { _mm256_castsi128_si256(vector) }
+}
+
+pub(crate) fn mm256_cvtepi16_epi32(vector: __m128i) -> __m256i {
+    unsafe { _mm256_cvtepi16_epi32(vector) }
+}
+
+pub(crate) fn mm_packs_epi16(lhs: __m128i, rhs: __m128i) -> __m128i {
+    unsafe { _mm_packs_epi16(lhs, rhs) }
+}
+pub(crate) fn mm256_packs_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
+    unsafe { _mm256_packs_epi32(lhs, rhs) }
+}
+
+pub(crate) fn mm256_extracti128_si256<const CONTROL: i32>(vector: __m256i) -> __m128i {
+    debug_assert!(CONTROL == 0 || CONTROL == 1);
+    unsafe { _mm256_extracti128_si256(vector, CONTROL) }
+}
+
+pub(crate) fn mm256_inserti128_si256<const CONTROL: i32>(
+    vector: __m256i,
+    vector_i128: __m128i,
+) -> __m256i {
+    debug_assert!(CONTROL == 0 || CONTROL == 1);
+    unsafe { _mm256_inserti128_si256(vector, vector_i128, CONTROL) }
+}
+
+pub(crate) fn mm256_blend_epi16<const CONTROL: i32>(lhs: __m256i, rhs: __m256i) -> __m256i {
+    debug_assert!(CONTROL >= 0 && CONTROL < 256);
+    unsafe { _mm256_blend_epi16(lhs, rhs, CONTROL) }
+}
+
+pub(crate) fn mm_movemask_epi8(vector: __m128i) -> i32 {
+    unsafe {
_mm_movemask_epi8(vector) }
+}
+
+pub(crate) fn mm256_permutevar8x32_epi32(vector: __m256i, control: __m256i) -> __m256i {
+    unsafe { _mm256_permutevar8x32_epi32(vector, control) }
+}
+
+pub(crate) fn mm256_sllv_epi32(vector: __m256i, counts: __m256i) -> __m256i {
+    unsafe { _mm256_sllv_epi32(vector, counts) }
+}
diff --git a/polynomials-avx2/src/lib.rs b/polynomials-avx2/src/lib.rs
index 1c4cde0da..e24519afd 100644
--- a/polynomials-avx2/src/lib.rs
+++ b/polynomials-avx2/src/lib.rs
@@ -1,13 +1,17 @@
-#[cfg(target_arch = "x86")]
-use core::arch::x86::*;
-#[cfg(target_arch = "x86_64")]
-use core::arch::x86_64::*;
-use libcrux_traits::{Operations, FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R};
+use crate::intrinsics::*;
+use libcrux_traits::Operations;

+#[cfg(test)]
 mod debug;
-mod portable;

-const BARRETT_MULTIPLIER: i16 = 20159;
+mod intrinsics;
+
+mod arithmetic;
+mod compress;
+mod ntt;
+mod portable;
+mod sampling;
+mod serialize;

 #[derive(Clone, Copy)]
 pub struct SIMD256Vector {
@@ -17,977 +21,24 @@ pub struct SIMD256Vector {
 #[inline(always)]
 fn zero() -> SIMD256Vector {
     SIMD256Vector {
-        elements: unsafe { _mm256_setzero_si256() },
+        elements: mm256_setzero_si256(),
     }
 }

 #[inline(always)]
 fn to_i16_array(v: SIMD256Vector) -> [i16; 16] {
-    let mut out = [0i16; 16];
-
-    unsafe {
-        _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, v.elements);
-    }
+    let mut output = [0i16; 16];
+    mm256_storeu_si256(&mut output[..], v.elements);

-    out
+    output
 }

 #[inline(always)]
-fn from_i16_array(array: [i16; 16]) -> SIMD256Vector {
+fn from_i16_array(array: &[i16]) -> SIMD256Vector {
     SIMD256Vector {
-        elements: unsafe { _mm256_loadu_si256(array.as_ptr() as *const __m256i) },
+        elements: mm256_loadu_si256(array),
     }
 }

-#[inline(always)]
-fn add(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector {
-    lhs.elements = unsafe { _mm256_add_epi16(lhs.elements, rhs.elements) };
-
-    lhs
-}
-
-#[inline(always)]
-fn sub(mut lhs: SIMD256Vector, rhs: &SIMD256Vector) -> SIMD256Vector {
-    lhs.elements = unsafe { _mm256_sub_epi16(lhs.elements, rhs.elements) };
-
-    lhs
-}
-
-#[inline(always)]
-fn multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector {
-    v.elements = unsafe {
-        let c = _mm256_set1_epi16(c);
-
-        _mm256_mullo_epi16(v.elements, c)
-    };
-
-    v
-}
-
-#[inline(always)]
-fn bitwise_and_with_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector {
-    v.elements = unsafe {
-        let c = _mm256_set1_epi16(c);
-
-        _mm256_and_si256(v.elements, c)
-    };
-
-    v
-}
-
-#[inline(always)]
-fn shift_right<const SHIFT_BY: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
-    v.elements = unsafe { _mm256_srai_epi16(v.elements, SHIFT_BY) };
-
-    v
-}
-
-#[inline(always)]
-fn shift_left<const SHIFT_BY: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
-    v.elements = unsafe { _mm256_slli_epi16(v.elements, SHIFT_BY) };
-
-    v
-}
-
-#[inline(always)]
-fn cond_subtract_3329(mut v: SIMD256Vector) -> SIMD256Vector {
-    v.elements = unsafe {
-        let field_modulus = _mm256_set1_epi16(FIELD_MODULUS);
-
-        let v_minus_field_modulus = _mm256_sub_epi16(v.elements, field_modulus);
-
-        let sign_mask = _mm256_srai_epi16(v_minus_field_modulus, 15);
-        let conditional_add_field_modulus = _mm256_and_si256(sign_mask, field_modulus);
-
-        _mm256_add_epi16(v_minus_field_modulus, conditional_add_field_modulus)
-    };
-
-    v
-}
-
-#[inline(always)]
-fn barrett_reduce(mut v: SIMD256Vector) -> SIMD256Vector {
-    v.elements = unsafe {
-        let t = _mm256_mulhi_epi16(v.elements, _mm256_set1_epi16(BARRETT_MULTIPLIER));
-        let t = _mm256_add_epi16(t, _mm256_set1_epi16(512));
-
-        let quotient =
_mm256_srai_epi16(t, 10); - - let quotient_times_field_modulus = - _mm256_mullo_epi16(quotient, _mm256_set1_epi16(FIELD_MODULUS)); - - _mm256_sub_epi16(v.elements, quotient_times_field_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_by_constant(mut v: SIMD256Vector, c: i16) -> SIMD256Vector { - v.elements = unsafe { - let c = _mm256_set1_epi16(c); - let value_low = _mm256_mullo_epi16(v.elements, c); - - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm256_mulhi_epi16(v.elements, c); - - _mm256_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_by_constants(mut v: __m256i, c: __m256i) -> __m256i { - v = unsafe { - let value_low = _mm256_mullo_epi16(v, c); - - let k = _mm256_mullo_epi16( - value_low, - _mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm256_mulhi_epi16(v, c); - - _mm256_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn montgomery_reduce_i32s(mut v: __m256i) -> __m256i { - v = unsafe { - let k = _mm256_mullo_epi16( - v, - _mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), - ); - let k_times_modulus = _mm256_mulhi_epi16(k, _mm256_set1_epi32(FIELD_MODULUS as i32)); - - let value_high = _mm256_srli_epi32(v, 16); - - let result = _mm256_sub_epi16(value_high, k_times_modulus); - - let result = _mm256_slli_epi32(result, 16); - _mm256_srai_epi32(result, 16) - }; - - v -} - -#[inline(always)] -fn montgomery_multiply_m128i_by_constants(mut v: __m128i, c: __m128i) -> __m128i { - v = unsafe { - let value_low = _mm_mullo_epi16(v, c); - - let k = _mm_mullo_epi16( - value_low, - _mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), - ); - let k_times_modulus = _mm_mulhi_epi16(k, _mm_set1_epi16(FIELD_MODULUS)); - - let value_high = _mm_mulhi_epi16(v, c); - - _mm_sub_epi16(value_high, k_times_modulus) - }; - - v -} - -#[inline(always)] -fn compress_1(mut v: SIMD256Vector) -> SIMD256Vector { - v.elements = unsafe { - let field_modulus_halved = _mm256_set1_epi16((FIELD_MODULUS - 1) / 2); - let field_modulus_quartered = _mm256_set1_epi16((FIELD_MODULUS - 1) / 4); - - let shifted = _mm256_sub_epi16(field_modulus_halved, v.elements); - let mask = _mm256_srai_epi16(shifted, 15); - - let shifted_to_positive = _mm256_xor_si256(mask, shifted); - let shifted_to_positive_in_range = - _mm256_sub_epi16(shifted_to_positive, field_modulus_quartered); - - _mm256_srli_epi16(shifted_to_positive_in_range, 15) - }; - - v -} - -// This implementation was taken from: -// https://ei1333.github.io/library/math/combinatorics/vectorize-mod-int.hpp.html -// -// TODO: Optimize this implementation if performance numbers suggest doing so. 
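
[Editor's note, not part of the patch: the comment above belongs to the `mulhi_mm256_epi32` helper removed just below. AVX2 has no 32-bit "multiply high" instruction, so the helper emulates one with two `_mm256_mul_epu32` widening multiplies (even lanes, and odd lanes pre-shuffled into even position) plus unpacks that gather the high 32 bits of each product. The scalar sketch below is purely illustrative; it shows the identity the removed `compress` relies on when it multiplies by `10_321_340`, which is the ceiling of 2^35 / 3329, and shifts right by 35 bits in total (32 from the mulhi itself, 3 more from the explicit `_mm256_srli_epi32(..., 35 - 32)`).]

// Illustrative scalar model of the mulhi-based division (editorial sketch).
// mulhi_u32 mirrors what mulhi_mm256_epi32 computes per 32-bit lane.
fn mulhi_u32(a: u32, b: u32) -> u32 {
    ((a as u64 * b as u64) >> 32) as u32
}

fn main() {
    // Multiplying by ceil(2^35 / 3329) = 10_321_340 and shifting right by
    // 35 bits divides a coefficient-sized value by the field modulus 3329
    // without an actual division instruction.
    let x: u32 = 12345;
    let quotient = mulhi_u32(x, 10_321_340) >> 3; // 32 + 3 = 35 bits total
    assert_eq!(quotient as u64, (x as u64 * 10_321_340) >> 35);
    assert_eq!(quotient, x / 3329);
    println!("12345 / 3329 = {quotient}");
}
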
-#[inline(always)]
-fn mulhi_mm256_epi32(lhs: __m256i, rhs: __m256i) -> __m256i {
-    let result = unsafe {
-        let prod02 = _mm256_mul_epu32(lhs, rhs);
-        let prod13 = _mm256_mul_epu32(
-            _mm256_shuffle_epi32(lhs, 0b11_11_01_01),
-            _mm256_shuffle_epi32(rhs, 0b11_11_01_01),
-        );
-
-        _mm256_unpackhi_epi64(
-            _mm256_unpacklo_epi32(prod02, prod13),
-            _mm256_unpackhi_epi32(prod02, prod13),
-        )
-    };
-
-    result
-}
-
-#[inline(always)]
-fn compress<const COEFFICIENT_BITS: i32>(mut v: SIMD256Vector) -> SIMD256Vector {
-    v.elements = unsafe {
-        let field_modulus_halved = _mm256_set1_epi32(((FIELD_MODULUS as i32) - 1) / 2);
-        let compression_factor = _mm256_set1_epi32(10_321_340);
-        let coefficient_bits_mask = _mm256_set1_epi32((1 << COEFFICIENT_BITS) - 1);
-
-        // Compress the first 8 coefficients
-        let coefficients_low = _mm256_castsi256_si128(v.elements);
-        let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);
-
-        let compressed_low = _mm256_slli_epi32(coefficients_low, COEFFICIENT_BITS);
-        let compressed_low = _mm256_add_epi32(compressed_low, field_modulus_halved);
-
-        let compressed_low = mulhi_mm256_epi32(compressed_low, compression_factor);
-        let compressed_low = _mm256_srli_epi32(compressed_low, 35 - 32);
-        let compressed_low = _mm256_and_si256(compressed_low, coefficient_bits_mask);
-
-        // Compress the next 8 coefficients
-        let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
-        let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);
-
-        let compressed_high = _mm256_slli_epi32(coefficients_high, COEFFICIENT_BITS);
-        let compressed_high = _mm256_add_epi32(compressed_high, field_modulus_halved);
-
-        let compressed_high = mulhi_mm256_epi32(compressed_high, compression_factor);
-        let compressed_high = _mm256_srli_epi32(compressed_high, 35 - 32);
-        let compressed_high = _mm256_and_si256(compressed_high, coefficient_bits_mask);
-
-        // Combine them
-        let compressed = _mm256_packs_epi32(compressed_low, compressed_high);
-
-        _mm256_permute4x64_epi64(compressed, 0b11_01_10_00)
-    };
-
-    v
-}
-
-#[inline(always)]
-fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(
-    mut v: SIMD256Vector,
-) -> SIMD256Vector {
-    v.elements = unsafe {
-        let field_modulus = _mm256_set1_epi32(FIELD_MODULUS as i32);
-        let two_pow_coefficient_bits = _mm256_set1_epi32(1 << COEFFICIENT_BITS);
-
-        // Decompress the first 8 coefficients
-        let coefficients_low = _mm256_castsi256_si128(v.elements);
-        let coefficients_low = _mm256_cvtepi16_epi32(coefficients_low);
-
-        let decompressed_low = _mm256_mullo_epi32(coefficients_low, field_modulus);
-        let decompressed_low = _mm256_slli_epi32(decompressed_low, 1);
-        let decompressed_low = _mm256_add_epi32(decompressed_low, two_pow_coefficient_bits);
-
-        // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
-        // of support for const generic expressions.
-        let decompressed_low = _mm256_srli_epi32(decompressed_low, COEFFICIENT_BITS);
-        let decompressed_low = _mm256_srli_epi32(decompressed_low, 1);
-
-        // Decompress the next 8 coefficients
-        let coefficients_high = _mm256_extracti128_si256(v.elements, 1);
-        let coefficients_high = _mm256_cvtepi16_epi32(coefficients_high);
-
-        let decompressed_high = _mm256_mullo_epi32(coefficients_high, field_modulus);
-        let decompressed_high = _mm256_slli_epi32(decompressed_high, 1);
-        let decompressed_high = _mm256_add_epi32(decompressed_high, two_pow_coefficient_bits);
-
-        // We can't shift in one go by (COEFFICIENT_BITS + 1) due to the lack
-        // of support for const generic expressions.
- let decompressed_high = _mm256_srli_epi32(decompressed_high, COEFFICIENT_BITS); - let decompressed_high = _mm256_srli_epi32(decompressed_high, 1); - - // Combine them - let compressed = _mm256_packs_epi32(decompressed_low, decompressed_high); - - _mm256_permute4x64_epi64(compressed, 0b11_01_10_00) - }; - - v -} - -#[inline(always)] -fn ntt_layer_1_step( - mut v: SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { - let zetas = _mm256_set_epi16( - -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, - zeta1, -zeta0, -zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); - let rhs = montgomery_multiply_by_constants(rhs, zetas); - - let lhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); - - _mm256_add_epi16(lhs, rhs) - }; - - v -} - -#[inline(always)] -fn ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { - let zetas = _mm256_set_epi16( - -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, - -zeta0, zeta0, zeta0, zeta0, zeta0, - ); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b11_10_11_10); - let rhs = montgomery_multiply_by_constants(rhs, zetas); - - let lhs = _mm256_shuffle_epi32(v.elements, 0b01_00_01_00); - - _mm256_add_epi16(lhs, rhs) - }; - - v -} - -#[inline(always)] -fn ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = unsafe { - let rhs = _mm256_extracti128_si256(v.elements, 1); - let rhs = montgomery_multiply_m128i_by_constants(rhs, _mm_set1_epi16(zeta)); - - let lhs = _mm256_castsi256_si128(v.elements); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); - - combined - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_1_step( - mut v: SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_shuffle_epi32(v.elements, 0b11_11_01_01); - - let rhs = _mm256_shuffle_epi32(v.elements, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, - ), - ); - - let sum = barrett_reduce(SIMD256Vector { elements: sum }).elements; - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_0_0_1_1_0_0) - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_2_step(mut v: SIMD256Vector, zeta0: i16, zeta1: i16) -> SIMD256Vector { - v.elements = unsafe { - let lhs = _mm256_permute4x64_epi64(v.elements, 0b11_11_01_01); - - let rhs = _mm256_permute4x64_epi64(v.elements, 0b10_10_00_00); - let rhs = _mm256_mullo_epi16( - rhs, - _mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), - ); - - let sum = _mm256_add_epi16(lhs, rhs); - let sum_times_zetas = montgomery_multiply_by_constants( - sum, - _mm256_set_epi16( - zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, - ), - ); - - _mm256_blend_epi16(sum, sum_times_zetas, 0b1_1_1_1_0_0_0_0) - }; - - v -} - -#[inline(always)] -fn inv_ntt_layer_3_step(mut v: SIMD256Vector, zeta: i16) -> SIMD256Vector { - v.elements = 
unsafe { - let lhs = _mm256_extracti128_si256(v.elements, 1); - let rhs = _mm256_castsi256_si128(v.elements); - - let lower_coefficients = _mm_add_epi16(lhs, rhs); - - let upper_coefficients = _mm_sub_epi16(lhs, rhs); - let upper_coefficients = - montgomery_multiply_m128i_by_constants(upper_coefficients, _mm_set1_epi16(zeta)); - - let combined = _mm256_castsi128_si256(lower_coefficients); - let combined = _mm256_inserti128_si256(combined, upper_coefficients, 1); - - combined - }; - - v -} - -#[inline(always)] -fn ntt_multiply( - lhs: &SIMD256Vector, - rhs: &SIMD256Vector, - zeta0: i16, - zeta1: i16, - zeta2: i16, - zeta3: i16, -) -> SIMD256Vector { - let products = unsafe { - // Compute the first term of the product - let shuffle_with = _mm256_set_epi8( - 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, - 12, 9, 8, 5, 4, 1, 0, - ); - const PERMUTE_WITH: i32 = 0b11_01_10_00; - - // Prepare the left hand side - let lhs_shuffled = _mm256_shuffle_epi8(lhs.elements, shuffle_with); - let lhs_shuffled = _mm256_permute4x64_epi64(lhs_shuffled, PERMUTE_WITH); - - let lhs_evens = _mm256_castsi256_si128(lhs_shuffled); - let lhs_evens = _mm256_cvtepi16_epi32(lhs_evens); - - let lhs_odds = _mm256_extracti128_si256(lhs_shuffled, 1); - let lhs_odds = _mm256_cvtepi16_epi32(lhs_odds); - - // Prepare the right hand side - let rhs_shuffled = _mm256_shuffle_epi8(rhs.elements, shuffle_with); - let rhs_shuffled = _mm256_permute4x64_epi64(rhs_shuffled, PERMUTE_WITH); - - let rhs_evens = _mm256_castsi256_si128(rhs_shuffled); - let rhs_evens = _mm256_cvtepi16_epi32(rhs_evens); - - let rhs_odds = _mm256_extracti128_si256(rhs_shuffled, 1); - let rhs_odds = _mm256_cvtepi16_epi32(rhs_odds); - - // Start operating with them - let left = _mm256_mullo_epi32(lhs_evens, rhs_evens); - - let right = _mm256_mullo_epi32(lhs_odds, rhs_odds); - let right = montgomery_reduce_i32s(right); - let right = _mm256_mullo_epi32( - right, - _mm256_set_epi32( - -(zeta3 as i32), - zeta3 as i32, - -(zeta2 as i32), - zeta2 as i32, - -(zeta1 as i32), - zeta1 as i32, - -(zeta0 as i32), - zeta0 as i32, - ), - ); - - let products_left = _mm256_add_epi32(left, right); - let products_left = montgomery_reduce_i32s(products_left); - - // Compute the second term of the product - let rhs_adjacent_swapped = _mm256_shuffle_epi8( - rhs.elements, - _mm256_set_epi8( - 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, - 5, 4, 7, 6, 1, 0, 3, 2, - ), - ); - let products_right = _mm256_madd_epi16(lhs.elements, rhs_adjacent_swapped); - let products_right = montgomery_reduce_i32s(products_right); - let products_right = _mm256_slli_epi32(products_right, 16); - - // Combine them into one vector - _mm256_blend_epi16(products_left, products_right, 0b1_0_1_0_1_0_1_0) - }; - - SIMD256Vector { elements: products } -} - -#[inline(always)] -fn serialize_1(v: SIMD256Vector) -> [u8; 2] { - let mut serialized = [0u8; 2]; - - let bits_packed = unsafe { - let lsb_shifted_up = _mm256_slli_epi16(v.elements, 7); - - let low_lanes = _mm256_castsi256_si128(lsb_shifted_up); - let high_lanes = _mm256_extracti128_si256(lsb_shifted_up, 1); - - let msbs = _mm_packus_epi16(low_lanes, high_lanes); - - _mm_movemask_epi8(msbs) - }; - - serialized[0] = bits_packed as u8; - serialized[1] = (bits_packed >> 8) as u8; - - serialized -} - -#[inline(always)] -fn deserialize_1(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsb_to_msb = _mm256_set_epi16( - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, 
- 1 << 5, - 1 << 6, - 1 << 7, - 1 << 0, - 1 << 1, - 1 << 2, - 1 << 3, - 1 << 4, - 1 << 5, - 1 << 6, - 1 << 7, - ); - - let coefficients = _mm256_set_epi16( - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsb_to_msb); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 7); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 1) - 1)) - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_4(v: SIMD256Vector) -> [u8; 8] { - let mut serialized = [0u8; 16]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - 1 << 4, - 1, - ), - ); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_2_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, - ), - ); - - let combined = _mm256_permutevar8x32_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0), - ); - let combined = _mm256_castsi256_si128(combined); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, combined); - } - - serialized[0..8].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_4(bytes: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - 1 << 0, - 1 << 4, - ); - - let coefficients = _mm256_set_epi16( - bytes[7] as i16, - bytes[7] as i16, - bytes[6] as i16, - bytes[6] as i16, - bytes[5] as i16, - bytes[5] as i16, - bytes[4] as i16, - bytes[4] as i16, - bytes[3] as i16, - bytes[3] as i16, - bytes[2] as i16, - bytes[2] as i16, - bytes[1] as i16, - bytes[1] as i16, - bytes[0] as i16, - bytes[0] as i16, - ); - - let coefficients_in_msb = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients_in_lsb = _mm256_srli_epi16(coefficients_in_msb, 4); - - _mm256_and_si256(coefficients_in_lsb, _mm256_set1_epi16((1 << 4) - 1)) - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_5(v: SIMD256Vector) -> [u8; 10] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - 1 << 5, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 22); - - let adjacent_8_combined = _mm256_shuffle_epi32(adjacent_4_combined, 0b00_00_10_00); - let adjacent_8_combined = _mm256_sllv_epi32( - adjacent_8_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_8_combined = _mm256_srli_epi64(adjacent_8_combined, 12); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, 
lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(5) as *mut __m128i, upper_8); - } - - serialized[0..10].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_5(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_5(v); - - from_i16_array(portable::to_i16_array(output)) -} - -#[inline(always)] -fn serialize_10(v: SIMD256Vector) -> [u8; 20] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - 1 << 10, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 12); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, - 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, - ), - ); - - let lower_8 = _mm256_castsi256_si128(adjacent_8_combined); - let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1); - - _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8); - _mm_storeu_si128(serialized.as_mut_ptr().offset(10) as *mut __m128i, upper_8); - } - - serialized[0..20].try_into().unwrap() -} - -#[inline(always)] -fn deserialize_10(v: &[u8]) -> SIMD256Vector { - let deserialized = unsafe { - let shift_lsbs_to_msbs = _mm256_set_epi16( - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - 1 << 0, - 1 << 2, - 1 << 4, - 1 << 6, - ); - - let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i); - let lower_coefficients = _mm_shuffle_epi8( - lower_coefficients, - _mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0), - ); - let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(4) as *const __m128i); - let upper_coefficients = _mm_shuffle_epi8( - upper_coefficients, - _mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6), - ); - - let coefficients = _mm256_castsi128_si256(lower_coefficients); - let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1); - - let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs); - let coefficients = _mm256_srli_epi16(coefficients, 6); - let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 10) - 1)); - - coefficients - }; - - SIMD256Vector { - elements: deserialized, - } -} - -#[inline(always)] -fn serialize_11(v: SIMD256Vector) -> [u8; 22] { - let input = portable::from_i16_array(to_i16_array(v)); - - portable::serialize_11(input) -} - -#[inline(always)] -fn deserialize_11(v: &[u8]) -> SIMD256Vector { - let output = portable::deserialize_11(v); - - from_i16_array(portable::to_i16_array(output)) -} - -#[inline(always)] -fn serialize_12(v: SIMD256Vector) -> [u8; 24] { - let mut serialized = [0u8; 32]; - - unsafe { - let adjacent_2_combined = _mm256_madd_epi16( - v.elements, - _mm256_set_epi16( - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - 1 << 12, - 1, - ), - ); - - let adjacent_4_combined = _mm256_sllv_epi32( - adjacent_2_combined, - _mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8), - ); - let adjacent_4_combined = _mm256_srli_epi64(adjacent_4_combined, 8); - - let adjacent_8_combined = _mm256_shuffle_epi8( - adjacent_4_combined, - _mm256_set_epi8( - -1, -1, -1, -1, 
13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11,
-                10, 9, 8, 5, 4, 3, 2, 1, 0,
-            ),
-        );
-
-        let lower_8 = _mm256_castsi256_si128(adjacent_8_combined);
-        let upper_8 = _mm256_extracti128_si256(adjacent_8_combined, 1);
-
-        _mm_storeu_si128(serialized.as_mut_ptr() as *mut __m128i, lower_8);
-        _mm_storeu_si128(serialized.as_mut_ptr().offset(12) as *mut __m128i, upper_8);
-    }
-
-    serialized[0..24].try_into().unwrap()
-}
-
-#[inline(always)]
-fn deserialize_12(v: &[u8]) -> SIMD256Vector {
-    let deserialized = unsafe {
-        let shift_lsbs_to_msbs = _mm256_set_epi16(
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-            1 << 0,
-            1 << 4,
-        );
-
-        let lower_coefficients = _mm_loadu_si128(v.as_ptr() as *const __m128i);
-        let lower_coefficients = _mm_shuffle_epi8(
-            lower_coefficients,
-            _mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0),
-        );
-        let upper_coefficients = _mm_loadu_si128(v.as_ptr().offset(8) as *const __m128i);
-        let upper_coefficients = _mm_shuffle_epi8(
-            upper_coefficients,
-            _mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4),
-        );
-
-        let coefficients = _mm256_castsi128_si256(lower_coefficients);
-        let coefficients = _mm256_inserti128_si256(coefficients, upper_coefficients, 1);
-
-        let coefficients = _mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);
-        let coefficients = _mm256_srli_epi16(coefficients, 4);
-        let coefficients = _mm256_and_si256(coefficients, _mm256_set1_epi16((1 << 12) - 1));
-
-        coefficients
-    };
-
-    SIMD256Vector {
-        elements: deserialized,
-    }
-}
-
-#[inline(always)]
-fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) {
-    portable::rej_sample(a)
-}
-
 impl Operations for SIMD256Vector {
     fn ZERO() -> Self {
         zero()
     }
@@ -997,80 +48,120 @@ impl Operations for SIMD256Vector {
         to_i16_array(v)
     }

-    fn from_i16_array(array: [i16; 16]) -> Self {
+    fn from_i16_array(array: &[i16]) -> Self {
         from_i16_array(array)
     }

     fn add(lhs: Self, rhs: &Self) -> Self {
-        add(lhs, rhs)
+        Self {
+            elements: arithmetic::add(lhs.elements, rhs.elements),
+        }
     }

     fn sub(lhs: Self, rhs: &Self) -> Self {
-        sub(lhs, rhs)
+        Self {
+            elements: arithmetic::sub(lhs.elements, rhs.elements),
+        }
     }

     fn multiply_by_constant(v: Self, c: i16) -> Self {
-        multiply_by_constant(v, c)
+        Self {
+            elements: arithmetic::multiply_by_constant(v.elements, c),
+        }
     }

-    fn bitwise_and_with_constant(v: Self, c: i16) -> Self {
-        bitwise_and_with_constant(v, c)
+    fn bitwise_and_with_constant(vector: Self, constant: i16) -> Self {
+        Self {
+            elements: arithmetic::bitwise_and_with_constant(vector.elements, constant),
+        }
     }

-    fn shift_right<const SHIFT_BY: i32>(v: Self) -> Self {
-        shift_right::<{ SHIFT_BY }>(v)
+    fn shift_right<const SHIFT_BY: i32>(vector: Self) -> Self {
+        Self {
+            elements: arithmetic::shift_right::<{ SHIFT_BY }>(vector.elements),
+        }
     }

-    fn shift_left<const SHIFT_BY: i32>(v: Self) -> Self {
-        shift_left::<{ SHIFT_BY }>(v)
+    fn shift_left<const SHIFT_BY: i32>(vector: Self) -> Self {
+        Self {
+            elements: arithmetic::shift_left::<{ SHIFT_BY }>(vector.elements),
+        }
     }

-    fn cond_subtract_3329(v: Self) -> Self {
-        cond_subtract_3329(v)
+    fn cond_subtract_3329(vector: Self) -> Self {
+        Self {
+            elements: arithmetic::cond_subtract_3329(vector.elements),
+        }
    }

-    fn barrett_reduce(v: Self) -> Self {
-        barrett_reduce(v)
+    fn barrett_reduce(vector: Self) -> Self {
+        Self {
+            elements: arithmetic::barrett_reduce(vector.elements),
+        }
    }

-    fn montgomery_multiply_by_constant(v: Self, r: i16) -> Self {
-        montgomery_multiply_by_constant(v, r)
+    fn
montgomery_multiply_by_constant(vector: Self, constant: i16) -> Self {
+        Self {
+            elements: arithmetic::montgomery_multiply_by_constant(vector.elements, constant),
+        }
     }

-    fn compress_1(v: Self) -> Self {
-        compress_1(v)
+    fn compress_1(vector: Self) -> Self {
+        Self {
+            elements: compress::compress_message_coefficient(vector.elements),
+        }
     }

-    fn compress<const COEFFICIENT_BITS: i32>(v: Self) -> Self {
-        compress::<COEFFICIENT_BITS>(v)
+    fn compress<const COEFFICIENT_BITS: i32>(vector: Self) -> Self {
+        Self {
+            elements: compress::compress_ciphertext_coefficient::<COEFFICIENT_BITS>(
+                vector.elements,
+            ),
+        }
     }

-    fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(v: Self) -> Self {
-        decompress_ciphertext_coefficient::<COEFFICIENT_BITS>(v)
+    fn decompress_ciphertext_coefficient<const COEFFICIENT_BITS: i32>(vector: Self) -> Self {
+        Self {
+            elements: compress::decompress_ciphertext_coefficient::<COEFFICIENT_BITS>(
+                vector.elements,
+            ),
+        }
     }

-    fn ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self {
-        ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3)
+    fn ntt_layer_1_step(vector: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self {
+        Self {
+            elements: ntt::ntt_layer_1_step(vector.elements, zeta0, zeta1, zeta2, zeta3),
+        }
     }

-    fn ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self {
-        ntt_layer_2_step(a, zeta0, zeta1)
+    fn ntt_layer_2_step(vector: Self, zeta0: i16, zeta1: i16) -> Self {
+        Self {
+            elements: ntt::ntt_layer_2_step(vector.elements, zeta0, zeta1),
+        }
     }

-    fn ntt_layer_3_step(a: Self, zeta: i16) -> Self {
-        ntt_layer_3_step(a, zeta)
+    fn ntt_layer_3_step(vector: Self, zeta: i16) -> Self {
+        Self {
+            elements: ntt::ntt_layer_3_step(vector.elements, zeta),
+        }
     }

-    fn inv_ntt_layer_1_step(a: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self {
-        inv_ntt_layer_1_step(a, zeta0, zeta1, zeta2, zeta3)
+    fn inv_ntt_layer_1_step(vector: Self, zeta0: i16, zeta1: i16, zeta2: i16, zeta3: i16) -> Self {
+        Self {
+            elements: ntt::inv_ntt_layer_1_step(vector.elements, zeta0, zeta1, zeta2, zeta3),
+        }
     }

-    fn inv_ntt_layer_2_step(a: Self, zeta0: i16, zeta1: i16) -> Self {
-        inv_ntt_layer_2_step(a, zeta0, zeta1)
+    fn inv_ntt_layer_2_step(vector: Self, zeta0: i16, zeta1: i16) -> Self {
+        Self {
+            elements: ntt::inv_ntt_layer_2_step(vector.elements, zeta0, zeta1),
+        }
     }

-    fn inv_ntt_layer_3_step(a: Self, zeta: i16) -> Self {
-        inv_ntt_layer_3_step(a, zeta)
+    fn inv_ntt_layer_3_step(vector: Self, zeta: i16) -> Self {
+        Self {
+            elements: ntt::inv_ntt_layer_3_step(vector.elements, zeta),
+        }
     }

     fn ntt_multiply(
@@ -1081,58 +172,72 @@ impl Operations for SIMD256Vector {
         zeta2: i16,
         zeta3: i16,
     ) -> Self {
-        ntt_multiply(lhs, rhs, zeta0, zeta1, zeta2, zeta3)
+        Self {
+            elements: ntt::ntt_multiply(lhs.elements, rhs.elements, zeta0, zeta1, zeta2, zeta3),
+        }
     }

-    fn serialize_1(a: Self) -> [u8; 2] {
-        serialize_1(a)
+    fn serialize_1(vector: Self) -> [u8; 2] {
+        serialize::serialize_1(vector.elements)
     }

-    fn deserialize_1(a: &[u8]) -> Self {
-        deserialize_1(a)
+    fn deserialize_1(bytes: &[u8]) -> Self {
+        Self {
+            elements: serialize::deserialize_1(bytes),
+        }
     }

-    fn serialize_4(a: Self) -> [u8; 8] {
-        serialize_4(a)
+    fn serialize_4(vector: Self) -> [u8; 8] {
+        serialize::serialize_4(vector.elements)
     }

-    fn deserialize_4(a: &[u8]) -> Self {
-        deserialize_4(a)
+    fn deserialize_4(bytes: &[u8]) -> Self {
+        Self {
+            elements: serialize::deserialize_4(bytes),
+        }
     }

-    fn serialize_5(a: Self) -> [u8; 10] {
-        serialize_5(a)
+    fn serialize_5(vector: Self) -> [u8; 10] {
+        serialize::serialize_5(vector.elements)
     }

-    fn deserialize_5(a: &[u8]) -> Self {
-        deserialize_5(a)
+    fn deserialize_5(bytes: &[u8]) -> Self {
+        Self {
+
elements: serialize::deserialize_5(bytes), + } } - fn serialize_10(a: Self) -> [u8; 20] { - serialize_10(a) + fn serialize_10(vector: Self) -> [u8; 20] { + serialize::serialize_10(vector.elements) } - fn deserialize_10(a: &[u8]) -> Self { - deserialize_10(a) + fn deserialize_10(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_10(bytes), + } } - fn serialize_11(a: Self) -> [u8; 22] { - serialize_11(a) + fn serialize_11(vector: Self) -> [u8; 22] { + serialize::serialize_11(vector.elements) } - fn deserialize_11(a: &[u8]) -> Self { - deserialize_11(a) + fn deserialize_11(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_11(bytes), + } } - fn serialize_12(a: Self) -> [u8; 24] { - serialize_12(a) + fn serialize_12(vector: Self) -> [u8; 24] { + serialize::serialize_12(vector.elements) } - fn deserialize_12(a: &[u8]) -> Self { - deserialize_12(a) + fn deserialize_12(bytes: &[u8]) -> Self { + Self { + elements: serialize::deserialize_12(bytes), + } } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rej_sample(a) + fn rej_sample(input: &[u8], output: &mut [i16]) -> usize { + sampling::rejection_sample(input, output) } } diff --git a/polynomials-avx2/src/ntt.rs b/polynomials-avx2/src/ntt.rs new file mode 100644 index 000000000..2ebb1561d --- /dev/null +++ b/polynomials-avx2/src/ntt.rs @@ -0,0 +1,244 @@ +use crate::intrinsics::*; + +use crate::arithmetic; +use libcrux_traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R}; + +#[inline(always)] +fn montgomery_multiply_by_constants(v: __m256i, c: __m256i) -> __m256i { + let value_low = mm256_mullo_epi16(v, c); + + let k = mm256_mullo_epi16( + value_low, + mm256_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi16(FIELD_MODULUS)); + + let value_high = mm256_mulhi_epi16(v, c); + + mm256_sub_epi16(value_high, k_times_modulus) +} + +#[inline(always)] +fn montgomery_reduce_i32s(v: __m256i) -> __m256i { + let k = mm256_mullo_epi16( + v, + mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32), + ); + let k_times_modulus = mm256_mulhi_epi16(k, mm256_set1_epi32(FIELD_MODULUS as i32)); + + let value_high = mm256_srli_epi32::<16>(v); + + let result = mm256_sub_epi16(value_high, k_times_modulus); + + let result = mm256_slli_epi32::<16>(result); + + mm256_srai_epi32::<16>(result) +} + +#[inline(always)] +fn montgomery_multiply_m128i_by_constants(v: __m128i, c: __m128i) -> __m128i { + let value_low = mm_mullo_epi16(v, c); + + let k = mm_mullo_epi16( + value_low, + mm_set1_epi16(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i16), + ); + let k_times_modulus = mm_mulhi_epi16(k, mm_set1_epi16(FIELD_MODULUS)); + + let value_high = mm_mulhi_epi16(v, c); + + mm_sub_epi16(value_high, k_times_modulus) +} + +#[inline(always)] +pub(crate) fn ntt_layer_1_step( + vector: __m256i, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> __m256i { + let zetas = mm256_set_epi16( + -zeta3, -zeta3, zeta3, zeta3, -zeta2, -zeta2, zeta2, zeta2, -zeta1, -zeta1, zeta1, zeta1, + -zeta0, -zeta0, zeta0, zeta0, + ); + + let rhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); + + mm256_add_epi16(lhs, rhs) +} + +#[inline(always)] +pub(crate) fn ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let zetas = mm256_set_epi16( + -zeta1, -zeta1, -zeta1, -zeta1, zeta1, zeta1, zeta1, zeta1, -zeta0, -zeta0, -zeta0, -zeta0, + zeta0, zeta0, zeta0, 
zeta0, + ); + + let rhs = mm256_shuffle_epi32::<0b11_10_11_10>(vector); + let rhs = montgomery_multiply_by_constants(rhs, zetas); + + let lhs = mm256_shuffle_epi32::<0b01_00_01_00>(vector); + + mm256_add_epi16(lhs, rhs) +} + +#[inline(always)] +pub(crate) fn ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let rhs = mm256_extracti128_si256::<1>(vector); + let rhs = montgomery_multiply_m128i_by_constants(rhs, mm_set1_epi16(zeta)); + + let lhs = mm256_castsi256_si128(vector); + + let lower_coefficients = mm_add_epi16(lhs, rhs); + let upper_coefficients = mm_sub_epi16(lhs, rhs); + + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); + + combined +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_1_step( + vector: __m256i, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> __m256i { + let lhs = mm256_shuffle_epi32::<0b11_11_01_01>(vector); + + let rhs = mm256_shuffle_epi32::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1), + ); + + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + mm256_set_epi16( + zeta3, zeta3, 0, 0, zeta2, zeta2, 0, 0, zeta1, zeta1, 0, 0, zeta0, zeta0, 0, 0, + ), + ); + + let sum = arithmetic::barrett_reduce(sum); + + mm256_blend_epi16::<0b1_1_0_0_1_1_0_0>(sum, sum_times_zetas) +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_2_step(vector: __m256i, zeta0: i16, zeta1: i16) -> __m256i { + let lhs = mm256_permute4x64_epi64::<0b11_11_01_01>(vector); + + let rhs = mm256_permute4x64_epi64::<0b10_10_00_00>(vector); + let rhs = mm256_mullo_epi16( + rhs, + mm256_set_epi16(-1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1), + ); + + let sum = mm256_add_epi16(lhs, rhs); + let sum_times_zetas = montgomery_multiply_by_constants( + sum, + mm256_set_epi16( + zeta1, zeta1, zeta1, zeta1, 0, 0, 0, 0, zeta0, zeta0, zeta0, zeta0, 0, 0, 0, 0, + ), + ); + + mm256_blend_epi16::<0b1_1_1_1_0_0_0_0>(sum, sum_times_zetas) +} + +#[inline(always)] +pub(crate) fn inv_ntt_layer_3_step(vector: __m256i, zeta: i16) -> __m256i { + let lhs = mm256_extracti128_si256::<1>(vector); + let rhs = mm256_castsi256_si128(vector); + + let lower_coefficients = mm_add_epi16(lhs, rhs); + + let upper_coefficients = mm_sub_epi16(lhs, rhs); + let upper_coefficients = + montgomery_multiply_m128i_by_constants(upper_coefficients, mm_set1_epi16(zeta)); + + let combined = mm256_castsi128_si256(lower_coefficients); + let combined = mm256_inserti128_si256::<1>(combined, upper_coefficients); + + combined +} + +#[inline(always)] +pub(crate) fn ntt_multiply( + lhs: __m256i, + rhs: __m256i, + zeta0: i16, + zeta1: i16, + zeta2: i16, + zeta3: i16, +) -> __m256i { + // Compute the first term of the product + let shuffle_with = mm256_set_epi8( + 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0, 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, + 9, 8, 5, 4, 1, 0, + ); + const PERMUTE_WITH: i32 = 0b11_01_10_00; + + // Prepare the left hand side + let lhs_shuffled = mm256_shuffle_epi8(lhs, shuffle_with); + let lhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(lhs_shuffled); + + let lhs_evens = mm256_castsi256_si128(lhs_shuffled); + let lhs_evens = mm256_cvtepi16_epi32(lhs_evens); + + let lhs_odds = mm256_extracti128_si256::<1>(lhs_shuffled); + let lhs_odds = mm256_cvtepi16_epi32(lhs_odds); + + // Prepare the right hand side + let rhs_shuffled = mm256_shuffle_epi8(rhs, shuffle_with); + let 
rhs_shuffled = mm256_permute4x64_epi64::<{ PERMUTE_WITH }>(rhs_shuffled); + + let rhs_evens = mm256_castsi256_si128(rhs_shuffled); + let rhs_evens = mm256_cvtepi16_epi32(rhs_evens); + + let rhs_odds = mm256_extracti128_si256::<1>(rhs_shuffled); + let rhs_odds = mm256_cvtepi16_epi32(rhs_odds); + + // Start operating with them + let left = mm256_mullo_epi32(lhs_evens, rhs_evens); + + let right = mm256_mullo_epi32(lhs_odds, rhs_odds); + let right = montgomery_reduce_i32s(right); + let right = mm256_mullo_epi32( + right, + mm256_set_epi32( + -(zeta3 as i32), + zeta3 as i32, + -(zeta2 as i32), + zeta2 as i32, + -(zeta1 as i32), + zeta1 as i32, + -(zeta0 as i32), + zeta0 as i32, + ), + ); + + let products_left = mm256_add_epi32(left, right); + let products_left = montgomery_reduce_i32s(products_left); + + // Compute the second term of the product + let rhs_adjacent_swapped = mm256_shuffle_epi8( + rhs, + mm256_set_epi8( + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5, + 4, 7, 6, 1, 0, 3, 2, + ), + ); + let products_right = mm256_madd_epi16(lhs, rhs_adjacent_swapped); + let products_right = montgomery_reduce_i32s(products_right); + let products_right = mm256_slli_epi32::<16>(products_right); + + // Combine them into one vector + mm256_blend_epi16::<0b1_0_1_0_1_0_1_0>(products_left, products_right) +} diff --git a/polynomials-avx2/src/portable.rs b/polynomials-avx2/src/portable.rs index b18c02d19..844e6a3c2 100644 --- a/polynomials-avx2/src/portable.rs +++ b/polynomials-avx2/src/portable.rs @@ -1,4 +1,4 @@ -pub use libcrux_traits::{FIELD_ELEMENTS_IN_VECTOR, FIELD_MODULUS}; +pub use libcrux_traits::FIELD_ELEMENTS_IN_VECTOR; type FieldElement = i16; @@ -112,27 +112,3 @@ pub(crate) fn deserialize_11(bytes: &[u8]) -> PortableVector { result } - -#[inline(always)] -pub(crate) fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - let mut result = [0i16; 16]; - let mut sampled = 0; - for bytes in a.chunks(3) { - let b1 = bytes[0] as i16; - let b2 = bytes[1] as i16; - let b3 = bytes[2] as i16; - - let d1 = ((b2 & 0xF) << 8) | b1; - let d2 = (b3 << 4) | (b2 >> 4); - - if d1 < FIELD_MODULUS && sampled < 16 { - result[sampled] = d1; - sampled += 1 - } - if d2 < FIELD_MODULUS && sampled < 16 { - result[sampled] = d2; - sampled += 1 - } - } - (sampled, result) -} diff --git a/polynomials-avx2/src/sampling.rs b/polynomials-avx2/src/sampling.rs new file mode 100644 index 000000000..40542efec --- /dev/null +++ b/polynomials-avx2/src/sampling.rs @@ -0,0 +1,782 @@ +use crate::intrinsics::*; + +use crate::serialize::{deserialize_12, serialize_1}; +use libcrux_traits::FIELD_MODULUS; + +const REJECTION_SAMPLE_SHUFFLE_TABLE: [[u8; 16]; 256] = [ + [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, + ], // 0 + [ + 0, 1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 1 + [ + 2, 3, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 2 + [ + 0, 1, 2, 3, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 3 + [ + 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 4 + [ + 0, 1, 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 5 + [ + 2, 3, 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 6 + [ + 0, 1, 2, 3, 4, 5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 7 + [ + 6, 7, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 8 + [ + 0, 1, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 9 + [ + 2, 3, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 10 + [ + 0, 1, 2, 3, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 11 + [ + 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 12 + [ + 0, 1, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 13 + [ + 2, 3, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 14 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 15 + [ + 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 16 + [ + 0, 1, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 17 + [ + 2, 3, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 18 + [ + 0, 1, 2, 3, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 19 + [ + 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 20 + [ + 0, 1, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 21 + [ + 2, 3, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 22 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 23 + [ + 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 24 + [ + 0, 1, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 25 + [ + 2, 3, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 26 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 27 + [ + 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 28 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 29 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 30 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 31 + [ + 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 32 + [ + 0, 1, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 33 + [ + 2, 3, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 34 + [ + 0, 1, 2, 3, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 35 + [ + 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 36 + [ + 0, 1, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 37 + [ + 2, 3, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 38 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 39 + [ + 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 40 + [ + 0, 1, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 41 + [ + 2, 3, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 42 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 43 + [ + 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 
44 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 45 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 46 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 47 + [ + 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 48 + [ + 0, 1, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 49 + [ + 2, 3, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 50 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 51 + [ + 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 52 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 53 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 54 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 55 + [ + 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 56 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 57 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 58 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 59 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 60 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 61 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 62 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0xff, 0xff, 0xff, 0xff], // 63 + [ + 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 64 + [ + 0, 1, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 65 + [ + 2, 3, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 66 + [ + 0, 1, 2, 3, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 67 + [ + 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 68 + [ + 0, 1, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 69 + [ + 2, 3, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 70 + [ + 0, 1, 2, 3, 4, 5, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 71 + [ + 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 72 + [ + 0, 1, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 73 + [ + 2, 3, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 74 + [ + 0, 1, 2, 3, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 75 + [ + 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 76 + [ + 0, 1, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 77 + [ + 2, 3, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 78 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 79 + [ + 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 80 + [ + 0, 1, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 81 + [ + 2, 3, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, + ], // 82 + [ + 0, 1, 2, 3, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 83 + [ + 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 84 + [ + 0, 1, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 85 + [ + 2, 3, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 86 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 87 + [ + 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 88 + [ + 0, 1, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 89 + [ + 2, 3, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 90 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 91 + [ + 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 92 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 93 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 94 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 0xff, 0xff, 0xff, 0xff], // 95 + [ + 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 96 + [ + 0, 1, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 97 + [ + 2, 3, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 98 + [ + 0, 1, 2, 3, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 99 + [ + 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 100 + [ + 0, 1, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 101 + [ + 2, 3, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 102 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 103 + [ + 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 104 + [ + 0, 1, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 105 + [ + 2, 3, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 106 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 107 + [ + 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 108 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 109 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 110 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 111 + [ + 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 112 + [ + 0, 1, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 113 + [ + 2, 3, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 114 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 115 + [ + 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 116 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 117 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 118 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 119 + [ + 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 120 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, + ], // 121 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 122 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 123 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 124 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 125 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff, 0xff, 0xff, + ], // 126 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0xff, 0xff], // 127 + [ + 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 128 + [ + 0, 1, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 129 + [ + 2, 3, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 130 + [ + 0, 1, 2, 3, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 131 + [ + 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 132 + [ + 0, 1, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 133 + [ + 2, 3, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 134 + [ + 0, 1, 2, 3, 4, 5, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 135 + [ + 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 136 + [ + 0, 1, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 137 + [ + 2, 3, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 138 + [ + 0, 1, 2, 3, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 139 + [ + 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 140 + [ + 0, 1, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 141 + [ + 2, 3, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 142 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 143 + [ + 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 144 + [ + 0, 1, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 145 + [ + 2, 3, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 146 + [ + 0, 1, 2, 3, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 147 + [ + 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 148 + [ + 0, 1, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 149 + [ + 2, 3, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 150 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 151 + [ + 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 152 + [ + 0, 1, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 153 + [ + 2, 3, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 154 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 155 + [ + 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 156 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 157 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 158 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 0xff, 0xff, 0xff, 0xff], // 159 + [ + 
10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 160 + [ + 0, 1, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 161 + [ + 2, 3, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 162 + [ + 0, 1, 2, 3, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 163 + [ + 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 164 + [ + 0, 1, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 165 + [ + 2, 3, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 166 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 167 + [ + 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 168 + [ + 0, 1, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 169 + [ + 2, 3, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 170 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 171 + [ + 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 172 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 173 + [ + 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 174 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 175 + [ + 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 176 + [ + 0, 1, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 177 + [ + 2, 3, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 178 + [ + 0, 1, 2, 3, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 179 + [ + 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 180 + [ + 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 181 + [ + 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 182 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 183 + [ + 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 184 + [ + 0, 1, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 185 + [ + 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 186 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 187 + [ + 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 188 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 189 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 190 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 0xff, 0xff], // 191 + [ + 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 192 + [ + 0, 1, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 193 + [ + 2, 3, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 194 + [ + 0, 1, 2, 3, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 195 + [ + 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 196 + [ + 0, 1, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 197 + [ + 2, 3, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, + ], // 198 + [ + 0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 199 + [ + 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 200 + [ + 0, 1, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 201 + [ + 2, 3, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 202 + [ + 0, 1, 2, 3, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 203 + [ + 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 204 + [ + 0, 1, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 205 + [ + 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 206 + [ + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 207 + [ + 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 208 + [ + 0, 1, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 209 + [ + 2, 3, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 210 + [ + 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 211 + [ + 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 212 + [ + 0, 1, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 213 + [ + 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 214 + [ + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 215 + [ + 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 216 + [ + 0, 1, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 217 + [ + 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 218 + [ + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 219 + [ + 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 220 + [ + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 221 + [ + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 222 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, 0xff, 0xff], // 223 + [ + 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 224 + [ + 0, 1, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 225 + [ + 2, 3, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 226 + [ + 0, 1, 2, 3, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 227 + [ + 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 228 + [ + 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 229 + [ + 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 230 + [ + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 231 + [ + 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 232 + [ + 0, 1, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 233 + [ + 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 234 + [ + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 235 + [ + 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ], // 236 + [ + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + ], // 237 + [ + 2, 3, 4, 5, 6, 7, 
10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 238
+    [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 239
+    [
+        8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    ], // 240
+    [
+        0, 1, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    ], // 241
+    [
+        2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    ], // 242
+    [
+        0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 243
+    [
+        4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    ], // 244
+    [
+        0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 245
+    [
+        2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 246
+    [0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 247
+    [
+        6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    ], // 248
+    [
+        0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 249
+    [
+        2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 250
+    [0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 251
+    [
+        4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff,
+    ], // 252
+    [0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 253
+    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0xff, 0xff], // 254
+    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], // 255
+];
+
+#[inline(always)]
+pub(crate) fn rejection_sample(input: &[u8], output: &mut [i16]) -> usize {
+    let field_modulus = mm256_set1_epi16(FIELD_MODULUS);
+
+    let potential_coefficients = deserialize_12(input);
+
+    // Mark every 12-bit candidate that is smaller than the field modulus and
+    // compress the per-lane masks into one bit per candidate.
+    let compare_with_field_modulus = mm256_cmpgt_epi16(field_modulus, potential_coefficients);
+    let good = serialize_1(compare_with_field_modulus);
+
+    // Each mask byte indexes a shuffle that moves the accepted candidates to
+    // the front of the corresponding 128-bit half, without branching.
+    let lower_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[0] as usize];
+    let lower_shuffles = mm_loadu_si128(&lower_shuffles);
+    let lower_coefficients = mm256_castsi256_si128(potential_coefficients);
+    let lower_coefficients = mm_shuffle_epi8(lower_coefficients, lower_shuffles);
+
+    mm_storeu_si128(&mut output[0..8], lower_coefficients);
+    let sampled_count = good[0].count_ones() as usize;
+
+    let upper_shuffles = REJECTION_SAMPLE_SHUFFLE_TABLE[good[1] as usize];
+    let upper_shuffles = mm_loadu_si128(&upper_shuffles);
+    let upper_coefficients = mm256_extracti128_si256::<1>(potential_coefficients);
+    let upper_coefficients = mm_shuffle_epi8(upper_coefficients, upper_shuffles);
+
+    mm_storeu_si128(
+        &mut output[sampled_count..sampled_count + 8],
+        upper_coefficients,
+    );
+
+    sampled_count + (good[1].count_ones() as usize)
+}
diff --git a/polynomials-avx2/src/serialize.rs b/polynomials-avx2/src/serialize.rs
new file mode 100644
index 000000000..7e5303b01
--- /dev/null
+++ b/polynomials-avx2/src/serialize.rs
@@ -0,0 +1,405 @@
+use crate::intrinsics::*;
+
+use crate::{portable, SIMD256Vector};
+
+#[inline(always)]
+pub(crate) fn serialize_1(vector: __m256i) -> [u8; 2] {
+    // Move the least significant bit of each 16-bit lane into its most
+    // significant bit ...
+    let lsb_shifted_up = mm256_slli_epi16::<15>(vector);
+
+    let low_lanes = mm256_castsi256_si128(lsb_shifted_up);
+    let high_lanes = mm256_extracti128_si256::<1>(lsb_shifted_up);
+
+    // ... then saturate the lanes to bytes and collect the 16 sign bits
+    // into a single 16-bit mask.
+    let msbs = mm_packs_epi16(low_lanes, high_lanes);
+
+    let bits_packed = mm_movemask_epi8(msbs);
+
+    let mut serialized = [0u8; 2];
+    serialized[0] = bits_packed as u8;
+    serialized[1] = (bits_packed >> 8) as u8;
+
+    serialized
+}
+
+#[inline(always)]
+pub(crate) fn deserialize_1(bytes: &[u8]) -> __m256i {
+    let shift_lsb_to_msb = mm256_set_epi16(
+        1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7,
+        1 << 0, 1 << 1, 1 << 2, 1 << 3, 1 << 4, 1 << 5, 1 << 6, 1 << 7,
+    );
+
+    let coefficients = mm256_set_epi16(
+        bytes[1] as i16, bytes[1] as i16, bytes[1] as i16, bytes[1] as i16,
+        bytes[1] as i16, bytes[1] as i16, bytes[1] as i16, bytes[1] as i16,
+        bytes[0] as i16, bytes[0] as i16, bytes[0] as i16, bytes[0] as i16,
+        bytes[0] as i16, bytes[0] as i16, bytes[0] as i16, bytes[0] as i16,
+    );
+
+    let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsb_to_msb);
+    let coefficients_in_lsb = mm256_srli_epi16::<7>(coefficients_in_msb);
+
+    mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 1) - 1))
+}
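For reference, the contract implemented by the AVX2 `rejection_sample` above can be stated as a short scalar model (a hypothetical sketch for illustration, not part of this patch; `FIELD_MODULUS` is 3329 for ML-KEM): keep every 12-bit candidate below the modulus, compact the survivors to the front of `output`, and return how many were kept.

// Hypothetical scalar model of the `rejection_sample` contract.
const FIELD_MODULUS: i16 = 3329;

fn rejection_sample_model(candidates: [i16; 16], output: &mut [i16]) -> usize {
    let mut sampled = 0;
    for candidate in candidates {
        if candidate < FIELD_MODULUS {
            // Accepted candidates are compacted to the front; this is exactly
            // what the shuffle-table lookup achieves without branching.
            output[sampled] = candidate;
            sampled += 1;
        }
    }
    sampled
}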
+#[inline(always)]
+pub(crate) fn serialize_4(vector: __m256i) -> [u8; 8] {
+    let mut serialized = [0u8; 16];
+
+    let adjacent_2_combined = mm256_madd_epi16(
+        vector,
+        mm256_set_epi16(
+            1 << 4, 1, 1 << 4, 1, 1 << 4, 1, 1 << 4, 1,
+            1 << 4, 1, 1 << 4, 1, 1 << 4, 1, 1 << 4, 1,
+        ),
+    );
+
+    let adjacent_8_combined = mm256_shuffle_epi8(
+        adjacent_2_combined,
+        mm256_set_epi8(
+            -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1, -1, -1,
+            -1, -1, -1, -1, -1, -1, 12, 8, 4, 0,
+        ),
+    );
+
+    let combined =
+        mm256_permutevar8x32_epi32(adjacent_8_combined, mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0));
+    let combined = mm256_castsi256_si128(combined);
+
+    mm_storeu_bytes_si128(&mut serialized[..], combined);
+
+    serialized[0..8].try_into().unwrap()
+}
+
+#[inline(always)]
+pub(crate) fn deserialize_4(bytes: &[u8]) -> __m256i {
+    let shift_lsbs_to_msbs = mm256_set_epi16(
+        1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4,
+        1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4,
+    );
+
+    let coefficients = mm256_set_epi16(
+        bytes[7] as i16, bytes[7] as i16, bytes[6] as i16, bytes[6] as i16,
+        bytes[5] as i16, bytes[5] as i16, bytes[4] as i16, bytes[4] as i16,
+        bytes[3] as i16, bytes[3] as i16, bytes[2] as i16, bytes[2] as i16,
+        bytes[1] as i16, bytes[1] as i16, bytes[0] as i16, bytes[0] as i16,
+    );
+
+    let coefficients_in_msb = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);
+    let coefficients_in_lsb = mm256_srli_epi16::<4>(coefficients_in_msb);
+
+    mm256_and_si256(coefficients_in_lsb, mm256_set1_epi16((1 << 4) - 1))
+}
+
+#[inline(always)]
+pub(crate) fn serialize_5(vector: __m256i) -> [u8; 10] {
+    let mut serialized = [0u8; 32];
+
+    let adjacent_2_combined = mm256_madd_epi16(
+        vector,
+        mm256_set_epi16(
+            1 << 5, 1, 1 << 5, 1, 1 << 5, 1, 1 << 5, 1,
+            1 << 5, 1, 1 << 5, 1, 1 << 5, 1, 1 << 5, 1,
+        ),
+    );
+
+    let adjacent_4_combined = mm256_sllv_epi32(
+        adjacent_2_combined,
+        mm256_set_epi32(0, 22, 0, 22, 0, 22, 0, 22),
+    );
+    let adjacent_4_combined = mm256_srli_epi64::<22>(adjacent_4_combined);
+
+    let adjacent_8_combined = mm256_shuffle_epi32::<0b00_00_10_00>(adjacent_4_combined);
+    let adjacent_8_combined = mm256_sllv_epi32(
+        adjacent_8_combined,
+        mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
+    );
+    let adjacent_8_combined = mm256_srli_epi64::<12>(adjacent_8_combined);
+
+    let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
+    let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
+
+    // The two 16-byte stores overlap in `serialized`; only the first 10 bytes
+    // form the result.
+    mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);
+    mm_storeu_bytes_si128(&mut serialized[5..21], upper_8);
+
+    serialized[0..10].try_into().unwrap()
+}
+
+#[inline(always)]
+pub(crate) fn deserialize_5(bytes: &[u8]) -> __m256i {
+    let output = portable::deserialize_5(bytes);
+
+    crate::from_i16_array(&portable::to_i16_array(output)).elements
+}
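The `serialize_4` path above packs sixteen 4-bit coefficients into eight bytes, with the even-indexed coefficient of each pair landing in the low nibble (that is what the `(1 << 4, 1)` multipliers in the `madd` step encode). A scalar equivalent, as a hypothetical sketch for cross-checking only:

// Hypothetical scalar equivalent of `serialize_4`: two 4-bit coefficients
// per byte, the even-indexed coefficient in the low nibble.
fn serialize_4_model(coefficients: [i16; 16]) -> [u8; 8] {
    let mut serialized = [0u8; 8];
    for i in 0..8 {
        serialized[i] =
            ((coefficients[2 * i] & 0x0f) | ((coefficients[2 * i + 1] & 0x0f) << 4)) as u8;
    }
    serialized
}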
+#[inline(always)]
+pub(crate) fn serialize_10(vector: __m256i) -> [u8; 20] {
+    let mut serialized = [0u8; 32];
+
+    let adjacent_2_combined = mm256_madd_epi16(
+        vector,
+        mm256_set_epi16(
+            1 << 10, 1, 1 << 10, 1, 1 << 10, 1, 1 << 10, 1,
+            1 << 10, 1, 1 << 10, 1, 1 << 10, 1, 1 << 10, 1,
+        ),
+    );
+
+    let adjacent_4_combined = mm256_sllv_epi32(
+        adjacent_2_combined,
+        mm256_set_epi32(0, 12, 0, 12, 0, 12, 0, 12),
+    );
+    let adjacent_4_combined = mm256_srli_epi64::<12>(adjacent_4_combined);
+
+    let adjacent_8_combined = mm256_shuffle_epi8(
+        adjacent_4_combined,
+        mm256_set_epi8(
+            -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12,
+            11, 10, 9, 8, 4, 3, 2, 1, 0,
+        ),
+    );
+
+    let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
+    let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
+
+    mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);
+    mm_storeu_bytes_si128(&mut serialized[10..26], upper_8);
+
+    serialized[0..20].try_into().unwrap()
+}
+
+#[inline(always)]
+pub(crate) fn deserialize_10(bytes: &[u8]) -> __m256i {
+    let shift_lsbs_to_msbs = mm256_set_epi16(
+        1 << 0, 1 << 2, 1 << 4, 1 << 6, 1 << 0, 1 << 2, 1 << 4, 1 << 6,
+        1 << 0, 1 << 2, 1 << 4, 1 << 6, 1 << 0, 1 << 2, 1 << 4, 1 << 6,
+    );
+
+    let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap());
+    let lower_coefficients = mm_shuffle_epi8(
+        lower_coefficients,
+        mm_set_epi8(9, 8, 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0),
+    );
+    let upper_coefficients = mm_loadu_si128(bytes[4..20].try_into().unwrap());
+    let upper_coefficients = mm_shuffle_epi8(
+        upper_coefficients,
+        mm_set_epi8(15, 14, 14, 13, 13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 7, 6),
+    );
+
+    let coefficients = mm256_castsi128_si256(lower_coefficients);
+    let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients);
+
+    let coefficients = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);
+    let coefficients = mm256_srli_epi16::<6>(coefficients);
+    let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 10) - 1));
+
+    coefficients
+}
+
+#[inline(always)]
+pub(crate) fn serialize_11(vector: __m256i) -> [u8; 22] {
+    // 11-bit (de)serialization falls back to the portable implementation.
+    let input = portable::from_i16_array(crate::to_i16_array(SIMD256Vector { elements: vector }));
+
+    portable::serialize_11(input)
+}
+
+#[inline(always)]
+pub(crate) fn deserialize_11(bytes: &[u8]) -> __m256i {
+    let output = portable::deserialize_11(bytes);
+
+    crate::from_i16_array(&portable::to_i16_array(output)).elements
+}
+
+#[inline(always)]
+pub(crate) fn serialize_12(vector: __m256i) -> [u8; 24] {
+    let mut serialized = [0u8; 32];
+
+    let adjacent_2_combined = mm256_madd_epi16(
+        vector,
+        mm256_set_epi16(
+            1 << 12, 1, 1 << 12, 1, 1 << 12, 1, 1 << 12, 1,
+            1 << 12, 1, 1 << 12, 1, 1 << 12, 1, 1 << 12, 1,
+        ),
+    );
+
+    let adjacent_4_combined =
+        mm256_sllv_epi32(adjacent_2_combined, mm256_set_epi32(0, 8, 0, 8, 0, 8, 0, 8));
+    let adjacent_4_combined = mm256_srli_epi64::<8>(adjacent_4_combined);
+
+    let adjacent_8_combined = mm256_shuffle_epi8(
+        adjacent_4_combined,
+        mm256_set_epi8(
+            -1, -1, -1, -1, 13, 12, 11, 10, 9, 8, 5, 4, 3, 2, 1, 0, -1, -1, -1, -1, 13, 12, 11,
+            10, 9, 8, 5, 4, 3, 2, 1, 0,
+        ),
+    );
+
+    let lower_8 = mm256_castsi256_si128(adjacent_8_combined);
+    let upper_8 = mm256_extracti128_si256::<1>(adjacent_8_combined);
+
+    mm_storeu_bytes_si128(&mut serialized[0..16], lower_8);
+    mm_storeu_bytes_si128(&mut serialized[12..28], upper_8);
+
+    serialized[0..24].try_into().unwrap()
+}
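The `mm256_madd_epi16` step shared by `serialize_10` and `serialize_12` above multiplies each pair of adjacent 16-bit lanes by `(1, 1 << BITS)` and sums them, fusing two BITS-bit coefficients into one contiguous 2*BITS-bit value inside a 32-bit lane. A hypothetical scalar model of that one step:

// Hypothetical model of the madd packing step: the even lane keeps the low
// bits, the odd lane is shifted up by BITS before the two are summed.
fn madd_pair_model<const BITS: u32>(even: u32, odd: u32) -> u32 {
    even + (odd << BITS)
}

// For example, two 12-bit coefficients become one 24-bit value:
// madd_pair_model::<12>(0xabc, 0xdef) == 0xdefabc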
+#[inline(always)]
+pub(crate) fn deserialize_12(bytes: &[u8]) -> __m256i {
+    let shift_lsbs_to_msbs = mm256_set_epi16(
+        1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4,
+        1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4, 1 << 0, 1 << 4,
+    );
+
+    let lower_coefficients = mm_loadu_si128(bytes[0..16].try_into().unwrap());
+    let lower_coefficients = mm_shuffle_epi8(
+        lower_coefficients,
+        mm_set_epi8(11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0),
+    );
+    let upper_coefficients = mm_loadu_si128(bytes[8..24].try_into().unwrap());
+    let upper_coefficients = mm_shuffle_epi8(
+        upper_coefficients,
+        mm_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4),
+    );
+
+    let coefficients = mm256_castsi128_si256(lower_coefficients);
+    let coefficients = mm256_inserti128_si256::<1>(coefficients, upper_coefficients);
+
+    let coefficients = mm256_mullo_epi16(coefficients, shift_lsbs_to_msbs);
+    let coefficients = mm256_srli_epi16::<4>(coefficients);
+    let coefficients = mm256_and_si256(coefficients, mm256_set1_epi16((1 << 12) - 1));
+
+    coefficients
+}
diff --git a/polynomials/build.rs b/polynomials/build.rs
index f15f3d581..ef1138666 100644
--- a/polynomials/build.rs
+++ b/polynomials/build.rs
@@ -16,11 +16,13 @@ fn main() {
         // We enable simd128 on all aarch64 builds.
         println!("cargo:rustc-cfg=feature=\"simd128\"");
     }
-    if (target_arch == "x86" || target_arch == "x86_64") && !disable_simd256 {
-        // We enable simd256 on all x86 and x86_64 builds.
+    if target_arch == "x86_64" && !disable_simd256 {
+        // We enable simd256 on all x86_64 builds.
         // Note that this doesn't mean the required CPU features are available.
        // But the compiler will support them and the runtime checks ensure that
         // it's only used when available.
+        //
+        // We don't enable this on x86 because it seems to generate invalid code.
         println!("cargo:rustc-cfg=feature=\"simd256\"");
     }
 }
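As the build-script comment notes, the `simd256` cfg only says the compiler may emit AVX2 code; the CPU still has to be checked at runtime before the fast path runs. A hypothetical sketch of such a gate (`use_avx2_path` is an illustrative name, not a function from this patch):

// Hypothetical runtime gate for code compiled under `feature = "simd256"`.
#[cfg(feature = "simd256")]
fn use_avx2_path() -> bool {
    // Compile-time support is not enough; check the actual CPU.
    std::is_x86_feature_detected!("avx2")
}

#[cfg(not(feature = "simd256"))]
fn use_avx2_path() -> bool {
    false
}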
println!("cargo:rustc-cfg=feature=\"simd256\""); } } diff --git a/polynomials/src/lib.rs b/polynomials/src/lib.rs index 34db69b5e..1dd33fd91 100644 --- a/polynomials/src/lib.rs +++ b/polynomials/src/lib.rs @@ -209,8 +209,10 @@ fn to_i16_array(v: PortableVector) -> [i16; FIELD_ELEMENTS_IN_VECTOR] { } #[inline(always)] -fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> PortableVector { - PortableVector { elements: array } +fn from_i16_array(array: &[i16]) -> PortableVector { + PortableVector { + elements: array[0..16].try_into().unwrap(), + } } #[inline(always)] @@ -1039,8 +1041,7 @@ fn deserialize_12(bytes: &[u8]) -> PortableVector { } #[inline(always)] -fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - let mut result = [0i16; 16]; +fn rej_sample(a: &[u8], result: &mut [i16]) -> usize { let mut sampled = 0; for bytes in a.chunks(3) { let b1 = bytes[0] as i16; @@ -1059,7 +1060,7 @@ fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { sampled += 1 } } - (sampled, result) + sampled } impl Operations for PortableVector { @@ -1071,7 +1072,7 @@ impl Operations for PortableVector { to_i16_array(v) } - fn from_i16_array(array: [i16; FIELD_ELEMENTS_IN_VECTOR]) -> Self { + fn from_i16_array(array: &[i16]) -> Self { from_i16_array(array) } @@ -1206,7 +1207,7 @@ impl Operations for PortableVector { deserialize_12(a) } - fn rej_sample(a: &[u8]) -> (usize, [i16; 16]) { - rej_sample(a) + fn rej_sample(a: &[u8], out: &mut [i16]) -> usize { + rej_sample(a, out) } } diff --git a/sys/platform/Cargo.toml b/sys/platform/Cargo.toml index 6c9f512af..6766bf9b9 100644 --- a/sys/platform/Cargo.toml +++ b/sys/platform/Cargo.toml @@ -10,4 +10,4 @@ readme.workspace = true description = "Platform detection crate for libcrux." [dependencies] -libc = "0.2.147" +libc = "0.2.153" diff --git a/sys/platform/src/lib.rs b/sys/platform/src/lib.rs index 96619cafa..d2e8b0310 100644 --- a/sys/platform/src/lib.rs +++ b/sys/platform/src/lib.rs @@ -9,9 +9,6 @@ #[macro_use] extern crate std; -// #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -// sflwnwc - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] mod x86; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] diff --git a/sys/platform/src/macos_arm.rs b/sys/platform/src/macos_arm.rs index 4d92236df..2edafdbdc 100644 --- a/sys/platform/src/macos_arm.rs +++ b/sys/platform/src/macos_arm.rs @@ -1,6 +1,30 @@ //! Obtain particular CPU features for AArch64 on macOS -use libc::{c_void, sysctlbyname}; +use libc::{c_char, c_void, sysctlbyname, uname, utsname}; + +#[allow(dead_code)] +fn cstr(src: &[i8]) -> &str { + // default to length if no `0` present + let end = src.iter().position(|&c| c == 0).unwrap_or(src.len()); + unsafe { core::str::from_utf8_unchecked(core::mem::transmute(&src[0..end])) } +} + +/// Check that we're actually on an ARM mac. +/// When this returns false, no other function in here must be called. +pub(crate) fn actually_arm() -> bool { + let mut buf = utsname { + sysname: [c_char::default(); 256], + nodename: [c_char::default(); 256], + release: [c_char::default(); 256], + version: [c_char::default(); 256], + machine: [c_char::default(); 256], + }; + let error = unsafe { uname(&mut buf) }; + // buf.machine == "arm" + // It could also be "arm64". 
 
 #[inline(always)]
 fn check(feature: &[i8]) -> bool {
@@ -20,6 +44,12 @@ fn check(feature: &[i8]) -> bool {
 
 #[inline(always)]
 fn sysctl() {
+    // Check that we're actually on arm and set everything to false if not.
+    // This can happen when running on an Intel Mac.
+    if !actually_arm() {
+        return;
+    }
+
     // hw.optional.AdvSIMD
     const ADV_SIMD_STR: [i8; 20] = [
         0x68, 0x77, 0x2e, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2e, 0x41, 0x64, 0x76,
diff --git a/traits/src/lib.rs b/traits/src/lib.rs
index 4000ab7d2..ef59e381f 100644
--- a/traits/src/lib.rs
+++ b/traits/src/lib.rs
@@ -8,7 +8,7 @@ pub trait Operations: Copy + Clone {
     fn ZERO() -> Self;
 
     fn to_i16_array(v: Self) -> [i16; 16];
-    fn from_i16_array(array: [i16; 16]) -> Self;
+    fn from_i16_array(array: &[i16]) -> Self;
 
     // Basic arithmetic
     fn add(lhs: Self, rhs: &Self) -> Self;
@@ -61,7 +61,7 @@ pub trait Operations: Copy + Clone {
     fn serialize_12(a: Self) -> [u8; 24];
     fn deserialize_12(a: &[u8]) -> Self;
 
-    fn rej_sample(a: &[u8]) -> (usize, [i16; 16]);
+    fn rej_sample(a: &[u8], out: &mut [i16]) -> usize;
 }
 
 // hax does not support traits with default implementations, so we use the following pattern
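For orientation, a hedged sketch of how a caller adapts to the new out-parameter signature of `rej_sample` (`collect_samples` is an illustrative name; it assumes `PortableVector` and the `Operations` trait are in scope; a 24-byte block yields at most 16 candidates because the loop consumes 3 bytes per 2 coefficients):

// Hypothetical caller of the new `rej_sample` signature: sampled
// coefficients are written into `out` and only the returned prefix is valid.
fn collect_samples(randomness: &[u8; 24], acc: &mut Vec<i16>) {
    let mut out = [0i16; 16];
    let sampled = PortableVector::rej_sample(randomness, &mut out);
    acc.extend_from_slice(&out[..sampled]);
}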