Skip to content

Commit

Permalink
Simplify process handle in talpid-openvpn
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 12, 2023
1 parent 52e68f0 commit b77b02f
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 138 deletions.
106 changes: 46 additions & 60 deletions talpid-openvpn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@ use std::{
io::{self, Write},
path::{Path, PathBuf},
process::ExitStatus,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
sync::{Arc, Mutex},
time::Duration,
};
#[cfg(target_os = "linux")]
use talpid_routing::{self, RequiredRoute};
use talpid_tunnel::TunnelEvent;
use talpid_types::{net::openvpn, ErrorExt};
use tokio::{sync::Mutex, task};
use tokio::task;

#[cfg(windows)]
use widestring::U16CString;
Expand Down Expand Up @@ -150,9 +147,7 @@ const OPENVPN_BIN_FILENAME: &str = "openvpn.exe";
pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
prepare_task: tokio::task::JoinHandle<io::Result<C::ProcessHandle>>,

child: Arc<Mutex<Option<C::ProcessHandle>>>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
closed: Arc<AtomicBool>,
/// Keep the `TempFile` for the user-pass file in the struct, so it's removed on drop.
_user_pass_file: mktemp::TempFile,
/// Keep the 'TempFile' for the proxy user-pass file in the struct, so it's removed on drop.
Expand All @@ -161,6 +156,9 @@ pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
event_server_abort_tx: triggered::Trigger,
server_join_handle: task::JoinHandle<std::result::Result<(), event_server::Error>>,

monitor_abort_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
monitor_abort_rx: oneshot::Receiver<()>,

#[cfg(windows)]
_wintun: Arc<Box<dyn WintunContext>>,
}
Expand Down Expand Up @@ -395,30 +393,28 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
wintun.clone(),
));

let (monitor_abort_tx, monitor_abort_rx) = oneshot::channel();

let monitor = OpenVpnMonitor {
prepare_task,
child: Arc::new(Mutex::new(None)),
proxy_monitor,
closed: Arc::new(AtomicBool::new(false)),
_user_pass_file: user_pass_file,
_proxy_auth_file: proxy_auth_file,

event_server_abort_tx,
server_join_handle,

monitor_abort_tx: Arc::new(Mutex::new(Some(monitor_abort_tx))),
monitor_abort_rx,

#[cfg(windows)]
_wintun: wintun,
};

let close_handle = monitor.close_handle();
tokio::spawn(async move {
if tunnel_close_rx.await.is_ok() {
if let Err(error) = close_handle.close().await {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close the tunnel")
);
}
close_handle.close();
}
});

Expand All @@ -441,11 +437,10 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

/// Creates a handle to this monitor, allowing the tunnel to be closed while some other
/// thread is blocked in `wait`.
fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> {
fn close_handle(&self) -> OpenVpnCloseHandle {
OpenVpnCloseHandle {
child: self.child.clone(),
monitor_abort_tx: self.monitor_abort_tx.clone(),
prepare_task: self.prepare_task.abort_handle(),
closed: self.closed.clone(),
}
}

Expand All @@ -463,7 +458,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

let proxy_task = async move {
let result = proxy_monitor.wait().await;
let _ = tunnel_close_handle.close().await;
let _ = tunnel_close_handle.close();
result.map_err(Error::ProxyError)
};

Expand All @@ -487,8 +482,8 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
}
_ => Ok(()),
},
WaitResult::Child(Ok(exit_status), closed) => {
if exit_status.success() || closed {
WaitResult::Child(Ok(exit_status)) => {
if exit_status.success() {
log::debug!(
"OpenVPN exited, as expected, with exit status: {}",
exit_status
Expand All @@ -499,7 +494,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
Err(Error::ChildProcessDied)
}
}
WaitResult::Child(Err(e), _) => {
WaitResult::Child(Err(e)) => {
log::error!("OpenVPN process wait error: {}", e);
Err(Error::ChildProcessError("Error when waiting", e))
}
Expand All @@ -516,26 +511,26 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
let mut child = match self.prepare_task.await {
Ok(Ok(child)) => child,
Ok(Err(error)) => {
self.closed.swap(true, Ordering::SeqCst);
return WaitResult::Preparation(Err(error));
}
Err(_) => return WaitResult::Preparation(Ok(())),
};

if self.closed.load(Ordering::SeqCst) {
let _ = child.kill().await;
return WaitResult::Preparation(Ok(()));
}

{
self.child.lock().await.replace(child);
}

let kill_child = async move {
let result = self.child.lock().await.take().unwrap().wait().await;
let closed = self.closed.load(Ordering::SeqCst);
let result = tokio::select! {
result = child.wait() => {
log::debug!("OpenVPN process exited");
result
}
_ = self.monitor_abort_rx => {
log::debug!("Killing OpenVPN process");
child.kill();
child.wait().await
}
};

self.event_server_abort_tx.trigger();
WaitResult::Child(result, closed)
WaitResult::Child(result)
};
let kill_event_dispatcher = async move {
let _ = self.server_join_handle.await;
Expand Down Expand Up @@ -658,24 +653,17 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

/// A handle to an `OpenVpnMonitor` for closing it.
#[derive(Debug)]
pub struct OpenVpnCloseHandle<H: ProcessHandle = OpenVpnProcHandle> {
child: Arc<Mutex<Option<H>>>,
pub struct OpenVpnCloseHandle {
monitor_abort_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
prepare_task: tokio::task::AbortHandle,
closed: Arc<AtomicBool>,
}

impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
/// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return.
pub async fn close(self) -> io::Result<()> {
if !self.closed.swap(true, Ordering::SeqCst) {
self.prepare_task.abort();
if let Some(child) = &mut *self.child.lock().await {
child.kill().await
} else {
Ok(())
}
} else {
Ok(())
impl OpenVpnCloseHandle {
/// Begin killing the OpenVPN monitor, making the `OpenVpnMonitor::wait` method return.
pub fn close(self) {
self.prepare_task.abort();
if let Some(tx) = self.monitor_abort_tx.lock().unwrap().take() {
let _ = tx.send(());
}
}
}
Expand All @@ -684,7 +672,7 @@ impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
#[derive(Debug)]
enum WaitResult {
Preparation(io::Result<()>),
Child(io::Result<ExitStatus>, bool),
Child(io::Result<ExitStatus>),
EventDispatcher,
}

Expand Down Expand Up @@ -713,8 +701,8 @@ pub trait ProcessHandle: Send + Sync + 'static {
/// Block until the subprocess exits or there is an error in the wait syscall.
async fn wait(&mut self) -> io::Result<ExitStatus>;

/// Kill the subprocess.
async fn kill(&mut self) -> io::Result<()>;
/// Kill the subprocess without waiting for it to complete.
fn kill(&mut self);
}

impl OpenVpnBuilder for OpenVpnCommand {
Expand Down Expand Up @@ -746,11 +734,11 @@ impl OpenVpnBuilder for OpenVpnCommand {
#[async_trait::async_trait]
impl ProcessHandle for OpenVpnProcHandle {
async fn wait(&mut self) -> io::Result<ExitStatus> {
self.wait().await
OpenVpnProcHandle::wait(self).await
}

async fn kill(&mut self) -> io::Result<()> {
self.nice_kill(OPENVPN_DIE_TIMEOUT).await
fn kill(&mut self) {
OpenVpnProcHandle::kill(self, OPENVPN_DIE_TIMEOUT)
}
}

Expand Down Expand Up @@ -1188,9 +1176,7 @@ mod tests {
Ok(ExitStatus::from_raw(self.0 as u32))
}

async fn kill(&mut self) -> io::Result<()> {
Ok(())
}
fn kill(&mut self) {}
}

fn create_init_args_plugin_log(
Expand Down Expand Up @@ -1304,7 +1290,7 @@ mod tests {
)
.unwrap();

testee.close_handle().close().await.unwrap();
testee.close_handle().close();
let result = testee.wait().await;
println!("[testee.wait(): {:?}]", result);
assert!(result.is_ok());
Expand Down
Loading

0 comments on commit b77b02f

Please sign in to comment.