diff --git a/futures-util/src/future/future/shared.rs b/futures-util/src/future/future/shared.rs index b4d9bff89..d2ca43f2a 100644 --- a/futures-util/src/future/future/shared.rs +++ b/futures-util/src/future/future/shared.rs @@ -9,7 +9,7 @@ use std::pin::Pin; use std::ptr; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{Acquire, SeqCst}; -use std::sync::{Arc, Mutex, Weak}; +use std::sync::{Arc, Mutex, MutexGuard, Weak}; /// Future for the [`shared`](super::FutureExt::shared) method. #[must_use = "futures do nothing unless you `.await` or poll them"] @@ -81,6 +81,7 @@ const IDLE: usize = 0; const POLLING: usize = 1; const COMPLETE: usize = 2; const POISONED: usize = 3; +const AWAKEN_DURING_POLLING: usize = 4; const NULL_WAKER_KEY: usize = usize::MAX; @@ -197,36 +198,47 @@ where } } -impl Inner -where - Fut: Future, - Fut::Output: Clone, -{ - /// Registers the current task to receive a wakeup when we are awoken. - fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) { - let mut wakers_guard = self.notifier.wakers.lock().unwrap(); - - let wakers_mut = wakers_guard.as_mut(); - - let wakers = match wakers_mut { - Some(wakers) => wakers, - None => return, - }; - - let new_waker = cx.waker(); +/// Registers the current task to receive a wakeup when we are awoken. +fn record_waker( + wakers_guard: &mut MutexGuard<'_, Option>>>, + waker_key: &mut usize, + cx: &mut Context<'_>, +) { + let wakers = match wakers_guard.as_mut() { + Some(wakers) => wakers, + None => return, + }; + + let new_waker = cx.waker(); + + if *waker_key == NULL_WAKER_KEY { + *waker_key = wakers.insert(Some(new_waker.clone())); + } else { + match wakers[*waker_key] { + Some(ref old_waker) if new_waker.will_wake(old_waker) => {} + // Could use clone_from here, but Waker doesn't specialize it. + ref mut slot => *slot = Some(new_waker.clone()), + } + } + debug_assert!(*waker_key != NULL_WAKER_KEY); +} - if *waker_key == NULL_WAKER_KEY { - *waker_key = wakers.insert(Some(new_waker.clone())); - } else { - match wakers[*waker_key] { - Some(ref old_waker) if new_waker.will_wake(old_waker) => {} - // Could use clone_from here, but Waker doesn't specialize it. - ref mut slot => *slot = Some(new_waker.clone()), +/// Wakes all tasks that are registered to be woken. +fn wake_all(waker_guard: &mut MutexGuard<'_, Option>>>) { + if let Some(wakers) = waker_guard.as_mut() { + for (_key, opt_waker) in wakers { + if let Some(waker) = opt_waker.take() { + waker.wake(); } } - debug_assert!(*waker_key != NULL_WAKER_KEY); } +} +impl Inner +where + Fut: Future, + Fut::Output: Clone, +{ /// Safety: callers must first ensure that `inner.state` /// is `COMPLETE` unsafe fn take_or_clone_output(self: Arc) -> Fut::Output { @@ -268,18 +280,22 @@ where return unsafe { Poll::Ready(inner.take_or_clone_output()) }; } - inner.record_waker(&mut this.waker_key, cx); + // Guard the state transition with mutex too + let mut wakers_guard = inner.notifier.wakers.lock().unwrap(); + record_waker(&mut wakers_guard, &mut this.waker_key, cx); - match inner + let prev = inner .notifier .state .compare_exchange(IDLE, POLLING, SeqCst, SeqCst) - .unwrap_or_else(|x| x) - { + .unwrap_or_else(|x| x); + drop(wakers_guard); + + match prev { IDLE => { // Lock acquired, fall through } - POLLING => { + POLLING | AWAKEN_DURING_POLLING => { // Another task is currently polling, at this point we just want // to ensure that the waker for this task is registered this.inner = Some(inner); @@ -324,15 +340,21 @@ where match poll_result { Poll::Pending => { - if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok() - { - // Success - drop(reset); - this.inner = Some(inner); - return Poll::Pending; - } else { - unreachable!() + match inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst) { + Ok(POLLING) => {} // success + Err(AWAKEN_DURING_POLLING) => { + // waker has been called inside future.poll, need to wake any new wakers registered + let mut wakers = inner.notifier.wakers.lock().unwrap(); + wake_all(&mut wakers); + let prev = inner.notifier.state.swap(IDLE, SeqCst); + assert_eq!(prev, AWAKEN_DURING_POLLING); + drop(wakers); + } + _ => unreachable!(), } + drop(reset); + this.inner = Some(inner); + return Poll::Pending; } Poll::Ready(output) => output, } @@ -387,14 +409,9 @@ where impl ArcWake for Notifier { fn wake_by_ref(arc_self: &Arc) { - let wakers = &mut *arc_self.wakers.lock().unwrap(); - if let Some(wakers) = wakers.as_mut() { - for (_key, opt_waker) in wakers { - if let Some(waker) = opt_waker.take() { - waker.wake(); - } - } - } + let mut wakers = arc_self.wakers.lock().unwrap(); + let _ = arc_self.state.compare_exchange(POLLING, AWAKEN_DURING_POLLING, SeqCst, SeqCst); + wake_all(&mut wakers); } } diff --git a/futures/tests/future_shared.rs b/futures/tests/future_shared.rs index bd69c1d7c..177b995a0 100644 --- a/futures/tests/future_shared.rs +++ b/futures/tests/future_shared.rs @@ -3,9 +3,11 @@ use futures::executor::{block_on, LocalPool}; use futures::future::{self, FutureExt, LocalFutureObj, TryFutureExt}; use futures::task::LocalSpawn; use std::cell::{Cell, RefCell}; +use std::future::Future; use std::panic::AssertUnwindSafe; +use std::pin::Pin; use std::rc::Rc; -use std::task::Poll; +use std::task::{Context, Poll}; use std::thread; struct CountClone(Rc>); @@ -271,3 +273,52 @@ fn poll_while_panic() { let _s = S {}; panic!("test_marker"); } + +#[test] +fn shared_futures_woken_during_polling() { + async fn yield_now() { + /// Yield implementation + struct YieldNow { + yielded: bool, + } + + impl Future for YieldNow { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + if self.yielded { + return Poll::Ready(()); + } + + self.yielded = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + + YieldNow { yielded: false }.await + } + fn test() { + let f1 = yield_now().shared(); + let f2 = f1.clone(); + let x1 = thread::spawn(move || { + block_on(async move { + f1.now_or_never(); + }) + }); + let x2 = thread::spawn(move || { + block_on(async move { + f2.await; + }) + }); + x1.join().ok(); + x2.join().ok(); + } + + for _ in 0..10 { + print!("."); + for _ in 0..10000 { + test(); + } + } +}