Skip to content

Commit

Permalink
feat: added scoped threads
Browse files Browse the repository at this point in the history
  • Loading branch information
erwanvivien committed Mar 9, 2023
1 parent 34867b2 commit 3a90a90
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@ use std::any::Any;
use std::fmt;
use std::mem;

use std::sync::Arc;
use std::sync::Mutex;
pub use std::thread::{current, sleep, Result, Thread, ThreadId};
use std::{
marker::PhantomData,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
time::Duration,
};

use wasm_bindgen::prelude::*;
use wasm_bindgen::*;
Expand Down Expand Up @@ -176,6 +184,21 @@ impl Builder {
unsafe { self.spawn_unchecked(f) }
}

pub fn spawn_scoped<'scope, 'env, F, T>(
self,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { self.spawn_unchecked(f) }?,
PhantomData,
))
}

/// Spawns a new thread without any lifetime restrictions by taking ownership
/// of the `Builder`, and returns an [`io::Result`] to its [`JoinHandle`].
///
Expand Down Expand Up @@ -358,3 +381,94 @@ where
{
Builder::new().spawn(f).expect("failed to spawn thread")
}

use core::num::NonZeroUsize;
pub fn available_parallelism() -> std::io::Result<NonZeroUsize> {
// TODO: Use [Navigator::hardware_concurrency](https://rustwasm.github.io/wasm-bindgen/api/web_sys/struct.Navigator.html#method.hardware_concurrency)
Ok(NonZeroUsize::new(8).unwrap())
}

pub struct ScopeData {
num_running_threads: AtomicUsize,
a_thread_panicked: AtomicBool,
main_thread: Thread,
}

pub struct Scope<'scope, 'env: 'scope> {
data: Arc<ScopeData>,
/// Invariance over 'scope, to make sure 'scope cannot shrink,
/// which is necessary for soundness.
///
/// Without invariance, this would compile fine but be unsound:
///
/// ```compile_fail,E0373
/// std::thread::scope(|s| {
/// s.spawn(|| {
/// let a = String::from("abcd");
/// s.spawn(|| println!("{a:?}")); // might run after `a` is dropped
/// });
/// });
/// ```
scope: PhantomData<&'scope mut &'scope ()>,
env: PhantomData<&'env mut &'env ()>,
}

pub fn scope<'env, F, T>(f: F) -> T
where
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
{
// We put the `ScopeData` into an `Arc` so that other threads can finish their
// `decrement_num_running_threads` even after this function returns.
let scope = Scope {
data: Arc::new(ScopeData {
num_running_threads: AtomicUsize::new(0),
main_thread: current(),
a_thread_panicked: AtomicBool::new(false),
}),
env: PhantomData,
scope: PhantomData,
};

// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
let result = catch_unwind(AssertUnwindSafe(|| f(&scope)));

// Wait until all the threads are finished.
while scope.data.num_running_threads.load(Ordering::Acquire) != 0 {
// park();
// TODO: Replaced by a wasm-friendly version of park()
sleep(Duration::from_millis(1));
}

// Throw any panic from `f`, or the return value of `f` if no thread panicked.
match result {
Err(e) => resume_unwind(e),
Ok(_) if scope.data.a_thread_panicked.load(Ordering::Relaxed) => {
panic!("a scoped thread panicked")
}
Ok(result) => result,
}
}

pub struct ScopedJoinHandle<'scope, T>(crate::JoinHandle<T>, PhantomData<&'scope ()>);
impl<'scope, T> ScopedJoinHandle<'scope, T> {
pub fn join(self) -> std::io::Result<T> {
self.0
.join()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, ""))
}
}

pub fn spawn_scoped<'scope, 'env, F, T>(
builder: crate::Builder,
scope: &'scope Scope<'scope, 'env>,
f: F,
) -> std::io::Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
Ok(ScopedJoinHandle(
unsafe { builder.spawn_unchecked(f) }?,
PhantomData,
))
}

0 comments on commit 3a90a90

Please sign in to comment.