From 97bd790206113f6e55778e1714dcaee617841ea7 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sat, 16 Dec 2023 09:51:37 -0500 Subject: [PATCH] Expose group ID and allow inspection (#412) * Expose GroupID on MLSMessage * Provide handle() complementary to unwrap() * Add GroupInfo branch; remove redundant checks * clang-format * ValidatedContent * clang-format * clang-tidy * clang-tidy --- include/mls/messages.h | 24 ++++++++- include/mls/state.h | 16 ++++-- lib/mls_vectors/src/mls_vectors.cpp | 27 +++++----- src/messages.cpp | 44 +++++++++++++--- src/state.cpp | 81 ++++++++++++++++++----------- test/messages.cpp | 8 ++- 6 files changed, 144 insertions(+), 56 deletions(-) diff --git a/include/mls/messages.h b/include/mls/messages.h index a660aeb0..c7442a8f 100644 --- a/include/mls/messages.h +++ b/include/mls/messages.h @@ -570,17 +570,35 @@ struct AuthenticatedContent friend struct PrivateMessage; }; +struct ValidatedContent +{ + const AuthenticatedContent& authenticated_content() const; + + friend bool operator==(const ValidatedContent& lhs, + const ValidatedContent& rhs); + +private: + AuthenticatedContent content_auth; + + ValidatedContent(AuthenticatedContent content_auth_in); + + friend struct PublicMessage; + friend struct PrivateMessage; + friend class State; +}; + struct PublicMessage { PublicMessage() = default; + bytes get_group_id() const { return content.group_id; } epoch_t get_epoch() const { return content.epoch; } static PublicMessage protect(AuthenticatedContent content_auth, CipherSuite suite, const std::optional& membership_key, const std::optional& context); - std::optional unprotect( + std::optional unprotect( CipherSuite suite, const std::optional& membership_key, const std::optional& context) const; @@ -611,6 +629,7 @@ struct PrivateMessage { PrivateMessage() = default; + bytes get_group_id() const { return group_id; } epoch_t get_epoch() const { return epoch; } static PrivateMessage protect(AuthenticatedContent content_auth, @@ -618,7 +637,7 @@ struct PrivateMessage GroupKeySource& keys, const bytes& sender_data_secret, size_t padding_size); - std::optional unprotect( + std::optional unprotect( CipherSuite suite, GroupKeySource& keys, const bytes& sender_data_secret) const; @@ -649,6 +668,7 @@ struct MLSMessage var::variant message; + bytes group_id() const; epoch_t epoch() const; WireFormat wire_format() const; diff --git a/include/mls/state.h b/include/mls/state.h index 329aadb0..1741d10d 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -117,11 +117,16 @@ class State const MessageOpts& msg_opts); /// - /// Generic handshake message handler + /// Generic handshake message handlers /// std::optional handle(const MLSMessage& msg); std::optional handle(const MLSMessage& msg, std::optional cached_state); + + std::optional handle(const ValidatedContent& content_auth); + std::optional handle(const ValidatedContent& content_auth, + std::optional cached_state); + /// /// PSK management /// @@ -151,6 +156,11 @@ class State bytes epoch_authenticator() const; + /// + /// Unwrap messages so that applications can inspect them + /// + ValidatedContent unwrap(const MLSMessage& msg); + /// /// Application encryption and decryption /// @@ -318,7 +328,7 @@ class State std::optional cached_state, const std::optional& expected_params); std::optional handle( - const AuthenticatedContent& content_auth, + const ValidatedContent& val_content, std::optional cached_state, const std::optional& expected_params); @@ -334,8 +344,6 @@ class State template MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts); - AuthenticatedContent unprotect_to_content_auth(const MLSMessage& msg); - // Apply the changes requested by various messages LeafIndex apply(const Add& add); void apply(LeafIndex target, const Update& update); diff --git a/lib/mls_vectors/src/mls_vectors.cpp b/lib/mls_vectors/src/mls_vectors.cpp index ea44bdef..5d31e226 100644 --- a/lib/mls_vectors/src/mls_vectors.cpp +++ b/lib/mls_vectors/src/mls_vectors.cpp @@ -872,25 +872,26 @@ MessageProtectionTestVector::protect_priv( std::optional MessageProtectionTestVector::unprotect(const MLSMessage& message) { - auto do_unprotect = overloaded{ - [&](const PublicMessage& pt) { - return pt.unprotect(cipher_suite, membership_key, group_context()); - }, - [&](const PrivateMessage& ct) { - auto keys = group_keys(); - return ct.unprotect(cipher_suite, keys, sender_data_secret); - }, - [](const auto& /* other */) -> std::optional { - return std::nullopt; - } - }; + auto do_unprotect = + overloaded{ [&](const PublicMessage& pt) { + return pt.unprotect( + cipher_suite, membership_key, group_context()); + }, + [&](const PrivateMessage& ct) { + auto keys = group_keys(); + return ct.unprotect(cipher_suite, keys, sender_data_secret); + }, + [](const auto& /* other */) -> std::optional { + return std::nullopt; + } }; auto maybe_auth_content = var::visit(do_unprotect, message.message); if (!maybe_auth_content) { return std::nullopt; } - auto auth_content = opt::get(maybe_auth_content); + auto val_content = opt::get(maybe_auth_content); + const auto& auth_content = val_content.authenticated_content(); if (!auth_content.verify(cipher_suite, signature_pub, group_context())) { return std::nullopt; } diff --git a/src/messages.cpp b/src/messages.cpp index 37da6311..8d483193 100644 --- a/src/messages.cpp +++ b/src/messages.cpp @@ -464,6 +464,23 @@ AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in, { } +const AuthenticatedContent& +ValidatedContent::authenticated_content() const +{ + return content_auth; +} + +ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in) + : content_auth(std::move(content_auth_in)) +{ +} + +bool +operator==(const ValidatedContent& lhs, const ValidatedContent& rhs) +{ + return lhs.content_auth == rhs.content_auth; +} + struct GroupContentTBS { WireFormat wire_format = WireFormat::reserved; @@ -526,7 +543,7 @@ PublicMessage::protect(AuthenticatedContent content_auth, return pt; } -std::optional +std::optional PublicMessage::unprotect(CipherSuite suite, const std::optional& membership_key, const std::optional& context) const @@ -545,11 +562,11 @@ PublicMessage::unprotect(CipherSuite suite, break; } - return AuthenticatedContent{ + return { { AuthenticatedContent{ WireFormat::mls_public_message, content, auth, - }; + } } }; } bool @@ -756,7 +773,7 @@ PrivateMessage::protect(AuthenticatedContent content_auth, }; } -std::optional +std::optional PrivateMessage::unprotect(CipherSuite suite, GroupKeySource& keys, const bytes& sender_data_secret) const @@ -813,11 +830,11 @@ PrivateMessage::unprotect(CipherSuite suite, unmarshal_ciphertext_content(opt::get(content_pt), content, auth); - return AuthenticatedContent{ + return { { AuthenticatedContent{ WireFormat::mls_private_message, std::move(content), std::move(auth), - }; + } } }; } PrivateMessage::PrivateMessage(GroupContent content, @@ -832,6 +849,21 @@ PrivateMessage::PrivateMessage(GroupContent content, { } +bytes +MLSMessage::group_id() const +{ + return var::visit( + overloaded{ + [](const PublicMessage& pt) -> bytes { return pt.get_group_id(); }, + [](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); }, + [](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; }, + [](const auto& /* unused */) -> bytes { + throw InvalidParameterError("MLSMessage has no group_id"); + }, + }, + message); +} + epoch_t MLSMessage::epoch() const { diff --git a/src/state.cpp b/src/state.cpp index c45b6335..826503ba 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -407,15 +407,23 @@ State::protect(AuthenticatedContent&& content_auth, size_t padding_size) } } -AuthenticatedContent -State::unprotect_to_content_auth(const MLSMessage& msg) +ValidatedContent +State::unwrap(const MLSMessage& msg) { if (msg.version != ProtocolVersion::mls10) { throw InvalidParameterError("Unsupported version"); } const auto unprotect = overloaded{ - [&](const PublicMessage& pt) -> AuthenticatedContent { + [&](const PublicMessage& pt) -> ValidatedContent { + if (pt.get_group_id() != _group_id) { + throw ProtocolError("PublicMessage not for this group"); + } + + if (pt.get_epoch() != _epoch) { + throw ProtocolError("PublicMessage not for this epoch"); + } + auto maybe_content_auth = pt.unprotect(_suite, _key_schedule.membership_key, group_context()); if (!maybe_content_auth) { @@ -424,7 +432,15 @@ State::unprotect_to_content_auth(const MLSMessage& msg) return opt::get(maybe_content_auth); }, - [&](const PrivateMessage& ct) -> AuthenticatedContent { + [&](const PrivateMessage& ct) -> ValidatedContent { + if (ct.get_group_id() != _group_id) { + throw ProtocolError("PrivateMessage not for this group"); + } + + if (ct.get_epoch() != _epoch) { + throw ProtocolError("PrivateMessage not for this epoch"); + } + auto maybe_content_auth = ct.unprotect(_suite, _keys, _key_schedule.sender_data_secret); if (!maybe_content_auth) { @@ -433,12 +449,17 @@ State::unprotect_to_content_auth(const MLSMessage& msg) return opt::get(maybe_content_auth); }, - [](const auto& /* unused */) -> AuthenticatedContent { + [](const auto& /* unused */) -> ValidatedContent { throw ProtocolError("Invalid wire format"); }, }; - return var::visit(unprotect, msg.message); + auto val_content = var::visit(unprotect, msg.message); + if (!verify(val_content.content_auth)) { + throw InvalidParameterError("Message signature failed to verify"); + } + + return val_content; } Proposal @@ -783,13 +804,26 @@ State::group_context() const std::optional State::handle(const MLSMessage& msg) { - return handle(msg, std::nullopt, std::nullopt); + return handle(unwrap(msg), std::nullopt, std::nullopt); } std::optional State::handle(const MLSMessage& msg, std::optional cached_state) { - return handle(msg, std::move(cached_state), std::nullopt); + return handle(unwrap(msg), std::move(cached_state), std::nullopt); +} + +std::optional +State::handle(const ValidatedContent& content_auth) +{ + return handle(content_auth, std::nullopt, std::nullopt); +} + +std::optional +State::handle(const ValidatedContent& content_auth, + std::optional cached_state) +{ + return handle(content_auth, std::move(cached_state), std::nullopt); } std::optional @@ -797,30 +831,17 @@ State::handle(const MLSMessage& msg, std::optional cached_state, const std::optional& expected_params) { - auto content_auth = unprotect_to_content_auth(msg); - if (!verify(content_auth)) { - throw InvalidParameterError("Message signature failed to verify"); - } - - return handle(content_auth, std::move(cached_state), expected_params); + return handle(unwrap(msg), std::move(cached_state), expected_params); } std::optional -State::handle(const AuthenticatedContent& content_auth, +State::handle(const ValidatedContent& val_content, std::optional cached_state, const std::optional& expected_params) { - // Validate the GroupContent - const auto& content = content_auth.content; - if (content.group_id != _group_id) { - throw InvalidParameterError("GroupID mismatch"); - } - - if (content.epoch != _epoch) { - throw InvalidParameterError("Epoch mismatch"); - } - // Dispatch on content type + const auto& content_auth = val_content.authenticated_content(); + const auto& content = content_auth.content; switch (content.content_type()) { // Proposals get queued, do not result in a state transition case ContentType::proposal: @@ -1136,7 +1157,8 @@ State::Tombstone State::handle_reinit_commit(const MLSMessage& commit_msg) { // Verify the signature and process the commit - auto content_auth = unprotect_to_content_auth(commit_msg); + const auto val_content = unwrap(commit_msg); + const auto& content_auth = val_content.authenticated_content(); if (!verify(content_auth)) { throw InvalidParameterError("Message signature failed to verify"); } @@ -1417,7 +1439,8 @@ State::protect(const bytes& authenticated_data, std::tuple State::unprotect(const MLSMessage& ct) { - auto content_auth = unprotect_to_content_auth(ct); + const auto val_content = unwrap(ct); + const auto& content_auth = val_content.authenticated_content(); if (!verify(content_auth)) { throw InvalidParameterError("Message signature failed to verify"); @@ -1432,8 +1455,8 @@ State::unprotect(const MLSMessage& ct) } return { - std::move(content_auth.content.authenticated_data), - std::move(var::get(content_auth.content.content).data), + content_auth.content.authenticated_data, + var::get(content_auth.content.content).data, }; } diff --git a/test/messages.cpp b/test/messages.cpp index f3182dd6..a2f514fa 100644 --- a/test/messages.cpp +++ b/test/messages.cpp @@ -129,7 +129,9 @@ TEST_CASE_METHOD(MLSMessageTest, "PublicMessage Protect/Unprotect") auto pt = PublicMessage::protect( content_auth_original, suite, membership_key, context); - auto content_auth_unprotected = pt.unprotect(suite, membership_key, context); + auto val_auth_unprotected = pt.unprotect(suite, membership_key, context); + const auto& content_auth_unprotected = + opt::get(val_auth_unprotected).authenticated_content(); REQUIRE(content_auth_unprotected == content_auth_original); } @@ -145,7 +147,9 @@ TEST_CASE_METHOD(MLSMessageTest, "PrivateMessage Protect/Unprotect") auto ct = PrivateMessage::protect( content_auth_original, suite, keys, sender_data_secret, padding_size); - auto content_auth_unprotected = ct.unprotect(suite, keys, sender_data_secret); + auto val_auth_unprotected = ct.unprotect(suite, keys, sender_data_secret); + const auto& content_auth_unprotected = + opt::get(val_auth_unprotected).authenticated_content(); REQUIRE(content_auth_unprotected == content_auth_original); }