Skip to content

Commit

Permalink
Merge pull request #287 from cryspen/franziskus/avx2-extraction
Browse files Browse the repository at this point in the history
avx2 extraction improvements
  • Loading branch information
franziskuskiefer authored May 23, 2024
2 parents 7de9e87 + 76ab582 commit fa47313
Show file tree
Hide file tree
Showing 58 changed files with 8,099 additions and 2,095 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/mlkem.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin

# - name: ⬆ Upload build
# uses: ./.github/actions/upload_artifacts
# with:
# name: build_${{ matrix.os }}_${{ matrix.bits }}

# We get false positives here.
# TODO: Figure out what is going on here
# - name: 🏃🏻 Asan Linux
Expand Down Expand Up @@ -123,7 +128,6 @@ jobs:
cargo test --verbose $RUST_TARGET_FLAG
- name: 🏃🏻‍♀️ Test Release
if: ${{ matrix.os != 'macos-latest' }}
run: |
cargo clean
cargo test --verbose --release $RUST_TARGET_FLAG
Expand Down Expand Up @@ -195,6 +199,14 @@ jobs:
echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV
if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }}

# - name: 🔨 Build
# run: cargo build --benches

# - name: ⬆ Upload build
# uses: ./.github/actions/upload_artifacts
# with:
# name: benchmarks_${{ matrix.os }}_${{ matrix.bits }}

# Benchmarks ...

- name: 🏃🏻‍♀️ Benchmarks
Expand Down
30 changes: 23 additions & 7 deletions libcrux-ml-kem/hax.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,38 @@ class extractAction(argparse.Action):

def __call__(self, parser, args, values, option_string=None) -> None:
# Extract platform and sha3 interfaces
include_str = "+:libcrux_sha3::** -libcrux_sha3::x4::internal::**"
interface_include = "+!**"
# include_str = "+:libcrux_sha3::** -libcrux_sha3::x4::internal::**"
# interface_include = "+!**"
# cargo_hax_into = [
# "cargo",
# "hax",
# "into",
# "-i",
# include_str,
# "fstar",
# "--interfaces",
# interface_include,
# ]
# hax_env = {}
# shell(
# cargo_hax_into,
# cwd="../libcrux-sha3",
# env=hax_env,
# )

# Extract avx2
# include_str = "+:libcrux_sha3::** -libcrux_sha3::x4::internal::**"
# interface_include = "+!**"
cargo_hax_into = [
"cargo",
"hax",
"into",
"-i",
include_str,
"fstar",
"--interfaces",
interface_include,
]
hax_env = {}
shell(
cargo_hax_into,
cwd="../libcrux-sha3",
cwd="../polynomials-avx2",
env=hax_env,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ let v_BYTES_PER_RING_ELEMENT: usize = v_BITS_PER_RING_ELEMENT /! sz 8

let v_CPA_PKE_KEY_GENERATION_SEED_SIZE: usize = sz 32

/// Field modulus: 3329
let v_FIELD_MODULUS: i32 = 3329l
/// SHA3 512 digest size
let v_G_DIGEST_SIZE: usize = sz 64

/// SHA3 256 digest size
let v_H_DIGEST_SIZE: usize = sz 32

/// PKE message size
/// The size of an ML-KEM shared secret.
let v_SHARED_SECRET_SIZE: usize = sz 32

/// Field modulus: 3329
let v__FIELD_MODULUS: i16 = 3329s
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
module Libcrux_ml_kem.Hash_functions.Avx2
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
open Core
open FStar.Mul

/// The state.
/// It's only used for SHAKE128.
/// All other functions don't actually use any members.
type t_Simd256Hash = {
f_shake128_state:Libcrux_sha3.Generic_keccak.t_KeccakState (sz 4) Core.Core_arch.X86.t____m256i
}

[@@ FStar.Tactics.Typeclasses.tcinstance]
let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd256Hash v_K =
{
f_G_pre = (fun (input: t_Slice u8) -> true);
f_G_post = (fun (input: t_Slice u8) (out: t_Array u8 (sz 64)) -> true);
f_G
=
(fun (input: t_Slice u8) ->
let digest:t_Array u8 (sz 64) = Rust_primitives.Hax.repeat 0uy (sz 64) in
let digest:t_Array u8 (sz 64) = Libcrux_sha3.Portable.sha512 digest input in
digest);
f_H_pre = (fun (input: t_Slice u8) -> true);
f_H_post = (fun (input: t_Slice u8) (out: t_Array u8 (sz 32)) -> true);
f_H
=
(fun (input: t_Slice u8) ->
let digest:t_Array u8 (sz 32) = Rust_primitives.Hax.repeat 0uy (sz 32) in
let digest:t_Array u8 (sz 32) = Libcrux_sha3.Portable.sha256 digest input in
digest);
f_PRF_pre = (fun (v_LEN: usize) (input: t_Slice u8) -> true);
f_PRF_post = (fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) -> true);
f_PRF
=
(fun (v_LEN: usize) (input: t_Slice u8) ->
let digest:t_Array u8 v_LEN = Rust_primitives.Hax.repeat 0uy v_LEN in
let digest:t_Array u8 v_LEN = Libcrux_sha3.Portable.shake256 v_LEN digest input in
digest);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
Libcrux_sha3.Avx2.X4.shake256xN v_LEN v_K input);
f_shake128_init_absorb_pre = (fun (input: t_Array (t_Array u8 (sz 34)) v_K) -> true);
f_shake128_init_absorb_post
=
(fun (input: t_Array (t_Array u8 (sz 34)) v_K) (out: t_Simd256Hash) -> true);
f_shake128_init_absorb
=
(fun (input: t_Array (t_Array u8 (sz 34)) v_K) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let state:Libcrux_sha3.Generic_keccak.t_KeccakState (sz 4) Core.Core_arch.X86.t____m256i =
Libcrux_sha3.Avx2.X4.Incremental.shake128_absorb_finalxN v_K input
in
{ f_shake128_state = state } <: t_Simd256Hash);
f_shake128_squeeze_three_blocks_pre = (fun (self: t_Simd256Hash) -> true);
f_shake128_squeeze_three_blocks_post
=
(fun (self: t_Simd256Hash) (out1: (t_Simd256Hash & t_Array (t_Array u8 (sz 504)) v_K)) -> true);
f_shake128_squeeze_three_blocks
=
(fun (self: t_Simd256Hash) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let tmp0, out:(Libcrux_sha3.Generic_keccak.t_KeccakState (sz 4)
Core.Core_arch.X86.t____m256i &
t_Array (t_Array u8 (sz 504)) v_K) =
Libcrux_sha3.Avx2.X4.Incremental.shake128_squeeze3xN (sz 504) v_K self.f_shake128_state
in
let self:t_Simd256Hash = { self with f_shake128_state = tmp0 } <: t_Simd256Hash in
let hax_temp_output:t_Array (t_Array u8 (sz 504)) v_K = out in
self, hax_temp_output <: (t_Simd256Hash & t_Array (t_Array u8 (sz 504)) v_K));
f_shake128_squeeze_block_pre = (fun (self: t_Simd256Hash) -> true);
f_shake128_squeeze_block_post
=
(fun (self: t_Simd256Hash) (out1: (t_Simd256Hash & t_Array (t_Array u8 (sz 168)) v_K)) -> true);
f_shake128_squeeze_block
=
fun (self: t_Simd256Hash) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let tmp0, out:(Libcrux_sha3.Generic_keccak.t_KeccakState (sz 4) Core.Core_arch.X86.t____m256i &
t_Array (t_Array u8 (sz 168)) v_K) =
Libcrux_sha3.Avx2.X4.Incremental.shake128_squeezexN (sz 168) v_K self.f_shake128_state
in
let self:t_Simd256Hash = { self with f_shake128_state = tmp0 } <: t_Simd256Hash in
let hax_temp_output:t_Array (t_Array u8 (sz 168)) v_K = out in
self, hax_temp_output <: (t_Simd256Hash & t_Array (t_Array u8 (sz 168)) v_K)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
module Libcrux_ml_kem.Hash_functions.Neon
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
open Core
open FStar.Mul

/// The state.
/// It's only used for SHAKE128.
/// All other functions don't actually use any members.
type t_Simd128Hash = {
f_shake128_state:t_Array (t_Array (Libcrux_sha3.Generic_keccak.t_KeccakState (sz 1) u64) (sz 2))
(sz 2)
}

[@@ FStar.Tactics.Typeclasses.tcinstance]
let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd128Hash v_K =
{
f_G_pre = (fun (input: t_Slice u8) -> true);
f_G_post = (fun (input: t_Slice u8) (out: t_Array u8 (sz 64)) -> true);
f_G
=
(fun (input: t_Slice u8) ->
let digest:t_Array u8 (sz 64) = Rust_primitives.Hax.repeat 0uy (sz 64) in
let digest:t_Array u8 (sz 64) = Libcrux_sha3.Neon.sha512 digest input in
digest);
f_H_pre = (fun (input: t_Slice u8) -> true);
f_H_post = (fun (input: t_Slice u8) (out: t_Array u8 (sz 32)) -> true);
f_H
=
(fun (input: t_Slice u8) ->
let digest:t_Array u8 (sz 32) = Rust_primitives.Hax.repeat 0uy (sz 32) in
let digest:t_Array u8 (sz 32) = Libcrux_sha3.Neon.sha256 digest input in
digest);
f_PRF_pre = (fun (v_LEN: usize) (input: t_Slice u8) -> true);
f_PRF_post = (fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) -> true);
f_PRF
=
(fun (v_LEN: usize) (input: t_Slice u8) ->
let digest:t_Array u8 v_LEN = Rust_primitives.Hax.repeat 0uy v_LEN in
let digest:t_Array u8 v_LEN = Libcrux_sha3.Neon.shake256 v_LEN digest input in
digest);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
Libcrux_sha3.Neon.X2.shake256xN v_LEN v_K input);
f_shake128_init_absorb_pre = (fun (input: t_Array (t_Array u8 (sz 34)) v_K) -> true);
f_shake128_init_absorb_post
=
(fun (input: t_Array (t_Array u8 (sz 34)) v_K) (out: t_Simd128Hash) -> true);
f_shake128_init_absorb
=
(fun (input: t_Array (t_Array u8 (sz 34)) v_K) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let state:t_Array (t_Array (Libcrux_sha3.Generic_keccak.t_KeccakState (sz 1) u64) (sz 2))
(sz 2) =
Libcrux_sha3.Neon.X2.Incremental.shake128_absorb_finalxN v_K input
in
{ f_shake128_state = state } <: t_Simd128Hash);
f_shake128_squeeze_three_blocks_pre = (fun (self: t_Simd128Hash) -> true);
f_shake128_squeeze_three_blocks_post
=
(fun (self: t_Simd128Hash) (out1: (t_Simd128Hash & t_Array (t_Array u8 (sz 504)) v_K)) -> true);
f_shake128_squeeze_three_blocks
=
(fun (self: t_Simd128Hash) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let tmp0, out:(t_Array
(t_Array (Libcrux_sha3.Generic_keccak.t_KeccakState (sz 1) u64) (sz 2)) (sz 2) &
t_Array (t_Array u8 (sz 504)) v_K) =
Libcrux_sha3.Neon.X2.Incremental.shake128_squeeze3xN (sz 504) v_K self.f_shake128_state
in
let self:t_Simd128Hash = { self with f_shake128_state = tmp0 } <: t_Simd128Hash in
let hax_temp_output:t_Array (t_Array u8 (sz 504)) v_K = out in
self, hax_temp_output <: (t_Simd128Hash & t_Array (t_Array u8 (sz 504)) v_K));
f_shake128_squeeze_block_pre = (fun (self: t_Simd128Hash) -> true);
f_shake128_squeeze_block_post
=
(fun (self: t_Simd128Hash) (out1: (t_Simd128Hash & t_Array (t_Array u8 (sz 168)) v_K)) -> true);
f_shake128_squeeze_block
=
fun (self: t_Simd128Hash) ->
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Rust_primitives.Hax.t_Never)
in
()
in
let tmp0, out:(t_Array (t_Array (Libcrux_sha3.Generic_keccak.t_KeccakState (sz 1) u64) (sz 2))
(sz 2) &
t_Array (t_Array u8 (sz 168)) v_K) =
Libcrux_sha3.Neon.X2.Incremental.shake128_squeezexN (sz 168) v_K self.f_shake128_state
in
let self:t_Simd128Hash = { self with f_shake128_state = tmp0 } <: t_Simd128Hash in
let hax_temp_output:t_Array (t_Array u8 (sz 168)) v_K = out in
self, hax_temp_output <: (t_Simd128Hash & t_Array (t_Array u8 (sz 168)) v_K)
}
Loading

0 comments on commit fa47313

Please sign in to comment.