Skip to content

Commit

Permalink
Break test vectors CPP file into multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
bifurcation committed Jul 13, 2023
1 parent dd919f2 commit 21d554e
Show file tree
Hide file tree
Showing 16 changed files with 1,939 additions and 2,012 deletions.
43 changes: 43 additions & 0 deletions lib/mls_vectors/src/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "common.h"

namespace mls_vectors {

using namespace mls;

///
/// Assertions for verifying test vectors
///

std::ostream&
operator<<(std::ostream& str, const NodeIndex& obj)
{
return str << obj.val;
}

std::ostream&
operator<<(std::ostream& str, const NodeCount& obj)
{
return str << obj.val;
}

std::ostream&
operator<<(std::ostream& str, const std::vector<uint8_t>& obj)
{
return str << to_hex(obj);
}

std::ostream&
operator<<(std::ostream& str, const GroupContent::RawContent& obj)
{
return var::visit(
overloaded{
[&](const Proposal&) -> std::ostream& { return str << "[Proposal]"; },
[&](const Commit&) -> std::ostream& { return str << "[Commit]"; },
[&](const ApplicationData&) -> std::ostream& {
return str << "[ApplicationData]";
},
},
obj);
}

} // namespace mls_vectors
177 changes: 177 additions & 0 deletions lib/mls_vectors/src/crypto_basics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include "common.h"
#include <mls_vectors/mls_vectors.h>

namespace mls_vectors {

using namespace mls;

CryptoBasicsTestVector::RefHash::RefHash(CipherSuite suite,
PseudoRandom::Generator&& prg)
: label("RefHash")
, value(prg.secret("value"))
, out(suite.raw_ref(from_ascii(label), value))
{
}

std::optional<std::string>
CryptoBasicsTestVector::RefHash::verify(CipherSuite suite) const
{
VERIFY_EQUAL("ref hash", out, suite.raw_ref(from_ascii(label), value));
return std::nullopt;
}

CryptoBasicsTestVector::ExpandWithLabel::ExpandWithLabel(
CipherSuite suite,
PseudoRandom::Generator&& prg)
: secret(prg.secret("secret"))
, label("ExpandWithLabel")
, context(prg.secret("context"))
, length(static_cast<uint16_t>(prg.output_length()))
, out(suite.expand_with_label(secret, label, context, length))
{
}

std::optional<std::string>
CryptoBasicsTestVector::ExpandWithLabel::verify(CipherSuite suite) const
{
VERIFY_EQUAL("expand with label",
out,
suite.expand_with_label(secret, label, context, length));
return std::nullopt;
}

CryptoBasicsTestVector::DeriveSecret::DeriveSecret(
CipherSuite suite,
PseudoRandom::Generator&& prg)
: secret(prg.secret("secret"))
, label("DeriveSecret")
, out(suite.derive_secret(secret, label))
{
}

std::optional<std::string>
CryptoBasicsTestVector::DeriveSecret::verify(CipherSuite suite) const
{
VERIFY_EQUAL("derive secret", out, suite.derive_secret(secret, label));
return std::nullopt;
}

CryptoBasicsTestVector::DeriveTreeSecret::DeriveTreeSecret(
CipherSuite suite,
PseudoRandom::Generator&& prg)
: secret(prg.secret("secret"))
, label("DeriveTreeSecret")
, generation(prg.uint32("generation"))
, length(static_cast<uint16_t>(prg.output_length()))
, out(suite.derive_tree_secret(secret, label, generation, length))
{
}

std::optional<std::string>
CryptoBasicsTestVector::DeriveTreeSecret::verify(CipherSuite suite) const
{
VERIFY_EQUAL("derive tree secret",
out,
suite.derive_tree_secret(secret, label, generation, length));
return std::nullopt;
}

CryptoBasicsTestVector::SignWithLabel::SignWithLabel(
CipherSuite suite,
PseudoRandom::Generator&& prg)
: priv(prg.signature_key("priv"))
, pub(priv.public_key)
, content(prg.secret("content"))
, label("SignWithLabel")
, signature(priv.sign(suite, label, content))
{
}

std::optional<std::string>
CryptoBasicsTestVector::SignWithLabel::verify(CipherSuite suite) const
{
VERIFY("verify with label", pub.verify(suite, label, content, signature));

auto new_signature = priv.sign(suite, label, content);
VERIFY("sign with label", pub.verify(suite, label, content, new_signature));

return std::nullopt;
}

CryptoBasicsTestVector::EncryptWithLabel::EncryptWithLabel(
CipherSuite suite,
PseudoRandom::Generator&& prg)
: priv(prg.hpke_key("priv"))
, pub(priv.public_key)
, label("EncryptWithLabel")
, context(prg.secret("context"))
, plaintext(prg.secret("plaintext"))
{
auto ct = pub.encrypt(suite, label, context, plaintext);
kem_output = ct.kem_output;
ciphertext = ct.ciphertext;
}

std::optional<std::string>
CryptoBasicsTestVector::EncryptWithLabel::verify(CipherSuite suite) const
{
auto ct = HPKECiphertext{ kem_output, ciphertext };
auto pt = priv.decrypt(suite, label, context, ct);
VERIFY_EQUAL("decrypt with label", pt, plaintext);

auto new_ct = pub.encrypt(suite, label, context, plaintext);
auto new_pt = priv.decrypt(suite, label, context, new_ct);
VERIFY_EQUAL("encrypt with label", new_pt, plaintext);

return std::nullopt;
}

CryptoBasicsTestVector::CryptoBasicsTestVector(CipherSuite suite)
: PseudoRandom(suite, "crypto-basics")
, cipher_suite(suite)
, ref_hash(suite, prg.sub("ref_hash"))
, expand_with_label(suite, prg.sub("expand_with_label"))
, derive_secret(suite, prg.sub("derive_secret"))
, derive_tree_secret(suite, prg.sub("derive_tree_secret"))
, sign_with_label(suite, prg.sub("sign_with_label"))
, encrypt_with_label(suite, prg.sub("encrypt_with_label"))
{
}

std::optional<std::string>
CryptoBasicsTestVector::verify() const
{
auto result = ref_hash.verify(cipher_suite);
if (result) {
return result;
}

result = expand_with_label.verify(cipher_suite);
if (result) {
return result;
}

result = derive_secret.verify(cipher_suite);
if (result) {
return result;
}

result = derive_tree_secret.verify(cipher_suite);
if (result) {
return result;
}

result = sign_with_label.verify(cipher_suite);
if (result) {
return result;
}

result = encrypt_with_label.verify(cipher_suite);
if (result) {
return result;
}

return std::nullopt;
}

} // namespace mls_vectors
126 changes: 126 additions & 0 deletions lib/mls_vectors/src/key_schedule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "common.h"
#include <mls_vectors/mls_vectors.h>

namespace mls_vectors {

using namespace mls;

KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite,
uint32_t n_epochs)
: PseudoRandom(suite, "key-schedule")
, cipher_suite(suite)
, group_id(prg.secret("group_id"))
, initial_init_secret(prg.secret("group_id"))
{
auto group_context = GroupContext{ suite, group_id, 0, {}, {}, {} };
auto epoch = KeyScheduleEpoch(cipher_suite);
epoch.init_secret = initial_init_secret;

for (uint64_t i = 0; i < n_epochs; i++) {
auto epoch_prg = prg.sub(to_hex(tls::marshal(i)));

group_context.tree_hash = epoch_prg.secret("tree_hash");
group_context.confirmed_transcript_hash =
epoch_prg.secret("confirmed_transcript_hash");
auto ctx = tls::marshal(group_context);

// TODO(RLB) Add Test case for externally-driven epoch change
auto commit_secret = epoch_prg.secret("commit_secret");
auto psk_secret = epoch_prg.secret("psk_secret");
epoch = epoch.next_raw(commit_secret, psk_secret, std::nullopt, ctx);

auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw(
cipher_suite, epoch.joiner_secret, psk_secret);

auto exporter_prg = epoch_prg.sub("exporter");
auto exporter_label = to_hex(exporter_prg.secret("label"));
auto exporter_context = exporter_prg.secret("context");
auto exporter_length = cipher_suite.secret_size();
auto exported =
epoch.do_export(exporter_label, exporter_context, exporter_length);

epochs.push_back({ group_context.tree_hash,
commit_secret,
psk_secret,
group_context.confirmed_transcript_hash,

ctx,

epoch.joiner_secret,
welcome_secret,
epoch.init_secret,

epoch.sender_data_secret,
epoch.encryption_secret,
epoch.exporter_secret,
epoch.epoch_authenticator,
epoch.external_secret,
epoch.confirmation_key,
epoch.membership_key,
epoch.resumption_psk,

epoch.external_priv.public_key,

{
exporter_label,
exporter_context,
exporter_length,
exported,
} });

group_context.epoch += 1;
}
}

std::optional<std::string>
KeyScheduleTestVector::verify() const
{
auto group_context = GroupContext{ cipher_suite, group_id, 0, {}, {}, {} };
auto epoch = KeyScheduleEpoch(cipher_suite);
epoch.init_secret = initial_init_secret;

for (const auto& tve : epochs) {
group_context.tree_hash = tve.tree_hash;
group_context.confirmed_transcript_hash = tve.confirmed_transcript_hash;
auto ctx = tls::marshal(group_context);
VERIFY_EQUAL("group context", ctx, tve.group_context);

epoch =
epoch.next_raw(tve.commit_secret, tve.psk_secret, std::nullopt, ctx);

// Verify the rest of the epoch
VERIFY_EQUAL("joiner secret", epoch.joiner_secret, tve.joiner_secret);

auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw(
cipher_suite, tve.joiner_secret, tve.psk_secret);
VERIFY_EQUAL("welcome secret", welcome_secret, tve.welcome_secret);

VERIFY_EQUAL(
"sender data secret", epoch.sender_data_secret, tve.sender_data_secret);
VERIFY_EQUAL(
"encryption secret", epoch.encryption_secret, tve.encryption_secret);
VERIFY_EQUAL("exporter secret", epoch.exporter_secret, tve.exporter_secret);
VERIFY_EQUAL("epoch authenticator",
epoch.epoch_authenticator,
tve.epoch_authenticator);
VERIFY_EQUAL("external secret", epoch.external_secret, tve.external_secret);
VERIFY_EQUAL(
"confirmation key", epoch.confirmation_key, tve.confirmation_key);
VERIFY_EQUAL("membership key", epoch.membership_key, tve.membership_key);
VERIFY_EQUAL("resumption psk", epoch.resumption_psk, tve.resumption_psk);
VERIFY_EQUAL("init secret", epoch.init_secret, tve.init_secret);

VERIFY_EQUAL(
"external pub", epoch.external_priv.public_key, tve.external_pub);

auto exported = epoch.do_export(
tve.exporter.label, tve.exporter.context, tve.exporter.length);
VERIFY_EQUAL("exported", exported, tve.exporter.secret);

group_context.epoch += 1;
}

return std::nullopt;
}

} // namespace mls_vectors
Loading

0 comments on commit 21d554e

Please sign in to comment.