diff --git a/linkerd/stack/src/gate.rs b/linkerd/stack/src/gate.rs index 6f3411dcb7..8c0bee9968 100644 --- a/linkerd/stack/src/gate.rs +++ b/linkerd/stack/src/gate.rs @@ -173,11 +173,12 @@ where type Future = S::Future; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.permit.is_ready() { - return Poll::Ready(Ok(())); - } + // If we previously polled to ready and acquired a permit, clear it so + // we can reestablish readiness without holding it. + self.permit = Poll::Pending; let permit = ready!(self.poll_acquire(cx)); ready!(self.inner.poll_ready(cx))?; + tracing::trace!("Acquired permit"); self.permit = Poll::Ready(permit); Poll::Ready(Ok(())) } @@ -227,6 +228,7 @@ impl Gate { #[cfg(test)] mod tests { use super::*; + use std::sync::atomic::AtomicBool; use tokio_test::{assert_pending, assert_ready, task}; #[tokio::test] @@ -262,6 +264,47 @@ mod tests { assert_ready!(gate.poll_ready()).expect("ok"); } + #[tokio::test] + async fn gate_repolls_back_to_pending() { + let (tx, rx) = channel(); + let pending = Arc::new(AtomicBool::new(false)); + let (mut gate, mut handle) = { + struct Svc(S, Arc); + impl> Service for Svc { + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.1.load(std::sync::atomic::Ordering::Relaxed) { + return Poll::Pending; + } + self.0.poll_ready(cx) + } + fn call(&mut self, req: Req) -> Self::Future { + self.0.call(req) + } + } + + let pending = pending.clone(); + tower_test::mock::spawn_with::<(), (), _, _>(move |inner| { + Gate::new(rx.clone(), Svc(inner, pending.clone())) + }) + }; + + tx.open(); + handle.allow(1); + assert_ready!(gate.poll_ready()).expect("ok"); + + pending.store(true, std::sync::atomic::Ordering::Relaxed); + assert_pending!(gate.poll_ready()); + + pending.store(false, std::sync::atomic::Ordering::Relaxed); + assert_ready!(gate.poll_ready()).expect("ok"); + } + #[tokio::test] async fn notifies_on_open() { let (tx, rx) = channel();