Skip to content

Commit

Permalink
refactor(hydro_deploy)!: replace some uses of tokio::sync::RwLock w…
Browse files Browse the repository at this point in the history
…ith `std::sync::Mutex` #430 (3/3) (#1339)

`std::sync::Mutex` should be used as long as the lock is not held across
any `.await` yield points.

STACK: #1338 ~~#1341~~

#1356 next
  • Loading branch information
MingweiSamuel authored Jul 18, 2024
1 parent 12b8ba5 commit 141eae1
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 84 deletions.
2 changes: 1 addition & 1 deletion hydro_deploy/core/src/hydroflow_crate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ mod tests {

deployment.deploy().await.unwrap();

let stdout = service.try_read().unwrap().stdout().await;
let stdout = service.try_read().unwrap().stdout();

deployment.start().await.unwrap();

Expand Down
30 changes: 11 additions & 19 deletions hydro_deploy/core/src/hydroflow_crate/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,16 +139,16 @@ impl HydroflowCrateService {
}
}

pub async fn stdout(&self) -> Receiver<String> {
self.launched_binary.as_deref().unwrap().stdout().await
pub fn stdout(&self) -> Receiver<String> {
self.launched_binary.as_ref().unwrap().stdout()
}

pub async fn stderr(&self) -> Receiver<String> {
self.launched_binary.as_deref().unwrap().stderr().await
pub fn stderr(&self) -> Receiver<String> {
self.launched_binary.as_ref().unwrap().stderr()
}

pub async fn exit_code(&self) -> Option<i32> {
self.launched_binary.as_deref().unwrap().exit_code().await
pub fn exit_code(&self) -> Option<i32> {
self.launched_binary.as_ref().unwrap().exit_code()
}

fn build(&self) -> impl Future<Output = Result<&'static BuildOutput, BuildError>> {
Expand Down Expand Up @@ -244,11 +244,10 @@ impl Service for HydroflowCrateService {
serde_json::to_string::<InitConfig>(&(bind_config, self.meta.clone())).unwrap();

// request stdout before sending config so we don't miss the "ready" response
let stdout_receiver = binary.cli_stdout().await;
let stdout_receiver = binary.cli_stdout();

binary
.stdin()
.await
.send(format!("{formatted_bind_config}\n"))
.await?;

Expand Down Expand Up @@ -284,18 +283,12 @@ impl Service for HydroflowCrateService {

let formatted_defns = serde_json::to_string(&sink_ports).unwrap();

let stdout_receiver = self
.launched_binary
.as_deref_mut()
.unwrap()
.cli_stdout()
.await;
let stdout_receiver = self.launched_binary.as_ref().unwrap().cli_stdout();

self.launched_binary
.as_deref_mut()
.as_ref()
.unwrap()
.stdin()
.await
.send(format!("start: {formatted_defns}\n"))
.await
.unwrap();
Expand All @@ -315,14 +308,13 @@ impl Service for HydroflowCrateService {

async fn stop(&mut self) -> Result<()> {
self.launched_binary
.as_deref_mut()
.as_ref()
.unwrap()
.stdin()
.await
.send("stop\n".to_string())
.await?;

self.launched_binary.as_deref_mut().unwrap().wait().await;
self.launched_binary.as_mut().unwrap().wait().await;

Ok(())
}
Expand Down
10 changes: 5 additions & 5 deletions hydro_deploy/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,18 @@ pub struct ResourceResult {

#[async_trait]
pub trait LaunchedBinary: Send + Sync {
async fn stdin(&self) -> Sender<String>;
fn stdin(&self) -> Sender<String>;

/// Provides a oneshot channel for the CLI to handshake with the binary,
/// with the guarantee that as long as the CLI is holding on
/// to a handle, none of the messages will also be broadcast
/// to the user-facing [`LaunchedBinary::stdout`] channel.
async fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String>;
fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String>;

async fn stdout(&self) -> Receiver<String>;
async fn stderr(&self) -> Receiver<String>;
fn stdout(&self) -> Receiver<String>;
fn stderr(&self) -> Receiver<String>;

async fn exit_code(&self) -> Option<i32>;
fn exit_code(&self) -> Option<i32>;

async fn wait(&mut self) -> Option<i32>;
}
Expand Down
39 changes: 19 additions & 20 deletions hydro_deploy/core/src/localhost/launched_binary.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use async_channel::{Receiver, Sender};
use async_trait::async_trait;
use futures::io::BufReader;
use futures::{AsyncBufReadExt, AsyncWriteExt, StreamExt};
use tokio::sync::RwLock;

use crate::util::prioritized_broadcast;
use crate::LaunchedBinary;

pub struct LaunchedLocalhostBinary {
child: RwLock<async_process::Child>,
child: Mutex<async_process::Child>,
stdin_sender: Sender<String>,
stdout_cli_receivers: Arc<RwLock<Option<tokio::sync::oneshot::Sender<String>>>>,
stdout_receivers: Arc<RwLock<Vec<Sender<String>>>>,
stderr_receivers: Arc<RwLock<Vec<Sender<String>>>>,
stdout_cli_receivers: Arc<Mutex<Option<tokio::sync::oneshot::Sender<String>>>>,
stdout_receivers: Arc<Mutex<Vec<Sender<String>>>>,
stderr_receivers: Arc<Mutex<Vec<Sender<String>>>>,
}

#[cfg(unix)]
impl Drop for LaunchedLocalhostBinary {
fn drop(&mut self) {
let mut child = self.child.try_write().unwrap();
let mut child = self.child.lock().unwrap();

if let Ok(Some(_)) = child.try_status() {
return;
Expand Down Expand Up @@ -63,7 +62,7 @@ impl LaunchedLocalhostBinary {
);

Self {
child: RwLock::new(child),
child: Mutex::new(child),
stdin_sender,
stdout_cli_receivers,
stdout_receivers,
Expand All @@ -74,12 +73,12 @@ impl LaunchedLocalhostBinary {

#[async_trait]
impl LaunchedBinary for LaunchedLocalhostBinary {
async fn stdin(&self) -> Sender<String> {
fn stdin(&self) -> Sender<String> {
self.stdin_sender.clone()
}

async fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String> {
let mut receivers = self.stdout_cli_receivers.write().await;
fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String> {
let mut receivers = self.stdout_cli_receivers.lock().unwrap();

if receivers.is_some() {
panic!("Only one CLI stdout receiver is allowed at a time");
Expand All @@ -90,24 +89,24 @@ impl LaunchedBinary for LaunchedLocalhostBinary {
receiver
}

async fn stdout(&self) -> Receiver<String> {
let mut receivers = self.stdout_receivers.write().await;
fn stdout(&self) -> Receiver<String> {
let mut receivers = self.stdout_receivers.lock().unwrap();
let (sender, receiver) = async_channel::unbounded::<String>();
receivers.push(sender);
receiver
}

async fn stderr(&self) -> Receiver<String> {
let mut receivers = self.stderr_receivers.write().await;
fn stderr(&self) -> Receiver<String> {
let mut receivers = self.stderr_receivers.lock().unwrap();
let (sender, receiver) = async_channel::unbounded::<String>();
receivers.push(sender);
receiver
}

async fn exit_code(&self) -> Option<i32> {
fn exit_code(&self) -> Option<i32> {
self.child
.write()
.await
.lock()
.unwrap()
.try_status()
.ok()
.flatten()
Expand All @@ -120,7 +119,7 @@ impl LaunchedBinary for LaunchedLocalhostBinary {
}

async fn wait(&mut self) -> Option<i32> {
let _ = self.child.get_mut().status().await;
self.exit_code().await
let _ = self.child.get_mut().unwrap().status().await;
self.exit_code()
}
}
27 changes: 13 additions & 14 deletions hydro_deploy/core/src/ssh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::Cow;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::{Context, Result};
Expand All @@ -15,7 +15,6 @@ use futures::{AsyncBufReadExt, AsyncWriteExt, StreamExt};
use hydroflow_cli_integration::ServerBindConfig;
use nanoid::nanoid;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::RwLock;

use super::progress::ProgressTracker;
use super::util::async_retry;
Expand All @@ -28,19 +27,19 @@ struct LaunchedSshBinary {
session: Option<AsyncSession<TcpStream>>,
channel: AsyncChannel<TcpStream>,
stdin_sender: Sender<String>,
stdout_receivers: Arc<RwLock<Vec<Sender<String>>>>,
stdout_cli_receivers: Arc<RwLock<Option<tokio::sync::oneshot::Sender<String>>>>,
stderr_receivers: Arc<RwLock<Vec<Sender<String>>>>,
stdout_receivers: Arc<Mutex<Vec<Sender<String>>>>,
stdout_cli_receivers: Arc<Mutex<Option<tokio::sync::oneshot::Sender<String>>>>,
stderr_receivers: Arc<Mutex<Vec<Sender<String>>>>,
}

#[async_trait]
impl LaunchedBinary for LaunchedSshBinary {
async fn stdin(&self) -> Sender<String> {
fn stdin(&self) -> Sender<String> {
self.stdin_sender.clone()
}

async fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String> {
let mut receivers = self.stdout_cli_receivers.write().await;
fn cli_stdout(&self) -> tokio::sync::oneshot::Receiver<String> {
let mut receivers = self.stdout_cli_receivers.lock().unwrap();

if receivers.is_some() {
panic!("Only one CLI stdout receiver is allowed at a time");
Expand All @@ -51,21 +50,21 @@ impl LaunchedBinary for LaunchedSshBinary {
receiver
}

async fn stdout(&self) -> Receiver<String> {
let mut receivers = self.stdout_receivers.write().await;
fn stdout(&self) -> Receiver<String> {
let mut receivers = self.stdout_receivers.lock().unwrap();
let (sender, receiver) = async_channel::unbounded::<String>();
receivers.push(sender);
receiver
}

async fn stderr(&self) -> Receiver<String> {
let mut receivers = self.stderr_receivers.write().await;
fn stderr(&self) -> Receiver<String> {
let mut receivers = self.stderr_receivers.lock().unwrap();
let (sender, receiver) = async_channel::unbounded::<String>();
receivers.push(sender);
receiver
}

async fn exit_code(&self) -> Option<i32> {
fn exit_code(&self) -> Option<i32> {
// until the program exits, the exit status is meaningless
if self.channel.eof() {
self.channel.exit_status().ok()
Expand All @@ -76,7 +75,7 @@ impl LaunchedBinary for LaunchedSshBinary {

async fn wait(&mut self) -> Option<i32> {
self.channel.wait_eof().await.unwrap();
let ret = self.exit_code().await;
let ret = self.exit_code();
self.channel.wait_close().await.unwrap();
ret
}
Expand Down
46 changes: 26 additions & 20 deletions hydro_deploy/core/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::io;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::Result;
use async_channel::Sender;
use futures::future::join_all;
use futures::{Future, StreamExt};
use futures_core::Stream;
use tokio::sync::RwLock;

pub async fn async_retry<T, F: Future<Output = Result<T>>>(
mut thunk: impl FnMut() -> F,
Expand All @@ -26,60 +26,66 @@ pub async fn async_retry<T, F: Future<Output = Result<T>>>(
}

type PriorityBroadcacst = (
Arc<RwLock<Option<tokio::sync::oneshot::Sender<String>>>>,
Arc<RwLock<Vec<Sender<String>>>>,
Arc<Mutex<Option<tokio::sync::oneshot::Sender<String>>>>,
Arc<Mutex<Vec<Sender<String>>>>,
);

pub fn prioritized_broadcast<T: Stream<Item = io::Result<String>> + Send + Unpin + 'static>(
mut lines: T,
default: impl Fn(String) + Send + 'static,
) -> PriorityBroadcacst {
let priority_receivers = Arc::new(RwLock::new(None::<tokio::sync::oneshot::Sender<String>>));
let receivers = Arc::new(RwLock::new(Vec::<Sender<String>>::new()));
let priority_receivers = Arc::new(Mutex::new(None::<tokio::sync::oneshot::Sender<String>>));
let receivers = Arc::new(Mutex::new(Vec::<Sender<String>>::new()));

let weak_priority_receivers = Arc::downgrade(&priority_receivers);
let weak_receivers = Arc::downgrade(&receivers);

tokio::spawn(async move {
'line_loop: while let Some(Result::Ok(line)) = lines.next().await {
while let Some(Result::Ok(line)) = lines.next().await {
if let Some(cli_receivers) = weak_priority_receivers.upgrade() {
let mut cli_receivers = cli_receivers.write().await;
let mut cli_receivers = cli_receivers.lock().unwrap();

let successful_send = if let Some(r) = cli_receivers.take() {
r.send(line.clone()).is_ok()
} else {
false
};
drop(cli_receivers);

if successful_send {
continue 'line_loop;
continue;
}
}

if let Some(receivers) = weak_receivers.upgrade() {
let mut receivers = receivers.write().await;
let mut successful_send = false;
for r in receivers.iter() {
successful_send |= r.send(line.clone()).await.is_ok();
}

receivers.retain(|r| !r.is_closed());
let send_all = {
let mut receivers = receivers.lock().unwrap();
receivers.retain(|receiver| !receiver.is_closed());
join_all(receivers.iter().map(|receiver| {
// Create a future which doesn't need to hold the `receivers` lock.
let receiver = receiver.clone();
let line = &line;
async move { receiver.send(line.clone()).await }
}))
// Do not `.await` while holding onto the `std::sync::Mutex` `receivers` lock.
};

let successful_send = send_all.await.into_iter().any(|result| result.is_ok());
if !successful_send {
default(line);
(default)(line);
}
} else {
break;
}
}

if let Some(cli_receivers) = weak_priority_receivers.upgrade() {
let mut cli_receivers = cli_receivers.write().await;
let mut cli_receivers = cli_receivers.lock().unwrap();
drop(cli_receivers.take());
}

if let Some(receivers) = weak_receivers.upgrade() {
let mut receivers = receivers.write().await;
let mut receivers = receivers.lock().unwrap();
receivers.clear();
}
});
Expand All @@ -98,7 +104,7 @@ mod test {

let (tx2, mut rx2) = async_channel::unbounded::<_>();

receivers.try_write().unwrap().push(tx2);
receivers.lock().unwrap().push(tx2);

tx.send(Ok("hello".to_string())).await.unwrap();
assert_eq!(rx2.next().await, Some("hello".to_string()));
Expand Down
Loading

0 comments on commit 141eae1

Please sign in to comment.