Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Commit Creation and Handling #431

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
11 changes: 6 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
BUILD_DIR=build
TEST_DIR=build/test
CLANG_FORMAT=clang-format -i
CLANG_FORMAT_EXCLUDE="test_vectors.cpp"
CLANG_TIDY=OFF
OPENSSL11_MANIFEST=alternatives/openssl_1.1
OPENSSL3_MANIFEST=alternatives/openssl_3
Expand Down Expand Up @@ -98,8 +99,8 @@ cclean:
rm -rf ${BUILD_DIR}

format:
find include -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT}
find src -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT}
find test -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT}
find cmd -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT}
find lib -iname "*.h" -or -iname "*.cpp" | grep -v "test_vectors.cpp" | xargs ${CLANG_FORMAT}
for dir in include src test lib; \
do \
find $${dir} -iname "*.h" -or -iname "*.cpp" | grep -v ${CLANG_FORMAT_EXCLUDE} \
| xargs ${CLANG_FORMAT}; \
done
7 changes: 7 additions & 0 deletions include/mls/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,13 @@ contains(const Container& c, const Value& val)
return std::find(c.begin(), c.end(), val) != c.end();
}

template<typename Container, typename Value>
auto
find(const Container& c, const Value& val)
{
return std::find(c.begin(), c.end(), val);
}

template<typename Container, typename UnaryPredicate>
auto
find_if(Container& c, const UnaryPredicate& pred)
Expand Down
13 changes: 9 additions & 4 deletions include/mls/key_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ struct KeyScheduleEpoch
bytes epoch_authenticator;
bytes external_secret;
bytes confirmation_key;
bytes confirmation_tag;
bytes membership_key;
bytes resumption_psk;
bytes init_secret;
Expand All @@ -118,6 +119,7 @@ struct KeyScheduleEpoch
static KeyScheduleEpoch joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& confirmed_transcript_hash,
const bytes& context);

// Ciphersuite-only initializer, used by external joiner
Expand All @@ -136,10 +138,10 @@ struct KeyScheduleEpoch
KeyScheduleEpoch next(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& confirmed_transcript_hash,
const bytes& context) const;

GroupKeySource encryption_keys(LeafCount size) const;
bytes confirmation_tag(const bytes& confirmed_transcript_hash) const;
bytes do_export(const std::string& label,
const bytes& context,
size_t size) const;
Expand All @@ -161,10 +163,12 @@ struct KeyScheduleEpoch
const bytes& init_secret,
const bytes& commit_secret,
const bytes& psk_secret,
const bytes& confirmed_transcript_hash,
const bytes& context);
KeyScheduleEpoch next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& confirmed_transcript_hash,
const bytes& context) const;
static bytes welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
Expand All @@ -174,6 +178,7 @@ struct KeyScheduleEpoch
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& confirmed_transcript_hash,
const bytes& context);
};

Expand All @@ -194,10 +199,10 @@ struct TranscriptHash
bytes confirmed_in,
const bytes& confirmation_tag);

void update(const AuthenticatedContent& content_auth);
void update_confirmed(const AuthenticatedContent& content_auth);
// Updating hashes
bytes new_confirmed(const bytes& transcript_hash_input) const;
void set_confirmed(bytes confirmed_transcript_hash);
void update_interim(const bytes& confirmation_tag);
void update_interim(const AuthenticatedContent& content_auth);
};

bool
Expand Down
71 changes: 47 additions & 24 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,37 @@ class State
const bytes& leaf_secret,
const std::optional<CommitOpts>& opts,
const MessageOpts& msg_opts,
CommitParams params);
const CommitParams& params);

struct CommitMaterials;
CommitMaterials prepare_commit(const bytes& leaf_secret,
const std::optional<CommitOpts>& opts,
const CommitParams& params) const;
Welcome welcome(bool inline_tree,
const std::vector<PSKWithSecret>& psks,
const std::vector<KeyPackage>& joiners,
const std::vector<std::optional<bytes>>& path_secrets) const;

std::optional<State> handle(
const MLSMessage& msg,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);
std::optional<State> handle(
const ValidatedContent& val_content,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);

void handle_proposal(const AuthenticatedContent& content_auth);
State handle_commit(const AuthenticatedContent& content_auth,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params) const;

State ratchet(TreeKEMPublicKey new_tree,
LeafIndex committer,
const std::optional<NodeIndex>& path_secret_decrypt_node,
const std::optional<HPKECiphertext>& encrypted_path_secret,
ExtensionList extensions,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& confirmed_transcript_hash,
const bytes& confirmation_tag) const;

// Create an MLSMessage encapsulating some content
template<typename Inner>
AuthenticatedContent sign(const Sender& sender,
Expand All @@ -345,24 +365,25 @@ class State
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);

// Apply the changes requested by various messages
LeafIndex apply(const Add& add);
void apply(LeafIndex target, const Update& update);
void apply(LeafIndex target,
const Update& update,
const HPKEPrivateKey& leaf_priv);
LeafIndex apply(const Remove& remove);
void apply(const GroupContextExtensions& gce);
std::vector<LeafIndex> apply(const std::vector<CachedProposal>& proposals,
Proposal::Type required_type);
std::tuple<std::vector<LeafIndex>, std::vector<PSKWithSecret>> apply(
const std::vector<CachedProposal>& proposals);
static LeafIndex apply(TreeKEMPublicKey& tree, const Add& add);
static void apply(TreeKEMPublicKey& tree,
LeafIndex target,
const Update& update);
static LeafIndex apply(TreeKEMPublicKey& tree, const Remove& remove);
std::vector<LeafIndex> apply(TreeKEMPublicKey& tree,
const std::vector<CachedProposal>& proposals,
Proposal::Type required_type) const;
std::tuple<TreeKEMPublicKey,
std::vector<LeafIndex>,
std::vector<PSKWithSecret>,
ExtensionList>
apply(const std::vector<CachedProposal>& proposals) const;

// Verify that a specific key package or all members support a given set of
// extensions
bool extensions_supported(const ExtensionList& exts) const;

// Extract proposals and PSKs from cache
void cache_proposal(AuthenticatedContent content_auth);
std::optional<CachedProposal> resolve(
const ProposalOrRef& id,
std::optional<LeafIndex> sender_index) const;
Expand Down Expand Up @@ -409,11 +430,6 @@ class State
friend bool operator==(const State& lhs, const State& rhs);
friend bool operator!=(const State& lhs, const State& rhs);

// Derive and set the secrets for an epoch, given some new entropy
void update_epoch_secrets(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret);

// Signature verification over a handshake message
bool verify_internal(const AuthenticatedContent& content_auth) const;
bool verify_external(const AuthenticatedContent& content_auth) const;
Expand All @@ -425,8 +441,15 @@ class State
// Convert a Roster entry into LeafIndex
LeafIndex leaf_for_roster_entry(RosterIndex index) const;

// Create a draft successor state
State successor() const;
// Create a successor state
State successor(LeafIndex index,
TreeKEMPublicKey tree,
TreeKEMPrivateKey tree_priv,
ExtensionList extensions,
const bytes& confirmed_transcript_hash,
bool has_path,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret) const;
};

} // namespace MLS_NAMESPACE
17 changes: 17 additions & 0 deletions include/mls/treekem.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ struct TreeKEMPrivateKey
const UpdatePath& path,
const std::vector<LeafIndex>& except);

void decap(LeafIndex from,
const TreeKEMPublicKey& pub,
const bytes& context,
const NodeIndex& decrypt_node,
const HPKECiphertext& encrypted_path_secret);

void truncate(LeafCount size);

bool consistent(const TreeKEMPrivateKey& other) const;
Expand Down Expand Up @@ -150,6 +156,17 @@ struct TreeKEMPublicKey
std::optional<LeafNode> leaf_node(LeafIndex index) const;
std::vector<NodeIndex> resolve(NodeIndex index) const;

struct DecapCoords
{
size_t ancestor_node_index;
size_t resolution_node_index;
NodeIndex resolution_node;
};
DecapCoords decap_coords(
LeafIndex to,
LeafIndex from,
const std::vector<LeafIndex>& joiner_locations) const;

template<typename UnaryPredicate>
bool all_leaves(const UnaryPredicate& pred) const
{
Expand Down
68 changes: 46 additions & 22 deletions lib/mls_vectors/src/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,11 @@ KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite,
// TODO(RLB) Add Test case for externally-driven epoch change
auto commit_secret = epoch_prg.secret("commit_secret");
auto psk_secret = epoch_prg.secret("psk_secret");
epoch = epoch.next_raw(commit_secret, psk_secret, std::nullopt, ctx);
epoch = epoch.next_raw(commit_secret,
psk_secret,
std::nullopt,
group_context.confirmed_transcript_hash,
ctx);

auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw(
cipher_suite, epoch.joiner_secret, psk_secret);
Expand Down Expand Up @@ -645,8 +649,11 @@ KeyScheduleTestVector::verify() const
auto ctx = tls::marshal(group_context);
VERIFY_EQUAL("group context", ctx, tve.group_context);

epoch =
epoch.next_raw(tve.commit_secret, tve.psk_secret, std::nullopt, ctx);
epoch = epoch.next_raw(tve.commit_secret,
tve.psk_secret,
std::nullopt,
group_context.confirmed_transcript_hash,
ctx);

// Verify the rest of the epoch
VERIFY_EQUAL("joiner secret", epoch.joiner_secret, tve.joiner_secret);
Expand Down Expand Up @@ -959,17 +966,17 @@ TranscriptTestVector::TranscriptTestVector(CipherSuite suite)

auto group_id = prg.secret("group_id");
auto epoch = prg.uint64("epoch");
auto group_context_obj =
auto group_context =
GroupContext{ suite,
group_id,
epoch,
prg.secret("tree_hash_before"),
prg.secret("confirmed_transcript_hash_before"),
{} };
auto group_context = tls::marshal(group_context_obj);

auto init_secret = prg.secret("init_secret");
auto ks_epoch = KeyScheduleEpoch(suite, init_secret, group_context);
auto key_schedule_before =
KeyScheduleEpoch(suite, init_secret, tls::marshal(group_context));

auto sig_priv = prg.signature_key("sig_priv");
auto leaf_index = LeafIndex{ 0 };
Expand All @@ -980,17 +987,23 @@ TranscriptTestVector::TranscriptTestVector(CipherSuite suite)
group_id, epoch, { MemberSender{ leaf_index } }, {}, Commit{} },
suite,
sig_priv,
group_context_obj);
group_context);

transcript.update_confirmed(authenticated_content);
const auto new_confirmed = transcript.new_confirmed(authenticated_content.confirmed_transcript_hash_input());
transcript.set_confirmed(new_confirmed);

const auto confirmation_tag = ks_epoch.confirmation_tag(transcript.confirmed);
authenticated_content.set_confirmation_tag(confirmation_tag);
group_context.confirmed_transcript_hash = transcript.confirmed;
auto key_schedule_after =
key_schedule_before.next(suite.zero(),
{},
std::nullopt,
transcript.confirmed,
tls::marshal(group_context));

transcript.update_interim(authenticated_content);
transcript.update_interim(key_schedule_after.confirmation_tag);

// Store the required data
confirmation_key = ks_epoch.confirmation_key;
confirmation_key = key_schedule_after.confirmation_key;
confirmed_transcript_hash_after = transcript.confirmed;
interim_transcript_hash_after = transcript.interim;
}
Expand All @@ -1001,7 +1014,12 @@ TranscriptTestVector::verify() const
auto transcript = TranscriptHash(cipher_suite);
transcript.interim = interim_transcript_hash_before;

transcript.update(authenticated_content);
const auto new_confirmed = transcript.new_confirmed(authenticated_content.confirmed_transcript_hash_input());
transcript.set_confirmed(new_confirmed);

const auto input_confirmation_tag = opt::get(authenticated_content.auth.confirmation_tag);
transcript.update_interim(input_confirmation_tag);

VERIFY_EQUAL(
"confirmed", transcript.confirmed, confirmed_transcript_hash_after);
VERIFY_EQUAL("interim", transcript.interim, interim_transcript_hash_after);
Expand Down Expand Up @@ -1055,15 +1073,16 @@ WelcomeTestVector::WelcomeTestVector(CipherSuite suite)
cipher_suite, group_id, epoch, tree_hash, confirmed_transcript_hash, {}
};

auto key_schedule = KeyScheduleEpoch::joiner(
cipher_suite, joiner_secret, {}, tls::marshal(group_context));
auto confirmation_tag =
key_schedule.confirmation_tag(confirmed_transcript_hash);
auto key_schedule = KeyScheduleEpoch::joiner(cipher_suite,
joiner_secret,
{},
confirmed_transcript_hash,
tls::marshal(group_context));

auto group_info = GroupInfo{
group_context,
{},
confirmation_tag,
key_schedule.confirmation_tag,
};
group_info.sign(signer_index, signer_priv);

Expand Down Expand Up @@ -1098,10 +1117,15 @@ WelcomeTestVector::verify() const

// Verify confirmation tag
const auto& group_context = group_info.group_context;
auto key_schedule = KeyScheduleEpoch::joiner(
cipher_suite, group_secrets.joiner_secret, {}, tls::marshal(group_context));
auto confirmation_tag =
key_schedule.confirmation_tag(group_context.confirmed_transcript_hash);
auto key_schedule =
KeyScheduleEpoch::joiner(cipher_suite,
group_secrets.joiner_secret,
{},
group_context.confirmed_transcript_hash,
tls::marshal(group_context));
VERIFY_EQUAL("confirmation tag",
key_schedule.confirmation_tag,
group_info.confirmation_tag);

return std::nullopt;
}
Expand Down
4 changes: 2 additions & 2 deletions lib/mls_vectors/test/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ TEST_CASE("Welcome")
}
}

TEST_CASE("Tree Hashes")
TEST_CASE("Tree Hashes", "[.][all]")
{
for (auto suite : supported_suites) {
for (auto structure : all_tree_structures) {
Expand All @@ -97,7 +97,7 @@ TEST_CASE("Tree Operations")
}
}

TEST_CASE("TreeKEM")
TEST_CASE("TreeKEM", "[.][all]")
{
for (auto suite : supported_suites) {
for (auto structure : treekem_test_tree_structures) {
Expand Down
Loading
Loading