diff --git a/core/src/layers/timeout.rs b/core/src/layers/timeout.rs index 5812b79a4a26..6f322251ef00 100644 --- a/core/src/layers/timeout.rs +++ b/core/src/layers/timeout.rs @@ -288,22 +288,21 @@ impl TimeoutWrapper { #[inline] fn poll_timeout(&mut self, cx: &mut Context<'_>, op: &'static str) -> Result<()> { - if let Some(sleep) = self.sleep.as_mut() { - match sleep.as_mut().poll(cx) { - Poll::Pending => Ok(()), - Poll::Ready(_) => { - self.sleep = None; - Err( - Error::new(ErrorKind::Unexpected, "io operation timeout reached") - .with_operation(op) - .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) - .set_temporary(), - ) - } + let sleep = self + .sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); + + match sleep.as_mut().poll(cx) { + Poll::Pending => Ok(()), + Poll::Ready(_) => { + self.sleep = None; + Err( + Error::new(ErrorKind::Unexpected, "io operation timeout reached") + .with_operation(op) + .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) + .set_temporary(), + ) } - } else { - self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout))); - Ok(()) } } } @@ -482,19 +481,13 @@ mod tests { let op = Operator::from_inner(acc) .layer(TimeoutLayer::new().with_io_timeout(Duration::from_secs(1))); - let fut = async { - let mut reader = op.reader("test").await.unwrap(); - - let res = reader.read(&mut [0; 4]).await; - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(err.kind(), ErrorKind::Unexpected); - assert!(err.to_string().contains("timeout")) - }; + let mut reader = op.reader("test").await.unwrap(); - timeout(Duration::from_secs(2), fut) - .await - .expect("this test should not exceed 2 seconds") + let res = reader.read(&mut [0; 4]).await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")) } #[tokio::test] @@ -506,18 +499,32 @@ mod tests { .with_io_timeout(Duration::from_secs(1)), ); - let fut = async { - let mut lister = op.lister("test").await.unwrap(); + let mut lister = op.lister("test").await.unwrap(); - let res = lister.next().await.unwrap(); - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(err.kind(), ErrorKind::Unexpected); - assert!(err.to_string().contains("timeout")) - }; + let res = lister.next().await.unwrap(); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")) + } - timeout(Duration::from_secs(2), fut) + #[tokio::test] + async fn test_list_timeout_raw() { + let acc = MockService; + let timeout_layer = TimeoutLayer::new() + .with_timeout(Duration::from_secs(1)) + .with_io_timeout(Duration::from_secs(1)); + let timeout_acc = timeout_layer.layer(acc); + + let (_, mut lister) = Accessor::list(&timeout_acc, "test", OpList::default()) .await - .expect("this test should not exceed 2 seconds") + .unwrap(); + + use oio::ListExt; + let res = lister.next().await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")); } }