Skip to content

Commit

Permalink
refactor: use built-in tokio ctrl-c handling
Browse files Browse the repository at this point in the history
  • Loading branch information
pengowen123 committed Nov 3, 2022
1 parent 8d0f050 commit 5405814
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 102 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ url = "2.2.2"
# Missing a fix for macOS: https://github.com/serialport/serialport-rs/pull/58
# It will be included in serialport 4.2.1
tokio-serial = "5.4.3"
ctrlc = "3.2.3"
15 changes: 6 additions & 9 deletions src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ use std::{env, process};
use futures::{FutureExt, SinkExt, select};
use futures_util::{pin_mut, StreamExt};
use tokio::io::{AsyncReadExt};
use tokio::signal;
use tokio_tungstenite::{
connect_async,
tungstenite::protocol::Message,
};
use futures_channel::{mpsc, oneshot};
use futures_channel::mpsc;
use url::Url;

#[tokio::main]
Expand All @@ -36,12 +37,6 @@ async fn main() {
let (ws_stream, _) = connect_async(url).await.expect("Failed to connect");
println!("WebSocket handshake has been successfully completed");

let (ctrlc_tx, mut ctrlc_rx) = oneshot::channel();
let mut ctrlc_tx = Some(ctrlc_tx);
ctrlc::set_handler(move || {
ctrlc_tx.take().map(|c| c.send(()));
}).expect("Failed to set ctrl-c handler");

let (mut write, read) = ws_stream.split();

let stdin_to_ws = stdin_rx.map(|m| {
Expand All @@ -62,8 +57,10 @@ async fn main() {
select!(
_ = stdin_to_ws => {},
_ = ws_to_stdout => {},
_ = ctrlc_rx => {},

// Watch for ctrl-c
res = signal::ctrl_c().fuse() => if let Err(e) = res {
eprintln!("Failed to wait for ctrl-c signal: {}", e);
},
);

write.close().await.expect("Failed to close websocket");
Expand Down
147 changes: 55 additions & 92 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@ use std::{
sync::{
Arc,
Mutex,
mpsc::{SyncSender, Receiver}, atomic::{AtomicBool, Ordering},
}, thread,
},
thread,
};

use futures_channel::{mpsc, oneshot};
use futures_channel::{mpsc::{self, Sender}, oneshot};
use futures::{FutureExt, select};
use futures_util::{
stream::TryStreamExt,
future,
StreamExt,
pin_mut,
SinkExt
};

use tokio::{net::{TcpListener, TcpStream}, task::JoinHandle};
use tokio::{net::{TcpListener, TcpStream}, signal, sync::watch};
use tungstenite::{protocol::Message, Error};
use tokio_serial::{self, SerialPortBuilderExt, ErrorKind};

Expand All @@ -46,9 +45,9 @@ async fn handle_connection(
raw_stream: TcpStream,
addr: SocketAddr,
ryder_port: String,
force_disconnect_tx: SyncSender<()>,
force_disconnect_rx: Receiver<()>,
mut ctrlc_rx: watch::Receiver<()>,
mut ticket_rx: oneshot::Receiver<()>,
task_alive_token: Sender<()>,
) -> ServeNextInQueue {
println!("Incoming TCP connection from: {}", addr);

Expand Down Expand Up @@ -140,14 +139,16 @@ async fn handle_connection(

// For sending data to a dedicated thread that communicates with the serial device
let (tx_serial, mut rx_serial) = mpsc::unbounded();
// For closing the serial IO thread
let (close_serial_tx, close_serial_rx) = std::sync::mpsc::sync_channel(1);

// Set up message receiver for the WebSocket
let broadcast_incoming = incoming.try_for_each(|msg| {
let ws_receiver = incoming.try_for_each(|msg| {
async {
// If the client disconnected, stop listening and send a signal to close the serial
// port as well
if let Message::Close(_) = msg {
force_disconnect_tx.clone().send(()).unwrap();
close_serial_tx.send(()).unwrap();
return Err(Error::ConnectionClosed);
}

Expand All @@ -164,10 +165,10 @@ async fn handle_connection(
let (tx_ws, rx_ws) = mpsc::unbounded();

// Start thread to handle all serial port communication
let serial_io = std::thread::spawn(move || {
let serial_io = thread::spawn(move || {
loop {
// Watch for exit signal
if let Ok(()) = force_disconnect_rx.try_recv() {
if let Ok(()) = close_serial_rx.try_recv() {
return Ok::<(), tokio_serial::Error>(());
}

Expand Down Expand Up @@ -200,11 +201,16 @@ async fn handle_connection(
});

// Send responses to the WebSocket
let receive_from_others = rx_ws.map(Ok).forward(&mut outgoing);

// Wait for the client or the serial IO thread to end the connection
pin_mut!(broadcast_incoming, receive_from_others);
future::select(broadcast_incoming, receive_from_others).await;
let ws_sender = rx_ws.map(Ok).forward(&mut outgoing);

// Wait for a ctrl-c signal or for the client or serial IO thread to end the connection
pin_mut!(ws_receiver, ws_sender);
select! {
_ = ws_receiver => {},
_ = ws_sender => {},
// Close the serial IO thread on ctrl-c
_ = ctrlc_rx.changed().fuse() => close_serial_tx.send(()).unwrap(),
};

// Wait for the serial IO thread to exit
if let Ok(Err(e)) = serial_io.join() {
Expand All @@ -216,6 +222,9 @@ async fn handle_connection(
eprintln!("Failed to close WebSocket: {}", e);
}

// Signal that this task is completed
drop(task_alive_token);

ServeNextInQueue::Yes
}

Expand All @@ -233,72 +242,17 @@ async fn main() -> Result<(), IoError> {
let try_socket = TcpListener::bind(&addr).await;
let listener = try_socket.expect("Failed to bind");

// Set up ctrl-c handling infrastructure
struct ConnectionTerminator {
// A handle to the connection's task
handle: JoinHandle<()>,
// A sender to notify the task that it should terminate itself
sender: SyncSender<()>,
// Whether a termination signal has already been sent
signal_sent: bool,
}
let connections: Vec<ConnectionTerminator> = Vec::new();
let connections = Arc::new(Mutex::new(connections));
let (ctrlc_tx, ctrlc_rx) = std::sync::mpsc::sync_channel(1);

ctrlc::set_handler(move || {
ctrlc_tx.send(()).unwrap();
}).unwrap();

// Watch for ctrl-c inputs in a separate thread
let exiting = Arc::new(AtomicBool::new(false));
let exiting_thread = exiting.clone();
let (exit_tx, exit_rx) = oneshot::channel();
let connections_thread = connections.clone();
thread::spawn(move || {
loop {
if exiting_thread.load(Ordering::SeqCst) {
let mut connections = connections_thread.lock().unwrap();
let mut all_terminated = true;

for conn in &mut *connections {
if !conn.signal_sent {
// Ignore errors because the connection and its receiver may have already
// been dropped
let _ = conn.sender.send(());
conn.signal_sent = true;
}

if !conn.handle.is_finished() {
all_terminated = false;
}
}

// Exit once all remaining connections have been closed
if all_terminated {
exit_tx.send(()).unwrap();
break;
}
} else {
if let Ok(()) = ctrlc_rx.recv() {
exiting_thread.store(true, Ordering::SeqCst);
}
}
}
});

let queue = Arc::new(Mutex::new(ConnectionQueue::new()));
// Set up channel to wait for all tasks to finish
let (task_alive_token, mut tasks_finished_listener) = mpsc::channel(1);

let (ctrlc_tx, ctrlc_rx) = watch::channel(());
let ctrlc_rx_copy = ctrlc_rx.clone();

// Let's spawn the handling of each connection in a separate task.
let listen = async move {
let task_alive_token = task_alive_token;
while let Ok((stream, addr)) = listener.accept().await {
let mut connections = connections.lock().unwrap();

// Don't accept new connections while exiting
if exiting.load(Ordering::SeqCst) {
break;
}

// Add the connection to the queue
let queue_clone = queue.clone();
let (id, ticket_rx) = {
Expand All @@ -313,14 +267,13 @@ async fn main() -> Result<(), IoError> {

(id, rx)
};
let (force_disconnect_tx, force_disconnect_rx) = std::sync::mpsc::sync_channel(1);
let connection_handler = handle_connection(
stream,
addr,
ryder_port.clone(),
force_disconnect_tx.clone(),
force_disconnect_rx,
ctrlc_rx_copy.clone(),
ticket_rx,
task_alive_token.clone(),
).map(move |res| {
println!("{} disconnected", addr);

Expand All @@ -334,22 +287,32 @@ async fn main() -> Result<(), IoError> {
}
});

let handle = tokio::spawn(connection_handler);

// Register the connection so that it can be terminated properly on ctrl-c
let conn = ConnectionTerminator {
handle,
sender: force_disconnect_tx,
signal_sent: false,
};
connections.push(conn);
tokio::spawn(connection_handler);
}
};
}.fuse();

tokio::spawn(listen);
// Listen for new connections until ctrl-c is received
let mut ctrlc_rx = ctrlc_rx.clone();
let listen = tokio::spawn(async move {
pin_mut!(listen);
select! {
_ = listen => {},
_ = ctrlc_rx.changed().fuse() => {},
}
});

// Wait for ctrl-c
exit_rx.await.unwrap();
if let Err(e) = signal::ctrl_c().await {
eprintln!("Failed to wait for ctrl-c signal: {}", e);
}
ctrlc_tx.send(()).unwrap();

// Wait for all existing tasks to finish
listen.await.unwrap();
// This will return `None` when all `Sender`s (owned by the tasks) have been dropped
tasks_finished_listener.next().await;

println!("Shutting down");

Ok(())
}

0 comments on commit 5405814

Please sign in to comment.