From 843bfdea880010781008c44765b6b79db8d661f3 Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Fri, 17 Feb 2023 10:48:36 +0100 Subject: [PATCH] Abort on drop --- src/invoker/src/invocation_task.rs | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/invoker/src/invocation_task.rs b/src/invoker/src/invocation_task.rs index a67b824868..187c22fcd0 100644 --- a/src/invoker/src/invocation_task.rs +++ b/src/invoker/src/invocation_task.rs @@ -1,5 +1,8 @@ use std::cmp; use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use bytes::Bytes; use common::types::{PartitionLeaderEpoch, ServiceInvocationId}; @@ -17,6 +20,7 @@ use opentelemetry::sdk::propagation::TraceContextPropagator; use opentelemetry_http::HeaderInjector; use tokio::sync::mpsc; use tokio::task::JoinError; +use tokio::task::JoinHandle; use tracing::{debug, trace}; use super::message::{ @@ -266,7 +270,10 @@ where // Because the body sender blocks on waiting for the request body buffer to be available, // we need to spawn the request initiation separately, otherwise the loop below // will deadlock on the journal entry write. - let mut req_fut = tokio::task::spawn(client.request(req)); + // This task::spawn won't be required by hyper 1.0, as the connection will be driven by a task + // spawned somewhere else (perhaps in the connection pool). + // See: https://github.com/restatedev/restate/issues/96 and https://github.com/restatedev/restate/issues/76 + let mut req_fut = AbortOnDrop(tokio::task::spawn(client.request(req))); let mut http_stream_rx_res = None; @@ -531,3 +538,24 @@ where Ok(()) } } + +/// This wrapper makes sure we abort the task when the JoinHandle is dropped, +/// but it doesn't wait for the task to complete, because we simply don't have async drops! +/// For more: https://github.com/tokio-rs/tokio/issues/2596 +/// Inspired by: https://github.com/cyb0124/abort-on-drop +#[derive(Debug)] +struct AbortOnDrop(JoinHandle); + +impl Future for AbortOnDrop { + type Output = as Future>::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl Drop for AbortOnDrop { + fn drop(&mut self) { + self.0.abort() + } +}