Skip to content

Commit

Permalink
Make OpenVPN monitor fully async
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 12, 2023
1 parent cfdfcb1 commit 52e68f0
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 119 deletions.
4 changes: 3 additions & 1 deletion talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,11 @@ enum InternalTunnelMonitor {

impl InternalTunnelMonitor {
fn wait(self) -> Result<()> {
let handle = tokio::runtime::Handle::current();

match self {
#[cfg(not(target_os = "android"))]
InternalTunnelMonitor::OpenVpn(tun) => tun.wait()?,
InternalTunnelMonitor::OpenVpn(tun) => handle.block_on(tun.wait())?,
InternalTunnelMonitor::Wireguard(tun) => tun.wait()?,
}

Expand Down
174 changes: 67 additions & 107 deletions talpid-openvpn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::{
process::ExitStatus,
sync::{
atomic::{AtomicBool, Ordering},
mpsc, Arc,
Arc,
},
time::Duration,
};
Expand Down Expand Up @@ -109,19 +109,8 @@ pub enum Error {
CredentialsWriteError(#[error(source)] io::Error),

/// Failures related to the proxy service.
#[error(display = "Unable to start the proxy service")]
StartProxyError(#[error(source)] proxy::Error),

/// Error while monitoring proxy service
#[error(display = "Error while monitoring proxy service")]
MonitorProxyError(#[error(source)] io::Error),

/// The proxy exited unexpectedly
#[error(
display = "The proxy exited unexpectedly providing these details: {}",
_0
)]
ProxyExited(String),
#[error(display = "Proxy service failed")]
ProxyError(#[error(source)] proxy::Error),

/// The map is missing 'dev'
#[cfg(target_os = "linux")]
Expand Down Expand Up @@ -159,12 +148,7 @@ const OPENVPN_BIN_FILENAME: &str = "openvpn.exe";
/// Struct for monitoring an OpenVPN process.
#[derive(Debug)]
pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
spawn_task: Option<
tokio::task::JoinHandle<
std::result::Result<io::Result<C::ProcessHandle>, futures::future::Aborted>,
>,
>,
abort_spawn: futures::future::AbortHandle,
prepare_task: tokio::task::JoinHandle<io::Result<C::ProcessHandle>>,

child: Arc<Mutex<Option<C::ProcessHandle>>>,
proxy_monitor: Option<Box<dyn ProxyMonitor>>,
Expand All @@ -174,9 +158,8 @@ pub struct OpenVpnMonitor<C: OpenVpnBuilder = OpenVpnCommand> {
/// Keep the 'TempFile' for the proxy user-pass file in the struct, so it's removed on drop.
_proxy_auth_file: Option<mktemp::TempFile>,

runtime: tokio::runtime::Handle,
event_server_abort_tx: triggered::Trigger,
server_join_handle: Option<task::JoinHandle<std::result::Result<(), event_server::Error>>>,
server_join_handle: task::JoinHandle<std::result::Result<(), event_server::Error>>,

#[cfg(windows)]
_wintun: Arc<Box<dyn WintunContext>>,
Expand Down Expand Up @@ -406,25 +389,22 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

cmd.plugin(plugin_path, vec![ipc_path])
.log(log_path.as_deref());
let (spawn_task, abort_spawn) = futures::future::abortable(Self::prepare_process(
let prepare_task = tokio::spawn(Self::prepare_process(
cmd,
#[cfg(windows)]
wintun.clone(),
));
let spawn_task = tokio::spawn(spawn_task);

let monitor = OpenVpnMonitor {
spawn_task: Some(spawn_task),
abort_spawn,
prepare_task,
child: Arc::new(Mutex::new(None)),
proxy_monitor,
closed: Arc::new(AtomicBool::new(false)),
_user_pass_file: user_pass_file,
_proxy_auth_file: proxy_auth_file,

runtime: tokio::runtime::Handle::current(),
event_server_abort_tx,
server_join_handle: Some(server_join_handle),
server_join_handle,

#[cfg(windows)]
_wintun: wintun,
Expand Down Expand Up @@ -464,65 +444,39 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
fn close_handle(&self) -> OpenVpnCloseHandle<C::ProcessHandle> {
OpenVpnCloseHandle {
child: self.child.clone(),
abort_spawn: self.abort_spawn.clone(),
prepare_task: self.prepare_task.abort_handle(),
closed: self.closed.clone(),
}
}

/// Consumes the monitor and waits for both proxy and tunnel, as applicable.
pub fn wait(mut self) -> Result<()> {
pub async fn wait(mut self) -> Result<()> {
if let Some(mut proxy_monitor) = self.proxy_monitor.take() {
let (tx_tunnel, rx) = mpsc::channel();
let tx_proxy = tx_tunnel.clone();
let tunnel_close_handle = self.close_handle();
let proxy_close_handle = proxy_monitor.close_handle();

enum Stopped {
Tunnel(Result<()>),
Proxy(proxy::Result<()>),
}

let handle = self.runtime.clone();
handle.spawn(async move {
tx_tunnel
.send(Stopped::Tunnel(self.wait_tunnel().await))
.unwrap();
let tunnel_task = async move {
let result = self.wait_tunnel().await;
let _ = proxy_close_handle.close();
});

handle.spawn(async move {
tx_proxy
.send(Stopped::Proxy(proxy_monitor.wait().await))
.unwrap();
tunnel_close_handle.close().await
});

let result = rx.recv().expect("wait got no result");
let _ = rx.recv();

match result {
Stopped::Tunnel(tunnel_result) => tunnel_result,
Stopped::Proxy(proxy_result) => proxy_result.map_err(|error| match error {
proxy::Error::UnexpectedExit(details) => Error::ProxyExited(details),
proxy::Error::Io(error) => Error::MonitorProxyError(error),
}),
}
result
};

let proxy_task = async move {
let result = proxy_monitor.wait().await;
let _ = tunnel_close_handle.close().await;
result.map_err(Error::ProxyError)
};

join_return_first(tunnel_task, proxy_task).await
} else {
// No proxy active, wait only for the tunnel.
let handle = self.runtime.clone();
let (tx_tunnel, rx) = mpsc::channel();
handle.spawn(async move {
let x = self.wait_tunnel();
tx_tunnel.send(x.await).unwrap();
});
rx.recv().expect("wait_tunnel got no result")
self.wait_tunnel().await
}
}

/// Supplement `inner_wait_tunnel()` with logging and error handling.
async fn wait_tunnel(self) -> Result<()> {
let result = self.inner_wait_tunnel().await;
match result {
match self.inner_wait_tunnel().await {
WaitResult::Preparation(result) => match result {
Err(error) => {
log::debug!(
Expand Down Expand Up @@ -558,14 +512,8 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

/// Waits for both the child process and the event dispatcher in parallel. After both have
/// returned this returns the earliest result.
async fn inner_wait_tunnel(mut self) -> WaitResult {
let child = match self
.spawn_task
.take()
.unwrap()
.await
.expect("spawn task panicked")
{
async fn inner_wait_tunnel(self) -> WaitResult {
let mut child = match self.prepare_task.await {
Ok(Ok(child)) => child,
Ok(Err(error)) => {
self.closed.swap(true, Ordering::SeqCst);
Expand All @@ -583,26 +531,18 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
self.child.lock().await.replace(child);
}

let event_server_abort_tx = self.event_server_abort_tx.clone();

let kill_child = async move {
let result = self.child.lock().await.as_ref().unwrap().wait().await;
let result = self.child.lock().await.take().unwrap().wait().await;
let closed = self.closed.load(Ordering::SeqCst);
let result = WaitResult::Child(result, closed);
event_server_abort_tx.trigger();
result
self.event_server_abort_tx.trigger();
WaitResult::Child(result, closed)
};
let kill_event_dispatcher = async move {
let server_join_handle = self
.server_join_handle
.take()
.expect("No event server quit handle");
let _ = server_join_handle.await;
let _ = self.server_join_handle.await;
WaitResult::EventDispatcher
};

let (result, _) = tokio::join!(kill_child, kill_event_dispatcher);
result
join_return_first(kill_child, kill_event_dispatcher).await
}

fn create_proxy_auth_file(
Expand All @@ -627,7 +567,7 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
if let Some(ref settings) = proxy_settings {
let proxy_monitor = proxy::start_proxy(settings, proxy_resources)
.await
.map_err(Error::StartProxyError)?;
.map_err(Error::ProxyError)?;
return Ok(Some(proxy_monitor));
}
Ok(None)
Expand Down Expand Up @@ -717,19 +657,19 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
}

/// A handle to an `OpenVpnMonitor` for closing it.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct OpenVpnCloseHandle<H: ProcessHandle = OpenVpnProcHandle> {
child: Arc<Mutex<Option<H>>>,
abort_spawn: futures::future::AbortHandle,
prepare_task: tokio::task::AbortHandle,
closed: Arc<AtomicBool>,
}

impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
/// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return.
pub async fn close(self) -> io::Result<()> {
if !self.closed.swap(true, Ordering::SeqCst) {
self.abort_spawn.abort();
if let Some(child) = self.child.lock().await.as_ref() {
self.prepare_task.abort();
if let Some(child) = &mut *self.child.lock().await {
child.kill().await
} else {
Ok(())
Expand Down Expand Up @@ -771,10 +711,10 @@ pub trait OpenVpnBuilder {
#[async_trait::async_trait]
pub trait ProcessHandle: Send + Sync + 'static {
/// Block until the subprocess exits or there is an error in the wait syscall.
async fn wait(&self) -> io::Result<ExitStatus>;
async fn wait(&mut self) -> io::Result<ExitStatus>;

/// Kill the subprocess.
async fn kill(&self) -> io::Result<()>;
async fn kill(&mut self) -> io::Result<()>;
}

impl OpenVpnBuilder for OpenVpnCommand {
Expand Down Expand Up @@ -805,15 +745,35 @@ impl OpenVpnBuilder for OpenVpnCommand {

#[async_trait::async_trait]
impl ProcessHandle for OpenVpnProcHandle {
async fn wait(&self) -> io::Result<ExitStatus> {
async fn wait(&mut self) -> io::Result<ExitStatus> {
self.wait().await
}

async fn kill(&self) -> io::Result<()> {
async fn kill(&mut self) -> io::Result<()> {
self.nice_kill(OPENVPN_DIE_TIMEOUT).await
}
}

/// Join two futures and return the result of the first one to complete.
async fn join_return_first<R>(
future1: impl std::future::Future<Output = R>,
future2: impl std::future::Future<Output = R>,
) -> R {
futures::pin_mut!(future1);
futures::pin_mut!(future2);

match futures::future::select(future1, future2).await {
futures::future::Either::Left((result, other)) => {
let _ = other.await;
result
}
futures::future::Either::Right((result, other)) => {
let _ = other.await;
result
}
}
}

mod event_server {
use futures::stream::TryStreamExt;
use parity_tokio_ipc::Endpoint as IpcEndpoint;
Expand Down Expand Up @@ -1217,18 +1177,18 @@ mod tests {
#[async_trait::async_trait]
impl ProcessHandle for TestProcessHandle {
#[cfg(unix)]
async fn wait(&self) -> io::Result<ExitStatus> {
async fn wait(&mut self) -> io::Result<ExitStatus> {
use std::os::unix::process::ExitStatusExt;
Ok(ExitStatus::from_raw(self.0))
}

#[cfg(windows)]
async fn wait(&self) -> io::Result<ExitStatus> {
async fn wait(&mut self) -> io::Result<ExitStatus> {
use std::os::windows::process::ExitStatusExt;
Ok(ExitStatus::from_raw(self.0 as u32))
}

async fn kill(&self) -> io::Result<()> {
async fn kill(&mut self) -> io::Result<()> {
Ok(())
}
}
Expand Down Expand Up @@ -1307,7 +1267,7 @@ mod tests {
Box::new(TestWintunContext {}),
)
.unwrap();
assert!(testee.wait().is_ok());
assert!(testee.wait().await.is_ok());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand All @@ -1325,7 +1285,7 @@ mod tests {
Box::new(TestWintunContext {}),
)
.unwrap();
assert!(testee.wait().is_err());
assert!(testee.wait().await.is_err());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand All @@ -1345,7 +1305,7 @@ mod tests {
.unwrap();

testee.close_handle().close().await.unwrap();
let result = testee.wait();
let result = testee.wait().await;
println!("[testee.wait(): {:?}]", result);
assert!(result.is_ok());
}
Expand All @@ -1362,7 +1322,7 @@ mod tests {
Box::new(TestWintunContext {}),
)
.unwrap();
match result.wait() {
match result.wait().await {
Err(Error::StartProcessError) => (),
_ => panic!("Wrong error"),
}
Expand Down
Loading

0 comments on commit 52e68f0

Please sign in to comment.