diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index c8057e536b13..d665c54c5d33 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -1174,20 +1174,47 @@ mod tests { } #[derive(Debug, Copy, Clone)] - struct TestProcessHandle(i32); + struct TestProcessHandle { + exit_code: i32, + forever: bool, + } - #[async_trait::async_trait] - impl ProcessHandle for TestProcessHandle { - #[cfg(unix)] - async fn wait(&mut self) -> io::Result { - use std::os::unix::process::ExitStatusExt; - Ok(ExitStatus::from_raw(self.0)) + impl TestProcessHandle { + pub fn immediate(exit_code: i32) -> Self { + Self { + exit_code, + forever: false, + } } - #[cfg(windows)] + pub fn run_forever() -> Self { + Self { + exit_code: 0, + forever: true, + } + } + + fn status(&self) -> ExitStatus { + #[cfg(windows)] + { + use std::os::windows::process::ExitStatusExt; + ExitStatus::from_raw(self.exit_code as u32) + } + #[cfg(unix)] + { + use std::os::unix::process::ExitStatusExt; + ExitStatus::from_raw(self.exit_code) + } + } + } + + #[async_trait::async_trait] + impl ProcessHandle for TestProcessHandle { async fn wait(&mut self) -> io::Result { - use std::os::windows::process::ExitStatusExt; - Ok(ExitStatus::from_raw(self.0 as u32)) + if self.forever { + let _: () = futures::future::pending().await; + } + Ok(self.status()) } fn kill(&mut self) {} @@ -1253,7 +1280,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn exit_successfully() { let builder = TestOpenVpnBuilder { - process_handle: Some(TestProcessHandle(0)), + process_handle: Some(TestProcessHandle::immediate(0)), ..Default::default() }; let openvpn_init_args = create_init_args(); @@ -1271,7 +1298,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn exit_error() { let builder = TestOpenVpnBuilder { - process_handle: Some(TestProcessHandle(1)), + process_handle: Some(TestProcessHandle::immediate(1)), ..Default::default() }; let openvpn_init_args = create_init_args(); @@ -1286,10 +1313,11 @@ mod tests { assert!(testee.wait().await.is_err()); } - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + /// Test that the `OpenVpnMonitor` stops when the close handle closes it. + #[tokio::test(flavor = "current_thread", start_paused = true)] async fn wait_closed() { let builder = TestOpenVpnBuilder { - process_handle: Some(TestProcessHandle(1)), + process_handle: Some(TestProcessHandle::run_forever()), ..Default::default() }; let openvpn_init_args = create_init_args(); @@ -1303,9 +1331,11 @@ mod tests { .unwrap(); testee.close_handle().close(); - let result = testee.wait().await; - println!("[testee.wait(): {:?}]", result); - assert!(result.is_ok()); + + tokio::time::timeout(std::time::Duration::from_secs(10), testee.wait()) + .await + .expect("expected close handle to stop monitor") + .expect("expected successful result"); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]