diff --git a/.clang-tidy b/.clang-tidy index dcb223f1..b274ccde 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -3,8 +3,10 @@ Checks: '*, -altera-*, -fuchsia-*, -abseil-string-find-startswith, + -boost-use-ranges, -bugprone-easily-swappable-parameters, -bugprone-exception-escape, + -bugprone-chained-comparison, -cert-err58-cpp, -cppcoreguidelines-avoid-const-or-ref-data-members, -cppcoreguidelines-avoid-magic-numbers, @@ -32,6 +34,9 @@ Checks: '*, -readability-function-cognitive-complexity, -readability-identifier-length, -readability-magic-numbers, + -readability-math-missing-parentheses, + -readability-redundant-casting, ' WarningsAsErrors: '*' HeaderFilterRegex: '*' +ExcludeHeaderFilterRegex: 'catch_test_macros.hpp' diff --git a/Makefile b/Makefile index 20f8e27b..08fe4822 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ BUILD_DIR=build TEST_DIR=build/test CLANG_FORMAT=clang-format -i +CLANG_FORMAT_EXCLUDE="test_vectors.cpp" CLANG_TIDY=OFF OPENSSL11_MANIFEST=alternatives/openssl_1.1 OPENSSL3_MANIFEST=alternatives/openssl_3 @@ -98,8 +99,8 @@ cclean: rm -rf ${BUILD_DIR} format: - find include -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} - find src -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} - find test -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} - find cmd -iname "*.h" -or -iname "*.cpp" | xargs ${CLANG_FORMAT} - find lib -iname "*.h" -or -iname "*.cpp" | grep -v "test_vectors.cpp" | xargs ${CLANG_FORMAT} + for dir in include src test lib; \ + do \ + find $${dir} -iname "*.h" -or -iname "*.cpp" | grep -v ${CLANG_FORMAT_EXCLUDE} \ + | xargs ${CLANG_FORMAT}; \ + done diff --git a/include/mls/common.h b/include/mls/common.h index a94f0f49..6dff1b94 100644 --- a/include/mls/common.h +++ b/include/mls/common.h @@ -249,6 +249,13 @@ contains(const Container& c, const Value& val) return std::find(c.begin(), c.end(), val) != c.end(); } +template +auto +find(const Container& c, const Value& val) +{ + return std::find(c.begin(), c.end(), val); +} + template auto find_if(Container& c, const UnaryPredicate& pred) diff --git a/include/mls/key_schedule.h b/include/mls/key_schedule.h index 6cf5be2e..85cd7aba 100644 --- a/include/mls/key_schedule.h +++ b/include/mls/key_schedule.h @@ -106,6 +106,7 @@ struct KeyScheduleEpoch bytes epoch_authenticator; bytes external_secret; bytes confirmation_key; + bytes confirmation_tag; bytes membership_key; bytes resumption_psk; bytes init_secret; @@ -118,6 +119,7 @@ struct KeyScheduleEpoch static KeyScheduleEpoch joiner(CipherSuite suite_in, const bytes& joiner_secret, const std::vector& psks, + const bytes& confirmed_transcript_hash, const bytes& context); // Ciphersuite-only initializer, used by external joiner @@ -136,10 +138,10 @@ struct KeyScheduleEpoch KeyScheduleEpoch next(const bytes& commit_secret, const std::vector& psks, const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, const bytes& context) const; GroupKeySource encryption_keys(LeafCount size) const; - bytes confirmation_tag(const bytes& confirmed_transcript_hash) const; bytes do_export(const std::string& label, const bytes& context, size_t size) const; @@ -161,10 +163,12 @@ struct KeyScheduleEpoch const bytes& init_secret, const bytes& commit_secret, const bytes& psk_secret, + const bytes& confirmed_transcript_hash, const bytes& context); KeyScheduleEpoch next_raw(const bytes& commit_secret, const bytes& psk_secret, const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, const bytes& context) const; static bytes welcome_secret_raw(CipherSuite suite, const bytes& joiner_secret, @@ -174,6 +178,7 @@ struct KeyScheduleEpoch KeyScheduleEpoch(CipherSuite suite_in, const bytes& joiner_secret, const bytes& psk_secret, + const bytes& confirmed_transcript_hash, const bytes& context); }; @@ -194,10 +199,10 @@ struct TranscriptHash bytes confirmed_in, const bytes& confirmation_tag); - void update(const AuthenticatedContent& content_auth); - void update_confirmed(const AuthenticatedContent& content_auth); + // Updating hashes + bytes new_confirmed(const bytes& transcript_hash_input) const; + void set_confirmed(bytes confirmed_transcript_hash); void update_interim(const bytes& confirmation_tag); - void update_interim(const AuthenticatedContent& content_auth); }; bool diff --git a/include/mls/state.h b/include/mls/state.h index 5f736bea..d1d77ea7 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -321,17 +321,37 @@ class State const bytes& leaf_secret, const std::optional& opts, const MessageOpts& msg_opts, - CommitParams params); + const CommitParams& params); + + struct CommitMaterials; + CommitMaterials prepare_commit(const bytes& leaf_secret, + const std::optional& opts, + const CommitParams& params) const; + Welcome welcome(bool inline_tree, + const std::vector& psks, + const std::vector& joiners, + const std::vector>& path_secrets) const; - std::optional handle( - const MLSMessage& msg, - std::optional cached_state, - const std::optional& expected_params); std::optional handle( const ValidatedContent& val_content, std::optional cached_state, const std::optional& expected_params); + void handle_proposal(const AuthenticatedContent& content_auth); + State handle_commit(const AuthenticatedContent& content_auth, + std::optional cached_state, + const std::optional& expected_params) const; + + State ratchet(TreeKEMPublicKey new_tree, + LeafIndex committer, + const std::optional& path_secret_decrypt_node, + const std::optional& encrypted_path_secret, + ExtensionList extensions, + const std::vector& psks, + const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, + const bytes& confirmation_tag) const; + // Create an MLSMessage encapsulating some content template AuthenticatedContent sign(const Sender& sender, @@ -345,24 +365,26 @@ class State MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts); // Apply the changes requested by various messages - LeafIndex apply(const Add& add); - void apply(LeafIndex target, const Update& update); - void apply(LeafIndex target, - const Update& update, - const HPKEPrivateKey& leaf_priv); - LeafIndex apply(const Remove& remove); - void apply(const GroupContextExtensions& gce); - std::vector apply(const std::vector& proposals, - Proposal::Type required_type); - std::tuple, std::vector> apply( - const std::vector& proposals); + static LeafIndex apply(TreeKEMPublicKey& tree, const Add& add); + static void apply(TreeKEMPublicKey& tree, + LeafIndex target, + const Update& update); + static LeafIndex apply(TreeKEMPublicKey& tree, const Remove& remove); + static std::vector apply( + TreeKEMPublicKey& tree, + const std::vector& proposals, + Proposal::Type required_type); + std::tuple, + std::vector, + ExtensionList> + apply(const std::vector& proposals) const; // Verify that a specific key package or all members support a given set of // extensions bool extensions_supported(const ExtensionList& exts) const; // Extract proposals and PSKs from cache - void cache_proposal(AuthenticatedContent content_auth); std::optional resolve( const ProposalOrRef& id, std::optional sender_index) const; @@ -409,11 +431,6 @@ class State friend bool operator==(const State& lhs, const State& rhs); friend bool operator!=(const State& lhs, const State& rhs); - // Derive and set the secrets for an epoch, given some new entropy - void update_epoch_secrets(const bytes& commit_secret, - const std::vector& psks, - const std::optional& force_init_secret); - // Signature verification over a handshake message bool verify_internal(const AuthenticatedContent& content_auth) const; bool verify_external(const AuthenticatedContent& content_auth) const; @@ -425,8 +442,15 @@ class State // Convert a Roster entry into LeafIndex LeafIndex leaf_for_roster_entry(RosterIndex index) const; - // Create a draft successor state - State successor() const; + // Create a successor state + State successor(LeafIndex index, + TreeKEMPublicKey tree, + TreeKEMPrivateKey tree_priv, + ExtensionList extensions, + const bytes& confirmed_transcript_hash, + bool has_path, + const std::vector& psks, + const std::optional& force_init_secret) const; }; } // namespace MLS_NAMESPACE diff --git a/include/mls/treekem.h b/include/mls/treekem.h index 75a9550f..71e12335 100644 --- a/include/mls/treekem.h +++ b/include/mls/treekem.h @@ -94,6 +94,12 @@ struct TreeKEMPrivateKey const UpdatePath& path, const std::vector& except); + void decap(LeafIndex from, + const TreeKEMPublicKey& pub, + const bytes& context, + const NodeIndex& decrypt_node, + const HPKECiphertext& encrypted_path_secret); + void truncate(LeafCount size); bool consistent(const TreeKEMPrivateKey& other) const; @@ -150,6 +156,17 @@ struct TreeKEMPublicKey std::optional leaf_node(LeafIndex index) const; std::vector resolve(NodeIndex index) const; + struct DecapCoords + { + size_t ancestor_node_index; + size_t resolution_node_index; + NodeIndex resolution_node; + }; + DecapCoords decap_coords( + LeafIndex to, + LeafIndex from, + const std::vector& joiner_locations) const; + template bool all_leaves(const UnaryPredicate& pred) const { diff --git a/lib/hpke/src/base64.cpp b/lib/hpke/src/base64.cpp index 1b28cd02..453a32d1 100644 --- a/lib/hpke/src/base64.cpp +++ b/lib/hpke/src/base64.cpp @@ -57,7 +57,7 @@ to_base64url(const bytes& data) bytes from_base64(const std::string& enc) { - if (enc.length() == 0) { + if (enc.empty()) { return {}; } diff --git a/lib/hpke/src/group.cpp b/lib/hpke/src/group.cpp index 2d60dd33..6c3ff03e 100644 --- a/lib/hpke/src/group.cpp +++ b/lib/hpke/src/group.cpp @@ -728,7 +728,7 @@ struct ECKeyGroup : public EVPGroup } #endif - static inline int group_to_nid(Group::ID group_id) + static int group_to_nid(Group::ID group_id) { switch (group_id) { case Group::ID::P256: @@ -862,7 +862,7 @@ struct RawKeyGroup : public EVPGroup private: const int evp_type; - static inline int group_to_evp(Group::ID group_id) + static int group_to_evp(Group::ID group_id) { switch (group_id) { case Group::ID::X25519: diff --git a/lib/mls_vectors/src/mls_vectors.cpp b/lib/mls_vectors/src/mls_vectors.cpp index 5d31e226..a82849a2 100644 --- a/lib/mls_vectors/src/mls_vectors.cpp +++ b/lib/mls_vectors/src/mls_vectors.cpp @@ -13,8 +13,10 @@ using namespace MLS_NAMESPACE; /// Assertions for verifying test vectors /// +// For some reason, clang-tidy lints about C arrays are firing on this line. +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays) template, int> = 0> -std::ostream& +static std::ostream& operator<<(std::ostream& str, const T& obj) { auto u = static_cast>(obj); @@ -75,7 +77,7 @@ operator<<(std::ostream& str, const GroupContent::RawContent& obj) } template -inline std::enable_if_t +static inline std::enable_if_t operator<<(std::ostream& str, const T& obj) { return str << to_hex(tls::marshal(obj)); @@ -587,7 +589,11 @@ KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite, // TODO(RLB) Add Test case for externally-driven epoch change auto commit_secret = epoch_prg.secret("commit_secret"); auto psk_secret = epoch_prg.secret("psk_secret"); - epoch = epoch.next_raw(commit_secret, psk_secret, std::nullopt, ctx); + epoch = epoch.next_raw(commit_secret, + psk_secret, + std::nullopt, + group_context.confirmed_transcript_hash, + ctx); auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw( cipher_suite, epoch.joiner_secret, psk_secret); @@ -645,8 +651,11 @@ KeyScheduleTestVector::verify() const auto ctx = tls::marshal(group_context); VERIFY_EQUAL("group context", ctx, tve.group_context); - epoch = - epoch.next_raw(tve.commit_secret, tve.psk_secret, std::nullopt, ctx); + epoch = epoch.next_raw(tve.commit_secret, + tve.psk_secret, + std::nullopt, + group_context.confirmed_transcript_hash, + ctx); // Verify the rest of the epoch VERIFY_EQUAL("joiner secret", epoch.joiner_secret, tve.joiner_secret); @@ -959,17 +968,17 @@ TranscriptTestVector::TranscriptTestVector(CipherSuite suite) auto group_id = prg.secret("group_id"); auto epoch = prg.uint64("epoch"); - auto group_context_obj = + auto group_context = GroupContext{ suite, group_id, epoch, prg.secret("tree_hash_before"), prg.secret("confirmed_transcript_hash_before"), {} }; - auto group_context = tls::marshal(group_context_obj); auto init_secret = prg.secret("init_secret"); - auto ks_epoch = KeyScheduleEpoch(suite, init_secret, group_context); + auto key_schedule_before = + KeyScheduleEpoch(suite, init_secret, tls::marshal(group_context)); auto sig_priv = prg.signature_key("sig_priv"); auto leaf_index = LeafIndex{ 0 }; @@ -980,17 +989,27 @@ TranscriptTestVector::TranscriptTestVector(CipherSuite suite) group_id, epoch, { MemberSender{ leaf_index } }, {}, Commit{} }, suite, sig_priv, - group_context_obj); + group_context); - transcript.update_confirmed(authenticated_content); + const auto new_confirmed = transcript.new_confirmed( + authenticated_content.confirmed_transcript_hash_input()); + transcript.set_confirmed(new_confirmed); + ; - const auto confirmation_tag = ks_epoch.confirmation_tag(transcript.confirmed); - authenticated_content.set_confirmation_tag(confirmation_tag); + group_context.confirmed_transcript_hash = transcript.confirmed; + auto key_schedule_after = + key_schedule_before.next(suite.zero(), + {}, + std::nullopt, + transcript.confirmed, + tls::marshal(group_context)); - transcript.update_interim(authenticated_content); + authenticated_content.set_confirmation_tag( + key_schedule_after.confirmation_tag); + transcript.update_interim(key_schedule_after.confirmation_tag); // Store the required data - confirmation_key = ks_epoch.confirmation_key; + confirmation_key = key_schedule_after.confirmation_key; confirmed_transcript_hash_after = transcript.confirmed; interim_transcript_hash_after = transcript.interim; } @@ -1001,7 +1020,14 @@ TranscriptTestVector::verify() const auto transcript = TranscriptHash(cipher_suite); transcript.interim = interim_transcript_hash_before; - transcript.update(authenticated_content); + const auto new_confirmed = transcript.new_confirmed( + authenticated_content.confirmed_transcript_hash_input()); + transcript.set_confirmed(new_confirmed); + + const auto input_confirmation_tag = + opt::get(authenticated_content.auth.confirmation_tag); + transcript.update_interim(input_confirmation_tag); + VERIFY_EQUAL( "confirmed", transcript.confirmed, confirmed_transcript_hash_after); VERIFY_EQUAL("interim", transcript.interim, interim_transcript_hash_after); @@ -1055,15 +1081,16 @@ WelcomeTestVector::WelcomeTestVector(CipherSuite suite) cipher_suite, group_id, epoch, tree_hash, confirmed_transcript_hash, {} }; - auto key_schedule = KeyScheduleEpoch::joiner( - cipher_suite, joiner_secret, {}, tls::marshal(group_context)); - auto confirmation_tag = - key_schedule.confirmation_tag(confirmed_transcript_hash); + auto key_schedule = KeyScheduleEpoch::joiner(cipher_suite, + joiner_secret, + {}, + confirmed_transcript_hash, + tls::marshal(group_context)); auto group_info = GroupInfo{ group_context, {}, - confirmation_tag, + key_schedule.confirmation_tag, }; group_info.sign(signer_index, signer_priv); @@ -1098,10 +1125,15 @@ WelcomeTestVector::verify() const // Verify confirmation tag const auto& group_context = group_info.group_context; - auto key_schedule = KeyScheduleEpoch::joiner( - cipher_suite, group_secrets.joiner_secret, {}, tls::marshal(group_context)); - auto confirmation_tag = - key_schedule.confirmation_tag(group_context.confirmed_transcript_hash); + auto key_schedule = + KeyScheduleEpoch::joiner(cipher_suite, + group_secrets.joiner_secret, + {}, + group_context.confirmed_transcript_hash, + tls::marshal(group_context)); + VERIFY_EQUAL("confirmation tag", + key_schedule.confirmation_tag, + group_info.confirmation_tag); return std::nullopt; } diff --git a/lib/mls_vectors/test/mls_vectors.cpp b/lib/mls_vectors/test/mls_vectors.cpp index 010f6c0c..df2f48ee 100644 --- a/lib/mls_vectors/test/mls_vectors.cpp +++ b/lib/mls_vectors/test/mls_vectors.cpp @@ -78,7 +78,7 @@ TEST_CASE("Welcome") } } -TEST_CASE("Tree Hashes") +TEST_CASE("Tree Hashes", "[.][all]") { for (auto suite : supported_suites) { for (auto structure : all_tree_structures) { @@ -97,7 +97,7 @@ TEST_CASE("Tree Operations") } } -TEST_CASE("TreeKEM") +TEST_CASE("TreeKEM", "[.][all]") { for (auto suite : supported_suites) { for (auto structure : treekem_test_tree_structures) { diff --git a/lib/tls_syntax/test/tls_syntax.cpp b/lib/tls_syntax/test/tls_syntax.cpp index 49bd3056..bf547fb2 100644 --- a/lib/tls_syntax/test/tls_syntax.cpp +++ b/lib/tls_syntax/test/tls_syntax.cpp @@ -97,8 +97,8 @@ class TLSSyntaxTest }; template -void -ostream_test(T val, const std::vector& enc) +static void +ostream_test(const T& val, const std::vector& enc) { MLS_NAMESPACE::tls::ostream w; // NOLINT(misc-const-correctness) w << val; @@ -127,8 +127,8 @@ TEST_CASE_METHOD(TLSSyntaxTest, "TLS ostream") } template -void -istream_test(T val, T& data, const std::vector& enc) +static void +istream_test(const T& val, T& data, const std::vector& enc) { MLS_NAMESPACE::tls::istream r(enc); // NOLINT(misc-const-correctness) r >> data; diff --git a/src/key_schedule.cpp b/src/key_schedule.cpp index a5587ca2..15169b2b 100644 --- a/src/key_schedule.cpp +++ b/src/key_schedule.cpp @@ -243,7 +243,7 @@ GroupKeySource::get(ContentType type, void GroupKeySource::erase(ContentType type, LeafIndex sender, uint32_t generation) { - return chain(type, sender).erase(generation); + chain(type, sender).erase(generation); } // struct { @@ -301,14 +301,20 @@ KeyScheduleEpoch KeyScheduleEpoch::joiner(CipherSuite suite_in, const bytes& joiner_secret, const std::vector& psks, + const bytes& confirmed_transcript_hash, const bytes& context) { - return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context }; + return { suite_in, + joiner_secret, + make_psk_secret(suite_in, psks), + confirmed_transcript_hash, + context }; } KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, const bytes& joiner_secret, const bytes& psk_secret, + const bytes& confirmed_transcript_hash, const bytes& context) : suite(suite_in) , joiner_secret(joiner_secret) @@ -325,6 +331,8 @@ KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, , init_secret(suite.derive_secret(epoch_secret, "init")) , external_priv(HPKEPrivateKey::derive(suite, external_secret)) { + confirmation_tag = + suite.digest().hmac(confirmation_key, confirmed_transcript_hash); } KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in) @@ -339,6 +347,7 @@ KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, suite_in, make_joiner_secret(suite_in, context, init_secret, suite_in.zero()), { /* no PSKs */ }, + { /* confirmed transcript hash is the zero-length octet string */ }, context) { } @@ -347,11 +356,13 @@ KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in, const bytes& init_secret, const bytes& commit_secret, const bytes& psk_secret, + const bytes& confirmed_transcript_hash, const bytes& context) : KeyScheduleEpoch( suite_in, make_joiner_secret(suite_in, context, init_secret, commit_secret), psk_secret, + confirmed_transcript_hash, context) { } @@ -377,16 +388,21 @@ KeyScheduleEpoch KeyScheduleEpoch::next(const bytes& commit_secret, const std::vector& psks, const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, const bytes& context) const { - return next_raw( - commit_secret, make_psk_secret(suite, psks), force_init_secret, context); + return next_raw(commit_secret, + make_psk_secret(suite, psks), + force_init_secret, + confirmed_transcript_hash, + context); } KeyScheduleEpoch KeyScheduleEpoch::next_raw(const bytes& commit_secret, const bytes& psk_secret, const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, const bytes& context) const { auto actual_init_secret = init_secret; @@ -394,7 +410,8 @@ KeyScheduleEpoch::next_raw(const bytes& commit_secret, actual_init_secret = opt::get(force_init_secret); } - return { suite, actual_init_secret, commit_secret, psk_secret, context }; + return { suite, actual_init_secret, commit_secret, + psk_secret, confirmed_transcript_hash, context }; } GroupKeySource @@ -403,12 +420,6 @@ KeyScheduleEpoch::encryption_keys(LeafCount size) const return { suite, size, encryption_secret }; } -bytes -KeyScheduleEpoch::confirmation_tag(const bytes& confirmed_transcript_hash) const -{ - return suite.digest().hmac(confirmation_key, confirmed_transcript_hash); -} - bytes KeyScheduleEpoch::do_export(const std::string& label, const bytes& context, @@ -539,34 +550,22 @@ TranscriptHash::TranscriptHash(CipherSuite suite_in, update_interim(confirmation_tag); } -void -TranscriptHash::update(const AuthenticatedContent& content_auth) +bytes +TranscriptHash::new_confirmed(const bytes& transcript_hash_input) const { - update_confirmed(content_auth); - update_interim(content_auth); + return suite.digest().hash(interim + transcript_hash_input); } void -TranscriptHash::update_confirmed(const AuthenticatedContent& content_auth) +TranscriptHash::set_confirmed(bytes confirmed_transcript_hash) { - const auto transcript = - interim + content_auth.confirmed_transcript_hash_input(); - confirmed = suite.digest().hash(transcript); + confirmed = std::move(confirmed_transcript_hash); } void TranscriptHash::update_interim(const bytes& confirmation_tag) { - const auto transcript = confirmed + tls::marshal(confirmation_tag); - interim = suite.digest().hash(transcript); -} - -void -TranscriptHash::update_interim(const AuthenticatedContent& content_auth) -{ - const auto transcript = - confirmed + content_auth.interim_transcript_hash_input(); - interim = suite.digest().hash(transcript); + interim = suite.digest().hash(confirmed + tls::marshal(confirmation_tag)); } bool diff --git a/src/state.cpp b/src/state.cpp index 44972eee..86cd9a31 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -42,8 +42,7 @@ State::State(bytes group_id, _keys = _key_schedule.encryption_keys(_tree.size); // Update the interim transcript hash with a virtual confirmation tag - _transcript_hash.update_interim( - _key_schedule.confirmation_tag(_transcript_hash.confirmed)); + _transcript_hash.update_interim(_key_schedule.confirmation_tag); } TreeKEMPublicKey @@ -268,14 +267,12 @@ State::State(const HPKEPrivateKey& init_priv, // Ratchet forward into the current epoch auto group_ctx = tls::marshal(group_context()); - _key_schedule = - KeyScheduleEpoch::joiner(_suite, secrets.joiner_secret, psks, group_ctx); + _key_schedule = KeyScheduleEpoch::joiner( + _suite, secrets.joiner_secret, psks, _transcript_hash.confirmed, group_ctx); _keys = _key_schedule.encryption_keys(_tree.size); // Verify the confirmation - const auto confirmation_tag = - _key_schedule.confirmation_tag(_transcript_hash.confirmed); - if (confirmation_tag != group_info.confirmation_tag) { + if (_key_schedule.confirmation_tag != group_info.confirmation_tag) { throw ProtocolError("Confirmation failed to verify"); } } @@ -634,18 +631,32 @@ State::commit(const bytes& leaf_secret, return commit(leaf_secret, opts, msg_opts, NormalCommitParams{}); } -std::tuple -State::commit(const bytes& leaf_secret, - const std::optional& opts, - const MessageOpts& msg_opts, - CommitParams params) -{ - // Construct a commit from cached proposals - // TODO(rlb) ignore some proposals: - // * Update after Update - // * Update after Remove - // * Remove after Remove - Commit commit; +struct State::CommitMaterials +{ + // To be used locally + LeafIndex index; + TreeKEMPublicKey new_tree; + TreeKEMPrivateKey new_tree_priv; + ExtensionList extensions; + std::optional force_init_secret; + + // To be sent to other members + std::vector proposals; + std::optional path; + + // To be used in forming Welcome messages + std::vector joiners; + std::vector> path_secrets; + std::vector psks; +}; + +State::CommitMaterials +State::prepare_commit(const bytes& leaf_secret, + const std::optional& opts, + const CommitParams& params) const +{ + // Construct a proposal list from cached proposals + auto proposals = std::vector{}; auto joiners = std::vector{}; for (const auto& cached : _pending_proposals) { if (var::holds_alternative(cached.proposal.content)) { @@ -653,7 +664,7 @@ State::commit(const bytes& leaf_secret, joiners.push_back(add.key_package); } - commit.proposals.push_back({ cached.ref }); + proposals.push_back({ cached.ref }); } // Add the extra proposals to those we had cached @@ -665,7 +676,7 @@ State::commit(const bytes& leaf_secret, joiners.push_back(add.key_package); } - commit.proposals.push_back({ proposal }); + proposals.push_back({ proposal }); } } @@ -681,58 +692,48 @@ State::commit(const bytes& leaf_secret, } // Apply proposals - State next = successor(); - - const auto proposals = must_resolve(commit.proposals, _index); - if (!valid(proposals, _index, params)) { + const auto cached_proposals = must_resolve(proposals, _index); + if (!valid(cached_proposals, _index, params)) { throw ProtocolError("Invalid proposal list"); } - const auto [joiner_locations, psks] = next.apply(proposals); + auto [new_tree, joiner_locations, psks, extensions] = apply(cached_proposals); + auto index = _index; if (external_commit) { const auto& leaf_node = opt::get(external_commit).joiner_key_package.leaf_node; - next._index = next._tree.add_leaf(leaf_node); - } - - // If this is an external commit, indicate it in the sender field - auto sender = Sender{ MemberSender{ _index } }; - if (external_commit) { - sender = Sender{ NewMemberCommitSender{} }; + index = new_tree.add_leaf(leaf_node); } // KEM new entropy to the group and the new joiners - auto commit_secret = _suite.zero(); + auto new_tree_priv = _tree_priv; + auto path = std::optional{}; auto path_secrets = std::vector>(joiner_locations.size()); auto force_path = opts && opt::get(opts).force_path; - if (force_path || path_required(proposals)) { + if (force_path || path_required(cached_proposals)) { auto leaf_node_opts = LeafNodeOptions{}; if (opts) { leaf_node_opts = opt::get(opts).leaf_node_opts; } - auto new_priv = next._tree.update( - next._index, leaf_secret, next._group_id, _identity_priv, leaf_node_opts); + new_tree_priv = new_tree.update( + index, leaf_secret, _group_id, _identity_priv, leaf_node_opts); auto ctx = tls::marshal(GroupContext{ - next._suite, - next._group_id, - next._epoch + 1, - next._tree.root_hash(), - next._transcript_hash.confirmed, - next._extensions, + _suite, + _group_id, + _epoch + 1, + new_tree.root_hash(), + _transcript_hash.confirmed, + extensions, }); - auto path = next._tree.encap(new_priv, ctx, joiner_locations); - - next._tree_priv = new_priv; - commit.path = path; - commit_secret = new_priv.update_secret; + path = new_tree.encap(new_tree_priv, ctx, joiner_locations); for (size_t i = 0; i < joiner_locations.size(); i++) { auto [overlap, shared_path_secret, ok] = - new_priv.shared_path_secret(joiner_locations[i]); + new_tree_priv.shared_path_secret(joiner_locations[i]); silence_unused(overlap); silence_unused(ok); @@ -740,47 +741,79 @@ State::commit(const bytes& leaf_secret, } } - // Create the Commit message and advance the transcripts / key schedule - auto commit_content_auth = - sign(sender, commit, msg_opts.authenticated_data, msg_opts.encrypt); + return { + index, new_tree, new_tree_priv, extensions, force_init_secret, + proposals, path, joiners, path_secrets, psks, + }; +} - next._transcript_hash.update_confirmed(commit_content_auth); - next._epoch += 1; - next.update_epoch_secrets(commit_secret, psks, force_init_secret); +Welcome +State::welcome(bool inline_tree, + const std::vector& psks, + const std::vector& joiners, + const std::vector>& path_secrets) const +{ + // TODO(RLB) Suppress external_pub in this GroupInfo + auto group_info_obj = group_info(inline_tree); - const auto confirmation_tag = - next._key_schedule.confirmation_tag(next._transcript_hash.confirmed); - commit_content_auth.set_confirmation_tag(confirmation_tag); + auto welcome = + Welcome{ _suite, _key_schedule.joiner_secret, psks, group_info_obj }; + for (size_t i = 0; i < joiners.size(); i++) { + welcome.encrypt(joiners[i], path_secrets[i]); + } - next._transcript_hash.update_interim(commit_content_auth); + return welcome; +} - auto commit_message = - protect(std::move(commit_content_auth), msg_opts.padding_size); +std::tuple +State::commit(const bytes& leaf_secret, + const std::optional& opts, + const MessageOpts& msg_opts, + const CommitParams& params) +{ + // Compute the new group state + auto commit_materials = prepare_commit(leaf_secret, opts, params); - // Complete the GroupInfo and form the Welcome - auto group_info = GroupInfo{ - { - next._suite, - next._group_id, - next._epoch, - next._tree.root_hash(), - next._transcript_hash.confirmed, - next._extensions, - }, - { /* No other extensions */ }, - { confirmation_tag }, + // Form the AuthenticatedContent (with signature, but not confirmation tag) + const auto commit = Commit{ + commit_materials.proposals, + commit_materials.path, }; - if (opts && opt::get(opts).inline_tree) { - group_info.extensions.add(RatchetTreeExtension{ next._tree }); - } - group_info.sign(next._tree, next._index, next._identity_priv); - auto welcome = - Welcome{ _suite, next._key_schedule.joiner_secret, psks, group_info }; - for (size_t i = 0; i < joiners.size(); i++) { - welcome.encrypt(joiners[i], path_secrets[i]); + auto sender = Sender{ MemberSender{ _index } }; + if (commit_materials.force_init_secret) { + sender = Sender{ NewMemberCommitSender{} }; } + auto preliminary_commit = + sign(sender, commit, msg_opts.authenticated_data, msg_opts.encrypt); + + // Update confirmed transcript hash and ratchet the key schedule forward + const auto confirmed_transcript_hash = _transcript_hash.new_confirmed( + preliminary_commit.confirmed_transcript_hash_input()); + + const auto next = successor(commit_materials.index, + std::move(commit_materials.new_tree), + std::move(commit_materials.new_tree_priv), + std::move(commit_materials.extensions), + confirmed_transcript_hash, + commit_materials.path.has_value(), + commit_materials.psks, + commit_materials.force_init_secret); + + // Complete the AuthenticatedContent and encapsulate as MLSMessage + const auto confirmation_tag = next._key_schedule.confirmation_tag; + preliminary_commit.set_confirmation_tag(confirmation_tag); + const auto commit_message = + protect(std::move(preliminary_commit), msg_opts.padding_size); + + // Create the welcome message + const auto inline_tree = opts && opt::get(opts).inline_tree; + const auto welcome = next.welcome(inline_tree, + commit_materials.psks, + commit_materials.joiners, + commit_materials.path_secrets); + return std::make_tuple(commit_message, welcome, next); } @@ -826,14 +859,6 @@ State::handle(const ValidatedContent& content_auth, return handle(content_auth, std::move(cached_state), std::nullopt); } -std::optional -State::handle(const MLSMessage& msg, - std::optional cached_state, - const std::optional& expected_params) -{ - return handle(unwrap(msg), std::move(cached_state), expected_params); -} - std::optional State::handle(const ValidatedContent& val_content, std::optional cached_state, @@ -841,22 +866,62 @@ State::handle(const ValidatedContent& val_content, { // Dispatch on content type const auto& content_auth = val_content.authenticated_content(); - const auto& content = content_auth.content; - switch (content.content_type()) { + switch (content_auth.content.content_type()) { // Proposals get queued, do not result in a state transition case ContentType::proposal: - cache_proposal(content_auth); + handle_proposal(content_auth); return std::nullopt; // Commits are handled in the remainder of this method case ContentType::commit: - break; + return handle_commit( + content_auth, std::move(cached_state), expected_params); // Any other content type in this method is an error default: throw InvalidParameterError("Invalid content type"); } +} + +void +State::handle_proposal(const AuthenticatedContent& content_auth) +{ + auto ref = _suite.ref(content_auth); + if (stdx::any_of(_pending_proposals, + [&](const auto& cached) { return cached.ref == ref; })) { + return; + } + auto sender_location = std::optional(); + if (content_auth.content.sender.sender_type() == SenderType::member) { + const auto& sender = content_auth.content.sender.sender; + sender_location = var::get(sender).sender; + } + + const auto& proposal = var::get(content_auth.content.content); + + if (content_auth.content.sender.sender_type() == SenderType::external && + !valid_external_proposal_type(proposal.proposal_type())) { + throw ProtocolError("Invalid external proposal"); + } + + if (!valid(sender_location, proposal)) { + throw ProtocolError("Invalid proposal"); + } + + _pending_proposals.push_back({ + _suite.ref(content_auth), + proposal, + sender_location, + }); +} + +State +State::handle_commit(const AuthenticatedContent& content_auth, + std::optional cached_state, + const std::optional& expected_params) const +{ + const auto& content = content_auth.content; switch (content.sender.sender_type()) { case SenderType::member: case SenderType::new_member_commit: @@ -886,22 +951,17 @@ State::handle(const ValidatedContent& val_content, throw InvalidParameterError("Handle own commits with caching"); } - // Apply the commit + // Unwrap the Commit itself const auto& commit = var::get(content.content); + + // Apply the proposals attached to the commit const auto proposals = must_resolve(commit.proposals, sender); + auto [new_tree, joiner_locations, psks, extensions] = apply(proposals); + // Determine what type of Commit this is const auto params = infer_commit_type(sender, proposals, expected_params); auto external_commit = var::holds_alternative(params); - // Check that a path is present when required - if (path_required(proposals) && !commit.path) { - throw ProtocolError("Path required but not present"); - } - - // Apply the proposals - auto next = successor(); - auto [joiner_locations, psks] = next.apply(proposals); - // If this is an external commit, add the joiner to the tree and note the // location where they were added. Also, compute the "externally forced" // value that we will use for the init_secret (as opposed to the init_secret @@ -912,7 +972,7 @@ State::handle(const ValidatedContent& val_content, sender_location = opt::get(sender); } else { // Find where the joiner will be added - sender_location = next._tree.allocate_leaf(); + sender_location = new_tree.allocate_leaf(); // Extract the forced init secret auto kem_output = commit.valid_external(); @@ -924,8 +984,14 @@ State::handle(const ValidatedContent& val_content, _key_schedule.receive_external_init(opt::get(kem_output)); } - // Decapsulate and apply the UpdatePath, if provided - auto commit_secret = _suite.zero(); + // Check that a path is present when required + if (path_required(proposals) && !commit.path) { + throw ProtocolError("Path required but not present"); + } + + // Identify the encrypted path secret and how to decrypt it + auto path_secret_decrypt_node = std::optional{}; + auto encrypted_path_secret = std::optional{}; if (commit.path) { const auto& path = opt::get(commit.path); @@ -933,35 +999,98 @@ State::handle(const ValidatedContent& val_content, throw ProtocolError("Commit path has invalid leaf node"); } - if (!next._tree.parent_hash_valid(sender_location, path)) { + if (!new_tree.parent_hash_valid(sender_location, path)) { throw ProtocolError("Commit path has invalid parent hash"); } - next._tree.merge(sender_location, path); + new_tree.merge(sender_location, path); + + const auto coords = + new_tree.decap_coords(_index, sender_location, joiner_locations); + path_secret_decrypt_node = coords.resolution_node; + encrypted_path_secret = + path.nodes.at(coords.ancestor_node_index) + .encrypted_path_secret.at(coords.resolution_node_index); + } + + // Update the transcript hash + const auto new_confirmed_transcript_hash = _transcript_hash.new_confirmed( + content_auth.confirmed_transcript_hash_input()); + const auto new_confirmation_tag = + opt::get(content_auth.auth.confirmation_tag); + + return ratchet(std::move(new_tree), + sender_location, + path_secret_decrypt_node, + encrypted_path_secret, + extensions, + psks, + force_init_secret, + new_confirmed_transcript_hash, + new_confirmation_tag); +} + +State +State::ratchet(TreeKEMPublicKey new_tree, + LeafIndex committer, + const std::optional& path_secret_decrypt_node, + const std::optional& encrypted_path_secret, + ExtensionList extensions, + const std::vector& psks, + const std::optional& force_init_secret, + const bytes& confirmed_transcript_hash, + const bytes& confirmation_tag) const +{ + // Update the TreeKEM private key to match the public key + auto new_tree_priv = _tree_priv; + new_tree_priv.truncate(new_tree.size); + + const auto my_leaf = opt::get(new_tree.leaf_node(_index)); + const auto my_priv = new_tree_priv.private_key_cache.at(NodeIndex(_index)); + if (my_leaf.encryption_key != my_priv.public_key) { + if (!_cached_update) { + throw ProtocolError("Self-update without cached update"); + } + + const auto cached_update = opt::get(_cached_update); + if (my_leaf != cached_update.proposal.leaf_node) { + throw ProtocolError("Self-update does not match cached leaf node"); + } + + new_tree_priv.set_leaf_priv(cached_update.update_priv); + } + // Compute the new TreeKEM private key + const auto has_path = path_secret_decrypt_node && encrypted_path_secret; + if (has_path) { auto ctx = tls::marshal(GroupContext{ - next._suite, - next._group_id, - next._epoch + 1, - next._tree.root_hash(), - next._transcript_hash.confirmed, - next._extensions, + _suite, + _group_id, + _epoch + 1, + new_tree.root_hash(), + _transcript_hash.confirmed, + extensions, }); - next._tree_priv.decap( - sender_location, next._tree, ctx, path, joiner_locations); - commit_secret = next._tree_priv.update_secret; + new_tree_priv.decap(committer, + new_tree, + ctx, + opt::get(path_secret_decrypt_node), + opt::get(encrypted_path_secret)); } // Update the transcripts and advance the key schedule - next._transcript_hash.update(content_auth); - next._epoch += 1; - next.update_epoch_secrets(commit_secret, { psks }, force_init_secret); + auto next = successor(_index, + std::move(new_tree), + std::move(new_tree_priv), + std::move(extensions), + confirmed_transcript_hash, + has_path, + psks, + force_init_secret); // Verify the confirmation MAC - const auto confirmation_tag = - next._key_schedule.confirmation_tag(next._transcript_hash.confirmed); - if (!content_auth.check_confirmation_tag(confirmation_tag)) { + if (next._key_schedule.confirmation_tag != confirmation_tag) { throw ProtocolError("Confirmation failed to verify"); } @@ -1183,87 +1312,125 @@ State::handle_reinit_commit(const MLSMessage& commit_msg) /// LeafIndex -State::apply(const Add& add) +State::apply(TreeKEMPublicKey& tree, const Add& add) { - return _tree.add_leaf(add.key_package.leaf_node); + return tree.add_leaf(add.key_package.leaf_node); } void -State::apply(LeafIndex target, const Update& update) +State::apply(TreeKEMPublicKey& tree, LeafIndex target, const Update& update) { - _tree.update_leaf(target, update.leaf_node); -} - -void -State::apply(LeafIndex target, - const Update& update, - const HPKEPrivateKey& leaf_priv) -{ - _tree.update_leaf(target, update.leaf_node); - _tree_priv.set_leaf_priv(leaf_priv); + tree.update_leaf(target, update.leaf_node); } LeafIndex -State::apply(const Remove& remove) +State::apply(TreeKEMPublicKey& tree, const Remove& remove) { - if (!_tree.has_leaf(remove.removed)) { + if (!tree.has_leaf(remove.removed)) { throw ProtocolError("Attempt to remove non-member"); } - _tree.blank_path(remove.removed); + tree.blank_path(remove.removed); return remove.removed; } -void -State::apply(const GroupContextExtensions& gce) +std::vector +State::apply(TreeKEMPublicKey& tree, + const std::vector& proposals, + Proposal::Type required_type) { - // TODO(RLB): Update spec to clarify that you MUST verify that the new - // extensions are compatible with all members. - if (!extensions_supported(gce.group_context_extensions)) { - throw ProtocolError("Unsupported extensions in GroupContextExtensions"); + auto locations = std::vector{}; + for (const auto& cached : proposals) { + auto proposal_type = cached.proposal.proposal_type(); + if (proposal_type != required_type) { + continue; + } + + switch (proposal_type) { + case ProposalType::add: { + const auto joiner_location = + apply(tree, var::get(cached.proposal.content)); + locations.push_back(joiner_location); + break; + } + + case ProposalType::update: { + const auto& update = var::get(cached.proposal.content); + + if (!cached.sender) { + throw ProtocolError("Update without target leaf"); + } + + auto target = opt::get(cached.sender); + apply(tree, target, update); + break; + } + + case ProposalType::remove: { + const auto& remove = var::get(cached.proposal.content); + apply(tree, remove); + break; + } + + default: + throw ProtocolError("Unsupported proposal type"); + } } - _extensions = gce.group_context_extensions; + return locations; } -bool -State::extensions_supported(const ExtensionList& exts) const +std::tuple, + std::vector, + ExtensionList> +State::apply(const std::vector& proposals) const { - return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { - return leaf_node.verify_extension_support(exts); - }); -} + auto tree = _tree; + apply(tree, proposals, ProposalType::update); + apply(tree, proposals, ProposalType::remove); + auto joiner_locations = apply(tree, proposals, ProposalType::add); -void -State::cache_proposal(AuthenticatedContent content_auth) -{ - auto ref = _suite.ref(content_auth); - if (stdx::any_of(_pending_proposals, - [&](const auto& cached) { return cached.ref == ref; })) { - return; - } + // Extract the GroupContextExtensions proposal, if present + auto extensions = _extensions; + for (const auto& cached : proposals) { + if (cached.proposal.proposal_type() != + ProposalType::group_context_extensions) { + continue; + } - auto sender_location = std::optional(); - if (content_auth.content.sender.sender_type() == SenderType::member) { - const auto& sender = content_auth.content.sender.sender; - sender_location = var::get(sender).sender; + const auto& proposal = + var::get(cached.proposal.content); + if (!extensions_supported(proposal.group_context_extensions)) { + throw ProtocolError("Unsupported extensions in GroupContextExtensions"); + } + + extensions = proposal.group_context_extensions; + break; } - const auto& proposal = var::get(content_auth.content.content); + // Extract the PSK proposals and look up the secrets + auto psk_ids = std::vector{}; + for (const auto& cached : proposals) { + if (cached.proposal.proposal_type() != ProposalType::psk) { + continue; + } - if (content_auth.content.sender.sender_type() == SenderType::external && - !valid_external_proposal_type(proposal.proposal_type())) { - throw ProtocolError("Invalid external proposal"); + const auto& proposal = var::get(cached.proposal.content); + psk_ids.push_back(proposal.psk); } + auto psks = resolve(psk_ids); - if (!valid(sender_location, proposal)) { - throw ProtocolError("Invalid proposal"); - } + tree.truncate(); + tree.set_hash_all(); + return { tree, joiner_locations, psks, extensions }; +} - _pending_proposals.push_back({ - _suite.ref(content_auth), - proposal, - sender_location, +bool +State::extensions_supported(const ExtensionList& exts) const +{ + return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + return leaf_node.verify_extension_support(exts); }); } @@ -1314,7 +1481,7 @@ State::resolve(const std::vector& psks) const }, [&](const ResumptionPSK& res_psk) { - if (res_psk.psk_epoch == _epoch) { + if (res_psk.psk_group_id == _group_id && res_psk.psk_epoch == _epoch) { return _key_schedule.resumption_psk; } @@ -1332,103 +1499,6 @@ State::resolve(const std::vector& psks) const }); } -std::vector -State::apply(const std::vector& proposals, - Proposal::Type required_type) -{ - auto locations = std::vector{}; - for (const auto& cached : proposals) { - auto proposal_type = cached.proposal.proposal_type(); - if (proposal_type != required_type) { - continue; - } - - switch (proposal_type) { - case ProposalType::add: { - locations.push_back(apply(var::get(cached.proposal.content))); - break; - } - - case ProposalType::update: { - const auto& update = var::get(cached.proposal.content); - - if (!cached.sender) { - throw ProtocolError("Update without target leaf"); - } - - auto target = opt::get(cached.sender); - if (target != _index) { - apply(target, update); - break; - } - - if (!_cached_update) { - throw ProtocolError("Self-update with no cached secret"); - } - - const auto& cached_update = opt::get(_cached_update); - if (update != cached_update.proposal) { - throw ProtocolError("Self-update does not match cached data"); - } - - apply(target, update, cached_update.update_priv); - locations.push_back(target); - break; - } - - case ProposalType::remove: { - const auto& remove = var::get(cached.proposal.content); - locations.push_back(apply(remove)); - break; - } - - case ProposalType::group_context_extensions: { - const auto& gce = - var::get(cached.proposal.content); - apply(gce); - break; - } - - default: - throw ProtocolError("Unsupported proposal type"); - } - } - - // The cached update needs to be reset after applying proposals, so that it is - // in a clean state for the next epoch. - _cached_update.reset(); - - return locations; -} - -std::tuple, std::vector> -State::apply(const std::vector& proposals) -{ - apply(proposals, ProposalType::update); - apply(proposals, ProposalType::remove); - auto joiner_locations = apply(proposals, ProposalType::add); - apply(proposals, ProposalType::group_context_extensions); - - // Extract the PSK proposals and look up the secrets - // TODO(RLB): Factor this out, and also factor the above methods into - // apply_update, apply_remove, etc. - auto psk_ids = std::vector{}; - for (const auto& cached : proposals) { - if (cached.proposal.proposal_type() != ProposalType::psk) { - continue; - } - - const auto& proposal = var::get(cached.proposal.content); - psk_ids.push_back(proposal.psk); - } - auto psks = resolve(psk_ids); - - _tree.truncate(); - _tree_priv.truncate(_tree.size); - _tree.set_hash_all(); - return { joiner_locations, psks }; -} - /// /// Message protection /// @@ -2021,24 +2091,6 @@ operator!=(const State& lhs, const State& rhs) return !(lhs == rhs); } -void -State::update_epoch_secrets(const bytes& commit_secret, - const std::vector& psks, - const std::optional& force_init_secret) -{ - auto ctx = tls::marshal(GroupContext{ - _suite, - _group_id, - _epoch, - _tree.root_hash(), - _transcript_hash.confirmed, - _extensions, - }); - _key_schedule = - _key_schedule.next(commit_secret, psks, force_init_secret, ctx); - _keys = _key_schedule.encryption_keys(_tree.size); -} - /// /// Message encryption and decryption /// @@ -2152,7 +2204,7 @@ State::group_info(bool inline_tree) const _extensions, }, { /* No other extensions */ }, - _key_schedule.confirmation_tag(_transcript_hash.confirmed), + _key_schedule.confirmation_tag, }; group_info.extensions.add( @@ -2172,7 +2224,7 @@ State::roster() const auto leaves = std::vector{}; leaves.reserve(_tree.size.val); - _tree.all_leaves([&](auto /* i */, auto leaf) { + _tree.all_leaves([&](auto /* i */, const auto& leaf) { leaves.push_back(leaf); return true; }); @@ -2205,15 +2257,45 @@ State::leaf_for_roster_entry(RosterIndex index) const } State -State::successor() const -{ - // Copy everything, then clear things that shouldn't be copied +State::successor(LeafIndex index, + TreeKEMPublicKey tree, + TreeKEMPrivateKey tree_priv, + ExtensionList extensions, + const bytes& confirmed_transcript_hash, + bool has_path, + const std::vector& psks, + const std::optional& force_init_secret) const +{ + // Initialize a clone with updates, clear things that shouldn't be copied auto next = *this; + next._epoch += 1; + next._index = index; + next._tree = std::move(tree); + next._tree_priv = std::move(tree_priv); + next._extensions = std::move(extensions); next._pending_proposals.clear(); // Copy forward a resumption PSK next.add_resumption_psk(_group_id, _epoch, _key_schedule.resumption_psk); + // Compute the commit secret + auto commit_secret = next._suite.zero(); + if (has_path) { + commit_secret = next._tree_priv.update_secret; + } + + // Ratchet forward the key schedule + next._transcript_hash.set_confirmed(confirmed_transcript_hash); + + const auto ctx = tls::marshal(next.group_context()); + next._key_schedule = _key_schedule.next(commit_secret, + psks, + force_init_secret, + next._transcript_hash.confirmed, + ctx); + next._keys = next._key_schedule.encryption_keys(next._tree.size); + next._transcript_hash.update_interim(next._key_schedule.confirmation_tag); + return next; } diff --git a/src/treekem.cpp b/src/treekem.cpp index db12582a..6c3fb2c2 100644 --- a/src/treekem.cpp +++ b/src/treekem.cpp @@ -243,6 +243,25 @@ TreeKEMPublicKey::dump() const } #endif +void +TreeKEMPrivateKey::decap(LeafIndex from, + const TreeKEMPublicKey& pub, + const bytes& context, + const NodeIndex& decrypt_node, + const HPKECiphertext& encrypted_path_secret) +{ + const auto overlap_node = from.ancestor(index); + const auto priv = opt::get(private_key(decrypt_node)); + const auto path_secret = priv.decrypt( + suite, encrypt_label::update_path_node, context, encrypted_path_secret); + implant(pub, overlap_node, path_secret); + + // Check that the resulting state is consistent with the public key + if (!consistent(pub)) { + throw ProtocolError("TreeKEMPublicKey inconsistent with TreeKEMPrivateKey"); + } +} + void TreeKEMPrivateKey::decap(LeafIndex from, const TreeKEMPublicKey& pub, @@ -586,6 +605,48 @@ TreeKEMPublicKey::resolve(NodeIndex index) const return l; } +TreeKEMPublicKey::DecapCoords +TreeKEMPublicKey::decap_coords( + LeafIndex to, + LeafIndex from, + const std::vector& joiner_locations) const +{ + const auto to_node = NodeIndex(to); + const auto from_node = NodeIndex(from); + + // Find the index of the common ancestor in the filtered direct path + const auto ancestor = to.ancestor(from); + const auto from_fdp = filtered_direct_path(from_node); + const auto ancestor_node_it = stdx::find_if(from_fdp, [&](const auto& pair) { + const auto& [node, _resolution] = pair; + return node == ancestor; + }); + const auto ancestor_node_index = + static_cast(ancestor_node_it - from_fdp.begin()); + + // Find the appropriate node in the copath resolution + auto copath_child = ancestor.left(); + if (!from_node.is_below(copath_child)) { + copath_child = ancestor.right(); + } + + auto resolution = std::get<1>(*ancestor_node_it); + for (const auto& j : joiner_locations) { + const auto it = stdx::find(resolution, NodeIndex(j)); + if (it != resolution.end()) { + resolution.erase(it); + } + } + + const auto resolution_node_it = stdx::find_if( + resolution, [&](const auto i) { return to_node.is_below(i); }); + const auto resolution_node_index = + static_cast(resolution_node_it - resolution.begin()); + const auto resolution_node = *resolution_node_it; + + return { ancestor_node_index, resolution_node_index, resolution_node }; +} + TreeKEMPublicKey::FilteredDirectPath TreeKEMPublicKey::filtered_direct_path(NodeIndex index) const { diff --git a/test/treekem.cpp b/test/treekem.cpp index 37601e53..5cb5ec89 100644 --- a/test/treekem.cpp +++ b/test/treekem.cpp @@ -266,7 +266,7 @@ TEST_CASE_METHOD(TreeKEMTest, "TreeKEM encap/decap") } } -TEST_CASE("TreeKEM Interop") +TEST_CASE("TreeKEM Interop", "[.][all]") { for (auto suite : all_supported_suites) { for (auto structure : treekem_test_tree_structures) {