diff --git a/talpid-wireguard/build.rs b/talpid-wireguard/build.rs index ab3500330c26..0a7360569d70 100644 --- a/talpid-wireguard/build.rs +++ b/talpid-wireguard/build.rs @@ -14,6 +14,14 @@ fn main() { // Enable DAITA by default on desktop and android println!("cargo::rustc-check-cfg=cfg(daita)"); println!("cargo::rustc-cfg=daita"); + + // Ensure that the WireGuard tunnel works before exchanging ephemeral peers. + // This is useful after updating the WireGuard config, to force a WireGuard handshake. This + // should reduce the number of PQ timeouts. + println!("cargo::rustc-check-cfg=cfg(force_wireguard_handshake)"); + if matches!(target_os.as_str(), "linux" | "macos" | "windows") { + println!("cargo::rustc-cfg=force_wireguard_handshake"); + } } fn declare_libs_dir(base: &str) { diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs index 512d8715f17d..25190a702039 100644 --- a/talpid-wireguard/src/connectivity/mod.rs +++ b/talpid-wireguard/src/connectivity/mod.rs @@ -6,7 +6,6 @@ mod mock; mod monitor; mod pinger; -#[cfg(target_os = "android")] pub use check::Cancellable; pub use check::Check; pub use error::Error; diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index 31f3957253e9..8ca235c78aee 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -1,6 +1,8 @@ //! This module takes care of obtaining ephemeral peers, updating the WireGuard configuration and //! restarting obfuscation and WG tunnels when necessary. +#[cfg(force_wireguard_handshake)] +use super::connectivity; #[cfg(target_os = "android")] // On Android, the Tunnel trait is not imported by default. use super::Tunnel; use super::{config::Config, obfuscation::ObfuscatorHandle, CloseMsg, Error, TunnelType}; @@ -31,6 +33,9 @@ pub async fn config_ephemeral_peers( retry_attempt: u32, obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, + #[cfg(force_wireguard_handshake)] connectivity: &mut connectivity::Check< + connectivity::Cancellable, + >, ) -> std::result::Result<(), CloseMsg> { let iface_name = { let tunnel = tunnel.lock().await; @@ -44,8 +49,16 @@ pub async fn config_ephemeral_peers( log::trace!("Temporarily lowering tunnel MTU before ephemeral peer config"); try_set_ipv4_mtu(&iface_name, talpid_tunnel::MIN_IPV4_MTU); - config_ephemeral_peers_inner(tunnel, config, retry_attempt, obfuscator, close_obfs_sender) - .await?; + config_ephemeral_peers_inner( + tunnel, + config, + retry_attempt, + obfuscator, + close_obfs_sender, + #[cfg(force_wireguard_handshake)] + connectivity, + ) + .await?; log::trace!("Resetting tunnel MTU"); try_set_ipv4_mtu(&iface_name, config.mtu); @@ -75,6 +88,9 @@ pub async fn config_ephemeral_peers( retry_attempt: u32, obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, + #[cfg(force_wireguard_handshake)] connectivity: &mut connectivity::Check< + connectivity::Cancellable, + >, #[cfg(target_os = "android")] tun_provider: Arc>, ) -> Result<(), CloseMsg> { config_ephemeral_peers_inner( @@ -83,6 +99,8 @@ pub async fn config_ephemeral_peers( retry_attempt, obfuscator, close_obfs_sender, + #[cfg(force_wireguard_handshake)] + connectivity, #[cfg(target_os = "android")] tun_provider, ) @@ -95,8 +113,14 @@ async fn config_ephemeral_peers_inner( retry_attempt: u32, obfuscator: Arc>>, close_obfs_sender: sync_mpsc::Sender, + #[cfg(force_wireguard_handshake)] connectivity: &mut connectivity::Check< + connectivity::Cancellable, + >, #[cfg(target_os = "android")] tun_provider: Arc>, ) -> Result<(), CloseMsg> { + #[cfg(force_wireguard_handshake)] + establish_tunnel_connection(tunnel, connectivity).await?; + let ephemeral_private_key = PrivateKey::new_from_random(); let close_obfs_sender = close_obfs_sender.clone(); @@ -134,6 +158,10 @@ async fn config_ephemeral_peers_inner( &tun_provider, ) .await?; + + #[cfg(force_wireguard_handshake)] + establish_tunnel_connection(tunnel, connectivity).await?; + let entry_ephemeral_peer = request_ephemeral_peer( retry_attempt, &entry_config, @@ -214,7 +242,6 @@ async fn reconfigure_tunnel( *obfs_guard = super::obfuscation::apply_obfuscation_config( &mut config, close_obfs_sender, - #[cfg(target_os = "android")] tun_provider.clone(), ) .await @@ -268,6 +295,37 @@ async fn reconfigure_tunnel( Ok(config) } +/// Ensure that the WireGuard tunnel works. This is useful after updating the WireGuard config, to +/// force a WireGuard handshake. This should reduce the number of PQ timeouts. +#[cfg(force_wireguard_handshake)] +async fn establish_tunnel_connection( + tunnel: &Arc>>, + connectivity: &mut connectivity::Check, +) -> Result<(), CloseMsg> { + use talpid_types::ErrorExt; + + let mut shared_tunnel = tunnel.lock().await; + let tunnel = shared_tunnel.take().expect("tunnel was None"); + let ping_result = connectivity.establish_connectivity(&tunnel); + *shared_tunnel = Some(tunnel); + drop(shared_tunnel); + + match ping_result { + Ok(true) => Ok(()), + Ok(false) => { + log::warn!("Timeout while checking tunnel connection"); + Err(CloseMsg::PingErr) + } + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check tunnel connection") + ); + Err(CloseMsg::PingErr) + } + } +} + async fn request_ephemeral_peer( retry_attempt: u32, config: &Config, diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 2d282c6315c6..cf0a57e41084 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -274,6 +274,8 @@ impl WireguardMonitor { args.retry_attempt, obfuscator.clone(), ephemeral_obfs_sender, + #[cfg(force_wireguard_handshake)] + &mut connectivity_monitor, ) .await?;