Skip to content

Commit 97bd790

Browse files
authored
Expose group ID and allow inspection (#412)
* Expose GroupID on MLSMessage * Provide handle() complementary to unwrap() * Add GroupInfo branch; remove redundant checks * clang-format * ValidatedContent * clang-format * clang-tidy * clang-tidy
1 parent 56f23eb commit 97bd790

File tree

6 files changed

+144
-56
lines changed

6 files changed

+144
-56
lines changed

include/mls/messages.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,17 +570,35 @@ struct AuthenticatedContent
570570
friend struct PrivateMessage;
571571
};
572572

573+
struct ValidatedContent
574+
{
575+
const AuthenticatedContent& authenticated_content() const;
576+
577+
friend bool operator==(const ValidatedContent& lhs,
578+
const ValidatedContent& rhs);
579+
580+
private:
581+
AuthenticatedContent content_auth;
582+
583+
ValidatedContent(AuthenticatedContent content_auth_in);
584+
585+
friend struct PublicMessage;
586+
friend struct PrivateMessage;
587+
friend class State;
588+
};
589+
573590
struct PublicMessage
574591
{
575592
PublicMessage() = default;
576593

594+
bytes get_group_id() const { return content.group_id; }
577595
epoch_t get_epoch() const { return content.epoch; }
578596

579597
static PublicMessage protect(AuthenticatedContent content_auth,
580598
CipherSuite suite,
581599
const std::optional<bytes>& membership_key,
582600
const std::optional<GroupContext>& context);
583-
std::optional<AuthenticatedContent> unprotect(
601+
std::optional<ValidatedContent> unprotect(
584602
CipherSuite suite,
585603
const std::optional<bytes>& membership_key,
586604
const std::optional<GroupContext>& context) const;
@@ -611,14 +629,15 @@ struct PrivateMessage
611629
{
612630
PrivateMessage() = default;
613631

632+
bytes get_group_id() const { return group_id; }
614633
epoch_t get_epoch() const { return epoch; }
615634

616635
static PrivateMessage protect(AuthenticatedContent content_auth,
617636
CipherSuite suite,
618637
GroupKeySource& keys,
619638
const bytes& sender_data_secret,
620639
size_t padding_size);
621-
std::optional<AuthenticatedContent> unprotect(
640+
std::optional<ValidatedContent> unprotect(
622641
CipherSuite suite,
623642
GroupKeySource& keys,
624643
const bytes& sender_data_secret) const;
@@ -649,6 +668,7 @@ struct MLSMessage
649668
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
650669
message;
651670

671+
bytes group_id() const;
652672
epoch_t epoch() const;
653673
WireFormat wire_format() const;
654674

include/mls/state.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,16 @@ class State
117117
const MessageOpts& msg_opts);
118118

119119
///
120-
/// Generic handshake message handler
120+
/// Generic handshake message handlers
121121
///
122122
std::optional<State> handle(const MLSMessage& msg);
123123
std::optional<State> handle(const MLSMessage& msg,
124124
std::optional<State> cached_state);
125+
126+
std::optional<State> handle(const ValidatedContent& content_auth);
127+
std::optional<State> handle(const ValidatedContent& content_auth,
128+
std::optional<State> cached_state);
129+
125130
///
126131
/// PSK management
127132
///
@@ -151,6 +156,11 @@ class State
151156

152157
bytes epoch_authenticator() const;
153158

159+
///
160+
/// Unwrap messages so that applications can inspect them
161+
///
162+
ValidatedContent unwrap(const MLSMessage& msg);
163+
154164
///
155165
/// Application encryption and decryption
156166
///
@@ -318,7 +328,7 @@ class State
318328
std::optional<State> cached_state,
319329
const std::optional<CommitParams>& expected_params);
320330
std::optional<State> handle(
321-
const AuthenticatedContent& content_auth,
331+
const ValidatedContent& val_content,
322332
std::optional<State> cached_state,
323333
const std::optional<CommitParams>& expected_params);
324334

@@ -334,8 +344,6 @@ class State
334344
template<typename Inner>
335345
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);
336346

337-
AuthenticatedContent unprotect_to_content_auth(const MLSMessage& msg);
338-
339347
// Apply the changes requested by various messages
340348
LeafIndex apply(const Add& add);
341349
void apply(LeafIndex target, const Update& update);

lib/mls_vectors/src/mls_vectors.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -872,25 +872,26 @@ MessageProtectionTestVector::protect_priv(
872872
std::optional<GroupContent>
873873
MessageProtectionTestVector::unprotect(const MLSMessage& message)
874874
{
875-
auto do_unprotect = overloaded{
876-
[&](const PublicMessage& pt) {
877-
return pt.unprotect(cipher_suite, membership_key, group_context());
878-
},
879-
[&](const PrivateMessage& ct) {
880-
auto keys = group_keys();
881-
return ct.unprotect(cipher_suite, keys, sender_data_secret);
882-
},
883-
[](const auto& /* other */) -> std::optional<AuthenticatedContent> {
884-
return std::nullopt;
885-
}
886-
};
875+
auto do_unprotect =
876+
overloaded{ [&](const PublicMessage& pt) {
877+
return pt.unprotect(
878+
cipher_suite, membership_key, group_context());
879+
},
880+
[&](const PrivateMessage& ct) {
881+
auto keys = group_keys();
882+
return ct.unprotect(cipher_suite, keys, sender_data_secret);
883+
},
884+
[](const auto& /* other */) -> std::optional<ValidatedContent> {
885+
return std::nullopt;
886+
} };
887887

888888
auto maybe_auth_content = var::visit(do_unprotect, message.message);
889889
if (!maybe_auth_content) {
890890
return std::nullopt;
891891
}
892892

893-
auto auth_content = opt::get(maybe_auth_content);
893+
auto val_content = opt::get(maybe_auth_content);
894+
const auto& auth_content = val_content.authenticated_content();
894895
if (!auth_content.verify(cipher_suite, signature_pub, group_context())) {
895896
return std::nullopt;
896897
}

src/messages.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,23 @@ AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
464464
{
465465
}
466466

467+
const AuthenticatedContent&
468+
ValidatedContent::authenticated_content() const
469+
{
470+
return content_auth;
471+
}
472+
473+
ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
474+
: content_auth(std::move(content_auth_in))
475+
{
476+
}
477+
478+
bool
479+
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
480+
{
481+
return lhs.content_auth == rhs.content_auth;
482+
}
483+
467484
struct GroupContentTBS
468485
{
469486
WireFormat wire_format = WireFormat::reserved;
@@ -526,7 +543,7 @@ PublicMessage::protect(AuthenticatedContent content_auth,
526543
return pt;
527544
}
528545

529-
std::optional<AuthenticatedContent>
546+
std::optional<ValidatedContent>
530547
PublicMessage::unprotect(CipherSuite suite,
531548
const std::optional<bytes>& membership_key,
532549
const std::optional<GroupContext>& context) const
@@ -545,11 +562,11 @@ PublicMessage::unprotect(CipherSuite suite,
545562
break;
546563
}
547564

548-
return AuthenticatedContent{
565+
return { { AuthenticatedContent{
549566
WireFormat::mls_public_message,
550567
content,
551568
auth,
552-
};
569+
} } };
553570
}
554571

555572
bool
@@ -756,7 +773,7 @@ PrivateMessage::protect(AuthenticatedContent content_auth,
756773
};
757774
}
758775

759-
std::optional<AuthenticatedContent>
776+
std::optional<ValidatedContent>
760777
PrivateMessage::unprotect(CipherSuite suite,
761778
GroupKeySource& keys,
762779
const bytes& sender_data_secret) const
@@ -813,11 +830,11 @@ PrivateMessage::unprotect(CipherSuite suite,
813830

814831
unmarshal_ciphertext_content(opt::get(content_pt), content, auth);
815832

816-
return AuthenticatedContent{
833+
return { { AuthenticatedContent{
817834
WireFormat::mls_private_message,
818835
std::move(content),
819836
std::move(auth),
820-
};
837+
} } };
821838
}
822839

823840
PrivateMessage::PrivateMessage(GroupContent content,
@@ -832,6 +849,21 @@ PrivateMessage::PrivateMessage(GroupContent content,
832849
{
833850
}
834851

852+
bytes
853+
MLSMessage::group_id() const
854+
{
855+
return var::visit(
856+
overloaded{
857+
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
858+
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
859+
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
860+
[](const auto& /* unused */) -> bytes {
861+
throw InvalidParameterError("MLSMessage has no group_id");
862+
},
863+
},
864+
message);
865+
}
866+
835867
epoch_t
836868
MLSMessage::epoch() const
837869
{

0 commit comments

Comments
 (0)