Skip to content

Commit

Permalink
Use Init-Absorb-Squeeze* API in ML-KEM (#220)
Browse files Browse the repository at this point in the history
Enabling the incremental (nblocks) API for Shake128 (scalar and avx2) in ml-kem.

---------

Co-authored-by: Karthikeyan Bhargavan <[email protected]>
  • Loading branch information
franziskuskiefer and karthikbhargavan authored Mar 13, 2024
1 parent 9850114 commit 700700f
Show file tree
Hide file tree
Showing 33 changed files with 9,458 additions and 4,747 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ wasm-bindgen = { version = "0.2.87", optional = true }
# When using the hax toolchain, we have more dependencies.
# This is only required when doing proofs.
[target.'cfg(hax)'.dependencies]
hax-lib-macros = { version = "0.1.0-pre.1", git = "https://github.com/hacspec/hax" }
hax-lib-macros = { version = "0.1.0-pre.1", git = "https://github.com/hacspec/hax", branch = "main" }
hax-lib = { version = "0.1.0-pre.1", git = "https://github.com/hacspec/hax/", branch = "main" }

[target.'cfg(all(not(target_os = "windows"), target_arch = "x86_64", libjade))'.dependencies]
Expand All @@ -74,3 +74,4 @@ rand = []
wasm = ["wasm-bindgen"]
log = ["dep:log"]
tests = [] # Expose functions for testing.
experimental = [] # Expose experimental APIs.
3 changes: 2 additions & 1 deletion hax-driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def shell(command, expect=0, cwd=None, env={}):
f"-** +libcrux::kem::kyber::** +!libcrux_platform::platform::* {exclude_sha3_implementations} -libcrux::**::types::index_impls::**",
"fstar",
"--interfaces",
"+* -libcrux::kem::kyber::types +!libcrux_platform::**",
"+* -libcrux::kem::kyber::types +!libcrux_platform::** +!libcrux::digest::**",
],
cwd=".",
env=hax_env,
Expand All @@ -136,6 +136,7 @@ def shell(command, expect=0, cwd=None, env={}):
# remove this when https://github.com/hacspec/hax/issues/465 is
# closed)
shell(["rm", "-f", "./sys/platform/proofs/fstar/extraction/*.fst"])

elif options.kyber_specification:
shell(
cargo_hax_into
Expand Down
2 changes: 2 additions & 0 deletions kyber-c.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ files:
- name: libcrux_digest
api:
- [libcrux, digest]
include_in_h:
- '"libcrux_hacl_glue.h"'
- name: libcrux_platform
api:
- [libcrux_platform]
Expand Down
3 changes: 2 additions & 1 deletion kyber-crate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,6 @@ if [[ -n "$HACL_PACKAGES_HOME" ]]; then
cp internal/*.h $HACL_PACKAGES_HOME/libcrux/include/internal/
cp *.h $HACL_PACKAGES_HOME/libcrux/include
cp *.c $HACL_PACKAGES_HOME/libcrux/src
else
echo "Please set HACL_PACKAGES_HOME to the hacl-packages directory to copy the code over" 1>&2
fi
echo "Please set HACL_PACKAGES_HOME to the hacl-packages directory to copy the code over" 1>&2
2,492 changes: 1,435 additions & 1,057 deletions proofs/fstar/extraction-edited.patch

Large diffs are not rendered by default.

1,603 changes: 785 additions & 818 deletions proofs/fstar/extraction-secret-independent.patch

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions proofs/fstar/extraction/Libcrux.Digest.Incremental_x4.fsti
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module Libcrux.Digest.Incremental_x4
#set-options "--fuel 0 --ifuel 1 --z3rlimit 15"
open Core
open FStar.Mul

val t_Shake128StateX4:Type

val impl__Shake128StateX4__absorb_final
(v_N: usize)
(self: t_Shake128StateX4)
(input: t_Array (t_Slice u8) v_N)
: Prims.Pure t_Shake128StateX4 Prims.l_True (fun _ -> Prims.l_True)

val impl__Shake128StateX4__free_memory (self: t_Shake128StateX4)
: Prims.Pure Prims.unit Prims.l_True (fun _ -> Prims.l_True)

val impl__Shake128StateX4__new: Prims.unit
-> Prims.Pure t_Shake128StateX4 Prims.l_True (fun _ -> Prims.l_True)

val impl__Shake128StateX4__squeeze_blocks (v_N v_M: usize) (self: t_Shake128StateX4)
: Prims.Pure (t_Shake128StateX4 & t_Array (t_Array u8 v_N) v_M)
Prims.l_True
(fun _ -> Prims.l_True)
48 changes: 0 additions & 48 deletions proofs/fstar/extraction/Libcrux.Digest.fst

This file was deleted.

18 changes: 0 additions & 18 deletions proofs/fstar/extraction/Libcrux.Digest.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,5 @@ val sha3_256_ (payload: t_Slice u8)
val sha3_512_ (payload: t_Slice u8)
: Prims.Pure (t_Array u8 (sz 64)) Prims.l_True (fun _ -> Prims.l_True)

val shake128 (v_LEN: usize) (data: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN) Prims.l_True (fun _ -> Prims.l_True)

val shake256 (v_LEN: usize) (data: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN) Prims.l_True (fun _ -> Prims.l_True)

val shake128x4_portable (v_LEN: usize) (data0 data1 data2 data3: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN)
Prims.l_True
(fun _ -> Prims.l_True)

val shake128x4_256_ (v_LEN: usize) (data0 data1 data2 data3: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN)
Prims.l_True
(fun _ -> Prims.l_True)

val shake128x4 (v_LEN: usize) (data0 data1 data2 data3: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN & t_Array u8 v_LEN)
Prims.l_True
(fun _ -> Prims.l_True)
2 changes: 0 additions & 2 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Constants.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,4 @@ let v_FIELD_MODULUS: i32 = 3329l

let v_H_DIGEST_SIZE: usize = sz 32

let v_REJECTION_SAMPLING_SEED_SIZE: usize = sz 168 *! sz 5

let v_SHARED_SECRET_SIZE: usize = sz 32
200 changes: 114 additions & 86 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Hash_functions.fst
Original file line number Diff line number Diff line change
Expand Up @@ -9,95 +9,123 @@ let v_H (input: t_Slice u8) = Libcrux.Digest.sha3_256_ input

let v_PRF (v_LEN: usize) (input: t_Slice u8) = Libcrux.Digest.shake256 v_LEN input

let v_XOFx4 (v_K: usize) (input: t_Array (t_Array u8 (sz 34)) v_K) =
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.repeat (Rust_primitives.Hax.repeat 0uy (sz 840) <: t_Array u8 (sz 840)) v_K
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
if ~.(Libcrux_platform.Platform.simd256_support () <: bool)
let absorb (v_K: usize) (input: t_Array (t_Array u8 (sz 34)) v_K) =
let _:Prims.unit =
if true
then
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end = v_K
}
let _:Prims.unit =
if ~.((v_K =. sz 2 <: bool) || (v_K =. sz 3 <: bool) || (v_K =. sz 4 <: bool))
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: K == 2 || K == 3 || K == 4"

<:
Core.Ops.Range.t_Range usize)
Rust_primitives.Hax.t_Never)
in
()
in
let state:Libcrux.Digest.Incremental_x4.t_Shake128StateX4 =
Libcrux.Digest.Incremental_x4.impl__Shake128StateX4__new ()
in
let (data: t_Array (t_Slice u8) v_K):t_Array (t_Slice u8) v_K =
Rust_primitives.Hax.repeat (Rust_primitives.unsize (let list = [0uy] in
FStar.Pervasives.assert_norm (Prims.eq2 (List.Tot.length list) 1);
Rust_primitives.Hax.array_of_list 1 list)
<:
t_Slice u8)
v_K
in
let data:t_Array (t_Slice u8) v_K =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end = v_K
}
<:
Core.Ops.Range.t_Range usize)
<:
Core.Ops.Range.t_Range usize)
data
(fun data i ->
let data:t_Array (t_Slice u8) v_K = data in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize data
i
(Rust_primitives.unsize (input.[ i ] <: t_Array u8 (sz 34)) <: t_Slice u8)
<:
t_Array (t_Slice u8) v_K)
in
let state:Libcrux.Digest.Incremental_x4.t_Shake128StateX4 =
Libcrux.Digest.Incremental_x4.impl__Shake128StateX4__absorb_final v_K state data
in
state

let free_state (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4) =
let _:Prims.unit = Libcrux.Digest.Incremental_x4.impl__Shake128StateX4__free_memory xof_state in
()

let squeeze_block (v_K: usize) (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4) =
let tmp0, out1:(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 &
t_Array (t_Array u8 (sz 168)) v_K) =
Libcrux.Digest.Incremental_x4.impl__Shake128StateX4__squeeze_blocks (sz 168) v_K xof_state
in
let xof_state:Libcrux.Digest.Incremental_x4.t_Shake128StateX4 = tmp0 in
let (output: t_Array (t_Array u8 (sz 168)) v_K):t_Array (t_Array u8 (sz 168)) v_K = out1 in
let out:t_Array (t_Array u8 (sz 168)) v_K =
Rust_primitives.Hax.repeat (Rust_primitives.Hax.repeat 0uy (sz 168) <: t_Array u8 (sz 168)) v_K
in
let out:t_Array (t_Array u8 (sz 168)) v_K =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end = v_K
}
<:
Core.Ops.Range.t_Range usize)
<:
Core.Ops.Range.t_Range usize)
out
(fun out i ->
let out:t_Array (t_Array u8 (sz 168)) v_K = out in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
(output.[ i ] <: t_Array u8 (sz 168))
<:
Core.Ops.Range.t_Range usize)
out
(fun out i ->
let out:t_Array (t_Array u8 (sz 840)) v_K = out in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
(Libcrux.Digest.shake128 (sz 840)
(Rust_primitives.unsize (input.[ i ] <: t_Array u8 (sz 34)) <: t_Slice u8)
<:
t_Array u8 (sz 840))
t_Array (t_Array u8 (sz 168)) v_K)
in
let hax_temp_output:t_Array (t_Array u8 (sz 168)) v_K = out in
xof_state, hax_temp_output
<:
(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 & t_Array (t_Array u8 (sz 168)) v_K)

let squeeze_three_blocks (v_K: usize) (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4) =
let tmp0, out1:(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 &
t_Array (t_Array u8 (sz 504)) v_K) =
Libcrux.Digest.Incremental_x4.impl__Shake128StateX4__squeeze_blocks (sz 504) v_K xof_state
in
let xof_state:Libcrux.Digest.Incremental_x4.t_Shake128StateX4 = tmp0 in
let (output: t_Array (t_Array u8 (sz 504)) v_K):t_Array (t_Array u8 (sz 504)) v_K = out1 in
let out:t_Array (t_Array u8 (sz 504)) v_K =
Rust_primitives.Hax.repeat (Rust_primitives.Hax.repeat 0uy (sz 504) <: t_Array u8 (sz 504)) v_K
in
let out:t_Array (t_Array u8 (sz 504)) v_K =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end = v_K
}
<:
t_Array (t_Array u8 (sz 840)) v_K)
else
let out:t_Array (t_Array u8 (sz 840)) v_K =
match cast (v_K <: usize) <: u8 with
| 2uy ->
let d0, d1, _, _:(t_Array u8 (sz 840) & t_Array u8 (sz 840) & t_Array u8 (sz 840) &
t_Array u8 (sz 840)) =
Libcrux.Digest.shake128x4 (sz 840)
(Rust_primitives.unsize (input.[ sz 0 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 1 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 0 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 1 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 0) d0
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 1) d1
in
out
| 3uy ->
let d0, d1, d2, _:(t_Array u8 (sz 840) & t_Array u8 (sz 840) & t_Array u8 (sz 840) &
t_Array u8 (sz 840)) =
Libcrux.Digest.shake128x4 (sz 840)
(Rust_primitives.unsize (input.[ sz 0 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 1 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 2 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 0 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 0) d0
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 1) d1
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 2) d2
in
out
| 4uy ->
let d0, d1, d2, d3:(t_Array u8 (sz 840) & t_Array u8 (sz 840) & t_Array u8 (sz 840) &
t_Array u8 (sz 840)) =
Libcrux.Digest.shake128x4 (sz 840)
(Rust_primitives.unsize (input.[ sz 0 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 1 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 2 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
(Rust_primitives.unsize (input.[ sz 3 ] <: t_Array u8 (sz 34)) <: t_Slice u8)
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 0) d0
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 1) d1
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 2) d2
in
let out:t_Array (t_Array u8 (sz 840)) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out (sz 3) d3
in
out
| _ -> out
in
Core.Ops.Range.t_Range usize)
<:
Core.Ops.Range.t_Range usize)
out
(fun out i ->
let out:t_Array (t_Array u8 (sz 504)) v_K = out in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize out
i
(output.[ i ] <: t_Array u8 (sz 504))
<:
t_Array (t_Array u8 (sz 504)) v_K)
in
out
let hax_temp_output:t_Array (t_Array u8 (sz 504)) v_K = out in
xof_state, hax_temp_output
<:
(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 & t_Array (t_Array u8 (sz 504)) v_K)
25 changes: 23 additions & 2 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Hash_functions.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,33 @@ module Libcrux.Kem.Kyber.Hash_functions
open Core
open FStar.Mul

let v_BLOCK_SIZE: usize = sz 168

val v_G (input: t_Slice u8) : Prims.Pure (t_Array u8 (sz 64)) Prims.l_True (fun _ -> Prims.l_True)

val v_H (input: t_Slice u8) : Prims.Pure (t_Array u8 (sz 32)) Prims.l_True (fun _ -> Prims.l_True)

val v_PRF (v_LEN: usize) (input: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN) Prims.l_True (fun _ -> Prims.l_True)

val v_XOFx4 (v_K: usize) (input: t_Array (t_Array u8 (sz 34)) v_K)
: Prims.Pure (t_Array (t_Array u8 (sz 840)) v_K) Prims.l_True (fun _ -> Prims.l_True)
let v_THREE_BLOCKS: usize = v_BLOCK_SIZE *! sz 3

val absorb (v_K: usize) (input: t_Array (t_Array u8 (sz 34)) v_K)
: Prims.Pure Libcrux.Digest.Incremental_x4.t_Shake128StateX4
Prims.l_True
(fun _ -> Prims.l_True)

val free_state (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4)
: Prims.Pure Prims.unit Prims.l_True (fun _ -> Prims.l_True)

val squeeze_block (v_K: usize) (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4)
: Prims.Pure
(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 & t_Array (t_Array u8 (sz 168)) v_K)
Prims.l_True
(fun _ -> Prims.l_True)

val squeeze_three_blocks (v_K: usize) (xof_state: Libcrux.Digest.Incremental_x4.t_Shake128StateX4)
: Prims.Pure
(Libcrux.Digest.Incremental_x4.t_Shake128StateX4 & t_Array (t_Array u8 (sz 504)) v_K)
Prims.l_True
(fun _ -> Prims.l_True)
Loading

0 comments on commit 700700f

Please sign in to comment.