Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose group ID and allow inspection #412

Merged
merged 10 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions include/mls/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,17 +570,35 @@ struct AuthenticatedContent
friend struct PrivateMessage;
};

struct ValidatedContent
{
const AuthenticatedContent& authenticated_content() const;

friend bool operator==(const ValidatedContent& lhs,
const ValidatedContent& rhs);

private:
AuthenticatedContent content_auth;

ValidatedContent(AuthenticatedContent content_auth_in);

friend struct PublicMessage;
friend struct PrivateMessage;
friend class State;
};

struct PublicMessage
{
PublicMessage() = default;

bytes get_group_id() const { return content.group_id; }
epoch_t get_epoch() const { return content.epoch; }

static PublicMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context);
std::optional<AuthenticatedContent> unprotect(
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const;
Expand Down Expand Up @@ -611,14 +629,15 @@ struct PrivateMessage
{
PrivateMessage() = default;

bytes get_group_id() const { return group_id; }
epoch_t get_epoch() const { return epoch; }

static PrivateMessage protect(AuthenticatedContent content_auth,
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret,
size_t padding_size);
std::optional<AuthenticatedContent> unprotect(
std::optional<ValidatedContent> unprotect(
CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const;
Expand Down Expand Up @@ -649,6 +668,7 @@ struct MLSMessage
var::variant<PublicMessage, PrivateMessage, Welcome, GroupInfo, KeyPackage>
message;

bytes group_id() const;
epoch_t epoch() const;
WireFormat wire_format() const;

Expand Down
16 changes: 12 additions & 4 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,16 @@ class State
const MessageOpts& msg_opts);

///
/// Generic handshake message handler
/// Generic handshake message handlers
///
std::optional<State> handle(const MLSMessage& msg);
std::optional<State> handle(const MLSMessage& msg,
std::optional<State> cached_state);

std::optional<State> handle(const ValidatedContent& content_auth);
std::optional<State> handle(const ValidatedContent& content_auth,
std::optional<State> cached_state);

///
/// PSK management
///
Expand Down Expand Up @@ -151,6 +156,11 @@ class State

bytes epoch_authenticator() const;

///
/// Unwrap messages so that applications can inspect them
///
ValidatedContent unwrap(const MLSMessage& msg);

///
/// Application encryption and decryption
///
Expand Down Expand Up @@ -318,7 +328,7 @@ class State
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);
std::optional<State> handle(
const AuthenticatedContent& content_auth,
const ValidatedContent& val_content,
std::optional<State> cached_state,
const std::optional<CommitParams>& expected_params);

Expand All @@ -334,8 +344,6 @@ class State
template<typename Inner>
MLSMessage protect_full(Inner&& content, const MessageOpts& msg_opts);

AuthenticatedContent unprotect_to_content_auth(const MLSMessage& msg);

// Apply the changes requested by various messages
LeafIndex apply(const Add& add);
void apply(LeafIndex target, const Update& update);
Expand Down
27 changes: 14 additions & 13 deletions lib/mls_vectors/src/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,25 +872,26 @@ MessageProtectionTestVector::protect_priv(
std::optional<GroupContent>
MessageProtectionTestVector::unprotect(const MLSMessage& message)
{
auto do_unprotect = overloaded{
[&](const PublicMessage& pt) {
return pt.unprotect(cipher_suite, membership_key, group_context());
},
[&](const PrivateMessage& ct) {
auto keys = group_keys();
return ct.unprotect(cipher_suite, keys, sender_data_secret);
},
[](const auto& /* other */) -> std::optional<AuthenticatedContent> {
return std::nullopt;
}
};
auto do_unprotect =
overloaded{ [&](const PublicMessage& pt) {
return pt.unprotect(
cipher_suite, membership_key, group_context());
},
[&](const PrivateMessage& ct) {
auto keys = group_keys();
return ct.unprotect(cipher_suite, keys, sender_data_secret);
},
[](const auto& /* other */) -> std::optional<ValidatedContent> {
return std::nullopt;
} };

auto maybe_auth_content = var::visit(do_unprotect, message.message);
if (!maybe_auth_content) {
return std::nullopt;
}

auto auth_content = opt::get(maybe_auth_content);
auto val_content = opt::get(maybe_auth_content);
const auto& auth_content = val_content.authenticated_content();
if (!auth_content.verify(cipher_suite, signature_pub, group_context())) {
return std::nullopt;
}
Expand Down
44 changes: 38 additions & 6 deletions src/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,23 @@ AuthenticatedContent::AuthenticatedContent(WireFormat wire_format_in,
{
}

const AuthenticatedContent&
ValidatedContent::authenticated_content() const
{
return content_auth;
}

ValidatedContent::ValidatedContent(AuthenticatedContent content_auth_in)
: content_auth(std::move(content_auth_in))
{
}

bool
operator==(const ValidatedContent& lhs, const ValidatedContent& rhs)
{
return lhs.content_auth == rhs.content_auth;
}

struct GroupContentTBS
{
WireFormat wire_format = WireFormat::reserved;
Expand Down Expand Up @@ -526,7 +543,7 @@ PublicMessage::protect(AuthenticatedContent content_auth,
return pt;
}

std::optional<AuthenticatedContent>
std::optional<ValidatedContent>
PublicMessage::unprotect(CipherSuite suite,
const std::optional<bytes>& membership_key,
const std::optional<GroupContext>& context) const
Expand All @@ -545,11 +562,11 @@ PublicMessage::unprotect(CipherSuite suite,
break;
}

return AuthenticatedContent{
return { { AuthenticatedContent{
WireFormat::mls_public_message,
content,
auth,
};
} } };
}

bool
Expand Down Expand Up @@ -756,7 +773,7 @@ PrivateMessage::protect(AuthenticatedContent content_auth,
};
}

std::optional<AuthenticatedContent>
std::optional<ValidatedContent>
PrivateMessage::unprotect(CipherSuite suite,
GroupKeySource& keys,
const bytes& sender_data_secret) const
Expand Down Expand Up @@ -813,11 +830,11 @@ PrivateMessage::unprotect(CipherSuite suite,

unmarshal_ciphertext_content(opt::get(content_pt), content, auth);

return AuthenticatedContent{
return { { AuthenticatedContent{
WireFormat::mls_private_message,
std::move(content),
std::move(auth),
};
} } };
}

PrivateMessage::PrivateMessage(GroupContent content,
Expand All @@ -832,6 +849,21 @@ PrivateMessage::PrivateMessage(GroupContent content,
{
}

bytes
MLSMessage::group_id() const
{
return var::visit(
overloaded{
[](const PublicMessage& pt) -> bytes { return pt.get_group_id(); },
[](const PrivateMessage& ct) -> bytes { return ct.get_group_id(); },
[](const GroupInfo& gi) -> bytes { return gi.group_context.group_id; },
[](const auto& /* unused */) -> bytes {
throw InvalidParameterError("MLSMessage has no group_id");
},
},
message);
}

epoch_t
MLSMessage::epoch() const
{
Expand Down
Loading
Loading