diff --git a/libcrux-ml-kem/Cargo.toml b/libcrux-ml-kem/Cargo.toml index f928c0626..f83163e70 100644 --- a/libcrux-ml-kem/Cargo.toml +++ b/libcrux-ml-kem/Cargo.toml @@ -26,6 +26,7 @@ rand = { version = "0.8", optional = true } libcrux-platform = { version = "0.0.2-beta.2", path = "../sys/platform" } libcrux-sha3 = { version = "0.0.2-beta.2", path = "../libcrux-sha3" } libcrux-intrinsics = { version = "0.0.2-beta.2", path = "../libcrux-intrinsics" } +libcrux-secrets = { version = "0.0.2-beta.2", path = "../secrets" } hax-lib = { version = "0.1.0", git = "https://github.com/hacspec/hax/" } [features] diff --git a/libcrux-ml-kem/src/hash_functions.rs b/libcrux-ml-kem/src/hash_functions.rs index 572664cff..6f9c3f98f 100644 --- a/libcrux-ml-kem/src/hash_functions.rs +++ b/libcrux-ml-kem/src/hash_functions.rs @@ -8,6 +8,8 @@ // them to be properly abstracted in F*. We would like hax to do this automatically. // Related Issue: https://github.com/hacspec/hax/issues/616 +use libcrux_secrets::{AsSecret, AsSecretRef}; + use crate::constants::{G_DIGEST_SIZE, H_DIGEST_SIZE}; /// The SHA3 block size. @@ -89,7 +91,7 @@ pub(crate) mod portable { #[inline(always)] fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { let mut digest = [0u8; G_DIGEST_SIZE]; - portable::sha512(&mut digest, input); + portable::sha512(&mut digest, input.as_secret()); digest } @@ -99,7 +101,7 @@ pub(crate) mod portable { #[inline(always)] fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { let mut digest = [0u8; H_DIGEST_SIZE]; - portable::sha256(&mut digest, input); + portable::sha256(&mut digest, input.as_secret()); digest } @@ -110,7 +112,7 @@ pub(crate) mod portable { #[inline(always)] fn PRF(input: &[u8]) -> [u8; LEN] { let mut digest = [0u8; LEN]; - portable::shake256(&mut digest, input); + portable::shake256(&mut digest, input.as_secret()); digest } @@ -124,7 +126,7 @@ pub(crate) mod portable { let mut out = [[0u8; LEN]; K]; for i in 0..K { - portable::shake256(&mut out[i], &input[i]); + portable::shake256(&mut out[i], (&input[i]).as_secret()); } out } @@ -135,7 +137,7 @@ pub(crate) mod portable { let mut shake128_state = [incremental::shake128_init(); K]; for i in 0..K { - incremental::shake128_absorb_final(&mut shake128_state[i], &input[i]); + incremental::shake128_absorb_final(&mut shake128_state[i], (&input[i]).as_secret()); } PortableHash { shake128_state } } @@ -498,7 +500,7 @@ pub(crate) mod neon { #[inline(always)] fn G(input: &[u8]) -> [u8; G_DIGEST_SIZE] { let mut digest = [0u8; G_DIGEST_SIZE]; - libcrux_sha3::neon::sha512(&mut digest, input); + libcrux_sha3::neon::sha512(&mut digest, input.as_secret()); digest } @@ -508,7 +510,7 @@ pub(crate) mod neon { #[inline(always)] fn H(input: &[u8]) -> [u8; H_DIGEST_SIZE] { let mut digest = [0u8; H_DIGEST_SIZE]; - libcrux_sha3::neon::sha256(&mut digest, input); + libcrux_sha3::neon::sha256(&mut digest, input.as_secret()); digest } @@ -520,7 +522,12 @@ pub(crate) mod neon { fn PRF(input: &[u8]) -> [u8; LEN] { let mut digest = [0u8; LEN]; let mut dummy = [0u8; LEN]; - x2::shake256(input, input, &mut digest, &mut dummy); + x2::shake256( + input.as_secret(), + input.as_secret(), + &mut digest, + &mut dummy, + ); digest } @@ -538,20 +545,45 @@ pub(crate) mod neon { let mut out3 = [0u8; LEN]; match K as u8 { 2 => { - x2::shake256(&input[0], &input[1], &mut out0, &mut out1); + x2::shake256( + (&input[0]).as_secret(), + (&input[1]).as_secret(), + &mut out0, + &mut out1, + ); out[0] = out0; out[1] = out1; } 3 => { - x2::shake256(&input[0], &input[1], &mut out0, &mut out1); - x2::shake256(&input[2], &input[2], &mut out2, &mut out3); + x2::shake256( + (&input[0]).as_secret(), + (&input[1]).as_secret(), + &mut out0, + &mut out1, + ); + x2::shake256( + (&input[2]).as_secret(), + (&input[2]).as_secret(), + &mut out2, + &mut out3, + ); out[0] = out0; out[1] = out1; out[2] = out2; } 4 => { - x2::shake256(&input[0], &input[1], &mut out0, &mut out1); - x2::shake256(&input[2], &input[3], &mut out2, &mut out3); + x2::shake256( + (&input[0]).as_secret(), + (&input[1]).as_secret(), + &mut out0, + &mut out1, + ); + x2::shake256( + (&input[2]).as_secret(), + (&input[3]).as_secret(), + &mut out2, + &mut out3, + ); out[0] = out0; out[1] = out1; out[2] = out2; @@ -568,15 +600,35 @@ pub(crate) mod neon { let mut state = [x2::incremental::init(), x2::incremental::init()]; match K as u8 { 2 => { - x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); + x2::incremental::shake128_absorb_final( + &mut state[0], + (&input[0]).as_secret(), + (&input[1]).as_secret(), + ); } 3 => { - x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); - x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[2]); + x2::incremental::shake128_absorb_final( + &mut state[0], + (&input[0]).as_secret(), + (&input[1]).as_secret(), + ); + x2::incremental::shake128_absorb_final( + &mut state[1], + (&input[2]).as_secret(), + (&input[2]).as_secret(), + ); } 4 => { - x2::incremental::shake128_absorb_final(&mut state[0], &input[0], &input[1]); - x2::incremental::shake128_absorb_final(&mut state[1], &input[2], &input[3]); + x2::incremental::shake128_absorb_final( + &mut state[0], + (&input[0]).as_secret(), + (&input[1]).as_secret(), + ); + x2::incremental::shake128_absorb_final( + &mut state[1], + (&input[2]).as_secret(), + (&input[3]).as_secret(), + ); } _ => unreachable!("This function can only called be called with N = 2, 3, 4"), } diff --git a/libcrux-ml-kem/tests/nistkats.rs b/libcrux-ml-kem/tests/nistkats.rs index 99acc27a4..298696dd1 100644 --- a/libcrux-ml-kem/tests/nistkats.rs +++ b/libcrux-ml-kem/tests/nistkats.rs @@ -1,3 +1,4 @@ +use libcrux_secrets::{AsSecret, AsSecretRef}; use serde::Deserialize; use serde_json; use std::{fs::File, io::BufReader, path::Path}; @@ -52,23 +53,23 @@ macro_rules! impl_nist_known_answer_tests { let pk = unpacked::key_pair_serialized_public_key(&unpacked_keys); let sk = unpacked::key_pair_serialized_private_key(&unpacked_keys); - let public_key_hash = sha256(pk.as_slice()); - let secret_key_hash = sha256(sk.as_slice()); + let public_key_hash = sha256(pk.as_slice().as_secret()); + let secret_key_hash = sha256(sk.as_slice().as_secret()); assert_eq!(public_key_hash, kat.sha3_256_hash_of_public_key, "lhs: computed public key hash, rhs: hash from kat"); assert_eq!(secret_key_hash, kat.sha3_256_hash_of_secret_key, "lhs: computed secret key hash, rhs: hash from kat"); } - let public_key_hash = sha256(key_pair.pk()); + let public_key_hash = sha256(key_pair.pk().as_secret()); eprintln!("pk hash: {}", hex::encode(public_key_hash)); - let secret_key_hash = sha256(key_pair.sk()); + let secret_key_hash = sha256(key_pair.sk().as_secret()); assert_eq!(public_key_hash, kat.sha3_256_hash_of_public_key, "lhs: computed public key hash, rhs: hash from kat"); assert_eq!(secret_key_hash, kat.sha3_256_hash_of_secret_key, "lhs: computed secret key hash, rhs: hash from kat"); let (ciphertext, shared_secret) = encapsulate(key_pair.public_key(), kat.encapsulation_seed); - let ciphertext_hash = sha256(ciphertext.as_ref()); + let ciphertext_hash = sha256(ciphertext.as_ref().as_secret()); assert_eq!(ciphertext_hash, kat.sha3_256_hash_of_ciphertext, "lhs: computed ciphertext hash, rhs: hash from akt"); assert_eq!(shared_secret.as_ref(), kat.shared_secret, "lhs: computed shared secret from encapsulate, rhs: shared secret from kat"); diff --git a/libcrux-ml-kem/tests/self.rs b/libcrux-ml-kem/tests/self.rs index d54a72184..fedd07d09 100644 --- a/libcrux-ml-kem/tests/self.rs +++ b/libcrux-ml-kem/tests/self.rs @@ -1,5 +1,6 @@ use libcrux_ml_kem::{MlKemCiphertext, MlKemPrivateKey}; +use libcrux_secrets::AsSecret; use libcrux_sha3::shake256; use rand::{rngs::OsRng, thread_rng, RngCore}; @@ -212,7 +213,7 @@ fn compute_implicit_rejection_shared_secret let mut to_hash = secret_key[MlKemPrivateKey::::len() - SHARED_SECRET_SIZE..].to_vec(); to_hash.extend_from_slice(ciphertext.as_ref()); - shake256(&to_hash) + shake256((&to_hash).as_secret()) } macro_rules! impl_modified_secret_key { diff --git a/secrets/src/ct.rs b/secrets/src/ct.rs new file mode 100644 index 000000000..b16292820 --- /dev/null +++ b/secrets/src/ct.rs @@ -0,0 +1,248 @@ +//! # Constant time operations +//! +//! These are crude attempts to prevent LLVM from optimizing away the code in this +//! module. This is not guaranteed to work but at the time of writing, achieved +//! its goals. +//! `read_volatile` could be used as well but seems unnecessary at this point in +//! time. +//! Examine the output that LLVM produces for this code from time to time to ensure +//! operations are not being optimized away/constant-timedness is not being broken. + +// XXX: We have to disable some of this for C extraction for now. See eurydice/issues#37 + +use hax_lib::{ensures, fstar, loop_invariant, requires}; + +/// Return 1 if `value` is not zero and 0 otherwise. +#[ensures(|result| fstar!(r#"($value == (mk_u8 0) ==> $result == (mk_u8 0)) /\ + ($value =!= (mk_u8 0) ==> $result == (mk_u8 1))"#))] +fn inz(value: u8) -> u8 { + // We need the original value for the F* proof + let _orig_value = value; + + let value = value as u16; + let result = ((!value).wrapping_add(1) >> 8) as u8; + let res = result & 1; + + // F* proof + fstar!( + r#"if v $_orig_value = 0 then ( + assert($value == zero); + lognot_lemma $value; + assert((~.$value +. (mk_u16 1)) == zero); + assert((Core.Num.impl__u16__wrapping_add (~.$value <: u16) (mk_u16 1) <: u16) == zero); + logor_lemma $value zero; + assert(($value |. (Core.Num.impl__u16__wrapping_add (~.$value <: u16) (mk_u16 1) <: u16) <: u16) == $value); + assert (v $result == v (($value >>! (mk_i32 8)))); + assert ((v $value / pow2 8) == 0); + assert ($result == (mk_u8 0)); + logand_lemma (mk_u8 1) $result; + assert ($res == (mk_u8 0))) + else ( + assert (v $value <> 0); + lognot_lemma $value; + assert (v (~.$value) = pow2 16 - 1 - v $value); + assert (v (~.$value) + 1 = pow2 16 - v $value); + assert (v ($value) <= pow2 8 - 1); + assert ((v (~.$value) + 1) = (pow2 16 - pow2 8) + (pow2 8 - v $value)); + assert ((v (~.$value) + 1) = (pow2 8 - 1) * pow2 8 + (pow2 8 - v $value)); + assert ((v (~.$value) + 1)/pow2 8 = (pow2 8 - 1)); + assert (v ((Core.Num.impl__u16__wrapping_add (~.$value <: u16) (mk_u16 1) <: u16) >>! (mk_i32 8)) = pow2 8 - 1); + assert ($result = ones); + logand_lemma (mk_u8 1) $result; + assert ($res = (mk_u8 1)))"# + ); + + res +} + +#[inline(never)] // Don't inline this to avoid that the compiler optimizes this out. +#[ensures(|result| fstar!(r#"($value == (mk_u8 0) ==> $result == (mk_u8 0)) /\ + ($value =!= (mk_u8 0) ==> $result == (mk_u8 1))"#))] +fn is_non_zero(value: u8) -> u8 { + #[cfg(eurydice)] + return inz(value); + + // Eurydice can't handle this + // XXX: May be fixed by now + #[cfg(not(eurydice))] + core::hint::black_box(inz(value)) +} + +/// Return 1 if the bytes of `lhs` and `rhs` do not exactly +/// match and 0 otherwise. +#[requires(lhs.len() == rhs.len())] +#[ensures(|result| fstar!(r#"($lhs == $rhs ==> $result == (mk_u8 0)) /\ + ($lhs =!= $rhs ==> $result == (mk_u8 1))"#))] +fn _compare(lhs: &[u8], rhs: &[u8]) -> u8 { + let mut r: u8 = 0; + + for i in 0..lhs.len() { + loop_invariant!(|i: usize| { + fstar!( + r#"v $i <= Seq.length $lhs /\ + (if (Seq.slice $lhs 0 (v $i) = Seq.slice $rhs 0 (v $i)) then + $r == (mk_u8 0) + else ~ ($r == (mk_u8 0)))"# + ) + }); + + let nr = r | (lhs[i] ^ rhs[i]); + + // F* proof + fstar!( + r#"if $r =. (mk_u8 0) then ( + if (Seq.index $lhs (v $i) = Seq.index $rhs (v $i)) then ( + logxor_lemma (Seq.index $lhs (v $i)) (Seq.index $rhs (v $i)); + assert (((${lhs}.[ $i ] <: u8) ^. (${rhs}.[ $i ] <: u8) <: u8) = zero); + logor_lemma $r ((${lhs}.[ $i ] <: u8) ^. (${rhs}.[ $i ] <: u8) <: u8); + assert ($nr = $r); + assert (forall j. Seq.index (Seq.slice $lhs 0 (v $i)) j == Seq.index $lhs j); + assert (forall j. Seq.index (Seq.slice $rhs 0 (v $i)) j == Seq.index $rhs j); + eq_intro (Seq.slice $lhs 0 ((v $i) + 1)) (Seq.slice $rhs 0 ((v $i) + 1)) + ) + else ( + logxor_lemma (Seq.index $lhs (v $i)) (Seq.index $rhs (v $i)); + assert (((${lhs}.[ $i ] <: u8) ^. (${rhs}.[ $i ] <: u8) <: u8) <> zero); + logor_lemma r ((${lhs}.[ $i ] <: u8) ^. (${rhs}.[ $i ] <: u8) <: u8); + assert (v $nr > 0); + assert (Seq.index (Seq.slice $lhs 0 ((v $i)+1)) (v $i) <> + Seq.index (Seq.slice $rhs 0 ((v $i)+1)) (v $i)); + assert (Seq.slice $lhs 0 ((v $i)+1) <> Seq.slice $rhs 0 ((v $i) + 1)) + ) + ) else ( + logor_lemma $r ((${lhs}.[ $i ] <: u8) ^. (${rhs}.[ $i ] <: u8) <: u8); + assert (v $nr >= v $r); + assert (Seq.slice $lhs 0 (v $i) <> Seq.slice $rhs 0 (v $i)); + if (Seq.slice $lhs 0 ((v $i)+1) = Seq.slice $rhs 0 ((v $i) + 1)) then + (assert (forall j. j < (v $i) + 1 ==> Seq.index (Seq.slice $lhs 0 ((v $i)+1)) j == Seq.index (Seq.slice $rhs 0 ((v $i)+1)) j); + eq_intro (Seq.slice $lhs 0 (v $i)) (Seq.slice $rhs 0 (v $i)); + assert(False)) + )"# + ); + + r = nr; + } + + is_non_zero(r) +} + +/// If `selector` is not zero, return the bytes in `rhs`; return the bytes in +/// `lhs` otherwise. +#[requires( + lhs.len() == rhs.len() && + lhs.len() == N +)] +#[ensures(|result| fstar!(r#"($selector == (mk_u8 0) ==> $result == $lhs) /\ + ($selector =!= (mk_u8 0) ==> $result == $rhs)"#))] +#[fstar::options("--ifuel 0 --z3rlimit 50")] +fn select_ct(lhs: &[u8], rhs: &[u8], selector: u8) -> [u8; N] { + let mask = is_non_zero(selector).wrapping_sub(1); + let mut out = [0u8; N]; + + fstar!( + "assert (if $selector = (mk_u8 0) then $mask = ones else $mask = zero); + lognot_lemma $mask; + assert (if $selector = (mk_u8 0) then ~.$mask = zero else ~.$mask = ones)" + ); + + for i in 0..N { + loop_invariant!(|i: usize| { + fstar!( + r#"v $i <= v $SHARED_SECRET_SIZE /\ + (forall j. j < v $i ==> (if ($selector =. (mk_u8 0)) then Seq.index $out j == Seq.index $lhs j else Seq.index $out j == Seq.index $rhs j)) /\ + (forall j. j >= v $i ==> Seq.index $out j == (mk_u8 0))"# + ) + }); + fstar!(r#"assert ((${out}.[ $i ] <: u8) = (mk_u8 0))"#); + + let outi = (lhs[i] & mask) | (rhs[i] & !mask); + + fstar!( + r#"if ($selector = (mk_u8 0)) then ( + logand_lemma (${lhs}.[ $i ] <: u8) $mask; + assert (((${lhs}.[ $i ] <: u8) &. $mask <: u8) == (${lhs}.[ $i ] <: u8)); + logand_lemma (${rhs}.[ $i ] <: u8) (~.$mask); + assert (((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8) == zero); + logor_lemma ((${lhs}.[ $i ] <: u8) &. $mask <: u8) ((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8); + assert ((((${lhs}.[ $i ] <: u8) &. $mask <: u8) |. ((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8) <: u8) == (${lhs}.[ $i ] <: u8)); + logor_lemma (${out}.[ $i ] <: u8) (${lhs}.[ $i ] <: u8); + assert (((${out}.[ $i ] <: u8) |. (((${lhs}.[ $i ] <: u8) &. $mask <: u8) |. ((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8) <: u8) <: u8) == (${lhs}.[ $i ] <: u8)); + assert ($outi = (${lhs}.[ $i ] <: u8)) + ) + else ( + logand_lemma (${lhs}.[ $i ] <: u8) $mask; + assert (((${lhs}.[ $i ] <: u8) &. $mask <: u8) == zero); + logand_lemma (${rhs}.[ $i ] <: u8) (~.$mask); + assert (((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8) == (${rhs}.[ $i ] <: u8)); + logor_lemma (${rhs}.[ $i ] <: u8) zero; + assert ((logor zero (${rhs}.[ $i ] <: u8)) == (${rhs}.[ $i ] <: u8)); + assert ((((${lhs}.[ $i ] <: u8) &. $mask <: u8) |. ((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8)) == (${rhs}.[ $i ] <: u8)); + logor_lemma (${out}.[ $i ] <: u8) (${rhs}.[ $i ] <: u8); + assert (((${out}.[ $i ] <: u8) |. (((${lhs}.[ $i ] <: u8) &. $mask <: u8) |. ((${rhs}.[ $i ] <: u8) &. (~.$mask <: u8) <: u8) <: u8) <: u8) == (${rhs}.[ $i ] <: u8)); + assert ($outi = (${rhs}.[ $i ] <: u8)) + )"# + ); + + out[i] = outi; + } + + fstar!( + "if ($selector =. (mk_u8 0)) then ( + eq_intro $out $lhs + ) + else ( + eq_intro $out $rhs + )" + ); + + out +} + +// Don't inline this to avoid that the compiler optimizes this out. +#[inline(never)] +#[requires(lhs.len() == rhs.len())] +#[ensures(|result| fstar!(r#"($lhs == $rhs ==> $result == (mk_u8 0)) /\ + ($lhs =!= $rhs ==> $result == (mk_u8 1))"#))] +pub fn compare(lhs: &[u8], rhs: &[u8]) -> u8 { + #[cfg(eurydice)] + return _compare(lhs, rhs); + + #[cfg(not(eurydice))] + core::hint::black_box(_compare(lhs, rhs)) +} + +// Don't inline this to avoid that the compiler optimizes this out. +#[inline(never)] +#[requires( + lhs.len() == rhs.len() && + lhs.len() == N +)] +#[ensures(|result| fstar!(r#"($selector == (mk_u8 0) ==> $result == $lhs) /\ + ($selector =!= (mk_u8 0) ==> $result == $rhs)"#))] +pub fn select(lhs: &[u8], rhs: &[u8], selector: u8) -> [u8; N] { + #[cfg(eurydice)] + return select_ct(lhs, rhs, selector); + + #[cfg(not(eurydice))] + core::hint::black_box(select_ct(lhs, rhs, selector)) +} + +// Don't inline this to avoid that the compiler optimizes this out. +#[inline(never)] +#[requires( + lhs_c.len() == rhs_c.len() && + lhs_s.len() == rhs_s.len() && + lhs_s.len() == N +)] +#[ensures(|result| fstar!(r#"let selector = if $lhs_c =. $rhs_c then (mk_u8 0) else (mk_u8 1) in + ((selector == (mk_u8 0) ==> $result == $lhs_s) /\ + (selector =!= (mk_u8 0) ==> $result == $rhs_s))"#))] +pub fn compare_select( + lhs_c: &[u8], + rhs_c: &[u8], + lhs_s: &[u8], + rhs_s: &[u8], +) -> [u8; N] { + let selector = compare(lhs_c, rhs_c); + select(lhs_s, rhs_s, selector) +} diff --git a/secrets/src/integers.rs b/secrets/src/integers.rs index 1e248f65d..12f7df675 100644 --- a/secrets/src/integers.rs +++ b/secrets/src/integers.rs @@ -285,11 +285,13 @@ impl EncodeOps for U32 { } fn try_from_le_bytes(x: &[Secret]) -> Self { - todo!() + let x: &[u8] = unsafe { core::mem::transmute(x) }; + u32::from_le_bytes(x.try_into().unwrap()).classify() } fn try_from_be_bytes(x: &[Secret]) -> Self { - todo!() + let x: &[u8] = unsafe { core::mem::transmute(x) }; + u32::from_be_bytes(x.try_into().unwrap()).classify() } } diff --git a/secrets/src/lib.rs b/secrets/src/lib.rs index b0c7191f6..2ae001f7c 100644 --- a/secrets/src/lib.rs +++ b/secrets/src/lib.rs @@ -8,6 +8,7 @@ pub use traits::*; pub mod array; pub mod util; pub mod zeroize; +pub mod ct; mod sequences; pub use sequences::*;