Skip to content

Commit

Permalink
update F* and drop sha3 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed Apr 28, 2024
1 parent 99dab65 commit 2befd45
Show file tree
Hide file tree
Showing 21 changed files with 384 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ module Libcrux_ml_kem.Arithmetic
open Core
open FStar.Mul

/// Values having this type hold a representative 'x' of the Kyber field.
/// We use 'fe' as a shorthand for this type.
unfold
let t_FieldElement = i32

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R (mod FIELD_MODULUS).
/// We use 'fer' as a shorthand for this type.
unfold
let t_FieldElementTimesMontgomeryR = i32

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R^(-1) (mod FIELD_MODULUS).
/// We use 'mfe' as a shorthand for this type
unfold
let t_MontgomeryFieldElement = i32

/// This is calculated as ⌊(BARRETT_R / FIELD_MODULUS) + 1/2⌋
let v_BARRETT_MULTIPLIER: i64 = 20159L

let v_BARRETT_SHIFT: i64 = 26L
Expand All @@ -20,6 +29,7 @@ let v_BARRETT_R: i64 = 1L <<! v_BARRETT_SHIFT

let v_INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: u32 = 62209ul

/// This is calculated as (MONTGOMERY_R)^2 mod FIELD_MODULUS
let v_MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS: i32 = 1353l

let v_MONTGOMERY_SHIFT: u8 = 16uy
Expand All @@ -34,6 +44,10 @@ val get_n_least_significant_bits (n: u8) (value: u32)
let result:u32 = result in
result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32))

/// Given a field element `fe` such that -FIELD_MODULUS ≤ fe < FIELD_MODULUS,
/// output `o` such that:
/// - `o` is congruent to `fe`
/// - 0 ≤ `o` FIELD_MODULUS
val v__to_unsigned_representative (fe: i32)
: Prims.Pure u16
(requires
Expand All @@ -45,6 +59,13 @@ val v__to_unsigned_representative (fe: i32)
result >=. 0us &&
result <. (cast (Libcrux_ml_kem.Constants.v_FIELD_MODULUS <: i32) <: u16))

/// Signed Barrett Reduction
/// Given an input `value`, `barrett_reduce` outputs a representative `result`
/// such that:
/// - result ≡ value (mod FIELD_MODULUS)
/// - the absolute value of `result` is bound as follows:
/// `|result| ≤ FIELD_MODULUS / 2 · (|value|/BARRETT_R + 1)
/// In particular, if `|value| < BARRETT_R`, then `|result| < FIELD_MODULUS`.
val barrett_reduce (value: i32)
: Prims.Pure i32
(requires
Expand All @@ -56,6 +77,13 @@ val barrett_reduce (value: i32)
result >. (Core.Ops.Arith.Neg.neg Libcrux_ml_kem.Constants.v_FIELD_MODULUS <: i32) &&
result <. Libcrux_ml_kem.Constants.v_FIELD_MODULUS)

/// Signed Montgomery Reduction
/// Given an input `value`, `montgomery_reduce` outputs a representative `o`
/// such that:
/// - o ≡ value · MONTGOMERY_R^(-1) (mod FIELD_MODULUS)
/// - the absolute value of `o` is bound as follows:
/// `|result| ≤ (|value| / MONTGOMERY_R) + (FIELD_MODULUS / 2)
/// In particular, if `|value| ≤ FIELD_MODULUS * MONTGOMERY_R`, then `|o| < (3 · FIELD_MODULUS) / 2`.
val montgomery_reduce (value: i32)
: Prims.Pure i32
(requires
Expand All @@ -74,10 +102,39 @@ val montgomery_reduce (value: i32)
i32) &&
result <=. ((3l *! Libcrux_ml_kem.Constants.v_FIELD_MODULUS <: i32) /! 2l <: i32))

/// If x is some field element of the Kyber field and `mfe` is congruent to
/// x · MONTGOMERY_R^{-1}, this procedure outputs a value that is congruent to
/// `x`, as follows:
/// mfe · MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS ≡ x · MONTGOMERY_R^{-1} * (MONTGOMERY_R)^2 (mod FIELD_MODULUS)
/// => mfe · MONTGOMERY_R_SQUARED_MOD_FIELD_MODULUS ≡ x · MONTGOMERY_R (mod FIELD_MODULUS)
/// `montgomery_reduce` takes the value `x · MONTGOMERY_R` and outputs a representative
/// `x · MONTGOMERY_R * MONTGOMERY_R^{-1} ≡ x (mod FIELD_MODULUS)`
val v__to_standard_domain (mfe: i32) : Prims.Pure i32 Prims.l_True (fun _ -> Prims.l_True)

/// If `fe` is some field element 'x' of the Kyber field and `fer` is congruent to
/// `y · MONTGOMERY_R`, this procedure outputs a value that is congruent to
/// `x · y`, as follows:
/// `fe · fer ≡ x · y · MONTGOMERY_R (mod FIELD_MODULUS)`
/// `montgomery_reduce` takes the value `x · y · MONTGOMERY_R` and outputs a representative
/// `x · y · MONTGOMERY_R * MONTGOMERY_R^{-1} ≡ x · y (mod FIELD_MODULUS)`.
val montgomery_multiply_fe_by_fer (fe fer: i32)
: Prims.Pure i32 Prims.l_True (fun _ -> Prims.l_True)

/// Compute the product of two Kyber binomials with respect to the
/// modulus `X² - zeta`.
/// This function almost implements <strong>Algorithm 11</strong> of the
/// NIST FIPS 203 standard, which is reproduced below:
/// ```plaintext
/// Input: a₀, a₁, b₀, b₁ ∈ ℤq.
/// Input: γ ∈ ℤq.
/// Output: c₀, c₁ ∈ ℤq.
/// c₀ ← a₀·b₀ + a₁·b₁·γ
/// c₁ ← a₀·b₁ + a₁·b₀
/// return c₀, c₁
/// ```
/// We say "almost" because the coefficients output by this function are in
/// the Montgomery domain (unlike in the specification).
/// The NIST FIPS 203 standard can be found at
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
val ntt_multiply_binomials: (i32 & i32) -> (i32 & i32) -> zeta: i32
-> Prims.Pure (i32 & i32) Prims.l_True (fun _ -> Prims.l_True)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ val compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16)
result >=. 0l &&
result <. (Core.Num.impl__i32__pow 2l (cast (coefficient_bits <: u8) <: u32) <: i32))

/// The `compress_*` functions implement the `Compress` function specified in the NIST FIPS
/// 203 standard (Page 18, Expression 4.5), which is defined as:
/// ```plaintext
/// Compress_d: ℤq -> ℤ_{2ᵈ}
/// Compress_d(x) = ⌈(2ᵈ/q)·x⌋
/// ```
/// Since `⌈x⌋ = ⌊x + 1/2⌋` we have:
/// ```plaintext
/// Compress_d(x) = ⌊(2ᵈ/q)·x + 1/2⌋
/// = ⌊(2^{d+1}·x + q) / 2q⌋
/// ```
/// For further information about the function implementations, consult the
/// `implementation_notes.pdf` document in this directory.
/// The NIST FIPS 203 standard can be found at
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
val compress_message_coefficient (fe: u16)
: Prims.Pure u8
(requires fe <. (cast (Libcrux_ml_kem.Constants.v_FIELD_MODULUS <: i32) <: u16))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Libcrux_ml_kem.Constant_time_ops
open Core
open FStar.Mul

/// Return 1 if `value` is not zero and 0 otherwise.
val is_non_zero (value: u8)
: Prims.Pure u8
Prims.l_True
Expand All @@ -18,6 +19,8 @@ val is_non_zero (value: u8)
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

/// Return 1 if the bytes of `lhs` and `rhs` do not exactly
/// match and 0 otherwise.
val compare_ciphertexts_in_constant_time (v_CIPHERTEXT_SIZE: usize) (lhs rhs: t_Slice u8)
: Prims.Pure u8
Prims.l_True
Expand All @@ -33,6 +36,8 @@ val compare_ciphertexts_in_constant_time (v_CIPHERTEXT_SIZE: usize) (lhs rhs: t_
let _:Prims.unit = temp_0_ in
result =. 1uy <: bool))

/// If `selector` is not zero, return the bytes in `rhs`; return the bytes in
/// `lhs` otherwise.
val select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8)
: Prims.Pure (t_Array u8 (sz 32))
Prims.l_True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@ module Libcrux_ml_kem.Constants
open Core
open FStar.Mul

/// Each field element needs floor(log_2(FIELD_MODULUS)) + 1 = 12 bits to represent
let v_BITS_PER_COEFFICIENT: usize = sz 12

/// Coefficients per ring element
let v_COEFFICIENTS_IN_RING_ELEMENT: usize = sz 256

/// Bits required per (uncompressed) ring element
let v_BITS_PER_RING_ELEMENT: usize = v_COEFFICIENTS_IN_RING_ELEMENT *! sz 12

/// Bytes required per (uncompressed) ring element
let v_BYTES_PER_RING_ELEMENT: usize = v_BITS_PER_RING_ELEMENT /! sz 8

let v_CPA_PKE_KEY_GENERATION_SEED_SIZE: usize = sz 32

/// Field modulus: 3329
let v_FIELD_MODULUS: i32 = 3329l

let v_H_DIGEST_SIZE: usize = sz 32

/// PKE message size
let v_SHARED_SECRET_SIZE: usize = sz 32
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ let v_BLOCK_SIZE: usize = sz 168

let v_THREE_BLOCKS: usize = v_BLOCK_SIZE *! sz 3

/// Free the memory of the state.
/// **NOTE:** That this needs to be done manually for now.
val free_state (xof_state: Libcrux_sha3.X4.t_Shake128StateX4)
: Prims.Pure Prims.unit Prims.l_True (fun _ -> Prims.l_True)

Expand Down
96 changes: 96 additions & 0 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ module Libcrux_ml_kem.Ind_cpa
open Core
open FStar.Mul

/// Pad the `slice` with `0`s at the end.
val into_padded_array (v_LEN: usize) (slice: t_Slice u8)
: Prims.Pure (t_Array u8 v_LEN) Prims.l_True (fun _ -> Prims.l_True)

/// Sample a vector of ring elements from a centered binomial distribution.
val sample_ring_element_cbd
(v_K v_ETA2_RANDOMNESS_SIZE v_ETA2: usize)
(prf_input: t_Array u8 (sz 33))
Expand All @@ -15,6 +17,8 @@ val sample_ring_element_cbd
Prims.l_True
(fun _ -> Prims.l_True)

/// Sample a vector of ring elements from a centered binomial distribution and
/// convert them into their NTT representations.
val sample_vector_cbd_then_ntt
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
(prf_input: t_Array u8 (sz 33))
Expand All @@ -23,18 +27,56 @@ val sample_vector_cbd_then_ntt
Prims.l_True
(fun _ -> Prims.l_True)

/// Call [`compress_then_serialize_ring_element_u`] on each ring element.
val compress_then_serialize_u
(v_K v_OUT_LEN v_COMPRESSION_FACTOR v_BLOCK_LEN: usize)
(input: t_Array Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_K)
: Prims.Pure (t_Array u8 v_OUT_LEN) Prims.l_True (fun _ -> Prims.l_True)

/// Call [`deserialize_then_decompress_ring_element_u`] on each ring element
/// in the `ciphertext`.
val deserialize_then_decompress_u
(v_K v_CIPHERTEXT_SIZE v_U_COMPRESSION_FACTOR: usize)
(ciphertext: t_Array u8 v_CIPHERTEXT_SIZE)
: Prims.Pure (t_Array Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_K)
Prims.l_True
(fun _ -> Prims.l_True)

/// This function implements <strong>Algorithm 13</strong> of the
/// NIST FIPS 203 specification; this is the Kyber CPA-PKE encryption algorithm.
/// Algorithm 13 is reproduced below:
/// ```plaintext
/// Input: encryption key ekₚₖₑ ∈ 𝔹^{384k+32}.
/// Input: message m ∈ 𝔹^{32}.
/// Input: encryption randomness r ∈ 𝔹^{32}.
/// Output: ciphertext c ∈ 𝔹^{32(dᵤk + dᵥ)}.
/// N ← 0
/// t̂ ← ByteDecode₁₂(ekₚₖₑ[0:384k])
/// ρ ← ekₚₖₑ[384k: 384k + 32]
/// for (i ← 0; i < k; i++)
/// for(j ← 0; j < k; j++)
/// Â[i,j] ← SampleNTT(XOF(ρ, i, j))
/// end for
/// end for
/// for(i ← 0; i < k; i++)
/// r[i] ← SamplePolyCBD_{η₁}(PRF_{η₁}(r,N))
/// N ← N + 1
/// end for
/// for(i ← 0; i < k; i++)
/// e₁[i] ← SamplePolyCBD_{η₂}(PRF_{η₂}(r,N))
/// N ← N + 1
/// end for
/// e₂ ← SamplePolyCBD_{η₂}(PRF_{η₂}(r,N))
/// r̂ ← NTT(r)
/// u ← NTT-¹(Âᵀ ◦ r̂) + e₁
/// μ ← Decompress₁(ByteDecode₁(m)))
/// v ← NTT-¹(t̂ᵀ ◦ rˆ) + e₂ + μ
/// c₁ ← ByteEncode_{dᵤ}(Compress_{dᵤ}(u))
/// c₂ ← ByteEncode_{dᵥ}(Compress_{dᵥ}(v))
/// return c ← (c₁ ‖ c₂)
/// ```
/// The NIST FIPS 203 standard can be found at
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
val encrypt
(v_K v_CIPHERTEXT_SIZE v_T_AS_NTT_ENCODED_SIZE v_C1_LEN v_C2_LEN v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR v_BLOCK_LEN v_ETA1 v_ETA1_RANDOMNESS_SIZE v_ETA2 v_ETA2_RANDOMNESS_SIZE:
usize)
Expand All @@ -43,29 +85,83 @@ val encrypt
(randomness: t_Slice u8)
: Prims.Pure (t_Array u8 v_CIPHERTEXT_SIZE) Prims.l_True (fun _ -> Prims.l_True)

/// Call [`deserialize_to_uncompressed_ring_element`] for each ring element.
val deserialize_secret_key (v_K: usize) (secret_key: t_Slice u8)
: Prims.Pure (t_Array Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_K)
Prims.l_True
(fun _ -> Prims.l_True)

/// This function implements <strong>Algorithm 14</strong> of the
/// NIST FIPS 203 specification; this is the Kyber CPA-PKE decryption algorithm.
/// Algorithm 14 is reproduced below:
/// ```plaintext
/// Input: decryption key dkₚₖₑ ∈ 𝔹^{384k}.
/// Input: ciphertext c ∈ 𝔹^{32(dᵤk + dᵥ)}.
/// Output: message m ∈ 𝔹^{32}.
/// c₁ ← c[0 : 32dᵤk]
/// c₂ ← c[32dᵤk : 32(dᵤk + dᵥ)]
/// u ← Decompress_{dᵤ}(ByteDecode_{dᵤ}(c₁))
/// v ← Decompress_{dᵥ}(ByteDecode_{dᵥ}(c₂))
/// ŝ ← ByteDecode₁₂(dkₚₖₑ)
/// w ← v - NTT-¹(ŝᵀ ◦ NTT(u))
/// m ← ByteEncode₁(Compress₁(w))
/// return m
/// ```
/// The NIST FIPS 203 standard can be found at
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
val decrypt
(v_K v_CIPHERTEXT_SIZE v_VECTOR_U_ENCODED_SIZE v_U_COMPRESSION_FACTOR v_V_COMPRESSION_FACTOR:
usize)
(secret_key: t_Slice u8)
(ciphertext: t_Array u8 v_CIPHERTEXT_SIZE)
: Prims.Pure (t_Array u8 (sz 32)) Prims.l_True (fun _ -> Prims.l_True)

/// Call [`serialize_uncompressed_ring_element`] for each ring element.
val serialize_secret_key
(v_K v_OUT_LEN: usize)
(key: t_Array Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_K)
: Prims.Pure (t_Array u8 v_OUT_LEN) Prims.l_True (fun _ -> Prims.l_True)

/// Concatenate `t` and `ρ` into the public key.
val serialize_public_key
(v_K v_RANKED_BYTES_PER_RING_ELEMENT v_PUBLIC_KEY_SIZE: usize)
(tt_as_ntt: t_Array Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_K)
(seed_for_a: t_Slice u8)
: Prims.Pure (t_Array u8 v_PUBLIC_KEY_SIZE) Prims.l_True (fun _ -> Prims.l_True)

/// This function implements most of <strong>Algorithm 12</strong> of the
/// NIST FIPS 203 specification; this is the Kyber CPA-PKE key generation algorithm.
/// We say "most of" since Algorithm 12 samples the required randomness within
/// the function itself, whereas this implementation expects it to be provided
/// through the `key_generation_seed` parameter.
/// Algorithm 12 is reproduced below:
/// ```plaintext
/// Output: encryption key ekₚₖₑ ∈ 𝔹^{384k+32}.
/// Output: decryption key dkₚₖₑ ∈ 𝔹^{384k}.
/// d ←$ B
/// (ρ,σ) ← G(d)
/// N ← 0
/// for (i ← 0; i < k; i++)
/// for(j ← 0; j < k; j++)
/// Â[i,j] ← SampleNTT(XOF(ρ, i, j))
/// end for
/// end for
/// for(i ← 0; i < k; i++)
/// s[i] ← SamplePolyCBD_{η₁}(PRF_{η₁}(σ,N))
/// N ← N + 1
/// end for
/// for(i ← 0; i < k; i++)
/// e[i] ← SamplePolyCBD_{η₂}(PRF_{η₂}(σ,N))
/// N ← N + 1
/// end for
/// ŝ ← NTT(s)
/// ê ← NTT(e)
/// t̂ ← Â◦ŝ + ê
/// ekₚₖₑ ← ByteEncode₁₂(t̂) ‖ ρ
/// dkₚₖₑ ← ByteEncode₁₂(ŝ)
/// ```
/// The NIST FIPS 203 standard can be found at
/// <https://csrc.nist.gov/pubs/fips/203/ipd>.
val generate_keypair
(v_K v_PRIVATE_KEY_SIZE v_PUBLIC_KEY_SIZE v_RANKED_BYTES_PER_RING_ELEMENT v_ETA1 v_ETA1_RANDOMNESS_SIZE:
usize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,23 +71,28 @@ let t_MlKem1024PrivateKey = Libcrux_ml_kem.Types.t_MlKemPrivateKey (sz 3168)
unfold
let t_MlKem1024PublicKey = Libcrux_ml_kem.Types.t_MlKemPublicKey (sz 1568)

/// Decapsulate ML-KEM 1024
val decapsulate
(secret_key: Libcrux_ml_kem.Types.t_MlKemPrivateKey (sz 3168))
(ciphertext: Libcrux_ml_kem.Types.t_MlKemCiphertext (sz 1568))
: Prims.Pure (t_Array u8 (sz 32)) Prims.l_True (fun _ -> Prims.l_True)

/// Encapsulate ML-KEM 1024
val encapsulate
(public_key: Libcrux_ml_kem.Types.t_MlKemPublicKey (sz 1568))
(randomness: t_Array u8 (sz 32))
: Prims.Pure (Libcrux_ml_kem.Types.t_MlKemCiphertext (sz 1568) & t_Array u8 (sz 32))
Prims.l_True
(fun _ -> Prims.l_True)

/// Generate ML-KEM 1024 Key Pair
val generate_key_pair (randomness: t_Array u8 (sz 64))
: Prims.Pure (Libcrux_ml_kem.Types.t_MlKemKeyPair (sz 3168) (sz 1568))
Prims.l_True
(fun _ -> Prims.l_True)

/// Validate a public key.
/// Returns `Some(public_key)` if valid, and `None` otherwise.
val validate_public_key (public_key: Libcrux_ml_kem.Types.t_MlKemPublicKey (sz 1568))
: Prims.Pure (Core.Option.t_Option (Libcrux_ml_kem.Types.t_MlKemPublicKey (sz 1568)))
Prims.l_True
Expand Down
Loading

0 comments on commit 2befd45

Please sign in to comment.