diff --git a/core/src/layers/timeout.rs b/core/src/layers/timeout.rs index 4eec45bd87d0..6f322251ef00 100644 --- a/core/src/layers/timeout.rs +++ b/core/src/layers/timeout.rs @@ -288,22 +288,21 @@ impl TimeoutWrapper { #[inline] fn poll_timeout(&mut self, cx: &mut Context<'_>, op: &'static str) -> Result<()> { - if let Some(sleep) = self.sleep.as_mut() { - match sleep.as_mut().poll(cx) { - Poll::Pending => Ok(()), - Poll::Ready(_) => { - self.sleep = None; - Err( - Error::new(ErrorKind::Unexpected, "io operation timeout reached") - .with_operation(op) - .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) - .set_temporary(), - ) - } + let sleep = self + .sleep + .get_or_insert_with(|| Box::pin(tokio::time::sleep(self.timeout))); + + match sleep.as_mut().poll(cx) { + Poll::Pending => Ok(()), + Poll::Ready(_) => { + self.sleep = None; + Err( + Error::new(ErrorKind::Unexpected, "io operation timeout reached") + .with_operation(op) + .with_context("io_timeout", self.timeout.as_secs_f64().to_string()) + .set_temporary(), + ) } - } else { - self.sleep = Some(Box::pin(tokio::time::sleep(self.timeout))); - Ok(()) } } } @@ -380,6 +379,7 @@ mod tests { use async_trait::async_trait; use bytes::Bytes; + use futures::StreamExt; use tokio::time::sleep; use tokio::time::timeout; @@ -397,7 +397,7 @@ mod tests { impl Accessor for MockService { type Reader = MockReader; type Writer = (); - type Lister = (); + type Lister = MockLister; type BlockingReader = (); type BlockingWriter = (); type BlockingLister = (); @@ -424,6 +424,10 @@ mod tests { Ok(RpDelete::default()) } + + async fn list(&self, _: &str, _: OpList) -> Result<(RpList, Self::Lister)> { + Ok((RpList::default(), MockLister)) + } } #[derive(Debug, Clone, Default)] @@ -443,6 +447,15 @@ mod tests { } } + #[derive(Debug, Clone, Default)] + struct MockLister; + + impl oio::List for MockLister { + fn poll_next(&mut self, _: &mut Context<'_>) -> Poll>> { + Poll::Pending + } + } + #[tokio::test] async fn test_operation_timeout() { let acc = Arc::new(TypeEraseLayer.layer(MockService)) as FusedAccessor; @@ -468,18 +481,50 @@ mod tests { let op = Operator::from_inner(acc) .layer(TimeoutLayer::new().with_io_timeout(Duration::from_secs(1))); - let fut = async { - let mut reader = op.reader("test").await.unwrap(); + let mut reader = op.reader("test").await.unwrap(); - let res = reader.read(&mut [0; 4]).await; - assert!(res.is_err()); - let err = res.unwrap_err(); - assert_eq!(err.kind(), ErrorKind::Unexpected); - assert!(err.to_string().contains("timeout")) - }; + let res = reader.read(&mut [0; 4]).await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")) + } - timeout(Duration::from_secs(2), fut) + #[tokio::test] + async fn test_list_timeout() { + let acc = Arc::new(TypeEraseLayer.layer(MockService)) as FusedAccessor; + let op = Operator::from_inner(acc).layer( + TimeoutLayer::new() + .with_timeout(Duration::from_secs(1)) + .with_io_timeout(Duration::from_secs(1)), + ); + + let mut lister = op.lister("test").await.unwrap(); + + let res = lister.next().await.unwrap(); + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")) + } + + #[tokio::test] + async fn test_list_timeout_raw() { + let acc = MockService; + let timeout_layer = TimeoutLayer::new() + .with_timeout(Duration::from_secs(1)) + .with_io_timeout(Duration::from_secs(1)); + let timeout_acc = timeout_layer.layer(acc); + + let (_, mut lister) = Accessor::list(&timeout_acc, "test", OpList::default()) .await - .expect("this test should not exceed 2 seconds") + .unwrap(); + + use oio::ListExt; + let res = lister.next().await; + assert!(res.is_err()); + let err = res.unwrap_err(); + assert_eq!(err.kind(), ErrorKind::Unexpected); + assert!(err.to_string().contains("timeout")); } }