Skip to content

Commit 33553c9

Browse files
authored
Fix termination bug (#57)
1 parent 665d8ca commit 33553c9

File tree

3 files changed

+123
-4
lines changed

3 files changed

+123
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ tracing-subscriber = "0.3"
3333

3434
[package]
3535
name = "qorb"
36-
version = "0.1.1"
36+
version = "0.1.2"
3737
edition = "2021"
3838
description = "Connection Pooling"
3939
license = "MPL-2.0"

src/pool.rs

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ mod test {
542542
use async_trait::async_trait;
543543
use std::collections::BTreeMap;
544544
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
545-
use std::sync::atomic::{AtomicUsize, Ordering};
545+
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
546546

547547
#[derive(Clone)]
548548
struct TestResolver {
@@ -802,4 +802,112 @@ mod test {
802802
Error::Terminated,
803803
));
804804
}
805+
806+
struct SlowConnector {
807+
delay_ms: AtomicU64,
808+
}
809+
810+
impl SlowConnector {
811+
fn new() -> Self {
812+
Self {
813+
delay_ms: AtomicU64::new(1),
814+
}
815+
}
816+
817+
async fn go_slow(&self) {
818+
let delay_ms = self.delay_ms.load(Ordering::SeqCst);
819+
tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
820+
}
821+
}
822+
823+
#[async_trait]
824+
impl Connector for SlowConnector {
825+
type Connection = ();
826+
827+
async fn connect(&self, _backend: &Backend) -> Result<Self::Connection, backend::Error> {
828+
self.go_slow().await;
829+
Ok(())
830+
}
831+
832+
async fn is_valid(&self, _: &mut Self::Connection) -> Result<(), backend::Error> {
833+
self.go_slow().await;
834+
Ok(())
835+
}
836+
837+
async fn on_acquire(&self, _: &mut Self::Connection) -> Result<(), backend::Error> {
838+
self.go_slow().await;
839+
Ok(())
840+
}
841+
}
842+
843+
fn setup_tracing_subscriber() {
844+
use tracing_subscriber::fmt::format::FmtSpan;
845+
tracing_subscriber::fmt()
846+
.with_thread_names(true)
847+
.with_span_events(FmtSpan::ENTER)
848+
.with_max_level(tracing::Level::TRACE)
849+
.with_test_writer()
850+
.init();
851+
}
852+
853+
#[tokio::test]
854+
async fn test_terminate_with_slow_active_claim() {
855+
setup_tracing_subscriber();
856+
857+
let resolver = Box::new(TestResolver::new());
858+
let connector = Arc::new(SlowConnector::new());
859+
let address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
860+
861+
resolver.replace(BTreeMap::from([(
862+
backend::Name::new("aaa"),
863+
Backend::new(address),
864+
)]));
865+
866+
let pool = Pool::new(resolver, connector.clone(), Policy::default());
867+
let _handle = pool.claim().await.expect("Failed to get claim");
868+
869+
// This delay is enormous, but the point is that termination should not
870+
// get stuck behind any ongoing operations that might happen.
871+
connector.delay_ms.store(99999999, Ordering::SeqCst);
872+
873+
pool.terminate().await.unwrap();
874+
assert!(matches!(
875+
pool.terminate().await.unwrap_err(),
876+
Error::Terminated,
877+
));
878+
assert!(matches!(
879+
pool.claim().await.map(|_| ()).unwrap_err(),
880+
Error::Terminated,
881+
));
882+
}
883+
884+
#[tokio::test]
885+
async fn test_terminate_with_slow_setup() {
886+
setup_tracing_subscriber();
887+
888+
let resolver = Box::new(TestResolver::new());
889+
let connector = Arc::new(SlowConnector::new());
890+
let address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
891+
892+
resolver.replace(BTreeMap::from([(
893+
backend::Name::new("aaa"),
894+
Backend::new(address),
895+
)]));
896+
897+
// This delay is enormous, but the point is that termination should not
898+
// get stuck behind any ongoing operations that might happen.
899+
connector.delay_ms.store(99999999, Ordering::SeqCst);
900+
901+
let pool = Pool::new(resolver, connector.clone(), Policy::default());
902+
903+
pool.terminate().await.unwrap();
904+
assert!(matches!(
905+
pool.terminate().await.unwrap_err(),
906+
Error::Terminated,
907+
));
908+
assert!(matches!(
909+
pool.claim().await.map(|_| ()).unwrap_err(),
910+
Error::Terminated,
911+
));
912+
}
805913
}

src/slot.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ impl<Conn: Connection> SetWorker<Conn> {
578578
match work {
579579
Work::DoConnect => {
580580
let span = span!(Level::TRACE, "Slot worker connecting", slot_id);
581-
async {
581+
let connected = async {
582582
if !slot
583583
.loop_until_connected(
584584
&config,
@@ -590,19 +590,30 @@ impl<Conn: Connection> SetWorker<Conn> {
590590
{
591591
// The slot was instructed to exit
592592
// before it connected. Bail.
593-
return;
593+
event!(
594+
Level::TRACE,
595+
slot_id,
596+
"Terminating instead of connecting"
597+
);
598+
return false;
594599
}
595600
interval.reset_after(interval.period().add_spread(config.spread));
601+
true
596602
}
597603
.instrument(span)
598604
.await;
605+
606+
if !connected {
607+
return;
608+
}
599609
}
600610
Work::DoMonitor => {
601611
tokio::select! {
602612
biased;
603613
_ = &mut terminate_rx => {
604614
// If we've been instructed to bail out,
605615
// do that immediately.
616+
event!(Level::TRACE, slot_id, "Terminating while monitoring");
606617
return;
607618
},
608619
_ = interval.tick() => {

0 commit comments

Comments
 (0)