Skip to content

Commit

Permalink
Abort on drop
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Feb 17, 2023
1 parent 92672fd commit 843bfde
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/invoker/src/invocation_task.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::cmp;
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use bytes::Bytes;
use common::types::{PartitionLeaderEpoch, ServiceInvocationId};
Expand All @@ -17,6 +20,7 @@ use opentelemetry::sdk::propagation::TraceContextPropagator;
use opentelemetry_http::HeaderInjector;
use tokio::sync::mpsc;
use tokio::task::JoinError;
use tokio::task::JoinHandle;
use tracing::{debug, trace};

use super::message::{
Expand Down Expand Up @@ -266,7 +270,10 @@ where
// Because the body sender blocks on waiting for the request body buffer to be available,
// we need to spawn the request initiation separately, otherwise the loop below
// will deadlock on the journal entry write.
let mut req_fut = tokio::task::spawn(client.request(req));
// This task::spawn won't be required by hyper 1.0, as the connection will be driven by a task
// spawned somewhere else (perhaps in the connection pool).
// See: https://github.com/restatedev/restate/issues/96 and https://github.com/restatedev/restate/issues/76
let mut req_fut = AbortOnDrop(tokio::task::spawn(client.request(req)));

let mut http_stream_rx_res = None;

Expand Down Expand Up @@ -531,3 +538,24 @@ where
Ok(())
}
}

/// This wrapper makes sure we abort the task when the JoinHandle is dropped,
/// but it doesn't wait for the task to complete, because we simply don't have async drops!
/// For more: https://github.com/tokio-rs/tokio/issues/2596
/// Inspired by: https://github.com/cyb0124/abort-on-drop
#[derive(Debug)]
struct AbortOnDrop<T>(JoinHandle<T>);

impl<T> Future for AbortOnDrop<T> {
type Output = <JoinHandle<T> as Future>::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx)
}
}

impl<T> Drop for AbortOnDrop<T> {
fn drop(&mut self) {
self.0.abort()
}
}

0 comments on commit 843bfde

Please sign in to comment.