Skip to content

Commit

Permalink
Expose group ID and allow inspection (#412)
Browse files Browse the repository at this point in the history
* Expose GroupID on MLSMessage

* Provide handle() complementary to unwrap()

* Add GroupInfo branch; remove redundant checks

* clang-format

* ValidatedContent

* clang-format

* clang-tidy

* clang-tidy
  • Loading branch information
bifurcation committed Dec 16, 2023
1 parent 56f23eb commit 97bd790
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 56 deletions.
24 changes: 22 additions & 2 deletions include/mls/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,17 +570,35 @@ struct AuthenticatedContent
friend struct PrivateMessage;
};

struct ValidatedContent
{
const AuthenticatedContent& authenticated_content() const;

friend bool operator==(const ValidatedContent& lhs,
const ValidatedContent& rhs);

private:
AuthenticatedContent content_auth;

ValidatedContent(AuthenticatedContent content_auth_in);

friend struct PublicMessage;
friend struct PrivateMessage;
friend class State;
};

struct PublicMessage
{
PublicMessage() = default;

bytes get_group_id() const { return content.group_id; }
epoch_t get_epoch() const { return content.epoch; }

static PublicMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context);
std::optional<AuthenticatedContent> unprotect(
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const;
Expand Down Expand Up @@ -611,14 +629,15 @@ struct PrivateMessage
{
PrivateMessage() = default;

bytes get_group_id() const { return group_id; }
epoch_t get_epoch() const { return epoch; }

static PrivateMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret,
size_t padding_size);
std::optional<AuthenticatedContent> unprotect(
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const;
Expand Down Expand Up @@ -649,6 +668,7 @@ struct MLSMessage
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
message;

bytes group_id() const;
epoch_t epoch() const;
WireFormat wire_format() const;

Expand Down
16 changes: 12 additions & 4 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,16 @@ class State
const MessageOpts& msg_opts);

///
/// Generic handshake message handler
/// Generic handshake message handlers
///
std::optional<State> handle(const MLSMessage& msg);
std::optional<State> handle(const MLSMessage& msg,
std::optional<State> cached_state);

std::optional<State> handle(const ValidatedContent& content_auth);
std::optional<State> handle(const ValidatedContent& content_auth,
std::optional<State> cached_state);

///
/// PSK management
///
Expand Down Expand Up @@ -151,6 +156,11 @@ class State

bytes epoch_authenticator() const;

///
/// Unwrap messages so that applications can inspect them
///
ValidatedContent unwrap(const MLSMessage& msg);

///
/// Application encryption and decryption
///
Expand Down Expand Up @@ -318,7 +328,7 @@ class State
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);
std::optional<State> handle(
const AuthenticatedContent& content_auth,
const ValidatedContent& val_content,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);

Expand All @@ -334,8 +344,6 @@ class State
template<typename Inner>
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);

AuthenticatedContent unprotect_to_content_auth(const MLSMessage& msg);

// Apply the changes requested by various messages
LeafIndex apply(const Add& add);
void apply(LeafIndex target, const Update& update);
Expand Down
27 changes: 14 additions & 13 deletions lib/mls_vectors/src/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,25 +872,26 @@ MessageProtectionTestVector::protect_priv(
std::optional<GroupContent>
MessageProtectionTestVector::unprotect(const MLSMessage& message)
{
auto do_unprotect = overloaded{
[&](const PublicMessage& pt) {
return pt.unprotect(cipher_suite, membership_key, group_context());
},
[&](const PrivateMessage& ct) {
auto keys = group_keys();
return ct.unprotect(cipher_suite, keys, sender_data_secret);
},
[](const auto& /* other */) -> std::optional<AuthenticatedContent> {
return std::nullopt;
}
};
auto do_unprotect =
overloaded{ [&](const PublicMessage& pt) {
return pt.unprotect(
cipher_suite, membership_key, group_context());
},
[&](const PrivateMessage& ct) {
auto keys = group_keys();
return ct.unprotect(cipher_suite, keys, sender_data_secret);
},
[](const auto& /* other */) -> std::optional<ValidatedContent> {
return std::nullopt;
} };

auto maybe_auth_content = var::visit(do_unprotect, message.message);
if (!maybe_auth_content) {
return std::nullopt;
}

auto auth_content = opt::get(maybe_auth_content);
auto val_content = opt::get(maybe_auth_content);
const auto& auth_content = val_content.authenticated_content();
if (!auth_content.verify(cipher_suite, signature_pub, group_context())) {
return std::nullopt;
}
Expand Down
44 changes: 38 additions & 6 deletions src/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,23 @@ AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
{
}

const AuthenticatedContent&
ValidatedContent::authenticated_content() const
{
return content_auth;
}

ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
: content_auth(std::move(content_auth_in))
{
}

bool
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
{
return lhs.content_auth == rhs.content_auth;
}

struct GroupContentTBS
{
WireFormat wire_format = WireFormat::reserved;
Expand Down Expand Up @@ -526,7 +543,7 @@ PublicMessage::protect(AuthenticatedContent content_auth,
return pt;
}

std::optional<AuthenticatedContent>
std::optional<ValidatedContent>
PublicMessage::unprotect(CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const
Expand All @@ -545,11 +562,11 @@ PublicMessage::unprotect(CipherSuite suite,
break;
}

return AuthenticatedContent{
return { { AuthenticatedContent{
WireFormat::mls_public_message,
content,
auth,
};
} } };
}

bool
Expand Down Expand Up @@ -756,7 +773,7 @@ PrivateMessage::protect(AuthenticatedContent content_auth,
};
}

std::optional<AuthenticatedContent>
std::optional<ValidatedContent>
PrivateMessage::unprotect(CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const
Expand Down Expand Up @@ -813,11 +830,11 @@ PrivateMessage::unprotect(CipherSuite suite,

unmarshal_ciphertext_content(opt::get(content_pt), content, auth);

return AuthenticatedContent{
return { { AuthenticatedContent{
WireFormat::mls_private_message,
std::move(content),
std::move(auth),
};
} } };
}

PrivateMessage::PrivateMessage(GroupContent content,
Expand All @@ -832,6 +849,21 @@ PrivateMessage::PrivateMessage(GroupContent content,
{
}

bytes
MLSMessage::group_id() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
[](const auto& /* unused */) -> bytes {
throw InvalidParameterError("MLSMessage has no group_id");
},
},
message);
}

epoch_t
MLSMessage::epoch() const
{
Expand Down
Loading

0 comments on commit 97bd790

Please sign in to comment.