From 2fbbed647cfb130dd1734832e249bff164be9f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kalle=20Lindstr=C3=B6m?= Date: Mon, 18 Nov 2024 14:18:19 +0100 Subject: [PATCH] Hide logic for WgGoTunnel connectivty check --- talpid-wireguard/src/connectivity_check.rs | 243 +++++++++++++-------- talpid-wireguard/src/ephemeral.rs | 31 +-- talpid-wireguard/src/lib.rs | 36 +-- talpid-wireguard/src/wireguard_go/mod.rs | 36 +-- 4 files changed, 194 insertions(+), 152 deletions(-) diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index 6cfabbcd13e8..438b599541cf 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -1,8 +1,6 @@ #[cfg(target_os = "android")] use super::Tunnel; use super::{TunnelError, TunnelType}; -#[cfg(target_os = "android")] -use crate::wireguard_go::WgGoTunnel; use crate::{ ping_monitor::{new_pinger, Pinger}, stats::StatsMap, @@ -75,36 +73,50 @@ pub enum Error { /// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. pub struct ConnectivityMonitor { conn_state: ConnState, - initial_ping_timestamp: Option, - num_pings_sent: u32, - pinger: Box, - close_receiver: mpsc::Receiver<()>, + ping_state: PingState, + close_receiver: Option>, } impl ConnectivityMonitor { pub(super) fn new( addr: Ipv4Addr, #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, - close_receiver: mpsc::Receiver<()>, ) -> Result { - let pinger = new_pinger( - addr, - #[cfg(any(target_os = "macos", target_os = "linux"))] - interface, - ) - .map_err(Error::PingError)?; - - let now = Instant::now(); - Ok(Self { - conn_state: ConnState::new(now, Default::default()), - initial_ping_timestamp: None, - num_pings_sent: 0, - pinger, - close_receiver, + conn_state: ConnState::new(Instant::now(), Default::default()), + ping_state: PingState::new( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + )?, + close_receiver: None, }) } + pub(super) fn with_close_receiver(self, close_receiver: mpsc::Receiver<()>) -> Self { + Self { + close_receiver: Some(close_receiver), + ..self + } + } + + /// Returns true if monitor should be shut down + fn should_shut_down(&mut self, timeout: Duration) -> bool { + let Some(close_receiver) = self.close_receiver.as_ref() else { + return false; + }; + + match close_receiver.recv_timeout(timeout) { + Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, + Err(mpsc::RecvTimeoutError::Timeout) => false, + } + } + + fn reset(&mut self, current_iteration: Instant) { + self.ping_state.reset(); + self.conn_state.reset_after_suspension(current_iteration); + } + // checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is // successful at the start of a connection. pub(super) fn establish_connectivity( @@ -113,7 +125,10 @@ impl ConnectivityMonitor { tunnel_handle: &TunnelType, ) -> Result { // Send initial ping to prod WireGuard into connecting. - self.pinger.send_icmp().map_err(Error::PingError)?; + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; self.establish_connectivity_inner( retry_attempt, ESTABLISH_TIMEOUT, @@ -152,59 +167,6 @@ impl ConnectivityMonitor { Ok(false) } - pub(super) fn run( - &mut self, - tunnel_handle: Weak>>, - ) -> Result<(), Error> { - self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle) - } - - /// Returns true if monitor should be shut down - fn should_shut_down(&mut self, timeout: Duration) -> bool { - match self.close_receiver.recv_timeout(timeout) { - Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, - Err(mpsc::RecvTimeoutError::Timeout) => false, - } - } - - fn wait_loop( - &mut self, - iter_delay: Duration, - tunnel_handle: Weak>>, - ) -> Result<(), Error> { - let mut last_iteration = Instant::now(); - while !self.should_shut_down(iter_delay) { - let mut current_iteration = Instant::now(); - let time_slept = current_iteration - last_iteration; - if time_slept < (iter_delay * 2) { - let Some(tunnel) = tunnel_handle.upgrade() else { - return Ok(()); - }; - let lock = tunnel.blocking_lock(); - let Some(tunnel) = lock.as_ref() else { - return Ok(()); - }; - - if !self.check_connectivity(Instant::now(), tunnel)? { - return Ok(()); - } - drop(lock); - - let end = Instant::now(); - if end - current_iteration > Duration::from_secs(1) { - current_iteration = end; - } - } else { - // Loop was suspended for too long, so it's safer to assume that the host still has - // connectivity. - self.reset_pinger(); - self.conn_state.reset_after_suspension(current_iteration); - } - last_iteration = current_iteration; - } - Ok(()) - } - /// Returns true if connection is established fn check_connectivity( &mut self, @@ -227,12 +189,12 @@ impl ConnectivityMonitor { let new_stats = new_stats?; if self.conn_state.update(now, new_stats) { - self.reset_pinger(); + self.ping_state.reset(); return Ok(true); } self.maybe_send_ping(now)?; - Ok(!self.ping_timed_out(timeout) && self.conn_state.connected()) + Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected()) } } } @@ -258,33 +220,25 @@ impl ConnectivityMonitor { // 3 seconds. if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out()) && self + .ping_state .initial_ping_timestamp .map(|initial_ping_timestamp| { - initial_ping_timestamp.elapsed() / self.num_pings_sent < SECONDS_PER_PING + initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent + < SECONDS_PER_PING }) .unwrap_or(true) { - self.pinger.send_icmp().map_err(Error::PingError)?; - if self.initial_ping_timestamp.is_none() { - self.initial_ping_timestamp = Some(now); + self.ping_state + .pinger + .send_icmp() + .map_err(Error::PingError)?; + if self.ping_state.initial_ping_timestamp.is_none() { + self.ping_state.initial_ping_timestamp = Some(now); } - self.num_pings_sent += 1; + self.ping_state.num_pings_sent += 1; } Ok(()) } - - fn ping_timed_out(&self, timeout: Duration) -> bool { - self.initial_ping_timestamp - .map(|initial_ping_timestamp| initial_ping_timestamp.elapsed() > timeout) - .unwrap_or(false) - } - - /// Reset timeouts - assume that the last time bytes were received is now. - fn reset_pinger(&mut self) { - self.initial_ping_timestamp = None; - self.num_pings_sent = 0; - self.pinger.reset(); - } } enum ConnState { @@ -300,6 +254,45 @@ enum ConnState { }, } +struct PingState { + initial_ping_timestamp: Option, + num_pings_sent: u32, + pinger: Box, +} + +impl PingState { + pub(super) fn new( + addr: Ipv4Addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, + ) -> Result { + let pinger = new_pinger( + addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] + interface, + ) + .map_err(Error::PingError)?; + + Ok(Self { + initial_ping_timestamp: None, + num_pings_sent: 0, + pinger, + }) + } + + fn ping_timed_out(&self, timeout: Duration) -> bool { + self.initial_ping_timestamp + .map(|initial_ping_timestamp| initial_ping_timestamp.elapsed() > timeout) + .unwrap_or(false) + } + + /// Reset timeouts - assume that the last time bytes were received is now. + fn reset(&mut self) { + self.initial_ping_timestamp = None; + self.num_pings_sent = 0; + self.pinger.reset(); + } +} + impl ConnState { pub fn new(start: Instant, stats: StatsMap) -> Self { ConnState::Connecting { @@ -418,6 +411,66 @@ impl ConnState { } } +pub struct ConnectivityMonitorLoop { + connectivity_monitor: ConnectivityMonitor, +} + +impl ConnectivityMonitorLoop { + pub(super) fn new(connectivity_monitor: ConnectivityMonitor) -> Self { + debug_assert!( + connectivity_monitor.close_receiver.is_some(), + "Close receiver must be set" + ); + Self { + connectivity_monitor, + } + } + + pub(super) fn run(self, tunnel_handle: Weak>>) -> Result<(), Error> { + self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle) + } + + fn wait_loop( + mut self, + iter_delay: Duration, + tunnel_handle: Weak>>, + ) -> Result<(), Error> { + let mut last_iteration = Instant::now(); + while !self.connectivity_monitor.should_shut_down(iter_delay) { + let mut current_iteration = Instant::now(); + let time_slept = current_iteration - last_iteration; + if time_slept < (iter_delay * 2) { + let Some(tunnel) = tunnel_handle.upgrade() else { + return Ok(()); + }; + let lock = tunnel.blocking_lock(); + let Some(tunnel) = lock.as_ref() else { + return Ok(()); + }; + + if !self + .connectivity_monitor + .check_connectivity(Instant::now(), tunnel)? + { + return Ok(()); + } + drop(lock); + + let end = Instant::now(); + if end - current_iteration > Duration::from_secs(1) { + current_iteration = end; + } + } else { + // Loop was suspended for too long, so it's safer to assume that the host still has + // connectivity. + self.connectivity_monitor.reset(current_iteration); + } + last_iteration = current_iteration; + } + Ok(()) + } +} + #[cfg(test)] mod test { use futures::Future; diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index f7f74173ea42..a9283fcb2e2b 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -15,8 +15,6 @@ use std::{ #[cfg(target_os = "android")] use talpid_tunnel::tun_provider::TunProvider; -#[cfg(target_os = "android")] -use crate::connectivity_check::ConnectivityMonitor; use ipnetwork::IpNetwork; use talpid_types::net::wireguard::{PresharedKey, PrivateKey, PublicKey}; use tokio::sync::Mutex as AsyncMutex; @@ -77,7 +75,6 @@ pub async fn config_ephemeral_peers( obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, #[cfg(target_os = "android")] tun_provider: Arc>, - #[cfg(target_os = "android")] connectivity_monitor: Arc>, ) -> Result<(), CloseMsg> { config_ephemeral_peers_inner( tunnel, @@ -87,8 +84,6 @@ pub async fn config_ephemeral_peers( close_obfs_sender, #[cfg(target_os = "android")] tun_provider, - #[cfg(target_os = "android")] - connectivity_monitor, ) .await } @@ -100,7 +95,6 @@ async fn config_ephemeral_peers_inner( obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, #[cfg(target_os = "android")] tun_provider: Arc>, - #[cfg(target_os = "android")] connectivity_monitor: Arc>, ) -> Result<(), CloseMsg> { let ephemeral_private_key = PrivateKey::new_from_random(); let close_obfs_sender = close_obfs_sender.clone(); @@ -134,8 +128,6 @@ async fn config_ephemeral_peers_inner( close_obfs_sender, #[cfg(target_os = "android")] &tun_provider, - #[cfg(target_os = "android")] - &connectivity_monitor, ) .await?; @@ -168,8 +160,6 @@ async fn config_ephemeral_peers_inner( close_obfs_sender, #[cfg(target_os = "android")] &tun_provider, - #[cfg(target_os = "android")] - &connectivity_monitor, ) .await?; @@ -197,7 +187,6 @@ async fn reconfigure_tunnel( obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, tun_provider: &Arc>, - connectivity_monitor: &Arc>, ) -> Result { let mut obfs_guard = obfuscator.lock().await; if let Some(obfuscator_handle) = obfs_guard.take() { @@ -211,19 +200,17 @@ async fn reconfigure_tunnel( .await .map_err(CloseMsg::ObfuscatorFailed)?; } + { + let mut shared_tunnel = tunnel.lock().await; + let tunnel = shared_tunnel.take().expect("tunnel was None"); - let mut lock = tunnel.lock().await; - - let tunnel = lock.take().expect("tunnel was None"); - - let mut connectivity_monitor = connectivity_monitor.lock().unwrap(); - - let new_tunnel = tunnel - .set_config(&config, &mut connectivity_monitor) - .map_err(Error::TunnelError) - .map_err(CloseMsg::SetupError)?; + let updated_tunnel = tunnel + .set_config(&config) + .map_err(Error::TunnelError) + .map_err(CloseMsg::SetupError)?; - *lock = Some(new_tunnel); + *shared_tunnel = Some(updated_tunnel); + } Ok(config) } diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index ff88e5e829e6..61f54ef1c24c 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -26,8 +26,7 @@ use talpid_routing::{self, RequiredRoute}; use talpid_tunnel::tun_provider; use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; -#[cfg(target_os = "android")] -use crate::connectivity_check::ConnectivityMonitor; +use crate::connectivity_check::{ConnectivityMonitor, ConnectivityMonitorLoop}; use talpid_types::net::wireguard::TunnelParameters; use talpid_types::{ net::{AllowedTunnelTraffic, Endpoint, TransportProtocol}, @@ -230,13 +229,13 @@ impl WireguardMonitor { }; let gateway = config.ipv4_gateway; - let mut connectivity_monitor = connectivity_check::ConnectivityMonitor::new( + let mut connectivity_monitor = ConnectivityMonitor::new( gateway, #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name.clone(), - pinger_rx, ) - .map_err(Error::ConnectivityMonitorError)?; + .map_err(Error::ConnectivityMonitorError)? + .with_close_receiver(pinger_rx); let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); @@ -324,7 +323,7 @@ impl WireguardMonitor { let cloned_tunnel = Arc::clone(&tunnel); - let mut connectivity_monitor = tokio::task::spawn_blocking(move || { + let connectivity_monitor = tokio::task::spawn_blocking(move || { let lock = cloned_tunnel.blocking_lock(); let Some(tunnel) = lock.as_ref() else { @@ -360,7 +359,9 @@ impl WireguardMonitor { let weak_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { - if let Err(error) = connectivity_monitor.run(weak_tunnel) { + if let Err(error) = + ConnectivityMonitorLoop::new(connectivity_monitor).run(weak_tunnel) + { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -433,10 +434,6 @@ impl WireguardMonitor { let (pinger_tx, pinger_rx) = sync_mpsc::channel(); - let mut connectivity_monitor = - connectivity_check::ConnectivityMonitor::new(config.ipv4_gateway, pinger_rx) - .map_err(Error::ConnectivityMonitorError)?; - let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; let tunnel = Self::open_wireguard_go_tunnel( &config, @@ -447,11 +444,14 @@ impl WireguardMonitor { // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. should_negotiate_ephemeral_peer, - &mut connectivity_monitor, )?; let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); + let connectivity_monitor = ConnectivityMonitor::new(config.ipv4_gateway) + .map_err(Error::ConnectivityMonitorError)? + .with_close_receiver(pinger_rx); + let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::clone(&tunnel), @@ -466,7 +466,6 @@ impl WireguardMonitor { let tunnel_fut = async move { let close_obfs_sender: sync_mpsc::Sender = moved_close_obfs_sender; let obfuscator = moved_obfuscator; - let connectivity_monitor = Arc::new(Mutex::new(connectivity_monitor)); let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); @@ -483,7 +482,6 @@ impl WireguardMonitor { obfuscator.clone(), ephemeral_obfs_sender, args.tun_provider, - connectivity_monitor.clone(), ) .await?; @@ -500,7 +498,7 @@ impl WireguardMonitor { tokio::task::spawn_blocking(move || { let tunnel = Arc::downgrade(&tunnel); - if let Err(error) = connectivity_monitor.lock().unwrap().run(tunnel) { + if let Err(error) = ConnectivityMonitorLoop::new(connectivity_monitor).run(tunnel) { log::error!( "{}", error.display_chain_with_msg("Connectivity monitor failed") @@ -721,7 +719,6 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir: &Path, tun_provider: Arc>, #[cfg(target_os = "android")] gateway_only: bool, - #[cfg(target_os = "android")] connectivity_monitor: &mut ConnectivityMonitor, ) -> Result { let routes = config .get_tunnel_destinations() @@ -760,7 +757,6 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, - connectivity_monitor, ) .map_err(Error::TunnelError)? } else { @@ -772,7 +768,6 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, - connectivity_monitor, ) .map_err(Error::TunnelError)? }; @@ -1057,6 +1052,11 @@ pub enum TunnelError { #[error("Failed to configure Wireguard sockets to bypass the tunnel")] BypassError(#[source] tun_provider::Error), + /// TODO + #[cfg(target_os = "android")] + #[error("Failed to set up a working tunnel")] + TunnelUp, + /// Invalid tunnel interface name. #[error("Invalid tunnel interface name")] InterfaceNameError(#[source] std::ffi::NulError), diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index dd8462469c1e..a27d49c2fd8b 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -12,6 +12,8 @@ use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; #[cfg(daita)] use once_cell::sync::OnceCell; +#[cfg(target_os = "android")] +use std::net::Ipv4Addr; #[cfg(daita)] use std::{ffi::CString, fs, path::PathBuf}; use std::{ @@ -105,11 +107,7 @@ impl WgGoTunnel { } } - pub fn set_config( - self, - config: &Config, - connectivity_monitor: &mut ConnectivityMonitor, - ) -> Result { + pub fn set_config(self, config: &Config) -> Result { let state = self.as_state(); let log_path = state._logging_context.path.clone(); let tun_provider = Arc::clone(&state.tun_provider); @@ -126,7 +124,6 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, - connectivity_monitor, ) } WgGoTunnel::Singlehop(state) if config.is_multihop() => { @@ -138,7 +135,6 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, - connectivity_monitor, ) } WgGoTunnel::Singlehop(mut state) => { @@ -296,7 +292,6 @@ impl WgGoTunnel { tun_provider: Arc>, routes: impl Iterator, #[cfg(daita)] resource_dir: &Path, - connectivity_monitor: &mut ConnectivityMonitor, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -331,10 +326,7 @@ impl WgGoTunnel { config: config.clone(), }); - // TODO: explain - connectivity_monitor - .establish_connectivity(0, &tunnel) - .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; + tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?; Ok(tunnel) } @@ -346,7 +338,6 @@ impl WgGoTunnel { tun_provider: Arc>, routes: impl Iterator, #[cfg(daita)] resource_dir: &Path, - connectivity_monitor: &mut ConnectivityMonitor, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -397,10 +388,7 @@ impl WgGoTunnel { config: config.clone(), }); - // TODO: explain - connectivity_monitor - .establish_connectivity(0, &tunnel) - .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; + tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?; Ok(tunnel) } @@ -417,6 +405,20 @@ impl WgGoTunnel { Ok(()) } + + fn ensure_tunnel_is_running(&self, addr: Ipv4Addr) -> Result<()> { + let connection_established = ConnectivityMonitor::new(addr) + .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))? + .establish_connectivity(0, self) + .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; + + // Timed out + if !connection_established { + Err(TunnelError::TunnelUp) + } else { + Ok(()) + } + } } impl Tunnel for WgGoTunnel {