Skip to content

Commit

Permalink
stream revert on early data rejected, instead of close
Browse files Browse the repository at this point in the history
  • Loading branch information
dr7ana committed Sep 11, 2024
1 parent 060c58b commit 79bc67b
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 67 deletions.
12 changes: 9 additions & 3 deletions include/oxen/quic/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <map>
#include <memory>
#include <optional>
#include <set>
#include <stdexcept>

#include "connection_ids.hpp"
Expand Down Expand Up @@ -236,6 +237,7 @@ namespace oxen::quic
{
friend class TestHelper;
friend struct rotating_buffer;
friend struct connection_callbacks;

public:
// Non-movable/non-copyable; you must always hold a Connection in a shared_ptr
Expand Down Expand Up @@ -336,8 +338,6 @@ namespace oxen::quic
connection_established_callback conn_established_cb;
connection_closed_callback conn_closed_cb;

void early_data_rejected();

void set_remote_addr(const ngtcp2_addr& new_remote);

void store_associated_cid(const quic_cid& cid);
Expand Down Expand Up @@ -418,6 +418,12 @@ namespace oxen::quic

ustring remote_pubkey;

std::set<int64_t> _early_streams;

void make_early_streams(ngtcp2_conn* connptr);

void revert_early_streams();

struct connection_deleter
{
inline void operator()(ngtcp2_conn* c) const { ngtcp2_conn_del(c); }
Expand Down Expand Up @@ -503,7 +509,7 @@ namespace oxen::quic
void stream_execute_close(Stream& s, uint64_t app_code);
void stream_closed(int64_t id, uint64_t app_code);
void close_all_streams();
void check_pending_streams(uint64_t available);
void check_pending_streams(uint64_t available, bool is_early_stream = false);
int recv_datagram(bstring_view data, bool fin);
int ack_datagram(uint64_t dgram_id);
int recv_token(const uint8_t* token, size_t tokenlen);
Expand Down
5 changes: 3 additions & 2 deletions include/oxen/quic/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,10 @@ namespace oxen::quic
std::vector<ustring> inbound_alpns;
std::chrono::nanoseconds handshake_timeout{DEFAULT_HANDSHAKE_TIMEOUT};

std::unordered_map<ustring_view, gtls_session_ticket> session_tickets;
std::unordered_map<ustring_view, gtls_ticket_ptr> _session_tickets;
std::unordered_map<ustring_view, gtls_ticket_ptr> session_tickets;

std::unordered_map<ustring, ustring> encoded_transport_params;

std::unordered_map<ustring, ustring> path_validation_tokens;

const std::shared_ptr<event_base>& get_loop() { return net._loop->loop(); }
Expand Down
3 changes: 0 additions & 3 deletions include/oxen/quic/gnutls_crypto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@ namespace oxen::quic

struct gtls_session_ticket
{
friend class Endpoint;
friend class Connection;

std::vector<unsigned char> _key;
std::vector<unsigned char> _ticket;
gnutls_datum_t _data;
Expand Down
3 changes: 3 additions & 0 deletions include/oxen/quic/stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ namespace oxen::quic
size_t unsent_impl() const override;

private:
// Called if 0-RTT early data was rejected; marks all sent data as unacked
void revert_stream();

std::vector<ngtcp2_vec> pending() override;

size_t _unacked_size{0};
Expand Down
4 changes: 2 additions & 2 deletions src/btstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ namespace oxen::quic

if (auto type = msg.type(); type == message::TYPE_REPLY || type == message::TYPE_ERROR)
{
log::trace(log_cat, "Looking for request with req_id={}", msg.req_id);
log::debug(bp_cat, "Looking for request with req_id={}", msg.req_id);
// Iterate using forward iterators, s.t. we go highest (newest) rids to lowest (oldest) rids.
// As a result, our comparator checks if the sent request ID is greater thanthan the target rid
auto itr = std::lower_bound(
Expand All @@ -149,7 +149,7 @@ namespace oxen::quic

if (itr != sent_reqs.end())
{
log::debug(bp_cat, "Successfully matched response to sent request!");
log::debug(bp_cat, "Successfully matched response (req_id={}) to sent request!", msg.req_id);
auto req = std::move(*itr);
sent_reqs.erase(itr);
try
Expand Down
108 changes: 66 additions & 42 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ namespace oxen::quic
auto& conn = *static_cast<Connection*>(user_data);
assert(_conn == conn);

log::warning(log_cat, "Client dropping connection; early data rejected by server!");
conn.endpoint().drop_connection(conn, io_error{CONN_EARLY_DATA_REJECTED});
log::info(log_cat, "Client resetting early streams; 0-rtt rejected by server!");
conn.revert_early_streams();

return 0;
}
Expand Down Expand Up @@ -383,36 +383,32 @@ namespace oxen::quic

int Connection::client_handshake_completed()
{
/** TODO:
This section will be uncommented and finished upon completion of 0RTT and session resumption capabilities.
- If early data is NOT ACCEPTED, then the call to ngtcp2_conn_tls_early_data_rejected must be invoked
to reset aspects of connection state prior to early data rejection.
- If early data is ACCEPTED, then we can open streams and start doing things immediately. At that point,
we should encode and store 0RTT transport parameters.
Moreover, decoding and setting 0RTT transport parameters must be handled in connection creation. Both that
location and the required callbacks are comment-blocked in the relevant location.
*/
if (_0rtt_enabled and not tls_session->get_early_data_accepted())
if (_0rtt_enabled)
{
log::info(log_cat, "Early data was rejected by server!");

if (auto rv = ngtcp2_conn_tls_early_data_rejected(conn.get()); rv != 0)
if (not tls_session->get_early_data_accepted())
{
log::error(log_cat, "ngtcp2_conn_tls_early_data_rejected: {}", ngtcp2_strerror(rv));
return -1;
log::info(log_cat, "Early data was rejected by server!");

if (auto rv = ngtcp2_conn_tls_early_data_rejected(conn.get()); rv != 0)
{
log::error(log_cat, "ngtcp2_conn_tls_early_data_rejected failed: {}", ngtcp2_strerror(rv));
return -1;
}
}
}

ustring data;
data.resize(256);
ustring data;
data.resize(256);

if (auto len = ngtcp2_conn_encode_0rtt_transport_params(conn.get(), data.data(), data.size()); len > 0)
{
_endpoint.store_0rtt_transport_params(remote_pubkey, std::move(data));
log::info(log_cat, "Client encoded and stored 0rtt transport params");
if (auto len = ngtcp2_conn_encode_0rtt_transport_params(conn.get(), data.data(), data.size()); len > 0)
{
_endpoint.store_0rtt_transport_params(remote_pubkey, std::move(data));
log::info(log_cat, "Client successfully encoded and stored 0rtt transport params");
}
else
log::warning(log_cat, "Client could not encode 0-RTT transport parameters: {}", ngtcp2_strerror(len));
}
else
log::warning(log_cat, "Client could not encode 0-RTT transport parameters: {}", ngtcp2_strerror(len));

log::debug(log_cat, "Client handshake completed!");

return 0;
}
Expand Down Expand Up @@ -446,6 +442,8 @@ namespace oxen::quic
return -1;
}

log::debug(log_cat, "Server successfully submitted regular token on handshake completion...");

return 0;
}

Expand All @@ -462,11 +460,6 @@ namespace oxen::quic
return datagrams->recv_buffer.last_cleared;
}

void Connection::early_data_rejected()
{
close_connection();
}

void Connection::set_remote_addr(const ngtcp2_addr& new_remote)
{
_endpoint.call([this, new_remote]() { _path.set_new_remote(new_remote); });
Expand Down Expand Up @@ -524,6 +517,36 @@ namespace oxen::quic
_endpoint.close_connection(*this, io_error{error_code});
}

void Connection::make_early_streams(ngtcp2_conn* connptr)
{
log::debug(log_cat, "Client making streams to attempt 0-rtt early data!");

if (auto remaining = ngtcp2_conn_get_streams_bidi_left(connptr); remaining > 0)
{
log::debug(log_cat, "Client has room to promote {} streams for early data!", remaining);
check_pending_streams(remaining, true);
}
}

void Connection::revert_early_streams()
{
_endpoint.call([&]() {
log::debug(log_cat, "Client reverting {} 0-rtt streams", _early_streams.size());

for (auto& _id : _early_streams)
{
if (auto it = _streams.find(_id); it != _streams.end())
{
log::trace(log_cat, "Reverting stream (ID:{})...", _id);
it->second->revert_stream();
}
else
log::warning(log_cat, "Could not find early stream (ID:{}) to revert!", _id);
}
_early_streams.clear();
});
}

void Connection::handle_conn_packet(const Packet& pkt)
{
if (auto rv = ngtcp2_conn_in_closing_period(*this); rv != 0)
Expand Down Expand Up @@ -622,7 +645,7 @@ namespace oxen::quic
// so, we move them to the streams map, where they will get picked up by flush_streams and dump
// their buffers. If none are ready, we keep chugging along and make another stream as usual. Though
// if none of the pending streams are ready, the new stream really shouldn't be ready, but here we are
void Connection::check_pending_streams(uint64_t available)
void Connection::check_pending_streams(uint64_t available, bool is_early_stream)
{
log::trace(log_cat, "{} called", __PRETTY_FUNCTION__);
uint64_t popped = 0;
Expand All @@ -633,11 +656,15 @@ namespace oxen::quic

if (int rv = ngtcp2_conn_open_bidi_stream(conn.get(), &str->_stream_id, str.get()); rv == 0)
{
log::debug(log_cat, "Stream [ID:{}] ready for broadcast, moving out of pending streams", str->_stream_id);
auto _id = str->_stream_id;
log::debug(log_cat, "Stream [ID:{}] ready for broadcast, moving out of pending streams", _id);
str->set_ready();
popped += 1;
_streams[str->_stream_id] = std::move(str);
_streams[_id] = std::move(str);
pending_streams.pop_front();

if (_0rtt_enabled and is_early_stream)
_early_streams.emplace_hint(_early_streams.end(), _id);
}
else
return;
Expand Down Expand Up @@ -1638,7 +1665,9 @@ namespace oxen::quic
callbacks.handshake_confirmed = connection_callbacks::on_handshake_confirmed;
callbacks.recv_retry = ngtcp2_crypto_recv_retry_cb;
callbacks.recv_new_token = connection_callbacks::on_recv_token;
callbacks.tls_early_data_rejected = connection_callbacks::on_early_data_rejected;

if (_0rtt_enabled)
callbacks.tls_early_data_rejected = connection_callbacks::on_early_data_rejected;

// Clients should be the ones providing a remote pubkey here. This way we can emplace it into
// the gnutlssession object to be verified. Servers should be verifying via callback
Expand Down Expand Up @@ -1694,13 +1723,8 @@ namespace oxen::quic
}
else
{
if (connection_callbacks::extend_max_local_streams_bidi(nullptr, 0, this) != 0)
{
log::warning(
log_cat,
"Client failed open streams to send early data; disabling 0rtt and proceeding...");
_0rtt_enabled = false;
}
make_early_streams(connptr);
log::info(log_cat, "Client encoded and set 0rtt params, ready to attempt early data!");
}
}
}
Expand Down
20 changes: 9 additions & 11 deletions src/endpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ namespace oxen::quic
_validate_0rtt_ticket = rtt._check ? std::move(rtt._check) : [this](gtls_ticket_ptr ticket, time_t current) -> bool {
auto key = ticket->key();

// auto [it, b] = _session_tickets.try_emplace(key, ticket);

if (auto it = _session_tickets.find(key); it != _session_tickets.end())
if (auto it = session_tickets.find(key); it != session_tickets.end())
{
if (auto exp = gnutls_db_check_entry_expire_time(*it->second); current < exp)
{
Expand All @@ -107,16 +105,16 @@ namespace oxen::quic
log::debug(log_cat, "Found expired anti-replay ticket for incoming connection");
}

_session_tickets[std::move(key)] = std::move(ticket);
session_tickets[std::move(key)] = std::move(ticket);
return 0;
};

_get_session_ticket = rtt._fetch ? std::move(rtt._fetch) : [this](ustring_view key) -> gtls_ticket_ptr {
gtls_ticket_ptr ret = nullptr;
if (auto it = _session_tickets.find(key); it != _session_tickets.end())
if (auto it = session_tickets.find(key); it != session_tickets.end())
{
ret = std::move(it->second);
_session_tickets.erase(it);
session_tickets.erase(it);
log::debug(log_cat, "Found session ticket for remote; entry extracted and returned...");
}
else
Expand All @@ -127,7 +125,7 @@ namespace oxen::quic

_put_session_ticket = rtt._put ? std::move(rtt._put) : [this](gtls_ticket_ptr ticket, time_t /* exp */) {
auto key = ticket->key();
auto [_, b] = _session_tickets.insert_or_assign(std::move(key), std::move(ticket));
auto [_, b] = session_tickets.insert_or_assign(std::move(key), std::move(ticket));

log::debug(
log_cat, "Stored anti-replay ticket for connection to remote{}!", b ? "" : "; old ticket overwritten");
Expand Down Expand Up @@ -501,25 +499,25 @@ namespace oxen::quic

void Endpoint::store_session_ticket(gtls_session_ticket ticket)
{
log::trace(log_cat, "Storing session ticket...");
return _put_session_ticket(gtls_session_ticket::make(std::move(ticket)), 0);
auto key = ticket.key();
auto [_, b] = session_tickets.insert_or_assign(std::move(key), std::move(ticket));

log::debug(log_cat, "Stored anti-replay ticket for connection to remote{}!", b ? "" : "; old ticket overwritten");
}

gtls_ticket_ptr Endpoint::get_session_ticket(const ustring_view& remote_pk)
{
log::trace(log_cat, "Fetching session ticket (remote key: {})...", buffer_printer{remote_pk});
return _get_session_ticket(remote_pk);
}

void Endpoint::store_0rtt_transport_params(ustring remote_pk, ustring encoded_params)
{
log::trace(log_cat, "Storing 0rtt tranpsport params...");
encoded_transport_params.insert_or_assign(std::move(remote_pk), std::move(encoded_params));
}

std::optional<ustring> Endpoint::get_0rtt_transport_params(const ustring& remote_pk)
{
log::trace(log_cat, "Fetching 0rtt transport params...");
if (auto itr = encoded_transport_params.find(remote_pk); itr != encoded_transport_params.end())
return itr->second;

Expand Down
1 change: 1 addition & 0 deletions src/gnutls_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace oxen::quic
{
if (htype == GNUTLS_HANDSHAKE_NEW_SESSION_TICKET)
{
log::debug(log_cat, "Client received new session ticket from server!");
auto* conn = get_connection_from_gnutls(session);
auto remote_key = conn->remote_key();
auto& ep = conn->endpoint();
Expand Down
8 changes: 8 additions & 0 deletions src/stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,14 @@ namespace oxen::quic
return std::make_pair(std::move(it), offset);
}

void Stream::revert_stream()
{
assert(endpoint.in_event_loop());
log::trace(log_cat, "Stream (ID:{}) reverting after early data rejected...", _stream_id);
_unacked_size = 0;
log::debug(log_cat, "Stream (ID:{}) has {}B in buffer, 0B unacacked...", _stream_id, size());
}

std::vector<ngtcp2_vec> Stream::pending()
{
log::trace(log_cat, "{} called", __PRETTY_FUNCTION__);
Expand Down
8 changes: 4 additions & 4 deletions tests/010-migration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ namespace oxen::quic::test
}

// Uncomment this when NGTCP2 releases v1.2.0
// SECTION("Immediate migration")
// {
// TestHelper::migrate_connection_immediate(conn, client_secondary);
// }
SECTION("Immediate migration")
{
TestHelper::migrate_connection_immediate(conn, client_secondary);
}

address_flipped = true;
conn_promise_a.set_value();
Expand Down

0 comments on commit 79bc67b

Please sign in to comment.