Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement light MLS #422

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/mls/core_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ struct ExtensionType
static constexpr Extension::Type external_pub = 4;
static constexpr Extension::Type external_senders = 5;

static constexpr Extension::Type flags = 6;
static constexpr Extension::Type membership_proof = 7;

// XXX(RLB) There is no IANA-registered type for this extension yet, so we use
// a value from the vendor-specific space
static constexpr Extension::Type sframe_parameters = 0xff02;
Expand Down
35 changes: 35 additions & 0 deletions include/mls/messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ struct RatchetTreeExtension
TLS_SERIALIZABLE(tree)
};

struct MembershipProofExtension
{
std::vector<TreeSlice> slices;

static const uint16_t type;
TLS_SERIALIZABLE(slices)
};

struct ExternalSender
{
SignaturePublicKey signature_key;
Expand All @@ -43,6 +51,20 @@ struct ExternalSendersExtension
TLS_SERIALIZABLE(senders);
};

struct FlagsExtension
{
std::vector<uint8_t> flag_data;

void set(size_t pos);
void unset(size_t pos);
bool get(size_t pos) const;

static const uint16_t type;

// XXX(RLB): This should check for extra zero bytes on deserialize.
TLS_SERIALIZABLE(flag_data);
};

struct SFrameParameters
{
uint16_t cipher_suite;
Expand Down Expand Up @@ -257,6 +279,15 @@ struct Welcome
const std::vector<PSKWithSecret>& psks);
};

struct LightCommit
{
GroupContext group_context;
bytes confirmation_tag;
TreeSlice sender_membership_proof;
std::optional<HPKECiphertext> encrypted_path_secret;
std::optional<NodeIndex> decryption_node_index;
};

///
/// Proposals & Commit
///
Expand Down Expand Up @@ -623,6 +654,10 @@ struct PublicMessage
bytes membership_mac(CipherSuite suite,
const bytes& membership_key,
const std::optional<GroupContext>& context) const;

// XXX(RLB) This is a hack to avoid unwrapping across epochs. We should do
// something more elegant, like unchecked_content()
friend class State;
};

struct PrivateMessage
Expand Down
22 changes: 20 additions & 2 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,20 @@ struct RosterIndex : public UInt32

struct CommitOpts
{
// Include these proposals in the commit by value
std::vector<Proposal> extra_proposals;
bool inline_tree;
bool force_path;

// Send a ratchet_tree extension in the Welcome
bool inline_tree = false;

// Send an UpdatePath even if none is required
bool force_path = false;

// Send a membership_proof extension in the Welcome covering the committer and
// the new joiners
bool membership_proof = false;

// Update the committer's LeafNode in the following way
LeafNodeOptions leaf_node_opts;
};

Expand Down Expand Up @@ -127,6 +138,12 @@ class State
std::optional<State> handle(const ValidatedContent& content_auth,
std::optional<State> cached_state);

///
/// Light MLS
///
LightCommit lighten_for(LeafIndex leaf, const MLSMessage& commit) const;
State handle(const LightCommit& light_commit) const;

///
/// PSK management
///
Expand All @@ -145,6 +162,7 @@ class State
const ExtensionList& extensions() const { return _extensions; }
const TreeKEMPublicKey& tree() const { return _tree; }
const bytes& resumption_psk() const { return _key_schedule.resumption_psk; }
bool is_full_client() const { return _tree.is_complete(); }

bytes do_export(const std::string& label,
const bytes& context,
Expand Down
4 changes: 2 additions & 2 deletions include/mls/tree_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ struct NodeIndex : public UInt32
// of `ancestor` that is not in the direct path of this node.
NodeIndex sibling(NodeIndex ancestor) const;

std::vector<NodeIndex> dirpath(LeafCount n);
std::vector<NodeIndex> copath(LeafCount n);
std::vector<NodeIndex> dirpath(LeafCount n) const;
std::vector<NodeIndex> copath(LeafCount n) const;

uint32_t level() const;
};
Expand Down
28 changes: 27 additions & 1 deletion include/mls/treekem.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ struct OptionalNode
TLS_SERIALIZABLE(node)
};

struct TreeSlice
{
LeafIndex leaf_index;
LeafCount n_leaves;
std::vector<OptionalNode> direct_path_nodes;
std::vector<bytes> copath_hashes;

bytes tree_hash(CipherSuite suite) const;
void add(const TreeSlice& other);

TLS_SERIALIZABLE(leaf_index, n_leaves, direct_path_nodes, copath_hashes);
};

struct TreeKEMPublicKey;

struct TreeKEMPrivateKey
Expand Down Expand Up @@ -107,15 +120,19 @@ struct TreeKEMPrivateKey
void implant(const TreeKEMPublicKey& pub,
NodeIndex start,
const bytes& path_secret);
void implant_matching(const TreeKEMPublicKey& pub,
NodeIndex start,
const bytes& path_secret);
};

struct TreeKEMPublicKey
{
CipherSuite suite;
LeafCount size{ 0 };
std::vector<OptionalNode> nodes;
std::map<NodeIndex, OptionalNode> nodes;

explicit TreeKEMPublicKey(CipherSuite suite);
TreeKEMPublicKey(CipherSuite suite, const TreeSlice& slice);

TreeKEMPublicKey() = default;
TreeKEMPublicKey(const TreeKEMPublicKey& other) = default;
Expand Down Expand Up @@ -144,12 +161,19 @@ struct TreeKEMPublicKey

bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const;
bool parent_hash_valid() const;
bool is_complete() const;

bool has_leaf(LeafIndex index) const;
std::optional<LeafIndex> find(const LeafNode& leaf) const;
std::optional<LeafNode> leaf_node(LeafIndex index) const;
std::vector<NodeIndex> resolve(NodeIndex index) const;

TreeSlice extract_slice(LeafIndex leaf) const;
void implant_slice(const TreeSlice& slice);
std::tuple<HPKECiphertext, NodeIndex> slice_path(UpdatePath path,
LeafIndex from,
LeafIndex to) const;

template<typename UnaryPredicate>
bool all_leaves(const UnaryPredicate& pred) const
{
Expand Down Expand Up @@ -228,6 +252,8 @@ struct TreeKEMPublicKey
bool exists_in_tree(const SignaturePublicKey& key,
std::optional<LeafIndex> except) const;

void implant_slice_unchecked(const TreeSlice& slice);

OptionalNode blank_node;

friend struct TreeKEMPrivateKey;
Expand Down
59 changes: 59 additions & 0 deletions src/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,71 @@ namespace MLS_NAMESPACE {

const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub;
const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree;
const Extension::Type MembershipProofExtension::type =
ExtensionType::membership_proof;
const Extension::Type ExternalSendersExtension::type =
ExtensionType::external_senders;
const Extension::Type FlagsExtension::type = ExtensionType::flags;
const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters;
const Extension::Type SFrameCapabilities::type =
ExtensionType::sframe_parameters;

void
FlagsExtension::set(size_t pos)
{
const auto byte_pos = pos >> 3;
const auto bit_pos = pos & 0x07;

// Ensure space
if (byte_pos >= flag_data.size()) {
flag_data.resize(byte_pos + 1);
}

// Set the bit
flag_data.at(byte_pos) |= uint8_t(1 << bit_pos);
}

void
FlagsExtension::unset(size_t pos)
{
const auto byte_pos = pos >> 3;
const auto bit_pos = pos & 0x07;

if (byte_pos >= flag_data.size()) {
return;
}

// Unset the bit
flag_data.at(byte_pos) &= ~uint8_t(1 << bit_pos);

// Trim any zero bytes
auto cut = flag_data.size() - 1;
while (cut > 0 && flag_data.at(cut) == 0) {
cut -= 1;
}

if (flag_data.at(cut) == 0) {
flag_data.clear();
return;
}

flag_data.resize(cut + 1);
}

bool
FlagsExtension::get(size_t pos) const
{
const auto byte_pos = pos >> 3;
const auto bit_pos = pos & 0x07;

if (byte_pos >= flag_data.size()) {
return false;
}

const auto bit = (flag_data.at(byte_pos) >> bit_pos) & 0x01;
return bit == 1;
}

bool
SFrameCapabilities::compatible(const SFrameParameters& params) const
{
Expand Down
6 changes: 4 additions & 2 deletions src/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,10 @@ Session::commit()
{
auto commit_secret = inner->fresh_secret();
auto encrypt = inner->encrypt_handshake;
auto [commit, welcome, new_state] = inner->history.front().commit(
commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 });
auto [commit, welcome, new_state] =
inner->history.front().commit(commit_secret,
CommitOpts{ {}, true, encrypt, false, {} },
{ encrypt, {}, 0 });

auto commit_msg = tls::marshal(commit);
auto welcome_msg = tls::marshal(welcome);
Expand Down
Loading
Loading