diff --git a/core/src/layers/timeout.rs b/core/src/layers/timeout.rs index 6f322251ef00..ab60dbe348f5 100644 --- a/core/src/layers/timeout.rs +++ b/core/src/layers/timeout.rs @@ -165,7 +165,7 @@ pub struct TimeoutAccessor { } impl TimeoutAccessor { - async fn timeout>, T>(&self, op: Operation, fut: F) -> Result { + async fn timeout>, T>(&self, op: Operation, fut: F) -> Result { tokio::time::timeout(self.timeout, fut).await.map_err(|_| { Error::new(ErrorKind::Unexpected, "operation timeout reached") .with_operation(op) @@ -174,7 +174,7 @@ impl TimeoutAccessor { })? } - async fn io_timeout>, T>( + async fn io_timeout>, T>( &self, op: Operation, fut: F, @@ -191,7 +191,7 @@ impl TimeoutAccessor { } #[cfg_attr(not(target_arch = "wasm32"), async_trait)] -#[cfg_attr(target_arch = "wasm32", async_trait(?Send))] +#[cfg_attr(target_arch = "wasm32", async_trait(? Send))] impl LayeredAccessor for TimeoutAccessor { type Inner = A; type Reader = TimeoutWrapper; @@ -334,28 +334,70 @@ impl oio::Read for TimeoutWrapper { } impl oio::Write for TimeoutWrapper { - fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll> { - self.poll_timeout(cx, WriteOperation::Write.into_static())?; + async fn poll_write(&mut self, cx: &mut Context<'_>, bs: &dyn oio::WriteBuf) -> Poll> { + let sleep = self + .sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); - let v = ready!(self.inner.poll_write(cx, bs)); - self.sleep = None; - Poll::Ready(v) + tokio::select! { + _ = sleep.as_mut() => { + self.sleep = None; + Ready(Err( + Error::new(ErrorKind::Unexpected, "io operation timeout reached") + .with_operation(WriteOperation::Write.into_static()) + .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) + .set_temporary(), + )) + } + result = self.inner.poll_write(cx, bs) => { + self.sleep = None; + Ready(result) + } + } } - fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_timeout(cx, WriteOperation::Close.into_static())?; + async fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + let sleep = self + .sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); - let v = ready!(self.inner.poll_close(cx)); - self.sleep = None; - Poll::Ready(v) + tokio::select! { + _ = sleep.as_mut() => { + self.sleep = None; + Ready(Err( + Error::new(ErrorKind::Unexpected, "io operation timeout reached") + .with_operation(WriteOperation::Write.into_static()) + .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) + .set_temporary(), + )) + } + result = self.inner.poll_close(cx) => { + self.sleep = None; + Ready(result) + } + } } - fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { - self.poll_timeout(cx, WriteOperation::Abort.into_static())?; + async fn poll_abort(&mut self, cx: &mut Context<'_>) -> Poll> { + let sleep = self + .sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); - let v = ready!(self.inner.poll_abort(cx)); - self.sleep = None; - Poll::Ready(v) + tokio::select! { + _ = sleep.as_mut() => { + self.sleep = None; + Ready(Err( + Error::new(ErrorKind::Unexpected, "io operation timeout reached") + .with_operation(WriteOperation::Write.into_static()) + .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) + .set_temporary(), + )) + } + result = self.inner.poll_abort(cx) => { + self.sleep = None; + Ready(result) + } + } } } @@ -393,7 +435,7 @@ mod tests { struct MockService; #[cfg_attr(not(target_arch = "wasm32"), async_trait)] - #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] + #[cfg_attr(target_arch = "wasm32", async_trait(? Send))] impl Accessor for MockService { type Reader = MockReader; type Writer = ();