Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose group ID and allow inspection #412

Merged
merged 10 commits into from
Dec 16, 2023
Merged
3 changes: 3 additions & 0 deletions include/mls/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ struct PublicMessage
{
PublicMessage() = default;

bytes get_group_id() const { return content.group_id; }
epoch_t get_epoch() const { return content.epoch; }

static PublicMessage protect(AuthenticatedContent content_auth,
Expand Down Expand Up @@ -611,6 +612,7 @@ struct PrivateMessage
{
PrivateMessage() = default;

bytes get_group_id() const { return group_id; }
epoch_t get_epoch() const { return epoch; }

static PrivateMessage protect(AuthenticatedContent content_auth,
Expand Down Expand Up @@ -649,6 +651,7 @@ struct MLSMessage
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
message;

bytes group_id() const;
epoch_t epoch() const;
WireFormat wire_format() const;

Expand Down
15 changes: 13 additions & 2 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ class State
std::optional<State> handle(const MLSMessage& msg);
std::optional<State> handle(const MLSMessage& msg,
std::optional<State> cached_state);

// In general, you should avoid these methods and prefer the MLSMessage
// variants. They are provided for cases where a message recipient wishes to
// unwrap and inspect a message before handling it.
std::optional<State> handle(const AuthenticatedContent& content_auth);
std::optional<State> handle(const AuthenticatedContent& content_auth,
std::optional<State> cached_state);

///
/// PSK management
///
Expand Down Expand Up @@ -151,6 +159,11 @@ class State

bytes epoch_authenticator() const;

///
/// Unwrap messages so that applications can inspect them
///
AuthenticatedContent unwrap(const MLSMessage& msg);

///
/// Application encryption and decryption
///
Expand Down Expand Up @@ -334,8 +347,6 @@ class State
template<typename Inner>
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);

AuthenticatedContent unprotect_to_content_auth(const MLSMessage& msg);

// Apply the changes requested by various messages
LeafIndex apply(const Add& add);
void apply(LeafIndex target, const Update& update);
Expand Down
15 changes: 15 additions & 0 deletions src/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,21 @@ PrivateMessage::PrivateMessage(GroupContent content,
{
}

bytes
MLSMessage::group_id() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
[](const auto& /* unused */) -> bytes {
throw InvalidParameterError("MLSMessage has no group_id");
},
},
message);
}

epoch_t
MLSMessage::epoch() const
{
Expand Down
66 changes: 45 additions & 21 deletions src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,14 +408,22 @@ State::protect(AuthenticatedContent&& content_auth, size_t padding_size)
}

AuthenticatedContent
State::unprotect_to_content_auth(const MLSMessage& msg)
State::unwrap(const MLSMessage& msg)
{
if (msg.version != ProtocolVersion::mls10) {
throw InvalidParameterError("Unsupported version");
}

const auto unprotect = overloaded{
[&](const PublicMessage& pt) -> AuthenticatedContent {
if (pt.get_group_id() != _group_id) {
throw ProtocolError("PublicMessage not for this group");
}

if (pt.get_epoch() != _epoch) {
throw ProtocolError("PublicMessage not for this epoch");
}

auto maybe_content_auth =
pt.unprotect(_suite, _key_schedule.membership_key, group_context());
if (!maybe_content_auth) {
Expand All @@ -425,6 +433,14 @@ State::unprotect_to_content_auth(const MLSMessage& msg)
},

[&](const PrivateMessage& ct) -> AuthenticatedContent {
if (ct.get_group_id() != _group_id) {
throw ProtocolError("PrivateMessage not for this group");
}

if (ct.get_epoch() != _epoch) {
throw ProtocolError("PrivateMessage not for this epoch");
}

auto maybe_content_auth =
ct.unprotect(_suite, _keys, _key_schedule.sender_data_secret);
if (!maybe_content_auth) {
Expand All @@ -438,7 +454,12 @@ State::unprotect_to_content_auth(const MLSMessage& msg)
},
};

return var::visit(unprotect, msg.message);
const auto content_auth = var::visit(unprotect, msg.message);
if (!verify(content_auth)) {
throw InvalidParameterError("Message signature failed to verify");
}

return content_auth;
}

Proposal
Expand Down Expand Up @@ -783,44 +804,47 @@ State::group_context() const
std::optional<State>
State::handle(const MLSMessage& msg)
{
return handle(msg, std::nullopt, std::nullopt);
return handle(unwrap(msg), std::nullopt, std::nullopt);
}

std::optional<State>
State::handle(const MLSMessage& msg, std::optional<State> cached_state)
{
return handle(msg, std::move(cached_state), std::nullopt);
return handle(unwrap(msg), std::move(cached_state), std::nullopt);
}

std::optional<State>
State::handle(const AuthenticatedContent& content_auth)
{
return handle(content_auth, std::nullopt, std::nullopt);
}

std::optional<State>
State::handle(const AuthenticatedContent& content_auth,
std::optional<State> cached_state)
{
return handle(content_auth, std::move(cached_state), std::nullopt);
}

std::optional<State>
State::handle(const MLSMessage& msg,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params)
{
auto content_auth = unprotect_to_content_auth(msg);
if (!verify(content_auth)) {
throw InvalidParameterError("Message signature failed to verify");
}

return handle(content_auth, std::move(cached_state), expected_params);
return handle(unwrap(msg), std::move(cached_state), expected_params);
}

std::optional<State>
State::handle(const AuthenticatedContent& content_auth,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params)
{
// Validate the GroupContent
const auto& content = content_auth.content;
if (content.group_id != _group_id) {
throw InvalidParameterError("GroupID mismatch");
}

if (content.epoch != _epoch) {
throw InvalidParameterError("Epoch mismatch");
}
// XXX(RLB): We assume that the AuthenticatedContent has come to us by way of
// `unwrap()`, so that its authenticity has already been checked. This avoids
// duplicate signature verification.

// Dispatch on content type
const auto& content = content_auth.content;
switch (content.content_type()) {
// Proposals get queued, do not result in a state transition
case ContentType::proposal:
Expand Down Expand Up @@ -1136,7 +1160,7 @@ State::Tombstone
State::handle_reinit_commit(const MLSMessage& commit_msg)
{
// Verify the signature and process the commit
auto content_auth = unprotect_to_content_auth(commit_msg);
auto content_auth = unwrap(commit_msg);
if (!verify(content_auth)) {
throw InvalidParameterError("Message signature failed to verify");
}
Expand Down Expand Up @@ -1417,7 +1441,7 @@ State::protect(const bytes& authenticated_data,
std::tuple<bytes, bytes>
State::unprotect(const MLSMessage& ct)
{
auto content_auth = unprotect_to_content_auth(ct);
auto content_auth = unwrap(ct);

if (!verify(content_auth)) {
throw InvalidParameterError("Message signature failed to verify");
Expand Down
Loading