Skip to content

Commit

Permalink
Add X-Wing combiner (#216)
Browse files Browse the repository at this point in the history
Add X-Wing combiner to kem module and hpke

---------

Co-authored-by: Franziskus Kiefer <[email protected]>
  • Loading branch information
raphaelrobert and franziskuskiefer authored Apr 4, 2024
1 parent c33e5df commit d7ec04e
Show file tree
Hide file tree
Showing 5 changed files with 574 additions and 42 deletions.
59 changes: 58 additions & 1 deletion src/hpke/hpke.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#![allow(non_camel_case_types, non_snake_case, unused_imports)]

use crate::kem::{X25519MlKem768Draft00PrivateKey, X25519MlKem768Draft00PublicKey};
use crate::ecdh::{self, x25519};
use crate::kem;
use crate::kem::{
kyber::kyber768, Ct, PublicKey, Ss, X25519MlKem768Draft00PrivateKey,
X25519MlKem768Draft00PublicKey, XWingKemDraft02PrivateKey, XWingKemDraft02PublicKey,
};

use super::aead::*;
use super::kdf::*;
Expand Down Expand Up @@ -425,6 +430,48 @@ pub fn SetupBaseS(
);
(ss.encode(), ct.encode())
}
KEM::XWingDraft02 => {
// TODO: This should re-use PublicKey::encapsulate but we need
// CryptoRng + Rng for that, not just a slice of randomness
let XWingKemDraft02PublicKey { pk_m, pk_x } =
XWingKemDraft02PublicKey::decode(pkR).map_err(|_| HpkeError::EncapError)?;

let (ct_m, ss_m) = kyber768::encapsulate(
&pk_m,
randomness[0..32]
.try_into()
.map_err(|_| HpkeError::EncapError)?,
);
let ek_x = x25519::PrivateKey(
randomness[..32]
.try_into()
.map_err(|_| HpkeError::EncapError)?,
);
let ct_x = x25519::secret_to_public(&ek_x).map_err(|_| HpkeError::EncapError)?;
let ss_x = x25519::derive(&pk_x, &ek_x).map_err(|_| HpkeError::EncapError)?;

let ct = Ct::XWingKemDraft02(
ct_m.as_slice()
.try_into()
.map_err(|_| HpkeError::EncapError)?,
ct_x.0
.as_slice()
.try_into()
.map_err(|_| HpkeError::EncapError)?,
);
let ss = Ss::XWingKemDraft02(
ss_m.as_slice()
.try_into()
.map_err(|_| HpkeError::EncapError)?,
ss_x.0
.as_slice()
.try_into()
.map_err(|_| HpkeError::EncapError)?,
ct_x,
pk_x,
);
(ss.encode(), ct.encode())
}
};

let key_schedule = KeySchedule(
Expand Down Expand Up @@ -474,6 +521,16 @@ pub fn SetupBaseR(
);
ss.encode()
}
KEM::XWingDraft02 => {
let ct = kem::Ct::decode(crate::kem::Algorithm::XWingKemDraft02, enc)
.map_err(|_| HpkeError::DecapError)?;
let sk = crate::kem::XWingKemDraft02PrivateKey::decode(skR)
.map_err(|_| HpkeError::DecapError)?;
let sk = &kem::PrivateKey::XWingKemDraft02(sk);
let ss = ct.decapsulate(sk).map_err(|_| HpkeError::DecapError)?;

ss.encode()
}
};
let key_schedule = KeySchedule(
config,
Expand Down
138 changes: 99 additions & 39 deletions src/hpke/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#![doc = include_str!("KEM_Security.md")]
#![allow(non_camel_case_types, non_snake_case)]

use crate::kem::{
kyber::{kyber768, MlKemKeyPair},
*,
use crate::{
ecdh::x25519,
kem::{
kyber::{kyber768, MlKemKeyPair},
*,
},
};

use super::errors::*;
Expand Down Expand Up @@ -62,6 +65,8 @@ pub enum KEM {
DHKEM_X448_HKDF_SHA512,
/// 0x0030
X25519Kyber768Draft00,
/// 0x004D
XWingDraft02,
}

/// [`u16`] value of the `kem_id`.
Expand All @@ -75,6 +80,7 @@ pub fn kem_value(kem_id: KEM) -> u16 {
KEM::DHKEM_X25519_HKDF_SHA256 => 0x00020,
KEM::DHKEM_X448_HKDF_SHA512 => 0x0021,
KEM::X25519Kyber768Draft00 => 0x0030,
KEM::XWingDraft02 => 0x004D,
}
}

Expand All @@ -89,6 +95,7 @@ fn kdf_for_kem(kem_id: KEM) -> KDF {
KEM::DHKEM_X25519_HKDF_SHA256 => KDF::HKDF_SHA256,
KEM::DHKEM_X448_HKDF_SHA512 => KDF::HKDF_SHA512,
KEM::X25519Kyber768Draft00 => KDF::HKDF_SHA256,
KEM::XWingDraft02 => KDF::HKDF_SHA256,
}
}

Expand All @@ -101,6 +108,7 @@ fn kem_to_named_group(alg: KEM) -> Algorithm {
KEM::DHKEM_X25519_HKDF_SHA256 => Algorithm::X25519,
KEM::DHKEM_X448_HKDF_SHA512 => Algorithm::X448,
KEM::X25519Kyber768Draft00 => Algorithm::X25519, // This is only used for DH operations
KEM::XWingDraft02 => Algorithm::X25519, // This is only used for DH operations
}
}

Expand All @@ -115,6 +123,7 @@ pub fn Nsecret(kem_id: KEM) -> usize {
KEM::DHKEM_X25519_HKDF_SHA256 => 32,
KEM::DHKEM_X448_HKDF_SHA512 => 64,
KEM::X25519Kyber768Draft00 => 64,
KEM::XWingDraft02 => 32,
}
}

Expand All @@ -129,6 +138,7 @@ pub fn Nenc(kem_id: KEM) -> usize {
KEM::DHKEM_X25519_HKDF_SHA256 => 32,
KEM::DHKEM_X448_HKDF_SHA512 => 56,
KEM::X25519Kyber768Draft00 => 1120,
KEM::XWingDraft02 => 1120,
}
}

Expand All @@ -143,6 +153,7 @@ pub fn Nsk(kem_id: KEM) -> usize {
KEM::DHKEM_X25519_HKDF_SHA256 => 32,
KEM::DHKEM_X448_HKDF_SHA512 => 56,
KEM::X25519Kyber768Draft00 => 2432,
KEM::XWingDraft02 => 2464,
}
}

Expand All @@ -157,6 +168,7 @@ pub fn Npk(kem_id: KEM) -> usize {
KEM::DHKEM_X25519_HKDF_SHA256 => 32,
KEM::DHKEM_X448_HKDF_SHA512 => 56,
KEM::X25519Kyber768Draft00 => 1216,
KEM::XWingDraft02 => 1216,
}
}

Expand All @@ -178,6 +190,7 @@ pub fn Ndh(kem_id: KEM) -> usize {
KEM::DHKEM_X25519_HKDF_SHA256 => 32,
KEM::DHKEM_X448_HKDF_SHA512 => 56,
KEM::X25519Kyber768Draft00 => 32,
KEM::XWingDraft02 => 32,
}
}

Expand Down Expand Up @@ -254,6 +267,10 @@ fn shared_secret_from_dh(alg: KEM, mut secret: Vec<u8>) -> Result<SharedSecret,
// This is only the x25519 part.
Ok(secret)
}
KEM::XWingDraft02 => {
// This is only the x25519 part.
Ok(secret)
}
}
}

Expand Down Expand Up @@ -305,6 +322,7 @@ pub fn SerializePublicKey(alg: KEM, pk: PublicKey) -> PublicKey {
KEM::DHKEM_X25519_HKDF_SHA256 => pk,
KEM::DHKEM_X448_HKDF_SHA512 => pk,
KEM::X25519Kyber768Draft00 => pk, // This must have been encoded before
KEM::XWingDraft02 => pk, // This must have been encoded before
}
}

Expand All @@ -328,6 +346,7 @@ pub fn DeserializePublicKey(alg: KEM, enc: &[u8]) -> HpkeBytesResult {
KEM::DHKEM_X25519_HKDF_SHA256 => enc.to_vec(),
KEM::DHKEM_X448_HKDF_SHA512 => enc.to_vec(),
KEM::X25519Kyber768Draft00 => enc.to_vec(), // Deserialization must be done later
KEM::XWingDraft02 => enc.to_vec(), // Deserialization must be done later
})
}

Expand Down Expand Up @@ -441,43 +460,83 @@ pub fn DeriveKeyPairX(alg: KEM, ikm: &InputKeyMaterial) -> Result<KeyPair, HpkeE
///
/// [NISTCurves]: https://doi.org/10.6028/nist.fips.186-4
pub fn DeriveKeyPair(alg: KEM, ikm: &InputKeyMaterial) -> Result<KeyPair, HpkeError> {
let kdf = kdf_for_kem(alg);
let dkp_prk = LabeledExtract(kdf, suite_id(alg), &empty(), dkp_prk_label(), ikm)?;

let named_group = kem_to_named_group(alg);
let sk = if alg == KEM::DHKEM_X25519_HKDF_SHA256 || alg == KEM::DHKEM_X448_HKDF_SHA512 {
LabeledExpand(kdf, suite_id(alg), &dkp_prk, sk_label(), &empty(), 32)?
} else {
let mut bitmask = 0xFFu8;
if alg == KEM::DHKEM_P521_HKDF_SHA512 {
bitmask = 0x01u8;
}
let mut sk = Vec::new();
for counter in 0..256 {
if sk.len() == 0 {
// Only keep looking if we didn't find one.
let mut bytes = LabeledExpand(
kdf,
suite_id(alg),
&dkp_prk,
candidate_label(),
&I2OSP(counter),
32,
)?;
bytes[0] = bytes[0] & bitmask;
// This check ensure sk != 0 or sk < order
if crate::ecdh::validate_scalar(named_group.try_into().unwrap(), &bytes).is_ok() {
sk = bytes;
match alg {
KEM::DHKEM_P256_HKDF_SHA256
| KEM::DHKEM_P384_HKDF_SHA384
| KEM::DHKEM_P521_HKDF_SHA512
| KEM::DHKEM_X25519_HKDF_SHA256
| KEM::DHKEM_X448_HKDF_SHA512 => {
let kdf = kdf_for_kem(alg);
let dkp_prk = LabeledExtract(kdf, suite_id(alg), &empty(), dkp_prk_label(), ikm)?;

let named_group = kem_to_named_group(alg);
let sk = if alg == KEM::DHKEM_X25519_HKDF_SHA256 || alg == KEM::DHKEM_X448_HKDF_SHA512 {
LabeledExpand(kdf, suite_id(alg), &dkp_prk, sk_label(), &empty(), 32)?
} else {
let mut bitmask = 0xFFu8;
if alg == KEM::DHKEM_P521_HKDF_SHA512 {
bitmask = 0x01u8;
}
let mut sk = Vec::new();
for counter in 0..256 {
if sk.len() == 0 {
// Only keep looking if we didn't find one.
let mut bytes = LabeledExpand(
kdf,
suite_id(alg),
&dkp_prk,
candidate_label(),
&I2OSP(counter),
32,
)?;
bytes[0] = bytes[0] & bitmask;
// This check ensure sk != 0 or sk < order
if crate::ecdh::validate_scalar(named_group.try_into().unwrap(), &bytes)
.is_ok()
{
sk = bytes;
}
}
}
sk
};
if sk.len() == 0 {
Result::<KeyPair, HpkeError>::Err(HpkeError::DeriveKeyPairError)
} else {
let pk = pk(alg, &sk)?;
Ok((sk, pk))
}
}
sk
};
if sk.len() == 0 {
Result::<KeyPair, HpkeError>::Err(HpkeError::DeriveKeyPairError)
} else {
let pk = pk(alg, &sk)?;
Ok((sk, pk))
KEM::X25519Kyber768Draft00 => Err(HpkeError::UnsupportedAlgorithm),
KEM::XWingDraft02 => {
// Use SHAKE128 to expand the ikm
let seed: [u8; 96] = crate::hacl::sha3::shake128(ikm);
// Use the first 64 bytes to generate the ML-KEM key pair
let MlKemKeyPair { sk, pk } = kyber768::generate_key_pair(
seed[..64]
.try_into()
.map_err(|_| HpkeError::DeriveKeyPairError)?,
);
// Use the next 32 bytes to generate the X25519 key pair
let (xsk, xpk) = DeriveKeyPair(KEM::DHKEM_X25519_HKDF_SHA256, &seed[..96])?;

let private = XWingKemDraft02PrivateKey {
sk_m: sk,
sk_x: x25519::PrivateKey(
xsk.try_into().map_err(|_| HpkeError::DeriveKeyPairError)?,
),
pk_x: x25519::PublicKey(
xpk.clone()
.try_into()
.map_err(|_| HpkeError::DeriveKeyPairError)?,
),
};
let public = XWingKemDraft02PublicKey {
pk_m: pk,
pk_x: x25519::PublicKey(xpk.try_into().map_err(|_| HpkeError::DeriveKeyPairError)?),
};
Ok((private.encode(), public.encode()))
}
}
}

Expand Down Expand Up @@ -514,7 +573,7 @@ pub fn GenerateKeyPair(alg: KEM, randomness: Randomness) -> Result<KeyPair, Hpke
&empty(),
32 + 64,
)?;
let (xsk, xpk) = DeriveKeyPair(alg, &seed[..32])?;
let (xsk, xpk) = DeriveKeyPair(KEM::DHKEM_X25519_HKDF_SHA256, &seed[..32])?;
let MlKemKeyPair { sk, pk } =
kyber768::generate_key_pair(seed[32..].try_into().unwrap());

Expand All @@ -524,10 +583,11 @@ pub fn GenerateKeyPair(alg: KEM, randomness: Randomness) -> Result<KeyPair, Hpke
};
let public = X25519MlKem768Draft00PublicKey {
mlkem: pk,
x25519: crate::ecdh::x25519::PublicKey(xpk.try_into().unwrap()),
x25519: x25519::PublicKey(xpk.try_into().unwrap()),
};
Ok((private.encode(), public.encode()))
}
KEM::XWingDraft02 => DeriveKeyPair(alg, &randomness),
}
}
}
Expand Down
Loading

0 comments on commit d7ec04e

Please sign in to comment.