From 5263760e697fc19d2d4d3fc9441f9fb36827340c Mon Sep 17 00:00:00 2001 From: Niket Naidu Date: Fri, 19 Jul 2024 22:37:01 -0700 Subject: [PATCH 1/3] Revert "Task Metadata instead of Droppable Future (#3)" This reverts commit 07eda0ef0978eb98d0996e44463aa530b7cf37c3. --- Cargo.toml | 1 + src/droppable_future.rs | 51 ++++++++++++++ src/lib.rs | 3 + src/ticked_async_executor.rs | 128 +++++++++++------------------------ 4 files changed, 93 insertions(+), 90 deletions(-) create mode 100644 src/droppable_future.rs diff --git a/Cargo.toml b/Cargo.toml index 1bf97de..f3fb95e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] async-task = "4.7" +pin-project = "1" [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/droppable_future.rs b/src/droppable_future.rs new file mode 100644 index 0000000..0ab6d57 --- /dev/null +++ b/src/droppable_future.rs @@ -0,0 +1,51 @@ +use std::{future::Future, pin::Pin}; + +use pin_project::{pin_project, pinned_drop}; + +#[pin_project(PinnedDrop)] +pub struct DroppableFuture +where + F: Future, + D: Fn(), +{ + #[pin] + future: F, + on_drop: D, +} + +impl DroppableFuture +where + F: Future, + D: Fn(), +{ + pub fn new(future: F, on_drop: D) -> Self { + Self { future, on_drop } + } +} + +impl Future for DroppableFuture +where + F: Future, + D: Fn(), +{ + type Output = F::Output; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.future.poll(cx) + } +} + +#[pinned_drop] +impl PinnedDrop for DroppableFuture +where + F: Future, + D: Fn(), +{ + fn drop(self: Pin<&mut Self>) { + (self.on_drop)(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 86ac703..1e5a011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,6 @@ +mod droppable_future; +use droppable_future::*; + mod task_identifier; pub use task_identifier::*; diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 1ade368..6c06a78 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -6,7 +6,7 @@ use std::{ }, }; -use crate::TaskIdentifier; +use crate::{DroppableFuture, TaskIdentifier}; #[derive(Debug)] pub enum TaskState { @@ -16,37 +16,11 @@ pub enum TaskState { Drop(TaskIdentifier), } -pub type Task = async_task::Task>; -type TaskRunnable = async_task::Runnable>; -type Payload = (TaskIdentifier, TaskRunnable); +pub type Task = async_task::Task; +type Payload = (TaskIdentifier, async_task::Runnable); -/// Task Metadata associated with TickedAsyncExecutor -/// -/// Primarily used to track when the Task is completed/cancelled -pub struct TaskMetadata -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - num_spawned_tasks: Arc, - identifier: TaskIdentifier, - observer: O, -} - -impl Drop for TaskMetadata -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - fn drop(&mut self) { - self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); - (self.observer)(TaskState::Drop(self.identifier.clone())); - } -} - -pub struct TickedAsyncExecutor -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - channel: (mpsc::Sender>, mpsc::Receiver>), +pub struct TickedAsyncExecutor { + channel: (mpsc::Sender, mpsc::Receiver), num_woken_tasks: Arc, num_spawned_tasks: Arc, @@ -79,22 +53,14 @@ where &self, identifier: impl Into, future: impl Future + Send + 'static, - ) -> Task + ) -> Task where T: Send + 'static, { let identifier = identifier.into(); - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - (self.observer)(TaskState::Spawn(identifier.clone())); - - let schedule = self.runnable_schedule_cb(identifier.clone()); - let (runnable, task) = async_task::Builder::new() - .metadata(TaskMetadata { - num_spawned_tasks: self.num_spawned_tasks.clone(), - identifier, - observer: self.observer.clone(), - }) - .spawn(|_m| future, schedule); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); + let (runnable, task) = async_task::spawn(future, schedule); runnable.schedule(); task } @@ -103,22 +69,14 @@ where &self, identifier: impl Into, future: impl Future + 'static, - ) -> Task + ) -> Task where T: 'static, { let identifier = identifier.into(); - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - (self.observer)(TaskState::Spawn(identifier.clone())); - - let schedule = self.runnable_schedule_cb(identifier.clone()); - let (runnable, task) = async_task::Builder::new() - .metadata(TaskMetadata { - num_spawned_tasks: self.num_spawned_tasks.clone(), - identifier, - observer: self.observer.clone(), - }) - .spawn_local(move |_m| future, schedule); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); + let (runnable, task) = async_task::spawn_local(future, schedule); runnable.schedule(); task } @@ -146,7 +104,29 @@ where .fetch_sub(num_woken_tasks, Ordering::Relaxed); } - fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable) { + fn droppable_future( + &self, + identifier: TaskIdentifier, + future: F, + ) -> DroppableFuture + where + F: Future, + { + let observer = self.observer.clone(); + + // Spawn Task + self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Spawn(identifier.clone())); + + // Droppable Future registering on_drop callback + let num_spawned_tasks = self.num_spawned_tasks.clone(); + DroppableFuture::new(future, move || { + num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); + observer(TaskState::Drop(identifier.clone())); + }) + } + + fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { let sender = self.channel.0.clone(); let num_woken_tasks = self.num_woken_tasks.clone(); let observer = self.observer.clone(); @@ -165,7 +145,7 @@ mod tests { use super::*; #[test] - fn test_multiple_local_tasks() { + fn test_multiple_tasks() { let executor = TickedAsyncExecutor::default(); executor .spawn_local("A", async move { @@ -187,7 +167,7 @@ mod tests { } #[test] - fn test_local_tasks_cancellation() { + fn test_task_cancellation() { let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); let task1 = executor.spawn_local("A", async move { loop { @@ -217,36 +197,4 @@ mod tests { executor.tick(); } } - - #[test] - fn test_tasks_cancellation() { - let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); - let task1 = executor.spawn("A", async move { - loop { - tokio::task::yield_now().await; - } - }); - - let task2 = executor.spawn(format!("B"), async move { - loop { - tokio::task::yield_now().await; - } - }); - assert_eq!(executor.num_tasks(), 2); - executor.tick(); - - executor - .spawn_local("CancelTasks", async move { - let (t1, t2) = join!(task1.cancel(), task2.cancel()); - assert_eq!(t1, None); - assert_eq!(t2, None); - }) - .detach(); - assert_eq!(executor.num_tasks(), 3); - - // Since we have cancelled the tasks above, the loops should eventually end - while executor.num_tasks() != 0 { - executor.tick(); - } - } } From 545a1d851ff3aa9b6a5fd8f7430e30699b2c1417 Mon Sep 17 00:00:00 2001 From: Niket Naidu Date: Fri, 19 Jul 2024 22:46:22 -0700 Subject: [PATCH 2/3] Added unit tests for tokio join and tokio select APIs --- src/ticked_async_executor.rs | 82 ++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 6c06a78..4888954 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -140,8 +140,6 @@ where #[cfg(test)] mod tests { - use tokio::join; - use super::*; #[test] @@ -185,7 +183,7 @@ mod tests { executor .spawn_local("CancelTasks", async move { - let (t1, t2) = join!(task1.cancel(), task2.cancel()); + let (t1, t2) = tokio::join!(task1.cancel(), task2.cancel()); assert_eq!(t1, None); assert_eq!(t2, None); }) @@ -197,4 +195,82 @@ mod tests { executor.tick(); } } + + #[test] + fn test_tokio_join() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + let (a, b) = tokio::join!(rx1.recv(), rx2.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + let (a, b) = tokio::join!(rx3.recv(), rx4.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + for _ in 0..10 { + executor.tick(); + } + tx2.try_send(20).unwrap(); + tx4.try_send(20).unwrap(); + + while executor.num_tasks() != 0 { + executor.tick(); + } + } + + #[test] + fn test_tokio_select() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (_tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + tokio::select! { + data = rx1.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx2.recv() => {} + } + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (_tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + tokio::select! { + data = rx3.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx4.recv() => {} + } + }) + .detach(); + + for _ in 0..10 { + executor.tick(); + } + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + while executor.num_tasks() != 0 { + executor.tick(); + } + } } From 6cd106fe14f63134e61f378f6d2ba264bed1ea93 Mon Sep 17 00:00:00 2001 From: Niket Naidu Date: Fri, 19 Jul 2024 22:49:55 -0700 Subject: [PATCH 3/3] Added tokio integration tests --- src/ticked_async_executor.rs | 80 +----------------------------------- tests/tokio_tests.rs | 79 +++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 79 deletions(-) create mode 100644 tests/tokio_tests.rs diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 4888954..70c3037 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -146,7 +146,7 @@ mod tests { fn test_multiple_tasks() { let executor = TickedAsyncExecutor::default(); executor - .spawn_local("A", async move { + .spawn("A", async move { tokio::task::yield_now().await; }) .detach(); @@ -195,82 +195,4 @@ mod tests { executor.tick(); } } - - #[test] - fn test_tokio_join() { - let executor = TickedAsyncExecutor::default(); - - let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); - let (tx2, mut rx2) = tokio::sync::mpsc::channel::(1); - executor - .spawn("ThreadedFuture", async move { - let (a, b) = tokio::join!(rx1.recv(), rx2.recv()); - assert_eq!(a.unwrap(), 10); - assert_eq!(b.unwrap(), 20); - }) - .detach(); - - let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); - let (tx4, mut rx4) = tokio::sync::mpsc::channel::(1); - executor - .spawn("LocalFuture", async move { - let (a, b) = tokio::join!(rx3.recv(), rx4.recv()); - assert_eq!(a.unwrap(), 10); - assert_eq!(b.unwrap(), 20); - }) - .detach(); - - tx1.try_send(10).unwrap(); - tx3.try_send(10).unwrap(); - for _ in 0..10 { - executor.tick(); - } - tx2.try_send(20).unwrap(); - tx4.try_send(20).unwrap(); - - while executor.num_tasks() != 0 { - executor.tick(); - } - } - - #[test] - fn test_tokio_select() { - let executor = TickedAsyncExecutor::default(); - - let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); - let (_tx2, mut rx2) = tokio::sync::mpsc::channel::(1); - executor - .spawn("ThreadedFuture", async move { - tokio::select! { - data = rx1.recv() => { - assert_eq!(data.unwrap(), 10); - } - _ = rx2.recv() => {} - } - }) - .detach(); - - let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); - let (_tx4, mut rx4) = tokio::sync::mpsc::channel::(1); - executor - .spawn("LocalFuture", async move { - tokio::select! { - data = rx3.recv() => { - assert_eq!(data.unwrap(), 10); - } - _ = rx4.recv() => {} - } - }) - .detach(); - - for _ in 0..10 { - executor.tick(); - } - - tx1.try_send(10).unwrap(); - tx3.try_send(10).unwrap(); - while executor.num_tasks() != 0 { - executor.tick(); - } - } } diff --git a/tests/tokio_tests.rs b/tests/tokio_tests.rs new file mode 100644 index 0000000..6b1db77 --- /dev/null +++ b/tests/tokio_tests.rs @@ -0,0 +1,79 @@ +use ticked_async_executor::TickedAsyncExecutor; + +#[test] +fn test_tokio_join() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + let (a, b) = tokio::join!(rx1.recv(), rx2.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + let (a, b) = tokio::join!(rx3.recv(), rx4.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + for _ in 0..10 { + executor.tick(); + } + tx2.try_send(20).unwrap(); + tx4.try_send(20).unwrap(); + + while executor.num_tasks() != 0 { + executor.tick(); + } +} + +#[test] +fn test_tokio_select() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (_tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + tokio::select! { + data = rx1.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx2.recv() => {} + } + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (_tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + tokio::select! { + data = rx3.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx4.recv() => {} + } + }) + .detach(); + + for _ in 0..10 { + executor.tick(); + } + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + while executor.num_tasks() != 0 { + executor.tick(); + } +}