diff --git a/tokio/src/blocking.rs b/tokio/src/blocking.rs index f172399d5ef..41cd6ec027b 100644 --- a/tokio/src/blocking.rs +++ b/tokio/src/blocking.rs @@ -1,8 +1,8 @@ cfg_rt! { + #[allow(unused_imports)] pub(crate) use crate::runtime::spawn_blocking; - cfg_fs! { - #[allow(unused_imports)] + cfg_io_blocking! { pub(crate) use crate::runtime::spawn_mandatory_blocking; } @@ -15,6 +15,7 @@ cfg_not_rt! { use std::pin::Pin; use std::task::{Context, Poll}; + #[allow(dead_code)] pub(crate) fn spawn_blocking(_f: F) -> JoinHandle where F: FnOnce() -> R + Send + 'static, @@ -24,7 +25,7 @@ cfg_not_rt! { panic!("requires the `rt` Tokio feature flag") } - cfg_fs! { + cfg_io_blocking! { pub(crate) fn spawn_mandatory_blocking(_f: F) -> Option> where F: FnOnce() -> R + Send + 'static, @@ -58,6 +59,7 @@ cfg_not_rt! { } } + #[allow(dead_code)] fn assert_send_sync() { } } diff --git a/tokio/src/io/blocking.rs b/tokio/src/io/blocking.rs index 1af5065456d..567367681f8 100644 --- a/tokio/src/io/blocking.rs +++ b/tokio/src/io/blocking.rs @@ -30,6 +30,7 @@ pub(crate) const DEFAULT_MAX_BUF_SIZE: usize = 2 * 1024 * 1024; enum State { Idle(Option), Busy(sys::Blocking<(io::Result, Buf, T)>), + Shutdown, } cfg_io_blocking! { @@ -73,11 +74,18 @@ where let mut inner = self.inner.take().unwrap(); let max_buf_size = cmp::min(dst.remaining(), DEFAULT_MAX_BUF_SIZE); - self.state = State::Busy(sys::run(move || { + + let rx = sys::run(move || { // SAFETY: the requirements are satisfied by `Blocking::new`. let res = unsafe { buf.read_from(&mut inner, max_buf_size) }; (res, buf, inner) - })); + }); + + self.state = if let Some(rx) = rx { + State::Busy(rx) + } else { + State::Shutdown + }; } State::Busy(ref mut rx) => { let (res, mut buf, inner) = ready!(Pin::new(rx).poll(cx))?; @@ -97,6 +105,9 @@ where } } } + State::Shutdown => { + return Poll::Ready(Err(gone())); + } } } } @@ -121,12 +132,19 @@ where let n = buf.copy_from(src, DEFAULT_MAX_BUF_SIZE); let mut inner = self.inner.take().unwrap(); - self.state = State::Busy(sys::run(move || { + let rx = sys::run(move || { let n = buf.len(); let res = buf.write_to(&mut inner).map(|()| n); (res, buf, inner) - })); + }); + + self.state = if let Some(rx) = rx { + State::Busy(rx) + } else { + State::Shutdown + }; + self.need_flush = true; return Poll::Ready(Ok(n)); @@ -139,6 +157,9 @@ where // If error, return res?; } + State::Shutdown => { + return Poll::Ready(Err(gone())); + } } } } @@ -153,10 +174,16 @@ where let buf = buf_cell.take().unwrap(); let mut inner = self.inner.take().unwrap(); - self.state = State::Busy(sys::run(move || { + let rx = sys::run(move || { let res = inner.flush().map(|()| 0); (res, buf, inner) - })); + }); + + self.state = if let Some(rx) = rx { + State::Busy(rx) + } else { + State::Shutdown + }; self.need_flush = false; } else { @@ -171,6 +198,9 @@ where // If error, return res?; } + State::Shutdown => { + return Poll::Ready(Err(gone())); + } } } } @@ -304,3 +334,10 @@ cfg_fs! { } } } + +fn gone() -> io::Error { + io::Error::new( + io::ErrorKind::Other, + crate::util::error::RUNTIME_SHUTTING_DOWN_ERROR, + ) +} diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index dc2c4309e66..c5415196647 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -289,7 +289,7 @@ cfg_io_blocking! { /// Types in this module can be mocked out in tests. mod sys { // TODO: don't rename - pub(crate) use crate::blocking::spawn_blocking as run; + pub(crate) use crate::blocking::spawn_mandatory_blocking as run; pub(crate) use crate::blocking::JoinHandle as Blocking; } } diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index c42924be77d..7dc50058860 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -6,7 +6,7 @@ mod pool; pub(crate) use pool::{spawn_blocking, BlockingPool, Spawner}; -cfg_fs! { +cfg_io_blocking! { pub(crate) use pool::spawn_mandatory_blocking; } diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 23180dc5245..1163cf51a72 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -185,7 +185,7 @@ where rt.spawn_blocking(func) } -cfg_fs! { +cfg_io_blocking! { #[cfg_attr(any( all(loom, not(test)), // the function is covered by loom tests test @@ -327,7 +327,7 @@ impl Spawner { } } - cfg_fs! { + cfg_io_blocking! { #[track_caller] #[cfg_attr(any( all(loom, not(test)), // the function is covered by loom tests diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 78a0114f48e..6d89182ff38 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -357,7 +357,7 @@ cfg_rt! { pub(crate) use blocking::Mandatory; } - cfg_fs! { + cfg_io_blocking! { pub(crate) use blocking::spawn_mandatory_blocking; }