diff --git a/Cargo.lock b/Cargo.lock index b7503bf74f..833af50ffd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1797,6 +1797,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "linkerd-proxy-pool" +version = "0.1.0" +dependencies = [ + "futures", + "linkerd-error", + "linkerd-proxy-core", + "linkerd-stack", + "linkerd-tracing", + "parking_lot", + "pin-project", + "thiserror", + "tokio", + "tokio-stream", + "tokio-test", + "tokio-util", + "tower-test", + "tracing", +] + [[package]] name = "linkerd-proxy-resolve" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 118ce889c2..a00b13f1f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ members = [ "linkerd/proxy/dns-resolve", "linkerd/proxy/http", "linkerd/proxy/identity-client", + "linkerd/proxy/pool", "linkerd/proxy/resolve", "linkerd/proxy/server-policy", "linkerd/proxy/tap", diff --git a/linkerd/proxy/pool/Cargo.toml b/linkerd/proxy/pool/Cargo.toml new file mode 100644 index 0000000000..ab1f8ac7eb --- /dev/null +++ b/linkerd/proxy/pool/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "linkerd-proxy-pool" +version = "0.1.0" +authors = ["Linkerd Developers "] +license = "Apache-2.0" +edition = "2021" +publish = false + +[dependencies] +futures = { version = "0.3", default-features = false } +linkerd-error = { path = "../../error" } +# linkerd-metrics = { path = "../../metrics" } +linkerd-proxy-core = { path = "../core" } +linkerd-stack = { path = "../../stack" } +parking_lot = "0.12" +pin-project = "1" +# rand = "0.8" +thiserror = "1" +tokio = { version = "1", features = ["rt", "sync", "time"] } +# tokio-stream = { version = "0.1", features = ["sync"] } +tokio-util = "0.7" +tracing = "0.1" + +[dev-dependencies] +linkerd-tracing = { path = "../../tracing" } +tokio-stream = { version = "0.1", features = ["sync"] } +tokio-test = "0.4" +tower-test = "0.4" diff --git a/linkerd/proxy/pool/src/error.rs b/linkerd/proxy/pool/src/error.rs new file mode 100644 index 0000000000..511089449a --- /dev/null 
+++ b/linkerd/proxy/pool/src/error.rs @@ -0,0 +1,69 @@ +//! Error types for the `Buffer` middleware. + +use linkerd_error::Error; +use std::{fmt, sync::Arc}; + +/// A shareable, terminal error produced by either a service or discovery +/// resolution. +/// +/// [`Service`]: crate::Service +/// [`Buffer`]: crate::buffer::Buffer +#[derive(Clone, Debug)] +pub struct TerminalFailure { + inner: Arc, +} + +/// An error produced when a buffer's worker closes unexpectedly. +pub struct Closed { + _p: (), +} + +// ===== impl TerminalFailure ===== + +impl TerminalFailure { + pub(crate) fn new(inner: Error) -> TerminalFailure { + let inner = Arc::new(inner); + TerminalFailure { inner } + } + + // Private to avoid exposing `Clone` trait as part of the public API + pub(crate) fn clone(&self) -> TerminalFailure { + TerminalFailure { + inner: self.inner.clone(), + } + } +} + +impl fmt::Display for TerminalFailure { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "buffered service failed: {}", self.inner) + } +} + +impl std::error::Error for TerminalFailure { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(&**self.inner) + } +} + +// ===== impl Closed ===== + +impl Closed { + pub(crate) fn new() -> Self { + Closed { _p: () } + } +} + +impl fmt::Debug for Closed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_tuple("Closed").finish() + } +} + +impl fmt::Display for Closed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("buffer's worker closed unexpectedly") + } +} + +impl std::error::Error for Closed {} diff --git a/linkerd/proxy/pool/src/failfast.rs b/linkerd/proxy/pool/src/failfast.rs new file mode 100644 index 0000000000..c2bc51ffa2 --- /dev/null +++ b/linkerd/proxy/pool/src/failfast.rs @@ -0,0 +1,130 @@ +use linkerd_stack::gate; +use std::pin::Pin; +use tokio::time; + +/// Manages the failfast state for a pool. 
+#[derive(Debug)] +pub(super) struct Failfast { + timeout: time::Duration, + sleep: Pin>, + state: Option, + gate: gate::Tx, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(super) enum State { + Waiting { since: time::Instant }, + Failfast { since: time::Instant }, +} + +// === impl Failfast === + +impl Failfast { + pub(super) fn new(timeout: time::Duration, gate: gate::Tx) -> Self { + Self { + timeout, + sleep: Box::pin(time::sleep(time::Duration::MAX)), + state: None, + gate, + } + } + + pub(super) fn duration(&self) -> time::Duration { + self.timeout + } + + /// Returns true if we are currently in a failfast state. + pub(super) fn is_active(&self) -> bool { + matches!(self.state, Some(State::Failfast { .. })) + } + + /// Clears any waiting or failfast state. + pub(super) fn set_ready(&mut self) -> Option { + let state = self.state.take()?; + if matches!(state, State::Failfast { .. }) { + tracing::trace!("Exiting failfast"); + let _ = self.gate.open(); + } + Some(state) + } + + /// Waits for the failfast timeout to expire and enters the failfast state. + pub(super) async fn timeout(&mut self) { + let since = match self.state { + // If we're already in failfast, then we don't need to wait. + Some(State::Failfast { .. }) => { + return; + } + + // Ensure that the timer's been initialized. + Some(State::Waiting { since }) => since, + None => { + let now = time::Instant::now(); + self.sleep.as_mut().reset(now + self.timeout); + self.state = Some(State::Waiting { since: now }); + now + } + }; + + // Wait for the failfast timer to expire. + tracing::trace!("Waiting for failfast timeout"); + self.sleep.as_mut().await; + tracing::trace!("Entering failfast"); + + // Once we enter failfast, shut the upstream gate so that we can + // advertise backpressure past the queue. 
+ self.state = Some(State::Failfast { since }); + let _ = self.gate.shut(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use futures::prelude::*; + + #[tokio::test(flavor = "current_thread", start_paused = true)] + async fn failfast() { + let (tx, gate_rx) = gate::channel(); + let dur = time::Duration::from_secs(1); + let mut failfast = Failfast::new(dur, tx); + + assert_eq!(dur, failfast.duration()); + assert!(gate_rx.is_open()); + + // The failfast timeout should not be initialized until the first + // request is received. + assert!(!failfast.is_active(), "failfast should be inactive"); + + failfast.timeout().await; + assert!(failfast.is_active(), "failfast should be active"); + assert!(gate_rx.is_shut(), "gate should be shut"); + + failfast + .timeout() + .now_or_never() + .expect("timeout must return immediately when in failfast"); + assert!(failfast.is_active(), "failfast should be active"); + assert!(gate_rx.is_shut(), "gate should be shut"); + + failfast.set_ready(); + assert!(!failfast.is_active(), "failfast should be inactive"); + assert!(gate_rx.is_open(), "gate should be open"); + + tokio::select! { + _ = time::sleep(time::Duration::from_millis(10)) => {} + _ = failfast.timeout() => unreachable!("timed out too quick"), + } + assert!(!failfast.is_active(), "failfast should be inactive"); + assert!(gate_rx.is_open(), "gate should be open"); + + assert!( + matches!(failfast.state, Some(State::Waiting { .. })), + "failfast should be waiting" + ); + + failfast.timeout().await; + assert!(failfast.is_active(), "failfast should be active"); + assert!(gate_rx.is_shut(), "gate should be shut"); + } +} diff --git a/linkerd/proxy/pool/src/future.rs b/linkerd/proxy/pool/src/future.rs new file mode 100644 index 0000000000..ce5bd24a13 --- /dev/null +++ b/linkerd/proxy/pool/src/future.rs @@ -0,0 +1,77 @@ +//! Future types for the [`Buffer`] middleware. +//! +//! 
[`Buffer`]: crate::buffer::Buffer + +use super::{error::Closed, message}; +use futures::ready; +use linkerd_error::Error; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +#[pin_project] +/// Future that completes when the buffered service eventually services the submitted request. +#[derive(Debug)] +pub struct ResponseFuture { + #[pin] + state: ResponseState, +} + +#[pin_project(project = ResponseStateProj)] +#[derive(Debug)] +enum ResponseState { + Failed { + error: Option, + }, + Rx { + #[pin] + rx: message::Rx, + }, + Poll { + #[pin] + fut: T, + }, +} + +impl ResponseFuture { + pub(crate) fn new(rx: message::Rx) -> Self { + ResponseFuture { + state: ResponseState::Rx { rx }, + } + } + + pub(crate) fn failed(err: Error) -> Self { + ResponseFuture { + state: ResponseState::Failed { error: Some(err) }, + } + } +} + +impl Future for ResponseFuture +where + F: Future>, + E: Into, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + loop { + match this.state.as_mut().project() { + ResponseStateProj::Failed { error } => { + return Poll::Ready(Err(error.take().expect("polled after error"))); + } + ResponseStateProj::Rx { rx } => match ready!(rx.poll(cx)) { + Ok(Ok(fut)) => this.state.set(ResponseState::Poll { fut }), + Ok(Err(e)) => return Poll::Ready(Err(e)), + Err(_) => return Poll::Ready(Err(Closed::new().into())), + }, + ResponseStateProj::Poll { fut } => return fut.poll(cx).map_err(Into::into), + } + } + } +} diff --git a/linkerd/proxy/pool/src/lib.rs b/linkerd/proxy/pool/src/lib.rs new file mode 100644 index 0000000000..44ca4604cb --- /dev/null +++ b/linkerd/proxy/pool/src/lib.rs @@ -0,0 +1,39 @@ +//! Adapted from [`tower::buffer`][buffer]. +//! +//! 
[buffer]: https://github.com/tower-rs/tower/tree/bf4ea948346c59a5be03563425a7d9f04aadedf2/tower/src/buffer +// +// Copyright (c) 2019 Tower Contributors + +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +mod error; +mod failfast; +mod future; +mod message; +mod service; +#[cfg(test)] +mod tests; +mod worker; + +pub use self::service::PoolQueue; +pub use linkerd_proxy_core::Update; + +use linkerd_stack::Service; + +/// A collection of services updated from a resolution. +pub trait Pool: Service { + /// Updates the pool's endpoints. + fn update_pool(&mut self, update: Update); + + /// Polls pending endpoints to ready. + /// + /// This is complementary to [`Service::poll_ready`], which may also handle + /// driving endpoints to ready. Unlike [`Service::poll_ready`], which + /// returns ready when *at least one* inner endpoint is ready, + /// [`Pool::poll_pool`] returns ready when *all* inner endpoints are ready. + fn poll_pool( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll>; +} diff --git a/linkerd/proxy/pool/src/message.rs b/linkerd/proxy/pool/src/message.rs new file mode 100644 index 0000000000..60bb4d77a5 --- /dev/null +++ b/linkerd/proxy/pool/src/message.rs @@ -0,0 +1,26 @@ +use linkerd_error::Result; +use tokio::{sync::oneshot, time}; + +/// A request message sent over the queue to the pool worker. +#[derive(Debug)] +pub(crate) struct Message { + pub(crate) req: Req, + pub(crate) tx: Tx, + pub(crate) span: tracing::Span, + pub(crate) t0: time::Instant, +} + +/// Sends the worker's result for a request back to the caller. +type Tx = oneshot::Sender>; + +/// Receives the worker's result for a request. +pub(crate) type Rx = oneshot::Receiver>; + +impl Message { + pub(crate) fn channel(req: Req) -> (Self, Rx) { + let (tx, rx) = oneshot::channel(); + let t0 = time::Instant::now(); + let span = tracing::Span::current(); + (Message { req, span, tx, t0 }, rx) + } +} diff --git a/linkerd/proxy/pool/src/service.rs b/linkerd/proxy/pool/src/service.rs new file mode 100644 index 
0000000000..9901d5baf5 --- /dev/null +++ b/linkerd/proxy/pool/src/service.rs @@ -0,0 +1,120 @@ +use crate::{error, future::ResponseFuture, message::Message, worker, Pool}; +use futures::TryStream; +use linkerd_error::{Error, Result}; +use linkerd_proxy_core::Update; +use linkerd_stack::{gate, Service}; +use parking_lot::Mutex; +use std::{ + future::Future, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{sync::mpsc, time}; +use tokio_util::sync::PollSender; + +/// A shareable service backed by a dynamic endpoint. +#[derive(Debug)] +pub struct PoolQueue { + tx: PollSender>, + terminal: Arc>>, +} + +/// Provides a copy of the terminal failure error to all handles. +#[derive(Clone, Debug, Default)] +pub(crate) struct Terminate { + inner: Arc>>, +} + +// === impl Terminate === + +impl Terminate { + pub(crate) fn send(self, error: error::TerminalFailure) { + *self.inner.lock() = Some(error); + } +} + +impl PoolQueue +where + Req: Send + 'static, + F: Send + 'static, +{ + pub fn spawn( + capacity: usize, + failfast: time::Duration, + resolution: R, + pool: P, + ) -> gate::Gate + where + T: Clone + Eq + std::fmt::Debug + Send, + R: TryStream> + Send + Unpin + 'static, + R::Error: Into + Send, + P: Pool + Send + 'static, + P::Error: Into + Send + Sync, + Req: Send + 'static, + { + let (gate_tx, gate_rx) = gate::channel(); + let (tx, rx) = mpsc::channel(capacity); + let inner = Self::new(tx); + let terminate = Terminate { + inner: inner.terminal.clone(), + }; + worker::spawn(rx, failfast, gate_tx, terminate, resolution, pool); + gate::Gate::new(gate_rx, inner) + } + + fn new(tx: mpsc::Sender>) -> Self { + Self { + tx: PollSender::new(tx), + terminal: Default::default(), + } + } + + #[inline] + fn error_or_closed(&self) -> Error { + (*self.terminal.lock()) + .clone() + .map(Into::into) + .unwrap_or_else(|| error::Closed::new().into()) + } +} + +impl Service for PoolQueue +where + Req: Send + 'static, + F: Future> + Send + 'static, + E: Into, +{ + type 
Response = Rsp; + type Error = Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + let poll = self.tx.poll_reserve(cx).map_err(|_| self.error_or_closed()); + tracing::trace!(?poll); + poll + } + + fn call(&mut self, req: Req) -> Self::Future { + tracing::trace!("Sending request to worker"); + let (msg, rx) = Message::channel(req); + if self.tx.send_item(msg).is_err() { + // The channel closed since poll_ready was called, so propagate the + // failure in the response future. + return ResponseFuture::failed(self.error_or_closed()); + } + ResponseFuture::new(rx) + } +} + +impl Clone for PoolQueue +where + Req: Send + 'static, + F: Send + 'static, +{ + fn clone(&self) -> Self { + Self { + terminal: self.terminal.clone(), + tx: self.tx.clone(), + } + } +} diff --git a/linkerd/proxy/pool/src/tests.rs b/linkerd/proxy/pool/src/tests.rs new file mode 100644 index 0000000000..0233381316 --- /dev/null +++ b/linkerd/proxy/pool/src/tests.rs @@ -0,0 +1,343 @@ +use crate::PoolQueue; +use futures::prelude::*; +use linkerd_proxy_core::Update; +use linkerd_stack::{Service, ServiceExt}; +use tokio::{sync::mpsc, time}; +use tokio_stream::wrappers::ReceiverStream; + +mod mock; + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn processes_requests() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 1, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(1); + assert!(poolq.ready().now_or_never().expect("ready").is_ok()); + let call = poolq.call(()); + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + call.await.expect("response"); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn processes_requests_cloned() 
{ + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq0 = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + let mut poolq1 = poolq0.clone(); + + handle.svc.allow(2); + assert!(poolq0.ready().now_or_never().expect("ready").is_ok()); + assert!(poolq1.ready().now_or_never().expect("ready").is_ok()); + let call0 = poolq0.call(()); + let call1 = poolq1.call(()); + + let ((), respond0) = handle.svc.next_request().await.expect("request"); + respond0.send_response(()); + call0.await.expect("response"); + + let ((), respond1) = handle.svc.next_request().await.expect("request"); + respond1.send_response(()); + call1.await.expect("response"); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn limits_request_capacity() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq0 = PoolQueue::spawn( + 1, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + let mut poolq1 = poolq0.clone(); + + handle.svc.allow(0); + assert!(poolq0.ready().now_or_never().expect("ready").is_ok()); + let mut _call0 = poolq0.call(()); + + assert!( + poolq0.ready().now_or_never().is_none(), + "poolq must not be ready when at capacity" + ); + assert!( + poolq1.ready().now_or_never().is_none(), + "poolq must not be ready when at capacity" + ); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn waits_for_endpoints() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 1, + 
time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + + updates + .try_send(Ok(Update::Reset(vec![( + "192.168.1.44:80".parse().unwrap(), + (), + )]))) + .ok() + .expect("send update"); + handle.set_poll(std::task::Poll::Pending); + tokio::task::yield_now().await; + + handle.set_poll(std::task::Poll::Ready(Ok(()))); + handle.svc.allow(1); + tokio::task::yield_now().await; + + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + call.await.expect("response"); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn updates_while_idle() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut _poolq = PoolQueue::spawn( + 1, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + updates + .try_send(Ok(Update::Reset(vec![( + "192.168.1.44:80".parse().unwrap(), + (), + )]))) + .ok() + .expect("send update"); + + tokio::task::yield_now().await; + assert_eq!( + handle.rx.try_recv().expect("must receive update"), + Update::Reset(vec![("192.168.1.44:80".parse().unwrap(), (),)]) + ); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn complete_resolution() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 1, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + // When we drop the update stream, everything continues to work as long as + // the pool is ready. 
+ handle.svc.allow(1); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + drop(updates); + let call = poolq.call(()); + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + assert!(call.await.is_ok()); + + handle.svc.allow(1); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + assert!(call.await.is_ok()); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn error_resolution() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call0 = poolq.call(()); + + updates + .try_send(Err(mock::ResolutionError)) + .ok() + .expect("send update"); + + call0.await.expect_err("response should fail"); + + assert!( + poolq.ready().await.is_err(), + "poolq must error after failed resolution" + ); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn error_pool_while_pending() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + handle.set_poll(std::task::Poll::Pending); + + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + handle.set_poll(std::task::Poll::Ready(Err(mock::PoolError))); + call.await.expect_err("response should fail"); +} + 
+#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn error_after_ready() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + updates + .try_send(Err(mock::ResolutionError)) + .ok() + .expect("send update"); + tokio::task::yield_now().await; + poolq.call(()).await.expect_err("response should fail"); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn terminates() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + drop(poolq); + assert!( + call.await.is_err(), + "call should fail when queue is dropped" + ); + assert!(updates.is_closed()); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn failfast() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + time::sleep(time::Duration::from_secs(1)).await; + assert!(call.await.is_err(), "call should failfast"); + if let Ok(_) = 
time::timeout(time::Duration::from_secs(1), poolq.ready()).await { + panic!("queue should not be ready while in failfast"); + } + + handle.svc.allow(1); + tokio::task::yield_now().await; + tracing::info!("Waiting for poolq to exit failfast"); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + // A delay doesn't impact failfast behavior when the pool is ready. + time::sleep(time::Duration::from_secs(1)).await; + let call = poolq.call(()); + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + assert!(call.await.is_ok(), "call should not failfast"); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); +} + +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn failfast_interrupted() { + let _trace = linkerd_tracing::test::with_default_filter("linkerd=trace"); + + let (pool, mut handle) = mock::pool::<(), (), ()>(); + let (_updates, u) = mpsc::channel::, mock::ResolutionError>>(1); + let mut poolq = PoolQueue::spawn( + 10, + time::Duration::from_secs(1), + ReceiverStream::from(u), + pool, + ); + + handle.svc.allow(0); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); + let call = poolq.call(()); + // Wait for half a failfast timeout and then allow the request to be + // processed. 
+ time::sleep(time::Duration::from_secs_f64(0.5)).await; + handle.svc.allow(1); + let ((), respond) = handle.svc.next_request().await.expect("request"); + respond.send_response(()); + assert!(call.await.is_ok(), "call should not failfast"); + assert!(poolq.ready().await.is_ok(), "poolq must be ready"); +} diff --git a/linkerd/proxy/pool/src/tests/mock.rs b/linkerd/proxy/pool/src/tests/mock.rs new file mode 100644 index 0000000000..f8317532ff --- /dev/null +++ b/linkerd/proxy/pool/src/tests/mock.rs @@ -0,0 +1,90 @@ +use linkerd_error::Error; +use linkerd_proxy_core::Update; +use parking_lot::Mutex; +use std::{ + sync::Arc, + task::{Context, Poll, Waker}, +}; +use tokio::sync::mpsc; +use tower_test::mock; + +pub fn pool() -> (MockPool, PoolHandle) { + let state = Arc::new(Mutex::new(State { + poll: Poll::Ready(Ok(())), + waker: None, + })); + let (updates_tx, updates_rx) = mpsc::unbounded_channel(); + let (mock, svc) = mock::pair(); + let h = PoolHandle { + rx: updates_rx, + state: state.clone(), + svc, + }; + let p = MockPool { + tx: updates_tx, + state, + svc: mock, + }; + (p, h) +} + +pub struct MockPool { + tx: mpsc::UnboundedSender>, + state: Arc>, + svc: mock::Mock, +} + +pub struct PoolHandle { + state: Arc>, + pub rx: mpsc::UnboundedReceiver>, + pub svc: mock::Handle, +} + +struct State { + poll: Poll>, + waker: Option, +} + +#[derive(Clone, Copy, Debug, thiserror::Error)] +#[error("mock pool error")] +pub struct PoolError; + +#[derive(Clone, Copy, Debug, thiserror::Error)] +#[error("mock resolution error")] +pub struct ResolutionError; + +impl crate::Pool for MockPool { + fn update_pool(&mut self, update: Update) { + self.tx.send(update).ok().unwrap(); + } + + fn poll_pool(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut s = self.state.lock(); + s.waker.replace(cx.waker().clone()); + s.poll.map_err(Into::into) + } +} + +impl linkerd_stack::Service for MockPool { + type Response = Rsp; + type Error = Error; + type Future = mock::future::ResponseFuture; + 
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.svc.poll_ready(cx) + } + + fn call(&mut self, req: Req) -> Self::Future { + self.svc.call(req) + } +} + +impl PoolHandle { + pub fn set_poll(&self, poll: Poll>) { + let mut s = self.state.lock(); + s.poll = poll; + if let Some(w) = s.waker.take() { + w.wake(); + } + } +} diff --git a/linkerd/proxy/pool/src/worker.rs b/linkerd/proxy/pool/src/worker.rs new file mode 100644 index 0000000000..f0e44a26aa --- /dev/null +++ b/linkerd/proxy/pool/src/worker.rs @@ -0,0 +1,349 @@ +use std::future::poll_fn; + +use crate::{ + error::TerminalFailure, + failfast::{self, Failfast}, + message::Message, + service::Terminate, + Pool, +}; +use futures::{future, TryStream, TryStreamExt}; +use linkerd_error::{Error, Result}; +use linkerd_proxy_core::Update; +use linkerd_stack::{gate, FailFastError, ServiceExt}; +use tokio::{sync::mpsc, task::JoinHandle, time}; +use tracing::{debug_span, Instrument}; + +#[derive(Debug)] +struct Worker { + pool: PoolDriver

, + discovery: Discovery, +} + +/// Manages the pool's readiness state, handling failfast timeouts. +#[derive(Debug)] +struct PoolDriver

{ + pool: P, + failfast: Failfast, +} + +/// Processes endpoint updates from service discovery. +#[derive(Debug)] +struct Discovery { + resolution: R, + closed: bool, +} + +/// Spawns a task that simultaneously updates a pool of services from a +/// discovery stream and dispatches requests to it. +/// +/// If the pool service does not become ready within the failfast timeout, then +/// requests are failed with a FailFastError until the pool becomes ready. While +/// in the failfast state, the provided gate is shut so that the caller may +/// exert backpressure to prevent requests from being added to the queue. +pub(crate) fn spawn( + mut reqs_rx: mpsc::Receiver>, + failfast: time::Duration, + gate: gate::Tx, + terminate: Terminate, + updates_rx: R, + pool: P, +) -> JoinHandle> +where + Req: Send + 'static, + T: Clone + Eq + std::fmt::Debug + Send, + R: TryStream> + Unpin + Send + 'static, + R::Error: Into + Send, + P: Pool + Send + 'static, + P::Future: Send + 'static, + P::Error: Into + Send, +{ + let mut terminate = Some(terminate); + let mut terminal_failure = None; + tokio::spawn( + async move { + let mut worker = Worker { + pool: PoolDriver::new(pool, Failfast::new(failfast, gate)), + discovery: Discovery::new(updates_rx), + }; + + loop { + // Drive the pool with discovery updates while waiting for a + // request. + // + // NOTE: We do NOT require that the pool become ready before + // processing a request, so this technically means that the + // queue supports capacity + 1 items. This behavior is + // inherited from tower::buffer. Correcting this is not worth + // the complexity. + let Message { req, tx, span, t0 } = tokio::select! { + biased; + + // If either the discovery stream or the pool fail, close + // the request stream and process any remaining requests. 
+ e = worker.discover_while_awaiting_requests(), if terminal_failure.is_none() => { + let err = TerminalFailure::new(e); + terminate.take().expect("must not fail twice").send(err.clone()); + reqs_rx.close(); + tracing::trace!("Closed"); + terminal_failure = Some(err); + continue; + } + + msg = reqs_rx.recv() => match msg { + Some(msg) => msg, + None => { + tracing::debug!("Requests channel closed"); + return Ok(()); + } + }, + }; + + // Preserve the original request's tracing context. + let _enter = span.enter(); + + // Wait for the pool to have at least one ready endpoint. + if terminal_failure.is_none() { + tracing::trace!("Waiting for pool"); + if let Err(e) = worker.ready_pool().await { + let err = TerminalFailure::new(e); + terminate + .take() + .expect("must not fail twice") + .send(err.clone()); + reqs_rx.close(); + terminal_failure = Some(err); + tracing::trace!("Closed"); + } else { + tracing::trace!("Pool ready"); + } + } + + // Process requests, either by dispatching them to the pool or + // by serving errors directly. + let _ = if let Some(e) = terminal_failure.clone() { + tx.send(Err(e.into())) + } else { + tx.send(worker.pool.call(req)) + }; + + // TODO(ver) track histogram from t0 until the request is dispatched. + tracing::debug!( + latency = (time::Instant::now() - t0).as_secs_f64(), + "Dispatched" + ); + } + } + .instrument(debug_span!("pool")), + ) +} + +// === impl Worker === + +impl Worker +where + T: Clone + Eq + std::fmt::Debug, + R: TryStream> + Unpin, + R::Error: Into, +{ + /// Attempts to update the pool with discovery updates. + /// + /// Additionally, this attempts to drive the pool to ready if it is + /// currently in failfast. + /// + /// If the discovery stream is closed, this never returns. + async fn discover_while_awaiting_requests(&mut self) -> Error + where + P: Pool, + P::Error: Into, + { + tracing::trace!("Discovering while awaiting requests"); + + loop { + let update = tokio::select! 
{ + e = self.pool.drive() => return e, + res = self.discovery.discover() => match res { + Err(e) => return e, + Ok(up) => up, + }, + }; + + tracing::debug!(?update, "Discovered"); + self.pool.pool.update_pool(update); + } + } + + /// Wait for the pool to have at least one ready endpoint, while also + /// processing service discovery updates (e.g. to provide new available + /// endpoints). + async fn ready_pool(&mut self) -> Result<(), Error> + where + P: Pool, + P::Error: Into, + { + loop { + tokio::select! { + // Tests, especially, depend on discovery updates being + // processed before ready returning. + biased; + + // If the pool updated, continue waiting for the pool to be + // ready. + res = self.discovery.discover() => { + let update = res?; + tracing::debug!(?update, "Discovered"); + self.pool.pool.update_pool(update); + } + + // When the pool is ready, clear any failfast state we may have + // set before returning. + res = self.pool.ready() => { + tracing::trace!(ready.ok = res.is_ok()); + return res; + } + } + } + } +} + +// === impl Discovery === + +impl Discovery +where + T: Clone + Eq + std::fmt::Debug, + R: TryStream> + Unpin, + R::Error: Into, +{ + fn new(resolution: R) -> Self { + Self { + resolution, + closed: false, + } + } + + /// Await the next service discovery update. + /// + /// If the discovery stream has closed, this never returns. + async fn discover(&mut self) -> Result, Error> { + if self.closed { + // Never returns. + return futures::future::pending().await; + } + + match self.resolution.try_next().await { + Ok(Some(up)) => Ok(up), + + Ok(None) => { + tracing::debug!("Resolution stream closed"); + self.closed = true; + // Never returns. + futures::future::pending().await + } + + Err(e) => { + let error = e.into(); + tracing::debug!(%error, "Resolution stream failed"); + self.closed = true; + Err(error) + } + } + } +} + +// === impl PoolDriver === + +impl

PoolDriver

{ + fn new(pool: P, failfast: Failfast) -> Self { + Self { pool, failfast } + } + + /// Drives all current endpoints to ready. + /// + /// If the service is in failfast, this clears the failfast state on readiness. + /// + /// If any endpoint fails, the error + /// is returned. + async fn drive(&mut self) -> Error + where + P: Pool, + P::Error: Into, + { + if self.failfast.is_active() { + tracing::trace!("Waiting to leave failfast"); + let res = self.pool.ready().await; + match self.failfast.set_ready() { + Some(failfast::State::Failfast { since }) => { + tracing::info!( + elapsed = (time::Instant::now() - since).as_secs_f64(), + "Available; exited failfast" + ); + } + _ => unreachable!("must be in failfast"), + } + if let Err(e) = res { + return e.into(); + } + } + + tracing::trace!("Driving pending endpoints"); + if let Err(e) = poll_fn(|cx| self.pool.poll_pool(cx)).await { + return e.into(); + } + + tracing::trace!("Driven"); + future::pending().await + } + + async fn ready(&mut self) -> Result<(), Error> + where + P: Pool, + P::Error: Into, + { + tokio::select! { + biased; + + res = self.pool.ready() => { + match self.failfast.set_ready() { + None => tracing::trace!("Ready"), + Some(failfast::State::Waiting { since }) => { + tracing::debug!( + elapsed = (time::Instant::now() - since).as_secs_f64(), + "Available" + ); + } + Some(failfast::State::Failfast { since }) => { + tracing::info!( + elapsed = (time::Instant::now() - since).as_secs_f64(), + "Available; exited failfast" + ); + } + } + if let Err(e) = res { + return Err(e.into()); + } + } + + () = self.failfast.timeout() => { + tracing::info!( + timeout = self.failfast.duration().as_secs_f64(), "Unavailable; entering failfast", + ); + } + } + + Ok(()) + } + + fn call(&mut self, req: Req) -> Result + where + P: Pool, + P::Error: Into, + { + // If we've tripped failfast, fail the request. 
+ if self.failfast.is_active() { + return Err(FailFastError::default().into()); + } + + // Otherwise dispatch the request to the pool. + Ok(self.pool.call(req)) + } +} diff --git a/linkerd/stack/src/failfast.rs b/linkerd/stack/src/failfast.rs index 0f26dc02f2..fb96bd7dcd 100644 --- a/linkerd/stack/src/failfast.rs +++ b/linkerd/stack/src/failfast.rs @@ -44,7 +44,7 @@ pub struct FailFast { } /// An error representing that an operation timed out. -#[derive(Debug, Error)] +#[derive(Debug, Default, Error)] #[error("service in fail-fast")] pub struct FailFastError(());