diff --git a/crates/node/src/abortable.rs b/crates/node/src/abortable.rs index dca8e5a409..2a80ed084d 100644 --- a/crates/node/src/abortable.rs +++ b/crates/node/src/abortable.rs @@ -290,3 +290,123 @@ impl AborterStatus { matches!(self, AborterStatus::ChildProcessTerminated) } } + +#[cfg(test)] +mod abortale_spawner_tests { + use std::sync::{Arc, Mutex}; + + use super::*; + + /// Test panicking a non-pinned task shouldn't cause the entire spawner to + /// come crashing down. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_abortable_spawner_panic_non_pinned_task() { + let mut spawner = AbortableSpawner::new(); + + spawner + .abortable("TestTask", |_aborter| async { + panic!(); + }) + .spawn(); + + spawner.run_to_completion().await; + } + + /// Test panicking a pinned task must cause the entire spawner to come + /// crashing down. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + #[should_panic = "AbortableSpawnerPanic"] + async fn test_abortable_spawner_panic_pinned_task() { + let mut spawner = AbortableSpawner::new(); + + spawner + .abortable("TestTask", |_aborter| async { + panic!("AbortableSpawnerPanic"); + }) + .pin() + .spawn(); + + spawner.run_to_completion().await; + } + + /// Test that cleanup jobs get triggered. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_cleanup_job() { + let mut spawner = AbortableSpawner::new(); + + struct Slot { + task_data: [String; 3], + } + + let slot = Arc::new(Mutex::new(Slot { + task_data: [String::new(), String::new(), String::new()], + })); + + let task_ids = ["TestTask#1", "TestTask#2", "TestTask#3"]; + + for (task_no, &id) in task_ids.iter().enumerate() { + let slot = Arc::clone(&slot); + + spawner + .abortable(id, |aborter| async move { + drop(aborter); + Ok(()) + }) + .with_cleanup(async move { + slot.lock().unwrap().task_data[task_no] = id.into(); + }) + .spawn(); + } + + spawner.run_to_completion().await; + + let slot_handle = slot.lock().unwrap(); + assert_eq!(slot_handle.task_data[0].as_str(), task_ids[0]); + assert_eq!(slot_handle.task_data[1].as_str(), task_ids[1]); + assert_eq!(slot_handle.task_data[2].as_str(), task_ids[2]); + } + + /// Test blocking jobs. + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_blocking_spawn() { + let (bing_tx, bing_rx) = tokio::sync::oneshot::channel(); + let (bong_tx, bong_rx) = tokio::sync::oneshot::channel(); + + let mut spawner = AbortableSpawner::new(); + spawner + .abortable("Bing", move |aborter| { + bing_rx.blocking_recv().unwrap(); + drop(aborter); + Ok(()) + }) + .spawn_blocking(); + spawner + .abortable("Bong", move |aborter| { + bong_rx.blocking_recv().unwrap(); + drop(aborter); + Ok(()) + }) + .spawn_blocking(); + + let spawner_run_fut = Box::pin(spawner.run_to_completion()); + let select_result = + futures::future::select(spawner_run_fut, std::future::ready(())) + .await; + let spawner_run_fut = match select_result { + futures::future::Either::Left(_) => unreachable!("Test failed"), + futures::future::Either::Right(((), fut)) => fut, + }; + + bing_tx.send(()).unwrap(); + let select_result = + futures::future::select(spawner_run_fut, std::future::ready(())) + .await; + let spawner_run_fut = match select_result { + futures::future::Either::Left(_) => unreachable!("Test failed"), + futures::future::Either::Right(((), fut)) => fut, + }; + + bong_tx.send(()).unwrap(); + spawner_run_fut.await; + } +}