Skip to content

Commit

Permalink
Merge pull request #188 from cryspen/lucas/kyber-serialize-specs
Browse files Browse the repository at this point in the history
feat(kyber/fstar): add specs for serialization functions
  • Loading branch information
franziskuskiefer authored Jan 26, 2024
2 parents d581937 + 654d0c4 commit 05fb489
Showing 1 changed file with 63 additions and 20 deletions.
83 changes: 63 additions & 20 deletions proofs/fstar/extraction-edited/Spec.Kyber.fst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@ module Spec.Kyber
open Core
open FStar.Mul

(** Utils *)
let map' #a #b
(f:(x:a -> b))
(s: t_Slice a): t_Slice b
= createi (length s) (fun i -> f (Seq.index s (v i)))

let flatten #t #n
(#m: usize {range (v n * v m) usize_inttype})
(x: t_Array (t_Array t m) n)
: t_Array t (m *! n)
= createi (m *! n) (fun i -> Seq.index (Seq.index x (v i / v m)) (v i % v m))

(** Constants *)
let v_BITS_PER_COEFFICIENT: usize = sz 12

Expand Down Expand Up @@ -131,9 +143,6 @@ assume val poly_inv_ntt: #p:params -> polynomial -> polynomial
assume val vector_ntt: #p:params -> vector p -> vector p
assume val vector_inv_ntt: #p:params -> vector p -> vector p

assume val vector_encode_12: #p:params -> vector p -> t_Array u8 (v_T_AS_NTT_ENCODED_SIZE p)
assume val vector_decode_12: #p:params -> t_Array u8 (v_T_AS_NTT_ENCODED_SIZE p) -> vector p

// note we take seed of size 32 not 34 as in hacspec
assume val sample_matrix_A: #p:params -> seed:t_Array u8 (sz 32) -> matrix p
// note we take seed of size 32 not 33 as in hacspec
Expand All @@ -153,36 +162,70 @@ let sample_poly_cbd #p seed domain_sep =
let sample_vector_cbd_then_ntt (#p:params) (seed:t_Array u8 (sz 32)) (domain_sep:usize) =
vector_ntt (sample_vector_cbd #p seed domain_sep)

type dT = d: nat {d = 1 \/ d = 4 \/ d = 5 \/ d = 11 \/ d = 12}
type dT = d: nat {d = 1 \/ d = 4 \/ d = 5 \/ d = 10 \/ d = 11 \/ d = 12}

assume val compress_ciphertext_coefficient
: dT -> field_element -> r: field_element
let compress_d (d: dT {d <> 12}) (x: field_element): field_element
= (pow2 d * x + 1664) / v v_FIELD_MODULUS

assume val bits_to_bytes (#bytes: usize) (f: (i:nat {i < v bytes * 8} -> bit))
: Pure (t_Array u8 bytes)
(requires True)
(ensures fun r -> (forall i. bit_vec_of_int_arr r 8 i == f i))

let encode_bytes
(d: dT)
(coefficients: polynomial)
: t_Array u8 (sz (32 * d))
assume val bytes_to_bits (#bytes: usize) (r: t_Array u8 bytes)
: Pure (i:nat {i < v bytes * 8} -> bit)
(requires True)
(ensures fun f -> (forall i. bit_vec_of_int_arr r 8 i == f i))

let byte_encode (d: dT) (coefficients: polynomial): t_Array u8 (sz (32 * d))
= bits_to_bytes #(sz (32 * d)) (bit_vec_of_nat_arr coefficients d)

let byte_decode (d: dT) (coefficients: t_Array u8 (sz (32 * d))): polynomial
= admit ()

let vector_encode_12 (#p:params) (v: vector p): t_Array u8 (v_T_AS_NTT_ENCODED_SIZE p)
= let s: t_Array (t_Array _ (sz 384)) p.v_RANK = map' (byte_encode 12) v in
flatten s

let vector_decode_12 (#p:params) (arr: t_Array u8 (v_T_AS_NTT_ENCODED_SIZE p)): vector p
= createi p.v_RANK (fun block ->
let block_size = (sz (32 * 12)) in
let slice = Seq.slice arr (v block * v block_size)
(v block * v block_size + v block_size) in
byte_decode 12 slice
)

let compress_then_byte_encode (d: dT {d <> 12}) (coefficients: polynomial): t_Array u8 (sz (32 * d))
= let coefs: t_Array nat (sz 256) = map (fun (f: nat {f < v v_FIELD_MODULUS}) ->
compress_ciphertext_coefficient d f <: nat
compress_d d f <: nat
) coefficients
in
bits_to_bytes #(sz (32 * d)) (bit_vec_of_nat_arr coefs d)
byte_encode d coefficients

let compress_then_encode_message: polynomial -> t_Array u8 v_SHARED_SECRET_SIZE
= encode_bytes 1
= byte_encode 1

let decode_then_decompress_message: t_Array u8 v_SHARED_SECRET_SIZE -> polynomial
= byte_decode 1

let compress_then_encode_u (p:params) (vec: vector p): t_Array u8 (v_C1_SIZE p)
= let d = p.v_VECTOR_U_COMPRESSION_FACTOR in
flatten (map #_ #_ #(fun _ -> True) (byte_encode (v d)) vec)

let decode_then_decompress_u (p:params) (arr: t_Array u8 (v_C1_SIZE p)): vector p
= let d = p.v_VECTOR_U_COMPRESSION_FACTOR in
createi p.v_RANK (fun block ->
let block_size = v_C1_BLOCK_SIZE p in
let slice = Seq.slice arr (v block * v block_size)
(v block * v block_size + v block_size) in
byte_decode (v d) slice
)

assume val decode_then_decompress_message: t_Array u8 v_SHARED_SECRET_SIZE -> polynomial
assume val compress_then_encode_u (p:params): vector p -> t_Array u8 (v_C1_SIZE p)

assume val decode_then_decompress_u: p:params -> t_Array u8 (v_C1_SIZE p) -> vector p
let compress_then_encode_v (p:params): polynomial -> t_Array u8 (v_C2_SIZE p)
= encode_bytes (v p.v_VECTOR_V_COMPRESSION_FACTOR)
assume val decode_then_decompress_v: p:params -> t_Array u8 (v_C2_SIZE p) -> polynomial
= byte_encode (v p.v_VECTOR_V_COMPRESSION_FACTOR)

let decode_then_decompress_v (p:params): t_Array u8 (v_C2_SIZE p) -> polynomial
= byte_decode (v p.v_VECTOR_V_COMPRESSION_FACTOR)

(** IND-CPA Functions *)

Expand Down

0 comments on commit 05fb489

Please sign in to comment.