diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fst b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fst index ea4a5084..cb9b3948 100644 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fst +++ b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fst @@ -17,14 +17,43 @@ let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16) = let compress_message_coefficient (fe: u16) = let (shifted: i16):i16 = Rust_primitives.mk_i16 1664 -! (cast (fe <: u16) <: i16) in + let _:Prims.unit = assert (v shifted == 1664 - v fe) in let mask:i16 = shifted >>! Rust_primitives.mk_i32 15 in + let _:Prims.unit = + assert (v mask = v shifted / pow2 15); + assert (if v shifted < 0 then mask = ones else mask = zero) + in let shifted_to_positive:i16 = mask ^. shifted in + let _:Prims.unit = + logxor_lemma shifted mask; + assert (v shifted < 0 ==> v shifted_to_positive = v (lognot shifted)); + neg_equiv_lemma shifted; + assert (v (lognot shifted) = - (v shifted) - 1); + assert (v shifted >= 0 ==> v shifted_to_positive = v (mask `logxor` shifted)); + assert (v shifted >= 0 ==> mask = zero); + assert (v shifted >= 0 ==> mask ^. shifted = shifted); + assert (v shifted >= 0 ==> v shifted_to_positive = v shifted); + assert (shifted_to_positive >=. mk_i16 0) + in let shifted_positive_in_range:i16 = shifted_to_positive -! Rust_primitives.mk_i16 832 in - cast ((shifted_positive_in_range >>! Rust_primitives.mk_i32 15 <: i16) &. Rust_primitives.mk_i16 1 - <: - i16) - <: - u8 + let _:Prims.unit = + assert (1664 - v fe >= 0 ==> v shifted_positive_in_range == 832 - v fe); + assert (1664 - v fe < 0 ==> v shifted_positive_in_range == - 2497 + v fe) + in + let r0:i16 = shifted_positive_in_range >>! Rust_primitives.mk_i32 15 in + let (r1: i16):i16 = r0 &. Rust_primitives.mk_i16 1 in + let res:u8 = cast (r1 <: i16) <: u8 in + let _:Prims.unit = + assert (v r0 = v shifted_positive_in_range / pow2 15); + assert (if v shifted_positive_in_range < 0 then r0 = ones else r0 = zero); + logand_lemma (mk_i16 1) r0; + assert (if v shifted_positive_in_range < 0 then r1 = mk_i16 1 else r1 = mk_i16 0); + assert ((v fe >= 833 && v fe <= 2496) ==> r1 = mk_i16 1); + assert (v fe < 833 ==> r1 = mk_i16 0); + assert (v fe > 2496 ==> r1 = mk_i16 0); + assert (v res = v r1) + in + res #push-options "--fuel 0 --ifuel 0 --z3rlimit 2000" @@ -167,38 +196,79 @@ let compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) #pop-options +#push-options "--z3rlimit 300 --ext context_pruning" + let decompress_ciphertext_coefficient (v_COEFFICIENT_BITS: i32) - (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) + (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) = - let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = + let _:Prims.unit = + assert_norm (pow2 1 == 2); + assert_norm (pow2 4 == 16); + assert_norm (pow2 5 == 32); + assert_norm (pow2 10 == 1024); + assert_norm (pow2 11 == 2048) + in + let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = Rust_primitives.Hax.Folds.fold_range (Rust_primitives.mk_usize 0) Libcrux_ml_kem.Vector.Traits.v_FIELD_ELEMENTS_IN_VECTOR - (fun v temp_1_ -> - let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in - let _:usize = temp_1_ in - true) - v - (fun v i -> - let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = v in + (fun a i -> + let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in let i:usize = i in + (v i < 16 ==> + (forall (j: nat). + (j >= v i /\ j < 16) ==> + v (Seq.index a.f_elements j) >= 0 /\ + v (Seq.index a.f_elements j) < pow2 (v v_COEFFICIENT_BITS))) /\ + (forall (j: nat). + j < v i ==> + v (Seq.index a.f_elements j) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS)) + a + (fun a i -> + let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = a in + let i:usize = i in + let _:Prims.unit = + assert (v (a.f_elements.[ i ] <: i16) < pow2 11); + assert (v (a.f_elements.[ i ] <: i16) == v (cast (a.f_elements.[ i ] <: i16) <: i32)); + assert (v (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) == + v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)); + assert (v ((cast (a.f_elements.[ i ] <: i16) <: i32) *! + (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)) == + v (cast (a.f_elements.[ i ] <: i16) <: i32) * + v (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32)) + in let decompressed:i32 = - (cast (v.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *! + (cast (a.Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements.[ i ] <: i16) <: i32) *! (cast (Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS <: i16) <: i32) in + let _:Prims.unit = + assert (v (decompressed <>! (v_COEFFICIENT_BITS +! mk_i32 1 <: i32)) == + v decompressed / pow2 (v v_COEFFICIENT_BITS + 1)) + in let decompressed:i32 = decompressed >>! (v_COEFFICIENT_BITS +! Rust_primitives.mk_i32 1 <: i32) in - let v:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = + let _:Prims.unit = + assert (v decompressed < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS); + assert (v (cast decompressed <: i16) < v Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS) + in + let a:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = { - v with + a with Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements = - Rust_primitives.Hax.Monomorphized_update_at.update_at_usize v + Rust_primitives.Hax.Monomorphized_update_at.update_at_usize a .Libcrux_ml_kem.Vector.Portable.Vector_type.f_elements i (cast (decompressed <: i32) <: i16) @@ -206,6 +276,8 @@ let decompress_ciphertext_coefficient <: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector in - v) + a) in - v + a + +#pop-options diff --git a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fsti b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fsti index 808234a2..dda306a6 100644 --- a/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fsti +++ b/libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Vector.Portable.Compress.fsti @@ -41,12 +41,12 @@ val compress_message_coefficient (fe: u16) fun result -> let result:u8 = result in Hax_lib.implies ((Rust_primitives.mk_u16 833 <=. fe <: bool) && - (fe <=. Rust_primitives.mk_u16 2596 <: bool)) + (fe <=. Rust_primitives.mk_u16 2496 <: bool)) (fun temp_0_ -> let _:Prims.unit = temp_0_ in result =. Rust_primitives.mk_u8 1 <: bool) && Hax_lib.implies (~.((Rust_primitives.mk_u16 833 <=. fe <: bool) && - (fe <=. Rust_primitives.mk_u16 2596 <: bool)) + (fe <=. Rust_primitives.mk_u16 2496 <: bool)) <: bool) (fun temp_0_ -> @@ -84,7 +84,18 @@ val compress_1_ (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) val decompress_ciphertext_coefficient (v_COEFFICIENT_BITS: i32) - (v: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) + (a: Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector) : Prims.Pure Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector - Prims.l_True - (fun _ -> Prims.l_True) + (requires + (v v_COEFFICIENT_BITS == 4 \/ v v_COEFFICIENT_BITS == 5 \/ v v_COEFFICIENT_BITS == 10 \/ + v v_COEFFICIENT_BITS == 11) /\ + (forall (i: nat). + i < 16 ==> + v (Seq.index a.f_elements i) >= 0 /\ + v (Seq.index a.f_elements i) < pow2 (v v_COEFFICIENT_BITS))) + (ensures + fun result -> + let result:Libcrux_ml_kem.Vector.Portable.Vector_type.t_PortableVector = result in + forall (i: nat). + i < 16 ==> + v (Seq.index result.f_elements i) < Libcrux_ml_kem.Vector.Traits.v_FIELD_MODULUS) diff --git a/libcrux-ml-kem/src/vector/portable/compress.rs b/libcrux-ml-kem/src/vector/portable/compress.rs index fa8e5a0e..555cc97c 100644 --- a/libcrux-ml-kem/src/vector/portable/compress.rs +++ b/libcrux-ml-kem/src/vector/portable/compress.rs @@ -25,8 +25,8 @@ use crate::vector::FIELD_MODULUS; /// . #[cfg_attr(hax, hax_lib::requires(fe < (FIELD_MODULUS as u16)))] #[cfg_attr(hax, hax_lib::ensures(|result| - hax_lib::implies(833 <= fe && fe <= 2596, || result == 1) && - hax_lib::implies(!(833 <= fe && fe <= 2596), || result == 0) + hax_lib::implies(833 <= fe && fe <= 2496, || result == 1) && + hax_lib::implies(!(833 <= fe && fe <= 2496), || result == 0) ))] pub(crate) fn compress_message_coefficient(fe: u16) -> u8 { // The approach used here is inspired by: @@ -35,6 +35,7 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 { // If 833 <= fe <= 2496, // then -832 <= shifted <= 831 let shifted: i16 = 1664 - (fe as i16); + hax_lib::fstar!("assert (v $shifted == 1664 - v $fe)"); // If shifted < 0, then // (shifted >> 15) ^ shifted = flip_bits(shifted) = -shifted - 1, and so @@ -44,13 +45,37 @@ pub(crate) fn compress_message_coefficient(fe: u16) -> u8 { // (shifted >> 15) ^ shifted = shifted, and so // if 0 <= shifted <= 831 then 0 <= shifted_positive <= 831 let mask = shifted >> 15; + hax_lib::fstar!("assert (v $mask = v $shifted / pow2 15); + assert (if v $shifted < 0 then $mask = ones else $mask = zero)"); let shifted_to_positive = mask ^ shifted; + hax_lib::fstar!("logxor_lemma $shifted $mask; + assert (v $shifted < 0 ==> v $shifted_to_positive = v (lognot $shifted)); + neg_equiv_lemma $shifted; + assert (v (lognot $shifted) = -(v $shifted) -1); + assert (v $shifted >= 0 ==> v $shifted_to_positive = v ($mask `logxor` $shifted)); + assert (v $shifted >= 0 ==> $mask = zero); + assert (v $shifted >= 0 ==> $mask ^. $shifted = $shifted); + assert (v $shifted >= 0 ==> v $shifted_to_positive = v $shifted); + assert ($shifted_to_positive >=. mk_i16 0)"); let shifted_positive_in_range = shifted_to_positive - 832; + hax_lib::fstar!("assert (1664 - v $fe >= 0 ==> v $shifted_positive_in_range == 832 - v $fe); + assert (1664 - v $fe < 0 ==> v $shifted_positive_in_range == -2497 + v $fe)"); // If x <= 831, then x - 832 <= -1, and so x - 832 < 0, which means // the most significant bit of shifted_positive_in_range will be 1. - ((shifted_positive_in_range >> 15) & 1) as u8 + let r0 = shifted_positive_in_range >> 15; + let r1: i16 = r0 & 1; + let res = r1 as u8; + hax_lib::fstar!("assert (v $r0 = v $shifted_positive_in_range / pow2 15); + assert (if v $shifted_positive_in_range < 0 then $r0 = ones else $r0 = zero); + logand_lemma (mk_i16 1) $r0; + assert (if v $shifted_positive_in_range < 0 then $r1 = mk_i16 1 else $r1 = mk_i16 0); + assert ((v $fe >= 833 && v $fe <= 2496) ==> $r1 = mk_i16 1); + assert (v $fe < 833 ==> $r1 = mk_i16 0); + assert (v $fe > 2496 ==> $r1 = mk_i16 0); + assert (v $res = v $r1)"); + res } #[cfg_attr(hax, @@ -147,23 +172,51 @@ pub(crate) fn compress(mut a: PortableVector) -> Po } #[inline(always)] +#[hax_lib::fstar::options("--z3rlimit 300 --ext context_pruning")] +#[hax_lib::requires(fstar!("(v $COEFFICIENT_BITS == 4 \\/ + v $COEFFICIENT_BITS == 5 \\/ + v $COEFFICIENT_BITS == 10 \\/ + v $COEFFICIENT_BITS == 11) /\\ + (forall (i:nat). i < 16 ==> v (Seq.index ${a}.f_elements i) >= 0 /\\ + v (Seq.index ${a}.f_elements i) < pow2 (v $COEFFICIENT_BITS))"))] +#[hax_lib::ensures(|result| fstar!("forall (i:nat). i < 16 ==> v (Seq.index ${result}.f_elements i) < $FIELD_MODULUS"))] pub(crate) fn decompress_ciphertext_coefficient( - mut v: PortableVector, + mut a: PortableVector, ) -> PortableVector { - // debug_assert!(to_i16_array(v) - // .into_iter() - // .all(|coefficient| coefficient.abs() < 1 << COEFFICIENT_BITS)); + hax_lib::fstar!("assert_norm (pow2 1 == 2); + assert_norm (pow2 4 == 16); + assert_norm (pow2 5 == 32); + assert_norm (pow2 10 == 1024); + assert_norm (pow2 11 == 2048)"); for i in 0..FIELD_ELEMENTS_IN_VECTOR { - let mut decompressed = v.elements[i] as i32 * FIELD_MODULUS as i32; + hax_lib::loop_invariant!(|i: usize| { fstar!("(v $i < 16 ==> (forall (j:nat). (j >= v $i /\\ j < 16) ==> + v (Seq.index ${a}.f_elements j) >= 0 /\\ v (Seq.index ${a}.f_elements j) < pow2 (v $COEFFICIENT_BITS))) /\\ + (forall (j:nat). j < v $i ==> + v (Seq.index ${a}.f_elements j) < v $FIELD_MODULUS)") }); + hax_lib::fstar!("assert (v (${a}.f_elements.[ $i ] <: i16) < pow2 11); + assert (v (${a}.f_elements.[ $i ] <: i16) == + v (cast (${a}.f_elements.[ $i ] <: i16) <: i32)); + assert (v ($FIELD_MODULUS <: i16) == + v (cast ($FIELD_MODULUS <: i16) <: i32)); + assert (v ((cast (${a}.f_elements.[ $i ] <: i16) <: i32) *! + (cast ($FIELD_MODULUS <: i16) <: i32)) == + v (cast (${a}.f_elements.[ $i ] <: i16) <: i32) * + v (cast ($FIELD_MODULUS <: i16) <: i32))"); + let mut decompressed = a.elements[i] as i32 * FIELD_MODULUS as i32; + hax_lib::fstar!("assert (v ($decompressed <>! ($COEFFICIENT_BITS +! mk_i32 1 <: i32)) == + v $decompressed / pow2 (v $COEFFICIENT_BITS + 1))"); decompressed = decompressed >> (COEFFICIENT_BITS + 1); - v.elements[i] = decompressed as i16; + hax_lib::fstar!("assert (v $decompressed < v $FIELD_MODULUS); + assert (v (cast $decompressed <: i16) < v $FIELD_MODULUS)"); + a.elements[i] = decompressed as i16; } - // debug_assert!(to_i16_array(v) - // .into_iter() - // .all(|coefficient| coefficient.abs() as u16 <= 1 << 12)); - - v + a }