diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 290e51afc18a..f32a6fe6da02 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -17,17 +17,14 @@ use std::{ io::{self, Write}, path::{Path, PathBuf}, process::ExitStatus, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::{Arc, Mutex}, time::Duration, }; #[cfg(target_os = "linux")] use talpid_routing::{self, RequiredRoute}; use talpid_tunnel::TunnelEvent; use talpid_types::{net::openvpn, ErrorExt}; -use tokio::{sync::Mutex, task}; +use tokio::task; #[cfg(windows)] use widestring::U16CString; @@ -150,9 +147,7 @@ const OPENVPN_BIN_FILENAME: &str = "openvpn.exe"; pub struct OpenVpnMonitor { prepare_task: tokio::task::JoinHandle>, - child: Arc>>, proxy_monitor: Option>, - closed: Arc, /// Keep the `TempFile` for the user-pass file in the struct, so it's removed on drop. _user_pass_file: mktemp::TempFile, /// Keep the 'TempFile' for the proxy user-pass file in the struct, so it's removed on drop. @@ -161,6 +156,9 @@ pub struct OpenVpnMonitor { event_server_abort_tx: triggered::Trigger, server_join_handle: task::JoinHandle>, + monitor_abort_tx: Arc>>>, + monitor_abort_rx: oneshot::Receiver<()>, + #[cfg(windows)] _wintun: Arc>, } @@ -395,17 +393,20 @@ impl OpenVpnMonitor { wintun.clone(), )); + let (monitor_abort_tx, monitor_abort_rx) = oneshot::channel(); + let monitor = OpenVpnMonitor { prepare_task, - child: Arc::new(Mutex::new(None)), proxy_monitor, - closed: Arc::new(AtomicBool::new(false)), _user_pass_file: user_pass_file, _proxy_auth_file: proxy_auth_file, event_server_abort_tx, server_join_handle, + monitor_abort_tx: Arc::new(Mutex::new(Some(monitor_abort_tx))), + monitor_abort_rx, + #[cfg(windows)] _wintun: wintun, }; @@ -413,12 +414,7 @@ impl OpenVpnMonitor { let close_handle = monitor.close_handle(); tokio::spawn(async move { if tunnel_close_rx.await.is_ok() { - if let Err(error) = close_handle.close().await { - log::error!( - "{}", - error.display_chain_with_msg("Failed to close the tunnel") - ); - } + close_handle.close(); } }); @@ -441,11 +437,10 @@ impl OpenVpnMonitor { /// Creates a handle to this monitor, allowing the tunnel to be closed while some other /// thread is blocked in `wait`. - fn close_handle(&self) -> OpenVpnCloseHandle { + fn close_handle(&self) -> OpenVpnCloseHandle { OpenVpnCloseHandle { - child: self.child.clone(), + monitor_abort_tx: self.monitor_abort_tx.clone(), prepare_task: self.prepare_task.abort_handle(), - closed: self.closed.clone(), } } @@ -463,7 +458,7 @@ impl OpenVpnMonitor { let proxy_task = async move { let result = proxy_monitor.wait().await; - let _ = tunnel_close_handle.close().await; + let _ = tunnel_close_handle.close(); result.map_err(Error::ProxyError) }; @@ -487,8 +482,8 @@ impl OpenVpnMonitor { } _ => Ok(()), }, - WaitResult::Child(Ok(exit_status), closed) => { - if exit_status.success() || closed { + WaitResult::Child(Ok(exit_status)) => { + if exit_status.success() { log::debug!( "OpenVPN exited, as expected, with exit status: {}", exit_status @@ -499,7 +494,7 @@ impl OpenVpnMonitor { Err(Error::ChildProcessDied) } } - WaitResult::Child(Err(e), _) => { + WaitResult::Child(Err(e)) => { log::error!("OpenVPN process wait error: {}", e); Err(Error::ChildProcessError("Error when waiting", e)) } @@ -516,26 +511,26 @@ impl OpenVpnMonitor { let mut child = match self.prepare_task.await { Ok(Ok(child)) => child, Ok(Err(error)) => { - self.closed.swap(true, Ordering::SeqCst); return WaitResult::Preparation(Err(error)); } Err(_) => return WaitResult::Preparation(Ok(())), }; - if self.closed.load(Ordering::SeqCst) { - let _ = child.kill().await; - return WaitResult::Preparation(Ok(())); - } - - { - self.child.lock().await.replace(child); - } - let kill_child = async move { - let result = self.child.lock().await.take().unwrap().wait().await; - let closed = self.closed.load(Ordering::SeqCst); + let result = tokio::select! { + result = child.wait() => { + log::debug!("OpenVPN process exited"); + result + } + _ = self.monitor_abort_rx => { + log::debug!("Killing OpenVPN process"); + child.kill(); + child.wait().await + } + }; + self.event_server_abort_tx.trigger(); - WaitResult::Child(result, closed) + WaitResult::Child(result) }; let kill_event_dispatcher = async move { let _ = self.server_join_handle.await; @@ -658,24 +653,17 @@ impl OpenVpnMonitor { /// A handle to an `OpenVpnMonitor` for closing it. #[derive(Debug)] -pub struct OpenVpnCloseHandle { - child: Arc>>, +pub struct OpenVpnCloseHandle { + monitor_abort_tx: Arc>>>, prepare_task: tokio::task::AbortHandle, - closed: Arc, } -impl OpenVpnCloseHandle { - /// Kills the underlying OpenVPN process, making the `OpenVpnMonitor::wait` method return. - pub async fn close(self) -> io::Result<()> { - if !self.closed.swap(true, Ordering::SeqCst) { - self.prepare_task.abort(); - if let Some(child) = &mut *self.child.lock().await { - child.kill().await - } else { - Ok(()) - } - } else { - Ok(()) +impl OpenVpnCloseHandle { + /// Begin killing the OpenVPN monitor, making the `OpenVpnMonitor::wait` method return. + pub fn close(self) { + self.prepare_task.abort(); + if let Some(tx) = self.monitor_abort_tx.lock().unwrap().take() { + let _ = tx.send(()); } } } @@ -684,7 +672,7 @@ impl OpenVpnCloseHandle { #[derive(Debug)] enum WaitResult { Preparation(io::Result<()>), - Child(io::Result, bool), + Child(io::Result), EventDispatcher, } @@ -713,8 +701,8 @@ pub trait ProcessHandle: Send + Sync + 'static { /// Block until the subprocess exits or there is an error in the wait syscall. async fn wait(&mut self) -> io::Result; - /// Kill the subprocess. - async fn kill(&mut self) -> io::Result<()>; + /// Kill the subprocess without waiting for it to complete. + fn kill(&mut self); } impl OpenVpnBuilder for OpenVpnCommand { @@ -746,11 +734,11 @@ impl OpenVpnBuilder for OpenVpnCommand { #[async_trait::async_trait] impl ProcessHandle for OpenVpnProcHandle { async fn wait(&mut self) -> io::Result { - self.wait().await + OpenVpnProcHandle::wait(self).await } - async fn kill(&mut self) -> io::Result<()> { - self.nice_kill(OPENVPN_DIE_TIMEOUT).await + fn kill(&mut self) { + OpenVpnProcHandle::kill(self, OPENVPN_DIE_TIMEOUT) } } @@ -1188,9 +1176,7 @@ mod tests { Ok(ExitStatus::from_raw(self.0 as u32)) } - async fn kill(&mut self) -> io::Result<()> { - Ok(()) - } + fn kill(&mut self) {} } fn create_init_args_plugin_log( @@ -1304,7 +1290,7 @@ mod tests { ) .unwrap(); - testee.close_handle().close().await.unwrap(); + testee.close_handle().close(); let result = testee.wait().await; println!("[testee.wait(): {:?}]", result); assert!(result.is_ok()); diff --git a/talpid-openvpn/src/process/openvpn.rs b/talpid-openvpn/src/process/openvpn.rs index 70af744ec418..41c103c86107 100644 --- a/talpid-openvpn/src/process/openvpn.rs +++ b/talpid-openvpn/src/process/openvpn.rs @@ -1,12 +1,13 @@ -use os_pipe::{pipe, PipeWriter}; -use parking_lot::Mutex; +use futures::channel::oneshot; use shell_escape; use std::{ ffi::{OsStr, OsString}, fmt, io, path::{Path, PathBuf}, + process::Stdio, + time::Duration, }; -use talpid_types::{net, ErrorExt}; +use talpid_types::net; static BASE_ARGUMENTS: &[&[&str]] = &[ &["--client"], @@ -364,16 +365,8 @@ impl fmt::Display for OpenVpnCommand { /// Handle to a running OpenVPN process. pub struct OpenVpnProcHandle { - /// Handle to the child process running OpenVPN. - /// - /// This handle is acquired by calling [`OpenVpnCommand::build`] (or - /// [`tokio::process::Command::spawn`]). - pub inner: tokio::process::Child, - /// Pipe handle to stdin of the OpenVPN process. Our custom fork of OpenVPN - /// has been changed so that it exits cleanly when stdin is closed. This is a hack - /// solution to cleanly shut OpenVPN down without using the - /// management interface (which would be the correct thing to do). - pub stdin: Mutex>, + stop_tx: Option>, + proc: tokio::task::JoinHandle>, } impl OpenVpnProcHandle { @@ -390,81 +383,75 @@ impl OpenVpnProcHandle { cmd = cmd.stderr(std::process::Stdio::null()) } - let (reader, writer) = pipe()?; - let proc_handle = cmd.stdin(reader).spawn()?; + let mut proc_handle = cmd.stdin(Stdio::piped()).spawn()?; - Ok(Self { - inner: proc_handle, - stdin: Mutex::new(Some(writer)), - }) - } + let (stop_tx, mut stop_rx) = oneshot::channel(); - /// Attempts to stop the OpenVPN process gracefully in the given time - /// period, otherwise kills the process. - pub async fn nice_kill(&mut self, timeout: std::time::Duration) -> io::Result<()> { - log::debug!("Trying to stop child process gracefully"); - self.stop().await; - - // Wait for the process to die for a maximum of `timeout`. - let wait_result = tokio::time::timeout(timeout, self.wait()).await; - match wait_result { - Ok(_) => log::debug!("Child process terminated gracefully"), - Err(_) => { - log::warn!( - "Child process did not terminate gracefully within timeout, forcing termination" - ); - self.kill().await?; - } - } - Ok(()) - } + let proc = tokio::spawn(async move { + let stdin = proc_handle.stdin.take().expect("expected stdin handle"); - /// Waits for the child to exit completely, returning the status that it - /// exited with. See [tokio::process::Child::wait] for in-depth - /// documentation. - async fn wait(&mut self) -> io::Result { - self.inner.wait().await - } + tokio::select! { + timeout = &mut stop_rx => { + // Dropping our stdin handle so that it is closed once. Closing the handle should + // gracefully stop our OpenVPN child process. This only works because our OpenVPN + // fork expects this. + let _ = drop(stdin); - /// Kill the OpenVPN process and drop its stdin handle. - async fn stop(&mut self) { - // Dropping our stdin handle so that it is closed once. Closing the handle should - // gracefully stop our OpenVPN child process. - if self.stdin.lock().take().is_none() { - log::warn!("Tried to close OpenVPN stdin handle twice, this is a bug"); - } - self.clean_up().await - } + if let Ok(timeout) = timeout { + // + // Controlled shutdown using nice_kill() + // - async fn kill(&mut self) -> io::Result<()> { - log::warn!("Killing OpenVPN process"); - self.inner.kill().await?; - log::debug!("OpenVPN forcefully killed"); - Ok(()) - } + log::debug!("Trying to stop child process gracefully"); - async fn has_stopped(&mut self) -> io::Result { - let exit_status = self.inner.try_wait()?; - Ok(exit_status.is_some()) - } + match tokio::time::timeout(timeout, proc_handle.wait()).await { + Ok(_) => log::debug!("Child process terminated gracefully"), + Err(_) => { + log::warn!( + "Child process did not terminate gracefully within timeout, forcing termination" + ); + proc_handle.kill().await?; + } + } + } else { + // + // If the abort channel is just dropped, kill the process immediately. + // + log::debug!("Killing OpenVPN process forcefully"); + let _ = proc_handle.kill().await; + } + + proc_handle.wait().await + } - /// Try to kill the OpenVPN process. - async fn clean_up(&mut self) { - let result = match self.has_stopped().await { - Ok(false) => self.kill().await, - Err(e) => { - log::error!( - "{}", - e.display_chain_with_msg("Failed to check if OpenVPN is running") - ); - self.kill().await + // + // If the process exits on its own, we're also done. + // + result = proc_handle.wait() => { + log::debug!("OpenVPN process terminated"); + result + } } - _ => Ok(()), - }; - if let Err(error) = result { - log::error!("{}", error.display_chain_with_msg("Failed to kill OpenVPN")); + }); + + Ok(Self { + stop_tx: Some(stop_tx), + proc, + }) + } + + /// Begins to kill the process, causing `wait()` to return. This function does not wait for the operation + /// to complete. + pub fn kill(&mut self, timeout: std::time::Duration) { + if let Some(tx) = self.stop_tx.take() { + let _ = tx.send(timeout); } } + + /// Waits for the child to exit completely. + pub async fn wait(&mut self) -> io::Result { + (&mut self.proc).await.expect("openvpn task panicked") + } } #[cfg(test)]