Skip to content

Commit

Permalink
Remove dependency on duct
Browse files Browse the repository at this point in the history
Remove the dependency on `duct` from `talpid-openvpn`, since we can use
`tokio` to spawn processes instead.
  • Loading branch information
MarkusPettersson98 committed Oct 10, 2023
1 parent 014c00e commit 4f15f07
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 111 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion talpid-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ edition.workspace = true
publish.workspace = true

[dependencies]
duct = "0.13"
err-derive = { workspace = true }
futures = "0.3.15"
ipnetwork = "0.16"
Expand Down Expand Up @@ -42,6 +41,7 @@ nftnl = { version = "0.6.2", features = ["nftnl-1-1-0"] }
mnl = { version = "0.2.2", features = ["mnl-1-0-4"] }
which = { version = "4.0", default-features = false }
talpid-dbus = { path = "../talpid-dbus" }
duct = "0.13"


[target.'cfg(target_os = "macos")'.dependencies]
Expand All @@ -51,6 +51,7 @@ trust-dns-server = { version = "0.23.0", features = ["resolver"] }
trust-dns-proto = "0.23.0"
subslice = "0.2"
async-trait = "0.1"
duct = "0.13"


[target.'cfg(windows)'.dependencies]
Expand Down
1 change: 0 additions & 1 deletion talpid-openvpn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ publish.workspace = true

[dependencies]
async-trait = "0.1"
duct = "0.13"
err-derive = { workspace = true }
futures = "0.3.15"
once_cell = { workspace = true }
Expand Down
130 changes: 66 additions & 64 deletions talpid-openvpn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ use std::{
process::ExitStatus,
sync::{
atomic::{AtomicBool, Ordering},
mpsc, Arc, Mutex,
mpsc, Arc,
},
thread,
time::Duration,
};
#[cfg(target_os = "linux")]
use talpid_routing::{self, RequiredRoute};
use talpid_tunnel::TunnelEvent;
use talpid_types::{net::openvpn, ErrorExt};
use tokio::task;
use tokio::{sync::Mutex, task};

#[cfg(windows)]
use widestring::U16CString;
Expand Down Expand Up @@ -437,16 +436,12 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
let close_handle = monitor.close_handle();
tokio::spawn(async move {
if tunnel_close_rx.await.is_ok() {
tokio::task::spawn_blocking(move || {
if let Err(error) = close_handle.close() {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close the tunnel")
);
}
})
.await
.expect("close handle panic");
if let Err(error) = close_handle.close().await {
log::error!(
"{}",
error.display_chain_with_msg("Failed to close the tunnel")
);
}
}
});

Expand Down Expand Up @@ -491,17 +486,18 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
}

let handle = self.runtime.clone();

thread::spawn(move || {
tx_tunnel.send(Stopped::Tunnel(self.wait_tunnel())).unwrap();
handle.spawn(async move {
tx_tunnel
.send(Stopped::Tunnel(self.wait_tunnel().await))
.unwrap();
let _ = proxy_close_handle.close();
});

thread::spawn(move || {
handle.spawn(async move {
tx_proxy
.send(Stopped::Proxy(handle.block_on(proxy_monitor.wait())))
.send(Stopped::Proxy(proxy_monitor.wait().await))
.unwrap();
let _ = tunnel_close_handle.close();
tunnel_close_handle.close().await
});

let result = rx.recv().expect("wait got no result");
Expand All @@ -516,13 +512,19 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
}
} else {
// No proxy active, wait only for the tunnel.
self.wait_tunnel()
let handle = self.runtime.clone();
let (tx_tunnel, rx) = mpsc::channel();
handle.spawn(async move {
let x = self.wait_tunnel();
tx_tunnel.send(x.await).unwrap();
});
rx.recv().expect("wait_tunnel got no result")
}
}

/// Supplement `inner_wait_tunnel()` with logging and error handling.
fn wait_tunnel(self) -> Result<()> {
let result = self.inner_wait_tunnel();
async fn wait_tunnel(self) -> Result<()> {
let result = self.inner_wait_tunnel().await;
match result {
WaitResult::Preparation(result) => match result {
Err(error) => {
Expand Down Expand Up @@ -559,10 +561,12 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {

/// Waits for both the child process and the event dispatcher in parallel. After both have
/// returned this returns the earliest result.
fn inner_wait_tunnel(mut self) -> WaitResult {
async fn inner_wait_tunnel(mut self) -> WaitResult {
let child = match self
.runtime
.block_on(self.spawn_task.take().unwrap())
.spawn_task
.take()
.unwrap()
.await
.expect("spawn task panicked")
{
Ok(Ok(child)) => Arc::new(child),
Expand All @@ -574,41 +578,33 @@ impl<C: OpenVpnBuilder + Send + 'static> OpenVpnMonitor<C> {
};

if self.closed.load(Ordering::SeqCst) {
let _ = child.kill();
let _ = child.kill().await;
return WaitResult::Preparation(Ok(()));
}

{
self.child.lock().unwrap().replace(child.clone());
self.child.lock().await.replace(child.clone());
}

let closed_handle = self.closed.clone();
let child_close_handle = self.close_handle();

let (child_tx, rx) = mpsc::channel();
let dispatcher_tx = child_tx.clone();

let event_server_abort_tx = self.event_server_abort_tx.clone();

thread::spawn(move || {
let result = child.wait();
let closed = closed_handle.load(Ordering::SeqCst);
child_tx.send(WaitResult::Child(result, closed)).unwrap();
let kill_child = async move {
let result = child.wait().await;
let closed = self.closed.load(Ordering::SeqCst);
let result = WaitResult::Child(result, closed);
event_server_abort_tx.trigger();
});

let server_join_handle = self
.server_join_handle
.take()
.expect("No event server quit handle");
self.runtime.spawn(async move {
result
};
let kill_event_dispatcher = async move {
let server_join_handle = self
.server_join_handle
.take()
.expect("No event server quit handle");
let _ = server_join_handle.await;
dispatcher_tx.send(WaitResult::EventDispatcher).unwrap();
let _ = child_close_handle.close();
});
WaitResult::EventDispatcher
};

let result = rx.recv().expect("inner_wait_tunnel no result");
let _ = rx.recv().expect("inner_wait_tunnel no second result");
let (result, _) = tokio::join!(kill_child, kill_event_dispatcher);
result
}

Expand Down Expand Up @@ -733,11 +729,11 @@ pub struct OpenVpnCloseHandle<H: ProcessHandle = OpenVpnProcHandle> {

impl<H: ProcessHandle> OpenVpnCloseHandle<H> {
/// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return.
pub fn close(self) -> io::Result<()> {
pub async fn close(self) -> io::Result<()> {
if !self.closed.swap(true, Ordering::SeqCst) {
self.abort_spawn.abort();
if let Some(child) = self.child.lock().unwrap().as_ref() {
child.kill()
if let Some(child) = self.child.lock().await.as_ref() {
child.kill().await
} else {
Ok(())
}
Expand Down Expand Up @@ -775,12 +771,13 @@ pub trait OpenVpnBuilder {
}

/// Trait for types acting as handles to subprocesses for `OpenVpnMonitor`
#[async_trait::async_trait]
pub trait ProcessHandle: Send + Sync + 'static {
/// Block until the subprocess exits or there is an error in the wait syscall.
fn wait(&self) -> io::Result<ExitStatus>;
async fn wait(&self) -> io::Result<ExitStatus>;

/// Kill the subprocess.
fn kill(&self) -> io::Result<()>;
async fn kill(&self) -> io::Result<()>;
}

impl OpenVpnBuilder for OpenVpnCommand {
Expand All @@ -799,7 +796,7 @@ impl OpenVpnBuilder for OpenVpnCommand {
}

fn start(&self) -> io::Result<OpenVpnProcHandle> {
OpenVpnProcHandle::new(self.build())
OpenVpnProcHandle::new(&mut self.build())
}

#[cfg(target_os = "linux")]
Expand All @@ -809,13 +806,14 @@ impl OpenVpnBuilder for OpenVpnCommand {
}
}

#[async_trait::async_trait]
impl ProcessHandle for OpenVpnProcHandle {
fn wait(&self) -> io::Result<ExitStatus> {
self.inner.wait().map(|output| output.status)
async fn wait(&self) -> io::Result<ExitStatus> {
self.inner.lock().await.wait().await
}

fn kill(&self) -> io::Result<()> {
self.nice_kill(OPENVPN_DIE_TIMEOUT)
async fn kill(&self) -> io::Result<()> {
self.nice_kill(OPENVPN_DIE_TIMEOUT).await
}
}

Expand Down Expand Up @@ -1219,20 +1217,21 @@ mod tests {
#[derive(Debug, Copy, Clone)]
struct TestProcessHandle(i32);

#[async_trait::async_trait]
impl ProcessHandle for TestProcessHandle {
#[cfg(unix)]
fn wait(&self) -> io::Result<ExitStatus> {
async fn wait(&self) -> io::Result<ExitStatus> {
use std::os::unix::process::ExitStatusExt;
Ok(ExitStatus::from_raw(self.0))
}

#[cfg(windows)]
fn wait(&self) -> io::Result<ExitStatus> {
async fn wait(&self) -> io::Result<ExitStatus> {
use std::os::windows::process::ExitStatusExt;
Ok(ExitStatus::from_raw(self.0 as u32))
}

fn kill(&self) -> io::Result<()> {
async fn kill(&self) -> io::Result<()> {
Ok(())
}
}
Expand Down Expand Up @@ -1374,8 +1373,11 @@ mod tests {
})
.unwrap();

testee.close_handle().close().unwrap();
assert!(testee.wait().is_ok());
// TODO: Remove this?
runtime.block_on(testee.close_handle().close()).unwrap();
let result = testee.wait();
println!("[testee.wait(): {:?}]", result);
assert!(result.is_ok());
}

#[test]
Expand Down
Loading

0 comments on commit 4f15f07

Please sign in to comment.