Skip to content

Commit

Permalink
Replace dummy message with direct wake to trigger endpoint events
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralith committed Dec 17, 2023
1 parent a1ac819 commit 10e0616
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 19 deletions.
32 changes: 21 additions & 11 deletions quinn/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
};

use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime};
use atomic_waker::AtomicWaker;
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
Expand Down Expand Up @@ -40,6 +41,7 @@ impl Connecting {
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
endpoint_driver: Arc<AtomicWaker>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
socket: Arc<dyn AsyncUdpSocket>,
runtime: Arc<dyn Runtime>,
Expand All @@ -50,6 +52,7 @@ impl Connecting {
handle,
conn,
endpoint_events,
endpoint_driver,
conn_events,
on_handshake_data_send,
on_connected_send,
Expand Down Expand Up @@ -233,7 +236,7 @@ impl Future for ConnectionDriver {
// If a timer expires, there might be more to transmit. When we transmit something, we
// might need to reset a timer. Hence, we must loop until neither happens.
keep_going |= conn.drive_timer(cx);
conn.forward_endpoint_events();
conn.forward_endpoint_events(&self.0.shared);
conn.forward_app_events(&self.0.shared);

if !conn.inner.is_drained() {
Expand Down Expand Up @@ -759,6 +762,7 @@ impl ConnectionRef {
handle: ConnectionHandle,
conn: proto::Connection,
endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
endpoint_driver: Arc<AtomicWaker>,
conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
on_handshake_data: oneshot::Sender<()>,
on_connected: oneshot::Sender<bool>,
Expand Down Expand Up @@ -786,7 +790,13 @@ impl ConnectionRef {
socket,
runtime,
}),
shared: Shared::default(),
shared: Shared {
endpoint_driver,
stream_budget_available: Default::default(),
stream_incoming: Default::default(),
datagrams: Default::default(),
closed: Default::default(),
},
}))
}

Expand Down Expand Up @@ -831,7 +841,7 @@ pub(crate) struct ConnectionInner {
pub(crate) shared: Shared,
}

#[derive(Debug, Default)]
#[derive(Debug)]
pub(crate) struct Shared {
/// Notified when new streams may be locally initiated due to an increase in stream ID flow
/// control budget
Expand All @@ -840,6 +850,7 @@ pub(crate) struct Shared {
stream_incoming: [Notify; 2],
datagrams: Notify,
closed: Notify,
endpoint_driver: Arc<AtomicWaker>,
}

pub(crate) struct State {
Expand Down Expand Up @@ -898,16 +909,15 @@ impl State {
false
}

fn forward_endpoint_events(&mut self) {
fn forward_endpoint_events(&mut self, shared: &Shared) {
if self.inner.poll_endpoint_events() {
shared.endpoint_driver.wake();
}
if self.inner.is_drained() {
// If the endpoint driver is gone, noop.
let _ = self.endpoint_events.send((
self.handle,
match self.inner.is_drained() {
false => EndpointEvent::Proto,
true => EndpointEvent::Drained,
},
));
let _ = self
.endpoint_events
.send((self.handle, EndpointEvent::Drained));
}
}

Expand Down
32 changes: 25 additions & 7 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,13 @@ impl Endpoint {
.connect(Instant::now(), config, addr, server_name)?;

let socket = endpoint.socket.clone();
Ok(endpoint
.connections
.insert(ch, conn, socket, self.runtime.clone()))
Ok(endpoint.connections.insert(
ch,
conn,
socket,
self.runtime.clone(),
self.inner.shared.driver.clone(),
))
}

/// Switch to a new UDP socket
Expand Down Expand Up @@ -325,7 +329,7 @@ impl Future for EndpointDriver {

let now = Instant::now();
let mut keep_going = false;
keep_going |= endpoint.drive_recv(cx, now)?;
keep_going |= endpoint.drive_recv(cx, now, &self.0.shared)?;
keep_going |= endpoint.handle_events(cx, &self.0.shared);
keep_going |= endpoint.drive_send(cx)?;

Expand Down Expand Up @@ -393,7 +397,12 @@ pub(crate) struct Shared {
}

impl State {
fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
fn drive_recv<'a>(
&'a mut self,
cx: &mut Context,
now: Instant,
shared: &Shared,
) -> Result<bool, io::Error> {
self.recv_limiter.start_cycle();
let mut metas = [RecvMeta::default(); BATCH_SIZE];
let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit();
Expand Down Expand Up @@ -431,6 +440,7 @@ impl State {
conn,
self.socket.clone(),
self.runtime.clone(),
shared.driver.clone(),
);
self.incoming.push_back(conn);
}
Expand Down Expand Up @@ -530,7 +540,6 @@ impl State {
for _ in 0..IO_LOOP_BOUND {
match self.events.poll_recv(cx) {
Poll::Ready(Some((ch, event))) => match event {
Proto => {}
Drained => {
self.connections.senders.remove(&ch);
if self.connections.is_empty() {
Expand Down Expand Up @@ -617,6 +626,7 @@ impl ConnectionSet {
conn: proto::Connection,
socket: Arc<dyn AsyncUdpSocket>,
runtime: Arc<dyn Runtime>,
driver: Arc<AtomicWaker>,
) -> Connecting {
let (send, recv) = mpsc::unbounded_channel();
if let Some((error_code, ref reason)) = self.close {
Expand All @@ -627,7 +637,15 @@ impl ConnectionSet {
.unwrap();
}
self.senders.insert(handle, send);
Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
Connecting::new(
handle,
conn,
self.sender.clone(),
driver,
recv,
socket,
runtime,
)
}

fn is_empty(&self) -> bool {
Expand Down
1 change: 0 additions & 1 deletion quinn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ enum ConnectionEvent {

#[derive(Debug)]
enum EndpointEvent {
Proto,
Drained,
Transmit(proto::Transmit, Bytes),
}
Expand Down

0 comments on commit 10e0616

Please sign in to comment.