From 10e06167a8a193d6024b57939170509d5a026026 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sun, 17 Dec 2023 14:49:31 -0800 Subject: [PATCH] Replace dummy message with direct wake to trigger endpoint events --- quinn/src/connection.rs | 32 +++++++++++++++++++++----------- quinn/src/endpoint.rs | 32 +++++++++++++++++++++++++------- quinn/src/lib.rs | 1 - 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/quinn/src/connection.rs b/quinn/src/connection.rs index d503a41ac..a102f5c7d 100644 --- a/quinn/src/connection.rs +++ b/quinn/src/connection.rs @@ -10,6 +10,7 @@ use std::{ }; use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime}; +use atomic_waker::AtomicWaker; use bytes::{Bytes, BytesMut}; use pin_project_lite::pin_project; use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId}; @@ -40,6 +41,7 @@ impl Connecting { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + endpoint_driver: Arc, conn_events: mpsc::UnboundedReceiver, socket: Arc, runtime: Arc, @@ -50,6 +52,7 @@ impl Connecting { handle, conn, endpoint_events, + endpoint_driver, conn_events, on_handshake_data_send, on_connected_send, @@ -233,7 +236,7 @@ impl Future for ConnectionDriver { // If a timer expires, there might be more to transmit. When we transmit something, we // might need to reset a timer. Hence, we must loop until neither happens. keep_going |= conn.drive_timer(cx); - conn.forward_endpoint_events(); + conn.forward_endpoint_events(&self.0.shared); conn.forward_app_events(&self.0.shared); if !conn.inner.is_drained() { @@ -759,6 +762,7 @@ impl ConnectionRef { handle: ConnectionHandle, conn: proto::Connection, endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, + endpoint_driver: Arc, conn_events: mpsc::UnboundedReceiver, on_handshake_data: oneshot::Sender<()>, on_connected: oneshot::Sender, @@ -786,7 +790,13 @@ impl ConnectionRef { socket, runtime, }), - shared: Shared::default(), + shared: Shared { + endpoint_driver, + stream_budget_available: Default::default(), + stream_incoming: Default::default(), + datagrams: Default::default(), + closed: Default::default(), + }, })) } @@ -831,7 +841,7 @@ pub(crate) struct ConnectionInner { pub(crate) shared: Shared, } -#[derive(Debug, Default)] +#[derive(Debug)] pub(crate) struct Shared { /// Notified when new streams may be locally initiated due to an increase in stream ID flow /// control budget @@ -840,6 +850,7 @@ pub(crate) struct Shared { stream_incoming: [Notify; 2], datagrams: Notify, closed: Notify, + endpoint_driver: Arc, } pub(crate) struct State { @@ -898,16 +909,15 @@ impl State { false } - fn forward_endpoint_events(&mut self) { + fn forward_endpoint_events(&mut self, shared: &Shared) { if self.inner.poll_endpoint_events() { + shared.endpoint_driver.wake(); + } + if self.inner.is_drained() { // If the endpoint driver is gone, noop. - let _ = self.endpoint_events.send(( - self.handle, - match self.inner.is_drained() { - false => EndpointEvent::Proto, - true => EndpointEvent::Drained, - }, - )); + let _ = self + .endpoint_events + .send((self.handle, EndpointEvent::Drained)); } } diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index f4ffca545..77a39fa47 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -198,9 +198,13 @@ impl Endpoint { .connect(Instant::now(), config, addr, server_name)?; let socket = endpoint.socket.clone(); - Ok(endpoint - .connections - .insert(ch, conn, socket, self.runtime.clone())) + Ok(endpoint.connections.insert( + ch, + conn, + socket, + self.runtime.clone(), + self.inner.shared.driver.clone(), + )) } /// Switch to a new UDP socket @@ -325,7 +329,7 @@ impl Future for EndpointDriver { let now = Instant::now(); let mut keep_going = false; - keep_going |= endpoint.drive_recv(cx, now)?; + keep_going |= endpoint.drive_recv(cx, now, &self.0.shared)?; keep_going |= endpoint.handle_events(cx, &self.0.shared); keep_going |= endpoint.drive_send(cx)?; @@ -393,7 +397,12 @@ pub(crate) struct Shared { } impl State { - fn drive_recv<'a>(&'a mut self, cx: &mut Context, now: Instant) -> Result { + fn drive_recv<'a>( + &'a mut self, + cx: &mut Context, + now: Instant, + shared: &Shared, + ) -> Result { self.recv_limiter.start_cycle(); let mut metas = [RecvMeta::default(); BATCH_SIZE]; let mut iovs = MaybeUninit::<[IoSliceMut<'a>; BATCH_SIZE]>::uninit(); @@ -431,6 +440,7 @@ impl State { conn, self.socket.clone(), self.runtime.clone(), + shared.driver.clone(), ); self.incoming.push_back(conn); } @@ -530,7 +540,6 @@ impl State { for _ in 0..IO_LOOP_BOUND { match self.events.poll_recv(cx) { Poll::Ready(Some((ch, event))) => match event { - Proto => {} Drained => { self.connections.senders.remove(&ch); if self.connections.is_empty() { @@ -617,6 +626,7 @@ impl ConnectionSet { conn: proto::Connection, socket: Arc, runtime: Arc, + driver: Arc, ) -> Connecting { let (send, recv) = mpsc::unbounded_channel(); if let Some((error_code, ref reason)) = self.close { @@ -627,7 +637,15 @@ impl ConnectionSet { .unwrap(); } self.senders.insert(handle, send); - Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime) + Connecting::new( + handle, + conn, + self.sender.clone(), + driver, + recv, + socket, + runtime, + ) } fn is_empty(&self) -> bool { diff --git a/quinn/src/lib.rs b/quinn/src/lib.rs index 974094862..e0b29d3ff 100644 --- a/quinn/src/lib.rs +++ b/quinn/src/lib.rs @@ -98,7 +98,6 @@ enum ConnectionEvent { #[derive(Debug)] enum EndpointEvent { - Proto, Drained, Transmit(proto::Transmit, Bytes), }