diff --git a/src/pool.rs b/src/pool.rs index 7a09a3e..0a58f1d 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -281,29 +281,29 @@ impl PoolInner { } }; + // Cancel safety: All branches of this select! statement are + // cancel-safe (mpsc::Receiver::recv, tokio::time::sleep_until, + // monitoring the tokio::sync::watch::Receivers) + // + // Futurelock safety: All select arms are queried concurrently. No + // awaiting happens outside this concurrent polling. tokio::select! { // Handle requests from clients request = self.rx.recv() => { match request { Some(Request::Claim { id, tx }) => { - self.claim_or_enqueue(id, tx).await + self.claim_or_enqueue(id, tx) } // The caller has explicitly asked us to terminate, and // we should respond to them once we've stopped doing // work. - Some(Request::Terminate) => { - self.terminate().await; - return; - }, + Some(Request::Terminate) => break, // The caller has abandoned their connection to the pool. // // We stop handling new requests, but have no one to // notify. Given that the caller no longer needs the // pool, we choose to terminate to avoid leaks. - None => { - self.terminate().await; - return; - } + None => break, } } // Timeout old requests from clients @@ -326,28 +326,36 @@ impl PoolInner { // Periodically rebalance the allocation of slots to backends _ = rebalance_interval.tick() => { event!(Level::INFO, "Rebalancing: timer tick"); - self.rebalance().await; + self.rebalance(); } // If any of the slots change state, update their allocations. - Some((name, status)) = &mut backend_status_stream.next(), if !backend_status_stream.is_empty() => { + Some((name, status)) = backend_status_stream.next(), if !backend_status_stream.is_empty() => { event!(Level::INFO, name = ?name, status = ?status, "Rebalancing: Backend has new status"); rebalance_interval.reset(); - self.rebalance().await; + self.rebalance(); if matches!(status, slot::SetState::Online { has_unclaimed_slots: true }) { - self.try_claim_from_queue().await; + self.try_claim_from_queue(); } }, } } + + // Out of an abundance of caution, to avoid futurelock: drop all + // possible unpolled futures before invoking terminate. 
+ drop(rebalance_interval); + drop(backend_status_stream); + drop(resolver_stream); + + self.terminate().await; } - async fn claim_or_enqueue( + fn claim_or_enqueue( &mut self, id: ClaimId, tx: oneshot::Sender, Error>>, ) { - let result = self.claim(id).await; + let result = self.claim(id); if result.is_ok() { let _ = tx.send(result); return; @@ -364,13 +372,13 @@ impl PoolInner { }); } - async fn try_claim_from_queue(&mut self) { + fn try_claim_from_queue(&mut self) { loop { let Some(request) = self.request_queue.pop_front() else { return; }; - let result = self.claim(request.id).await; + let result = self.claim(request.id); if result.is_ok() { let _ = request.tx.send(result); } else { @@ -394,16 +402,16 @@ impl PoolInner { } #[instrument(skip(self), name = "PoolInner::rebalance")] - async fn rebalance(&mut self) { + fn rebalance(&mut self) { #[cfg(feature = "probes")] probes::rebalance__start!(|| self.name.as_str()); - self.rebalance_inner().await; + self.rebalance_inner(); #[cfg(feature = "probes")] probes::rebalance__done!(|| self.name.as_str()); } - async fn rebalance_inner(&mut self) { + fn rebalance_inner(&mut self) { let mut questionable_backend_count = 0; let mut usable_backends = vec![]; @@ -412,7 +420,7 @@ impl PoolInner { for (name, slot_set) in iter { match slot_set.get_state() { slot::SetState::Offline => { - let _ = slot_set.set_wanted_count(1).await; + let _ = slot_set.set_wanted_count(1); questionable_backend_count += 1; } slot::SetState::Online { .. } => { @@ -442,7 +450,7 @@ impl PoolInner { let Some(slot_set) = self.slots.get_mut(&name) else { continue; }; - let _ = slot_set.set_wanted_count(slots_wanted_per_backend).await; + let _ = slot_set.set_wanted_count(slots_wanted_per_backend); } let mut new_priority_list = PriorityList::new(); @@ -472,7 +480,7 @@ impl PoolInner { self.priority_list = new_priority_list; } - async fn claim(&mut self, id: ClaimId) -> Result, Error> { + fn claim(&mut self, id: ClaimId) -> Result, Error> { let mut attempted_backend = vec![]; let mut result = Err(Error::NoBackends); @@ -504,7 +512,7 @@ impl PoolInner { // // Either way, put this backend back in the priority list after // we're done with it. - let Ok(claim) = set.claim(id).await else { + let Ok(claim) = set.claim(id) else { event!(Level::DEBUG, "Failed to actually get claim for backend"); rebalancer::claimed_err(&mut weighted_backend); attempted_backend.push(weighted_backend); @@ -525,7 +533,7 @@ impl PoolInner { Err(_) => probes::pool__claim__failed!(|| (self.name.as_str(), id.0)), } - self.priority_list.extend(attempted_backend.into_iter()); + self.priority_list.extend(attempted_backend); result } } diff --git a/src/resolvers/dns.rs b/src/resolvers/dns.rs index 030c691..ebc7d06 100644 --- a/src/resolvers/dns.rs +++ b/src/resolvers/dns.rs @@ -15,6 +15,7 @@ use hickory_resolver::config::ResolverOpts; use hickory_resolver::error::{ResolveError, ResolveErrorKind}; use hickory_resolver::TokioAsyncResolver; use std::collections::{BTreeMap, HashMap}; +use std::future::Future; use std::net::SocketAddr; use std::net::SocketAddrV6; use std::sync::atomic::{AtomicBool, Ordering}; @@ -171,22 +172,40 @@ impl DnsResolverWorker { .or_insert_with(|| Client::new(&self.config, address, failure_window)); } + // This function is cancel-safe. + async fn tick_and_query_dns(&mut self, query_interval: &mut tokio::time::Interval) { + // We want to wait for "query_interval"'s timeout to pass before + // starting to query DNS. 
However, if we're partway through "query_dns" + // and we are cancelled, we'd like to resume immediately. + // + // To accomplish this: + // - After we tick once, we "reset_immediately" so tick will fire + // again immediately if this future is dropped and re-created. + // - Once we finish "query_dns", we reset the query interval to + // actually respect the "tick period" of time. + query_interval.tick().await; + query_interval.reset_immediately(); + + self.query_dns().await; + if self.backends.is_empty() { + query_interval.reset_after(self.config.query_retry_if_no_records_found); + } else { + query_interval.reset(); + } + } + async fn run(mut self, mut terminate_rx: tokio::sync::oneshot::Receiver<()>) { let mut query_interval = tokio::time::interval(self.config.query_interval); loop { - let next_tick = query_interval.tick(); let next_backend_expiration = self.sleep_until_next_backend_expiration(); + // Cancel safety: All branches are cancel-safe. + // + // Futurelock safety: All select arms are queried concurrently. No + // awaiting happens outside this concurrent polling. tokio::select! { - _ = &mut terminate_rx => { - return; - }, - _ = next_tick => { - self.query_dns().await; - if self.backends.is_empty() { - query_interval.reset_after(self.config.query_retry_if_no_records_found); - } - }, + _ = &mut terminate_rx => return, + _ = self.tick_and_query_dns(&mut query_interval) => {}, backend_name = next_backend_expiration => { if self.backends.remove(&backend_name).is_some() { self.watch_tx.send_modify(|backends| { @@ -342,7 +361,8 @@ impl DnsResolverWorker { }); } - async fn sleep_until_next_backend_expiration(&self) -> backend::Name { + // This function is cancel-safe. + fn sleep_until_next_backend_expiration(&self) -> impl Future { let next_expiration = self.backends.iter().reduce(|soonest, backend| { let Some(backend_expiration) = backend.1.expires_at else { return soonest; @@ -364,20 +384,21 @@ impl DnsResolverWorker { } }); - let Some(( - name, - BackendRecord { - expires_at: Some(deadline), - .. - }, - )) = next_expiration - else { - let () = futures::future::pending().await; - unreachable!(); + let (name, deadline) = match next_expiration { + Some(( + name, + BackendRecord { + expires_at: Some(deadline), + .. + }, + )) => (name.clone(), *deadline), + _ => return futures::future::Either::Left(futures::future::pending()), }; - tokio::time::sleep_until((*deadline).into()).await; - name.clone() + futures::future::Either::Right(async move { + tokio::time::sleep_until(deadline.into()).await; + name + }) } } diff --git a/src/slot.rs b/src/slot.rs index 6e33f87..017133c 100644 --- a/src/slot.rs +++ b/src/slot.rs @@ -14,7 +14,7 @@ use derive_where::derive_where; use std::collections::BTreeMap; use std::sync::{Arc, Mutex}; use thiserror::Error; -use tokio::sync::{mpsc, oneshot, watch, Notify}; +use tokio::sync::{mpsc, watch}; use tokio::task::JoinHandle; use tokio::time::{interval, Duration}; use tracing::{event, instrument, span, Instrument, Level}; @@ -121,8 +121,8 @@ struct SlotInner { // All fields of the slot which need to be guarded behind a mutex guarded: Mutex>, - // A notification channel indicating that the slot needs recyling - recycling_needed: Notify, + // A watch channel indicating that the slot needs recycling + recycling_needed_tx: tokio::sync::watch::Sender, // This is wrapped in an "Arc" because it's shared with all slots in the slot set. 
stats: Arc>, @@ -284,33 +284,53 @@ impl Slot { backend: &Backend, terminate_rx: &mut tokio::sync::oneshot::Receiver<()>, ) -> bool { - let mut retry_duration = config.min_connection_backoff.add_spread(config.spread); + // Cancel safety: We don't care about the state of the slot if we're terminating. + // + // Futurelock safety: Both select arms are queried concurrently. No awaiting happens + // outside this concurrent polling. + tokio::select! { + biased; + _ = &mut *terminate_rx => { + false + } + _ = self.try_connect_forever(slot_id, config, connector, backend) => { + true + } + } + } + async fn try_connect_forever( + &self, + #[cfg_attr(not(feature = "probes"), allow(unused_variables))] slot_id: SlotId, + config: &SetConfig, + connector: &SharedConnector, + backend: &Backend, + ) { + let mut retry_duration = config.min_connection_backoff.add_spread(config.spread); loop { - tokio::select! { - biased; - _ = &mut *terminate_rx => { - return false; + let result = self.do_connect(slot_id, connector, backend).await; + match result { + Ok(conn) => { + self.inner.state_transition( + slot_id, + backend, + self.inner.guarded.lock().unwrap(), + State::ConnectedUnclaimed(DebugIgnore(conn)), + ); + return; } - result = self.do_connect(slot_id, connector, backend) => { - match result { - Ok(conn) => { - self.inner.state_transition( - slot_id, - backend, - self.inner.guarded.lock().unwrap(), - State::ConnectedUnclaimed(DebugIgnore(conn)), - ); - return true; - } - Err(err) => { - event!(Level::WARN, pool_name = self.inner.pool_name.as_str(), ?err, ?backend, "Failed to connect"); - self.inner.failure_window.add(1); - retry_duration = - retry_duration.exponential_backoff(config.max_connection_backoff); - tokio::time::sleep(retry_duration).await; - } - } + Err(err) => { + event!( + Level::WARN, + pool_name = self.inner.pool_name.as_str(), + ?err, + ?backend, + "Failed to connect" + ); + self.inner.failure_window.add(1); + retry_duration = + retry_duration.exponential_backoff(config.max_connection_backoff); + tokio::time::sleep(retry_duration).await; } } } @@ -348,6 +368,37 @@ impl Slot { res } + // Returns "true" if we should terminate, returns "false" otherwise. + async fn recycle_if_needed_or_terminate( + &self, + slot_id: SlotId, + connector: &SharedConnector, + timeout: Duration, + backend: &Backend, + terminate_rx: &mut tokio::sync::oneshot::Receiver<()>, + ) -> bool { + // Cancel safety: If we're terminating, we don't care about the state of the slot. If we + // finish recycling, "terminate_rx" is borrowed, and not cancelled. + // + // Futurelock safety: Both select arms are queried concurrently. No awaiting happens + // outside this concurrent polling. + tokio::select! { + biased; + _ = terminate_rx => { + event!( + Level::TRACE, + slot_id = slot_id.as_u64(), + "Terminating while recycling" + ); + true + }, + _ = self.recycle_if_needed(slot_id, connector, timeout, backend) => { + false + } + } + } + + // Recycles the connection, transitioning to "unclaimed" or "connecting". #[instrument( level = "trace", skip(self, connector), @@ -441,6 +492,37 @@ impl Slot { } } + // Returns "true" if we should terminate, returns "false" otherwise. + async fn validate_health_if_connected_or_terminate( + &self, + slot_id: SlotId, + connector: &SharedConnector, + timeout: Duration, + backend: &Backend, + terminate_rx: &mut tokio::sync::oneshot::Receiver<()>, + ) -> bool { + // Cancel safety: If we're terminating, we don't care about the state of the slot. 
If we + // finish recycling, "terminate_rx" is borrowed, and not cancelled. + // + // Futurelock safety: Both select arms are queried concurrently. No awaiting happens + // outside this concurrent polling. + tokio::select! { + biased; + _ = terminate_rx => { + event!( + Level::TRACE, + slot_id = slot_id.as_u64(), + "Terminating while validating health" + ); + true + }, + _ = self.validate_health_if_connected(slot_id, connector, timeout, backend) => { + false + } + } + } + + // Queries "is_valid" on a connection, updating the state based on health. #[instrument( level = "trace", skip(self, connector), @@ -642,233 +724,91 @@ impl Stats { } } -enum SetRequest { - Claim { - id: ClaimId, - tx: oneshot::Sender, Error>>, - }, - SetWantedCount { - count: usize, - }, -} - -// Owns and runs work on behalf of a [Set]. -struct SetWorker { +// Provides direct access to all underlying slots +// +// Shared by both a [`SetWorker`] and [`Set`] +struct Slots { pool_name: pool::Name, name: backend::Name, backend: Backend, - config: SetConfig, - - wanted_count: usize, - // Interface for actually connecting to backends backend_connector: SharedConnector, + config: SetConfig, - // Interface for receiving client requests - rx: mpsc::Receiver>, + // The actual underlying slots, by ID. + slots: BTreeMap>, + // The desired number of slots + wanted_count: usize, + next_slot_id: SlotId, - // Identifies that the set worker should terminate immediately - terminate_rx: tokio::sync::oneshot::Receiver<()>, + // If "true", new requests are rejected + terminating: bool, // Interface for communicating backend status status_tx: watch::Sender, - - // Sender and receiver for returning old handles. - // - // This is to guarantee a size, and to vend out permits to claim::Handles so they can be sure - // that their connections can return to the set without error. + // Sender for returning old handles. slot_tx: mpsc::Sender>, - slot_rx: mpsc::Receiver>, - - // The actual slots themselves. - slots: BTreeMap>, // Summary information about the health of all slots. // // Must be kept in lockstep with "Self::slots" stats: Arc>, - failure_window: Arc, - - next_slot_id: SlotId, } -impl SetWorker { - #[allow(clippy::too_many_arguments)] - fn new( - pool_name: pool::Name, - set_id: u16, - name: backend::Name, - rx: mpsc::Receiver>, - terminate_rx: tokio::sync::oneshot::Receiver<()>, - status_tx: watch::Sender, - config: SetConfig, - wanted_count: usize, - backend: Backend, - backend_connector: SharedConnector, - stats: Arc>, - failure_window: Arc, - ) -> Self { - let (slot_tx, slot_rx) = mpsc::channel(config.max_count); - let mut set = Self { - pool_name, - name, - backend, - config, - wanted_count, - backend_connector, - stats, - failure_window, - rx, - terminate_rx, - status_tx, - slot_tx, - slot_rx, - slots: BTreeMap::new(), - next_slot_id: SlotId::first(set_id), - }; - set.set_wanted_count(wanted_count); - set - } - - // Creates a new Slot, which always starts as "Connecting", and spawn a task - // to actually connect to the backend and monitor slot health. 
+impl Slots { #[instrument( - skip(self) - fields(pool_name = %self.pool_name), - name = "SetWorker::create_slot" + level = "trace", + skip(self), + err, + name = "Slots::claim", + fields(name = ?self.name), )] - fn create_slot(&mut self, slot_id: SlotId) { - let (terminate_tx, mut terminate_rx) = tokio::sync::oneshot::channel(); - let slot = Slot { - inner: Arc::new(SlotInner { - guarded: Mutex::new(SlotInnerGuarded { - state: State::Connecting, - status_tx: self.status_tx.clone(), - terminate_tx: Some(terminate_tx), - handle: None, - }), - recycling_needed: Notify::new(), - stats: self.stats.clone(), - failure_window: self.failure_window.clone(), - pool_name: self.pool_name.clone(), - }), - }; - let slot = self.slots.entry(slot_id).or_insert(slot).clone(); - self.stats - .lock() - .unwrap() - .enter_state(&State::::Connecting); + fn claim(&self, id: ClaimId) -> Result, Error> { + #[cfg(feature = "probes")] + probes::slot__set__claim__start!(|| ( + self.pool_name.as_str(), + id.0, + self.backend.address.to_string() + )); - slot.inner.guarded.lock().unwrap().handle = Some(tokio::task::spawn({ - let slot = slot.clone(); - let config = self.config.clone(); - let connector = self.backend_connector.clone(); - let backend = self.backend.clone(); - async move { - let mut interval = interval(config.health_interval); + // Before we vend out the slot's connection to a client, make sure that + // we have space to take it back once they're done with it. + let Ok(permit) = self.slot_tx.clone().try_reserve_owned() else { + event!(Level::TRACE, "Could not reserve slot_tx permit"); - loop { - event!( - Level::TRACE, - slot_id = slot_id.as_u64(), - "Starting Slot work loop" - ); - enum Work { - DoConnect, - DoMonitor, - } + #[cfg(feature = "probes")] + probes::slot__set__claim__failed!(|| ( + self.pool_name.as_str(), + id.0, + "Could not reserve slot_tx permit; all slots used" + )); - // We're deciding what work to do, based on the state, - // within an isolated scope. This is due to: - // https://github.com/rust-lang/rust/issues/69663 - // - // Even if we drop the MutexGuard before `.await` points, - // rustc still sees something "non-Send" held across an - // `.await`. - let work = { - let guarded = slot.inner.guarded.lock().unwrap(); - match &guarded.state { - State::Connecting => Work::DoConnect, - State::ConnectedUnclaimed(_) - | State::ConnectedRecycling(_) - | State::ConnectedChecking - | State::ConnectedClaimed => Work::DoMonitor, - State::Terminated => return, - } - }; + // This is more of an "all slots in-use" error, + // but it should look the same to clients. + return Err(Error::NoSlotsReady); + }; - match work { - Work::DoConnect => { - let span = span!( - Level::TRACE, - "Slot worker connecting", - slot_id = slot_id.as_u64() - ); - let connected = async { - if !slot - .loop_until_connected( - slot_id, - &config, - &connector, - &backend, - &mut terminate_rx, - ) - .await - { - // The slot was instructed to exit - // before it connected. Bail. - event!( - Level::TRACE, - slot_id = slot_id.as_u64(), - "Terminating instead of connecting" - ); - return false; - } - interval.reset_after(interval.period().add_spread(config.spread)); - true - } - .instrument(span) - .await; + let Some(handle) = self.take_connected_unclaimed_slot(permit, id) else { + event!(Level::TRACE, "Failed to take unclaimed slot"); - if !connected { - return; - } - } - Work::DoMonitor => { - tokio::select! { - biased; - _ = &mut terminate_rx => { - // If we've been instructed to bail out, - // do that immediately. 
- event!(Level::TRACE, slot_id = slot_id.as_u64(), "Terminating while monitoring"); - return; - }, - _ = interval.tick() => { - slot.validate_health_if_connected( - slot_id, - &connector, - config.health_check_timeout, - &backend, - ) - .await; - interval - .reset_after(interval.period().add_spread(config.spread)); - }, - _ = slot.inner.recycling_needed.notified() => { - slot.recycle_if_needed( - slot_id, - &connector, - config.health_check_timeout, - &backend, - ).await; - }, - } - } - } - } - } - })); + #[cfg(feature = "probes")] + probes::slot__set__claim__failed!(|| ( + self.pool_name.as_str(), + id.0, + "No unclaimed slots" + )); + return Err(Error::NoSlotsReady); + }; + + #[cfg(feature = "probes")] + probes::slot__set__claim__done!(|| ( + self.pool_name.as_str(), + id.0, + handle.slot_id().as_u64() + )); + + return Ok(handle); } // Borrows a connection out of the first unclaimed slot. @@ -878,14 +818,14 @@ impl SetWorker { #[instrument( skip(self, permit) fields(pool_name = %self.pool_name), - name = "SetWorker::take_connected_unclaimed_slot" + name = "Slots::take_connected_unclaimed_slot" )] fn take_connected_unclaimed_slot( - &mut self, + &self, permit: mpsc::OwnedPermit>, claim_id: ClaimId, ) -> Option> { - for (id, slot) in &mut self.slots { + for (id, slot) in &self.slots { let guarded = slot.inner.guarded.lock().unwrap(); event!(Level::TRACE, slot_id = id.as_u64(), state = ?guarded.state, "Considering slot"); if matches!(guarded.state, State::ConnectedUnclaimed(_)) { @@ -927,68 +867,11 @@ impl SetWorker { None } - // Takes back borrowed slots from clients who dropped their claim handles. - #[instrument( - level = "trace", - skip(self, borrowed_conn), - fields( - slot_id = borrowed_conn.id.as_u64(), - name = ?self.name, - ), - name = "SetWorker::recycle_connection" - )] - fn recycle_connection(&mut self, borrowed_conn: BorrowedConnection) { - let slot_id = borrowed_conn.id; - #[cfg(feature = "probes")] - crate::probes::handle__returned!(|| (self.pool_name.as_str(), slot_id.as_u64())); - let inner = self - .slots - .get_mut(&slot_id) - .expect( - "A borrowed connection was returned to this\ - pool, and it should reference a slot that \ - cannot be removed while borrowed", - ) - .inner - .clone(); - { - let guarded = inner.guarded.lock().unwrap(); - assert!( - matches!(guarded.state, State::ConnectedClaimed), - "Unexpected slot state {:?}", - guarded.state - ); - inner.state_transition( - slot_id, - &self.backend, - guarded, - State::ConnectedRecycling(DebugIgnore(borrowed_conn.conn)), - ); - } - - // If we tried to shrink the slot count while too many connections were - // in-use, it's possible there's more work to do. Try to conform the - // slot count after recycling each connection. - self.conform_slot_count(); - - inner.recycling_needed.notify_one(); - } - fn set_wanted_count(&mut self, count: usize) { self.wanted_count = std::cmp::min(count, self.config.max_count); self.conform_slot_count(); } - // Makes the number of slots as close to "desired_count" as we can get. 
- #[instrument( - level = "trace", - skip(self), - fields( - wanted_count = self.wanted_count, - name = ?self.name, - ), - name = "SetWorker::conform_slot_count" - )] fn conform_slot_count(&mut self) { let desired = self.wanted_count; @@ -1067,80 +950,275 @@ impl SetWorker { for slot_id in self.next_slot_id.0..self.next_slot_id.0 + new_slots { self.create_slot(SlotId(slot_id)); } - self.next_slot_id.0 += new_slots; + self.next_slot_id.0 += new_slots; + } + Equal => {} + } + } + + fn create_slot(&mut self, slot_id: SlotId) { + let (terminate_tx, mut terminate_rx) = tokio::sync::oneshot::channel(); + let (recycling_needed_tx, recycling_needed_rx) = tokio::sync::watch::channel(false); + let slot = Slot { + inner: Arc::new(SlotInner { + guarded: Mutex::new(SlotInnerGuarded { + state: State::Connecting, + status_tx: self.status_tx.clone(), + terminate_tx: Some(terminate_tx), + handle: None, + }), + recycling_needed_tx, + stats: self.stats.clone(), + failure_window: self.failure_window.clone(), + pool_name: self.pool_name.clone(), + }), + }; + let slot = self.slots.entry(slot_id).or_insert(slot).clone(); + self.stats + .lock() + .unwrap() + .enter_state(&State::::Connecting); + + slot.inner.guarded.lock().unwrap().handle = Some(tokio::task::spawn({ + let slot = slot.clone(); + let config = self.config.clone(); + let connector = self.backend_connector.clone(); + let backend = self.backend.clone(); + let mut recycling_needed_rx = recycling_needed_rx; + async move { + let mut interval = interval(config.health_interval); + + loop { + event!( + Level::TRACE, + slot_id = slot_id.as_u64(), + "Starting Slot work loop" + ); + enum Work { + DoConnect, + DoMonitor, + } + + // We're deciding what work to do, based on the state, + // within an isolated scope. This is due to: + // https://github.com/rust-lang/rust/issues/69663 + // + // Even if we drop the MutexGuard before `.await` points, + // rustc still sees something "non-Send" held across an + // `.await`. + let work = { + let guarded = slot.inner.guarded.lock().unwrap(); + match &guarded.state { + State::Connecting => Work::DoConnect, + State::ConnectedUnclaimed(_) + | State::ConnectedRecycling(_) + | State::ConnectedChecking + | State::ConnectedClaimed => Work::DoMonitor, + State::Terminated => return, + } + }; + + match work { + Work::DoConnect => { + let span = span!( + Level::TRACE, + "Slot worker connecting", + slot_id = slot_id.as_u64() + ); + let connected = async { + if !slot + .loop_until_connected( + slot_id, + &config, + &connector, + &backend, + &mut terminate_rx, + ) + .await + { + // The slot was instructed to exit + // before it connected. Bail. + event!( + Level::TRACE, + slot_id = slot_id.as_u64(), + "Terminating instead of connecting" + ); + return false; + } + interval.reset_after(interval.period().add_spread(config.spread)); + true + } + .instrument(span) + .await; + + if !connected { + return; + } + } + Work::DoMonitor => { + // Cancel safety: "tick()" and "changed()" are both cancel-safe; if we + // take one branch, the other is not consumed. "terminate_rx" is also + // borrowed, and will not be cancelled. + // + // Futurelock safety: The only borrowed future is "terminate_rx", which + // we will continue to poll in the other select arms. The other futures + // are dropped immediately. + tokio::select! 
{ + biased; + _ = &mut terminate_rx => { + event!( + Level::TRACE, + slot_id = slot_id.as_u64(), + "Terminating while monitoring" + ); + return; + }, + _ = interval.tick() => { + if slot.validate_health_if_connected_or_terminate( + slot_id, + &connector, + config.health_check_timeout, + &backend, + &mut terminate_rx, + ) + .await { return; } + interval.reset_after(interval.period().add_spread(config.spread)); + }, + _ = recycling_needed_rx.changed() => { + if *recycling_needed_rx.borrow_and_update() { + if slot.recycle_if_needed_or_terminate( + slot_id, + &connector, + config.health_check_timeout, + &backend, + &mut terminate_rx, + ).await { return; } + let _ = slot.inner.recycling_needed_tx.send(false); + } + }, + } + } + } + } } - Equal => {} - } + })); } +} - #[instrument( - level = "trace", - skip(self), - err, - name = "SetWorker::claim", - fields(name = ?self.name), - )] - fn claim(&mut self, id: ClaimId) -> Result, Error> { - #[cfg(feature = "probes")] - probes::slot__set__claim__start!(|| ( - self.pool_name.as_str(), - id.0, - self.backend.address.to_string() - )); - - // Before we vend out the slot's connection to a client, make sure that - // we have space to take it back once they're done with it. - let Ok(permit) = self.slot_tx.clone().try_reserve_owned() else { - event!(Level::TRACE, "Could not reserve slot_tx permit"); +// Owns and runs work on behalf of a [Set]. +struct SetWorker { + pool_name: pool::Name, + name: backend::Name, + backend: Backend, - #[cfg(feature = "probes")] - probes::slot__set__claim__failed!(|| ( - self.pool_name.as_str(), - id.0, - "Could not reserve slot_tx permit; all slots used" - )); + // Identifies that the set worker should terminate immediately + terminate_rx: tokio::sync::oneshot::Receiver<()>, - // This is more of an "all slots in-use" error, - // but it should look the same to clients. - return Err(Error::NoSlotsReady); - }; + // Sender and receiver for returning old handles. + // + // This is to guarantee a size, and to vend out permits to claim::Handles so + // they can be sure that their connections can return to the set without + // error. + slot_rx: mpsc::Receiver>, - let Some(handle) = self.take_connected_unclaimed_slot(permit, id) else { - event!(Level::TRACE, "Failed to take unclaimed slot"); + // The actual slots themselves. + slots: Arc>>, +} - #[cfg(feature = "probes")] - probes::slot__set__claim__failed!(|| ( - self.pool_name.as_str(), - id.0, - "No unclaimed slots" - )); - return Err(Error::NoSlotsReady); +impl SetWorker { + #[allow(clippy::too_many_arguments)] + fn new( + pool_name: pool::Name, + set_id: u16, + name: backend::Name, + terminate_rx: tokio::sync::oneshot::Receiver<()>, + status_tx: watch::Sender, + config: SetConfig, + wanted_count: usize, + backend: Backend, + backend_connector: SharedConnector, + stats: Arc>, + failure_window: Arc, + ) -> Self { + let (slot_tx, slot_rx) = mpsc::channel(config.max_count); + let set = Self { + pool_name: pool_name.clone(), + name: name.clone(), + backend: backend.clone(), + terminate_rx, + slot_rx, + slots: Arc::new(Mutex::new(Slots { + pool_name, + name, + backend, + backend_connector, + config, + slots: BTreeMap::new(), + wanted_count, + next_slot_id: SlotId::first(set_id), + terminating: false, + status_tx, + slot_tx, + stats, + failure_window, + })), }; + set.set_wanted_count(wanted_count); + set + } + // Takes back borrowed slots from clients who dropped their claim handles. 
+ #[instrument( + level = "trace", + skip(self, borrowed_conn), + fields( + slot_id = borrowed_conn.id.as_u64(), + name = ?self.name, + ), + name = "SetWorker::recycle_connection" + )] + fn recycle_connection(&self, borrowed_conn: BorrowedConnection) { + let slot_id = borrowed_conn.id; #[cfg(feature = "probes")] - probes::slot__set__claim__done!(|| ( - self.pool_name.as_str(), - id.0, - handle.slot_id().as_u64() - )); + crate::probes::handle__returned!(|| (self.pool_name.as_str(), slot_id.as_u64())); - return Ok(handle); + let mut slots = self.slots.lock().unwrap(); + let inner = slots + .slots + .get_mut(&slot_id) + .expect( + "A borrowed connection was returned to this\ + pool, and it should reference a slot that \ + cannot be removed while borrowed", + ) + .inner + .clone(); + { + let guarded = inner.guarded.lock().unwrap(); + assert!( + matches!(guarded.state, State::ConnectedClaimed), + "Unexpected slot state {:?}", + guarded.state + ); + inner.state_transition( + slot_id, + &self.backend, + guarded, + State::ConnectedRecycling(DebugIgnore(borrowed_conn.conn)), + ); + } + + // If we tried to shrink the slot count while too many connections were + // in-use, it's possible there's more work to do. Try to conform the + // slot count after recycling each connection. + slots.conform_slot_count(); + + let _ = inner.recycling_needed_tx.send(true); } - // Note that this function is not asynchronous. - // - // This is intentional: We should not be await-ing in the SetWorker - // task when servicing client requests. - fn handle_client_request(&mut self, request: SetRequest) { - match request { - SetRequest::Claim { id, tx } => { - let result = self.claim(id); - let _ = tx.send(result); - } - SetRequest::SetWantedCount { count } => { - self.set_wanted_count(count); - } - } + fn set_wanted_count(&self, count: usize) { + let mut slots = self.slots.lock().unwrap(); + slots.set_wanted_count(count); } #[instrument( @@ -1151,6 +1229,11 @@ impl SetWorker { )] async fn run(&mut self) { loop { + // Cancel safety: "recv()" is cancel-safe, but we also don't care about it if + // we're terminating. "terminate_rx" is borrowed, and will not be cancelled. + // + // Futurelock safety: Both select arms are queried concurrently. No awaiting happens + // outside this concurrent polling. tokio::select! { biased; // If we should exit, terminate immediately. @@ -1168,22 +1251,16 @@ impl SetWorker { }, } }, - // Handle requests from clients - request = self.rx.recv() => { - if let Some(request) = request { - self.handle_client_request(request); - } else { - // All clients have gone away, so terminate the set. - // Break out of the loop rather than return, so that the - // termination code runs. - break; - } - } } } // If we have exited from the run loop, tear down the background tasks - while let Some((_id, slot)) = self.slots.pop_first() { + let mut slots = { + let mut slots_lock = self.slots.lock().unwrap(); + slots_lock.terminating = true; + std::mem::take(&mut slots_lock.slots) + }; + while let Some((_id, slot)) = slots.pop_first() { let handle = { let mut lock = slot.inner.guarded.lock().unwrap(); @@ -1221,7 +1298,7 @@ pub(crate) enum SetState { /// A set of slots for a particular backend. 
pub(crate) struct Set { - tx: mpsc::Sender>, + slots: Arc>>, status_rx: watch::Receiver, @@ -1249,37 +1326,35 @@ impl Set { backend: Backend, backend_connector: SharedConnector, ) -> Self { - let (tx, rx) = mpsc::channel(1); let (terminate_tx, terminate_rx) = tokio::sync::oneshot::channel(); let (status_tx, status_rx) = watch::channel(SetState::Offline); let failure_duration = config.max_connection_backoff * 2; let stats = Arc::new(Mutex::new(Stats::default())); let failure_window = Arc::new(WindowedCounter::new(failure_duration)); + + let mut worker = SetWorker::new( + pool_name, + set_id, + name.clone(), + terminate_rx, + status_tx, + config, + wanted_count, + backend, + backend_connector, + stats.clone(), + failure_window.clone(), + ); + let slots = worker.slots.clone(); + let handle = tokio::task::spawn({ - let stats = stats.clone(); - let failure_window = failure_window.clone(); - let name = name.clone(); async move { - let mut worker = SetWorker::new( - pool_name, - set_id, - name, - rx, - terminate_rx, - status_tx, - config, - wanted_count, - backend, - backend_connector, - stats, - failure_window, - ); worker.run().await; } }); Self { - tx, + slots, status_rx, name, stats, @@ -1321,15 +1396,12 @@ impl Set { name = "Set::claim", fields(name = ?self.name), )] - pub(crate) async fn claim(&mut self, id: ClaimId) -> Result, Error> { - let (tx, rx) = oneshot::channel(); - - self.tx - .send(SetRequest::Claim { id, tx }) - .await - .map_err(|_| Error::SlotWorkerTerminated)?; - - rx.await.map_err(|_| Error::SlotWorkerTerminated)? + pub(crate) fn claim(&self, id: ClaimId) -> Result, Error> { + let slots = self.slots.lock().unwrap(); + if slots.terminating { + return Err(Error::SlotWorkerTerminated); + } + slots.claim(id) } /// Updates the number of "wanted" slots within the slot set. @@ -1342,11 +1414,12 @@ impl Set { name = "Set::set_wanted_count", fields(name = ?self.name), )] - pub(crate) async fn set_wanted_count(&mut self, count: usize) -> Result<(), Error> { - self.tx - .send(SetRequest::SetWantedCount { count }) - .await - .map_err(|_| Error::SlotWorkerTerminated)?; + pub(crate) fn set_wanted_count(&self, count: usize) -> Result<(), Error> { + let mut slots = self.slots.lock().unwrap(); + if slots.terminating { + return Err(Error::SlotWorkerTerminated); + } + slots.set_wanted_count(count); Ok(()) } @@ -1531,7 +1604,7 @@ mod test { #[tokio::test] async fn test_one_claim() { setup_tracing_subscriber(); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig::default(), @@ -1547,13 +1620,13 @@ mod test { .await .unwrap(); - let _conn = set.claim(ClaimId::new()).await.unwrap(); + let _conn = set.claim(ClaimId::new()).unwrap(); } #[tokio::test] async fn test_drain_slots() { setup_tracing_subscriber(); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig::default(), @@ -1570,8 +1643,8 @@ mod test { .unwrap(); // Grab a connection, then set the "Wanted" count to zero. 
- let conn = set.claim(ClaimId::new()).await.unwrap(); - set.set_wanted_count(0).await.unwrap(); + let conn = set.claim(ClaimId::new()).unwrap(); + set.set_wanted_count(0).unwrap(); // Let the connections drain loop { @@ -1598,7 +1671,7 @@ mod test { #[tokio::test] async fn test_no_slots_add_some_later() { setup_tracing_subscriber(); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig::default(), @@ -1610,13 +1683,12 @@ mod test { // We start with nothing available set.claim(ClaimId::new()) - .await .map(|_| ()) .expect_err("Should not be able to get claims yet"); assert_eq!(set.get_state(), SetState::Offline); // We can later adjust the count of desired slots - set.set_wanted_count(3).await.unwrap(); + set.set_wanted_count(3).unwrap(); // Let the connections fill up set.monitor() @@ -1625,13 +1697,13 @@ mod test { .unwrap(); // When this completes, the connections may be claimed - let _conn = set.claim(ClaimId::new()).await.unwrap(); + let _conn = set.claim(ClaimId::new()).unwrap(); } #[tokio::test] async fn test_all_claims() { setup_tracing_subscriber(); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig::default(), @@ -1652,12 +1724,11 @@ mod test { } } - let _conn1 = set.claim(ClaimId::new()).await.unwrap(); - let _conn2 = set.claim(ClaimId::new()).await.unwrap(); - let conn3 = set.claim(ClaimId::new()).await.unwrap(); + let _conn1 = set.claim(ClaimId::new()).unwrap(); + let _conn2 = set.claim(ClaimId::new()).unwrap(); + let conn3 = set.claim(ClaimId::new()).unwrap(); set.claim(ClaimId::new()) - .await .map(|_| ()) .expect_err("We should fail to acquire a 4th claim from 3 slot set"); @@ -1675,7 +1746,7 @@ mod test { } } - let _conn4 = set.claim(ClaimId::new()).await.unwrap(); + let _conn4 = set.claim(ClaimId::new()).unwrap(); } #[tokio::test] @@ -1684,7 +1755,7 @@ mod test { let connector = Arc::new(TestConnector::new()); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig { @@ -1728,14 +1799,13 @@ mod test { )); // Grab three connections - let _claim1 = set.claim(ClaimId::new()).await.expect("Failed to claim"); - let _claim2 = set.claim(ClaimId::new()).await.expect("Failed to claim"); - let claim3 = set.claim(ClaimId::new()).await.expect("Failed to claim"); + let _claim1 = set.claim(ClaimId::new()).expect("Failed to claim"); + let _claim2 = set.claim(ClaimId::new()).expect("Failed to claim"); + let claim3 = set.claim(ClaimId::new()).expect("Failed to claim"); // Cannot claim the fourth connection, this slot set is all used. assert!(matches!( set.claim(ClaimId::new()) - .await .map(|_| ()) .expect_err("Should have reached claim capacity"), Error::NoSlotsReady, @@ -1842,7 +1912,7 @@ mod test { let wanted_count = 5; let connector = Arc::new(TestConnector::new()); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig { @@ -1869,7 +1939,7 @@ mod test { } // Grab one of the slots. Inspect the state, validating it is connected. 
- let conn = set.claim(ClaimId::new()).await.unwrap(); + let conn = set.claim(ClaimId::new()).unwrap(); let raw_conn = conn.clone(); assert_eq!(raw_conn.get_state(), TestConnectionState::Connected); drop(conn); @@ -1891,7 +1961,7 @@ mod test { assert_eq!(raw_conn.get_state(), TestConnectionState::Recycled); connector.set_recyclable(false); - let conn = set.claim(ClaimId::new()).await.unwrap(); + let conn = set.claim(ClaimId::new()).unwrap(); let raw_conn = conn.clone(); assert_eq!(raw_conn.get_state(), TestConnectionState::Recycled); drop(conn); @@ -1923,7 +1993,7 @@ mod test { let wanted_count = 5; let connector = Arc::new(TestConnector::new()); let health_interval = Duration::from_millis(1); - let mut set = Set::new( + let set = Set::new( 0, pool::Name::new("my-pool"), SetConfig { @@ -1946,7 +2016,7 @@ mod test { // // This means no new connections, and existing connections will die // when health checked. - let conn = set.claim(ClaimId::new()).await.unwrap(); + let conn = set.claim(ClaimId::new()).unwrap(); connector.set_connectable(false); connector.set_valid(false); let raw_conn = conn.clone(); @@ -1991,19 +2061,21 @@ mod test { .await .unwrap(); - let conn = set.claim(ClaimId::new()).await.unwrap(); + let conn = set.claim(ClaimId::new()).unwrap(); // We should be able to terminate, even with a claim out. set.terminate().await; - assert!(matches!( - set.claim(ClaimId::new()).await.map(|_| ()).unwrap_err(), - Error::SlotWorkerTerminated, - )); - assert!(matches!( - set.set_wanted_count(1).await.unwrap_err(), - Error::SlotWorkerTerminated - )); + let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err(); + assert!( + matches!(err, Error::SlotWorkerTerminated,), + "Unexpected error: {err}" + ); + let err = set.set_wanted_count(1).unwrap_err(); + assert!( + matches!(err, Error::SlotWorkerTerminated), + "Unexpected error: {err}" + ); drop(conn); } @@ -2106,14 +2178,14 @@ mod test { }) .await .unwrap(); - handles.push(set.claim(ClaimId::new()).await.unwrap()); + handles.push(set.claim(ClaimId::new()).unwrap()); } // All future connections should be slow! connector.stall(); // This should start making new connections... - set.set_wanted_count(config.new_wanted).await.unwrap(); + set.set_wanted_count(config.new_wanted).unwrap(); set.terminate().await; @@ -2121,10 +2193,130 @@ mod test { drop(handles); - assert!(matches!( - set.claim(ClaimId::new()).await.map(|_| ()).unwrap_err(), - Error::SlotWorkerTerminated, - )); + let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err(); + assert!( + matches!(err, Error::SlotWorkerTerminated,), + "Unexpected err: {err}" + ); } } + + #[tokio::test] + async fn test_terminate_during_health_validation() { + setup_tracing_subscriber(); + + let connector = Arc::new(crate::test_utils::SlowConnector::new()); + let mut set = Set::new( + 0, + pool::Name::new("my-pool"), + SetConfig { + // Fast health checks so they run frequently + health_interval: Duration::from_millis(1), + spread: Duration::ZERO, + // Explicit timeout for health checks + health_check_timeout: Duration::from_secs(5), + ..Default::default() + }, + 3, + backend::Name::new("Test set"), + backend::Backend { address: BACKEND }, + connector.clone(), + ); + + // Wait for the connections to come online + set.monitor() + .wait_for(|state| matches!(state, SetState::Online { .. 
})) + .await + .unwrap(); + + // Make health validation take forever + connector.stall(); + + // Give the health check time to start (it runs every 1ms) + tokio::time::sleep(Duration::from_millis(10)).await; + + // Terminate should complete quickly, even though health checks are blocked + let start = std::time::Instant::now(); + set.terminate().await; + let elapsed = start.elapsed(); + + // Termination should be fast (less than 1 second) + // If it's slow, it means we waited for the stalled health check + assert!( + elapsed < Duration::from_secs(1), + "Termination took too long: {:?}", + elapsed + ); + + // After termination, connector should not be accessed + connector.panic_on_access(); + + let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err(); + assert!( + matches!(err, Error::SlotWorkerTerminated,), + "Unexpected err: {err}" + ); + } + + #[tokio::test] + async fn test_terminate_during_recycling() { + setup_tracing_subscriber(); + + let connector = Arc::new(crate::test_utils::SlowConnector::new()); + let mut set = Set::new( + 0, + pool::Name::new("my-pool"), + SetConfig { + // Make health checks slow so they don't interfere + health_interval: Duration::from_secs(1000), + // Explicit timeout for recycling operations + health_check_timeout: Duration::from_secs(5), + ..Default::default() + }, + 3, + backend::Name::new("Test set"), + backend::Backend { address: BACKEND }, + connector.clone(), + ); + + // Wait for the connections to come online + set.monitor() + .wait_for(|state| matches!(state, SetState::Online { .. })) + .await + .unwrap(); + + // Claim a connection to trigger recycling later + let handle = set.claim(ClaimId::new()).unwrap(); + + // Make recycling take forever + connector.stall(); + + // Return the handle to trigger recycling + drop(handle); + + // Give recycling time to start + tokio::time::sleep(Duration::from_millis(10)).await; + + // Terminate should complete quickly, even though recycling is blocked + let start = std::time::Instant::now(); + set.terminate().await; + let elapsed = start.elapsed(); + + // Termination should be fast (less than 1 second) + // If it's slow, it means we waited for the stalled recycling + assert!( + elapsed < Duration::from_secs(1), + "Termination took too long: {:?}", + elapsed + ); + + // After termination, connector should not be accessed + connector.panic_on_access(); + + let err = set.claim(ClaimId::new()).map(|_| ()).unwrap_err(); + assert!( + matches!(err, Error::SlotWorkerTerminated,), + "Unexpected err: {err}" + ); + } }
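The following standalone sketch (hypothetical names: Worker, Request, do-nothing terminate) illustrates the discipline the pool.rs changes adopt: every arm is created and polled inside a single tokio::select!, request handling stays synchronous, and anything that could still hold an unpolled future is dropped before the final terminate().await. It is an illustration of the pattern, not code from this repository, and assumes tokio with the "full" feature enabled.

use std::time::Duration;
use tokio::sync::{mpsc, oneshot};

enum Request {
    Work(oneshot::Sender<()>),
    Terminate,
}

struct Worker {
    rx: mpsc::Receiver<Request>,
}

impl Worker {
    async fn run(mut self) {
        let mut interval = tokio::time::interval(Duration::from_secs(1));
        loop {
            // Cancel safety: `recv()` and `tick()` are both cancel-safe.
            //
            // Futurelock safety: both arms are polled concurrently by this
            // select!; nothing is awaited while another future sits unpolled.
            tokio::select! {
                request = self.rx.recv() => {
                    match request {
                        // Handle the request without awaiting, mirroring the
                        // now-synchronous `claim_or_enqueue` / `rebalance`.
                        Some(Request::Work(tx)) => { let _ = tx.send(()); }
                        Some(Request::Terminate) | None => break,
                    }
                }
                _ = interval.tick() => {
                    // Periodic maintenance would go here.
                }
            }
        }

        // Drop anything that could hold an unpolled future before the final
        // await, mirroring the `drop(rebalance_interval)` /
        // `drop(backend_status_stream)` calls in `PoolInner::run`.
        drop(interval);
        self.terminate().await;
    }

    async fn terminate(&mut self) {
        // Tear down background work; in the real pool this awaits slot-set
        // termination.
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel(8);
    let worker = tokio::spawn(Worker { rx }.run());

    let (done_tx, done_rx) = oneshot::channel();
    let _ = tx.send(Request::Work(done_tx)).await;
    done_rx.await.unwrap();

    let _ = tx.send(Request::Terminate).await;
    worker.await.unwrap();
}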