Skip to content

Commit

Permalink
server: task manager update
Browse files Browse the repository at this point in the history
  • Loading branch information
pnmadelaine committed Oct 11, 2023
1 parent 4394127 commit 4ba4a14
Showing 1 changed file with 50 additions and 43 deletions.
93 changes: 50 additions & 43 deletions typhon/src/tasks.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::collections::HashMap;
use std::future::Future;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;

use std::collections::HashMap;
use std::future::Future;

#[derive(Debug)]
pub enum Error {
Expand All @@ -20,83 +18,91 @@ impl std::fmt::Display for Error {
enum Msg<Id> {
Cancel(Id),
Finish(Id),
Run(Id, oneshot::Sender<()>, JoinHandle<()>),
Run(Id, oneshot::Sender<mpsc::Sender<()>>, oneshot::Sender<()>),
Shutdown,
Wait(Id, oneshot::Sender<()>),
}

struct TaskHandle {
canceler: Option<oneshot::Sender<()>>,
handle: JoinHandle<()>,
waiters: Vec<oneshot::Sender<()>>,
}

pub struct Tasks<Id> {
handle: Mutex<Option<JoinHandle<()>>>,
sender: mpsc::Sender<Msg<Id>>,
msg_send: mpsc::Sender<Msg<Id>>,
shutdown_recv: Mutex<Option<oneshot::Receiver<()>>>,
}

impl<Id: std::cmp::Eq + std::hash::Hash + std::clone::Clone + Send + Sync + 'static> Tasks<Id> {
pub fn new() -> Self {
let (sender, mut receiver) = mpsc::channel(256);
let handle = tokio::spawn(async move {
let (msg_send, mut msg_recv) = mpsc::channel(256);
let (shutdown_send, shutdown_recv) = oneshot::channel();
tokio::spawn(async move {
let _shutdown_send = shutdown_send;
let (finish_send, mut finish_recv) = mpsc::channel(1);
let mut tasks: HashMap<Id, TaskHandle> = HashMap::new();
while let Some(msg) = receiver.recv().await {
match msg {
Msg::Cancel(id) => {
let mut shutdown = false;
while let Some(msg) = msg_recv.recv().await {
match (shutdown, msg) {
(false, Msg::Cancel(id)) => {
let _ = tasks
.get_mut(&id)
.map(|task| task.canceler.take().map(|send| send.send(())));
}
Msg::Finish(id) => {
(_, Msg::Finish(id)) => {
if let Some(task) = tasks.remove(&id) {
let _ = task.handle.await;
for send in task.waiters {
let _ = send.send(());
}
}
if shutdown && tasks.is_empty() {
break;
}
}
Msg::Run(id, sender, handle) => {
(false, Msg::Run(id, finish_send_send, cancel_send)) => {
let _ = finish_send_send.send(finish_send.clone());
let task = TaskHandle {
canceler: Some(sender),
handle,
canceler: Some(cancel_send),
waiters: Vec::new(),
};
tasks.insert(id, task);
}
Msg::Shutdown => {
(false, Msg::Shutdown) => {
shutdown = true;
let ids: Vec<_> = tasks.keys().cloned().collect();
for id in ids.iter() {
tasks
.get_mut(&id)
.get_mut(id)
.map(|task| task.canceler.take().map(|sender| sender.send(())));
}
for id in ids {
if let Some(mut task) = tasks.remove(&id) {
let _ = task.handle.await;
let _ = task.waiters.drain(..).map(|sender| sender.send(()));
}
if tasks.is_empty() {
break;
}
break;
}
Msg::Wait(id, sender) => match tasks.get_mut(&id) {
(_, Msg::Wait(id, sender)) => match tasks.get_mut(&id) {
Some(task) => {
task.waiters.push(sender);
}
None => {
let _ = sender.send(());
}
},
_ => (),
}
}
drop(finish_send);
let _ = finish_recv.recv().await;
});
let handle = Mutex::new(Some(handle));
Self { handle, sender }
let shutdown_recv = Mutex::new(Some(shutdown_recv));
Self {
msg_send,
shutdown_recv,
}
}

pub async fn wait(&self, id: &Id) -> () {
let (sender, receiver) = oneshot::channel();
let _ = self.sender.send(Msg::Wait(id.clone(), sender)).await;
let _ = self.msg_send.send(Msg::Wait(id.clone(), sender)).await;
let _ = receiver.await;
}

Expand All @@ -112,32 +118,33 @@ impl<Id: std::cmp::Eq + std::hash::Hash + std::clone::Clone + Send + Sync + 'sta
task: O,
finish: F,
) {
let (send, recv) = oneshot::channel::<()>();
let sender_self = self.sender.clone();
let (cancel_send, cancel_recv) = oneshot::channel::<()>();
let (finish_send_send, finish_send_recv) = oneshot::channel::<mpsc::Sender<()>>();
let sender_self = self.msg_send.clone();
let id_bis = id.clone();
let handle = tokio::spawn(async move {
tokio::spawn(async move {
let r = tokio::select! {
_ = recv => None,
_ = cancel_recv => None,
r = task => Some(r),
};
finish(r).await;
let _ = sender_self.send(Msg::Finish(id_bis)).await;
let _ = finish_send_recv.await;
});
let _ = self.sender.send(Msg::Run(id, send, handle)).await;
let _ = self
.msg_send
.send(Msg::Run(id, finish_send_send, cancel_send))
.await;
}

pub async fn cancel(&self, id: Id) {
let _ = self.sender.send(Msg::Cancel(id)).await;
let _ = self.msg_send.send(Msg::Cancel(id)).await;
}

pub async fn shutdown(&'static self) {
let handle = self.handle.lock().await.take();
if let Some(handle) = handle {
if self.sender.send(Msg::Shutdown).await.is_ok() {
let _ = handle.await;
} else {
handle.abort();
}
if let Some(shutdown_recv) = self.shutdown_recv.lock().await.take() {
let _ = self.msg_send.send(Msg::Shutdown).await;
let _ = shutdown_recv.await;
}
}
}

0 comments on commit 4ba4a14

Please sign in to comment.