From 608d47ee32d815bc8c2a428d7dde3d1955496114 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:31:43 +0100 Subject: [PATCH 1/8] Fix typos --- talpid-tunnel/src/lib.rs | 2 +- talpid-wireguard/src/connectivity_check.rs | 2 +- talpid-wireguard/src/ping_monitor/icmp.rs | 2 ++ talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs index 8a916c668d4b..81cb85b6da67 100644 --- a/talpid-tunnel/src/lib.rs +++ b/talpid-tunnel/src/lib.rs @@ -19,7 +19,7 @@ pub struct TunnelArgs<'a, L> where L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, { - /// Toktio runtime handle. + /// Tokio runtime handle. pub runtime: tokio::runtime::Handle, /// Resource directory path. pub resource_dir: &'a Path, diff --git a/talpid-wireguard/src/connectivity_check.rs b/talpid-wireguard/src/connectivity_check.rs index 8f44b9b1c6e6..a7b09778d934 100644 --- a/talpid-wireguard/src/connectivity_check.rs +++ b/talpid-wireguard/src/connectivity_check.rs @@ -62,7 +62,7 @@ pub enum Error { /// /// The connectivity monitor will start sending pings and start the countdown to `PING_TIMEOUT` in /// the following cases: -/// - In case that we have observed a bump in the outgoing traffic but no coressponding incoming +/// - In case that we have observed a bump in the outgoing traffic but no corresponding incoming /// traffic for longer than `BYTES_RX_TIMEOUT`, then the monitor will start pinging. /// - In case that no increase in outgoing or incoming traffic has been observed for longer than /// `TRAFFIC_TIMEOUT`, then the monitor will start pinging as well. diff --git a/talpid-wireguard/src/ping_monitor/icmp.rs b/talpid-wireguard/src/ping_monitor/icmp.rs index a0afc8a98c50..ad31349799ad 100644 --- a/talpid-wireguard/src/ping_monitor/icmp.rs +++ b/talpid-wireguard/src/ping_monitor/icmp.rs @@ -1,6 +1,7 @@ use byteorder::{NetworkEndian, WriteBytesExt}; use rand::Rng; use socket2::{Domain, Protocol, Socket, Type}; + use std::{ io::{self, Write}, net::{Ipv4Addr, SocketAddr}, @@ -59,6 +60,7 @@ pub struct Pinger { } impl Pinger { + /// Creates a new `Pinger`. pub fn new( addr: Ipv4Addr, #[cfg(not(target_os = "windows"))] interface_name: String, diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs index b1159bb6de97..579bcde65aa0 100644 --- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs @@ -122,7 +122,7 @@ impl Tunnel for NetlinkTunnel { wg.set_config(interface_index, &config) .await .map_err(|err| { - log::error!("Failed to fetch WireGuard device config: {}", err); + log::error!("Failed to set WireGuard device config: {}", err); TunnelError::SetConfigError }) }) From 948aa31e7c47a692999ebb1895c32bce18ab642b Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:37:07 +0100 Subject: [PATCH 2/8] Move constants --- talpid-tunnel/src/lib.rs | 13 +++++++++++++ talpid-wireguard/src/lib.rs | 10 ++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs index 81cb85b6da67..8ce3dd2d0d50 100644 --- a/talpid-tunnel/src/lib.rs +++ b/talpid-tunnel/src/lib.rs @@ -14,6 +14,19 @@ use talpid_routing::RouteManagerHandle; use talpid_types::net::AllowedTunnelTraffic; use tun_provider::TunProvider; +/// Size of IPv4 header in bytes +pub const IPV4_HEADER_SIZE: u16 = 20; +/// Size of IPv6 header in bytes +pub const IPV6_HEADER_SIZE: u16 = 40; +/// Size of wireguard header in bytes +pub const WIREGUARD_HEADER_SIZE: u16 = 40; +/// Size of ICMP header in bytes +pub const ICMP_HEADER_SIZE: u16 = 8; +/// Smallest allowed MTU for IPv4 in bytes +pub const MIN_IPV4_MTU: u16 = 576; +/// Smallest allowed MTU for IPv6 in bytes +pub const MIN_IPV6_MTU: u16 = 1280; + /// Arguments for creating a tunnel. pub struct TunnelArgs<'a, L> where diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index e8a63d0b1b17..35cfb58dce17 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -42,6 +42,9 @@ use tunnel_obfuscation::{ create_obfuscator, Error as ObfuscationError, Settings as ObfuscationSettings, Udp2TcpSettings, }; +#[cfg(any(target_os = "linux", target_os = "macos"))] +use talpid_tunnel::{IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, WIREGUARD_HEADER_SIZE}; + /// WireGuard config data-types pub mod config; mod connectivity_check; @@ -898,16 +901,11 @@ impl WireguardMonitor { } else { // Set route MTU by subtracting the WireGuard overhead from the tunnel MTU. Plus // some margin to make room for padding bytes. - // TODO: Move consts to shared location - const IPV4_HEADER_SIZE: u16 = 20; - const IPV6_HEADER_SIZE: u16 = 40; - const WIREGUARD_HEADER_SIZE: u16 = 40; - const PADDING_BYTES_MARGIN: u16 = 15; - let ip_overhead = match route.prefix.is_ipv4() { true => IPV4_HEADER_SIZE, false => IPV6_HEADER_SIZE, }; + const PADDING_BYTES_MARGIN: u16 = 15; let mtu = config.mtu - ip_overhead - WIREGUARD_HEADER_SIZE - PADDING_BYTES_MARGIN; route.mtu(mtu) From 3c57574b6613073fdbb51747995905d0cb270435 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:39:20 +0100 Subject: [PATCH 3/8] Add `set_mtu` for linux --- talpid-wireguard/src/lib.rs | 2 ++ talpid-wireguard/src/unix.rs | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 talpid-wireguard/src/unix.rs diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 35cfb58dce17..662eed8be453 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -51,6 +51,8 @@ mod connectivity_check; mod logging; mod ping_monitor; mod stats; +#[cfg(target_os = "linux")] +mod unix; #[cfg(wireguard_go)] mod wireguard_go; #[cfg(target_os = "linux")] diff --git a/talpid-wireguard/src/unix.rs b/talpid-wireguard/src/unix.rs new file mode 100644 index 000000000000..bef057a04270 --- /dev/null +++ b/talpid-wireguard/src/unix.rs @@ -0,0 +1,41 @@ +use std::{io, os::fd::AsRawFd}; + +use socket2::Domain; +use talpid_types::ErrorExt; + +pub fn set_mtu(interface_name: &str, mtu: u16) -> Result<(), io::Error> { + debug_assert_ne!( + interface_name, "eth0", + "Should be name of mullvad tunnel interface, e.g. 'wg0-mullvad'" + ); + + let sock = socket2::Socket::new( + Domain::IPV4, + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + + let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() }; + if interface_name.len() >= ifr.ifr_name.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Interface name too long", + )); + } + + unsafe { + std::ptr::copy_nonoverlapping( + interface_name.as_ptr() as *const i8, + &mut ifr.ifr_name as *mut _, + interface_name.len(), + ) + }; + ifr.ifr_ifru.ifru_mtu = mtu as i32; + + if unsafe { libc::ioctl(sock.as_raw_fd(), libc::SIOCSIFMTU, &ifr) } < 0 { + let e = std::io::Error::last_os_error(); + log::error!("{}", e.display_chain_with_msg("SIOCSIFMTU failed")); + return Err(e); + } + Ok(()) +} From 62b803248405eeb17068e41ced9b8998c5d20095 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:45:45 +0100 Subject: [PATCH 4/8] Add automatic MTU detection --- Cargo.lock | 71 +++++++++++++++++++++++++ talpid-wireguard/Cargo.toml | 1 + talpid-wireguard/src/lib.rs | 103 ++++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 1481155dc6cd..070d7c9c1a2b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1164,6 +1164,12 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.3.24" @@ -2188,6 +2194,12 @@ dependencies = [ "libc", ] +[[package]] +name = "no-std-net" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43794a0ace135be66a25d3ae77d41b91615fb68ae937f904090203e81f755b65" + [[package]] name = "notify" version = "6.1.1" @@ -2534,6 +2546,48 @@ version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4503fa043bf02cee09a9582e9554b4c6403b2ef55e4612e96561d294419429f8" +[[package]] +name = "pnet_base" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "872e46346144ebf35219ccaa64b1dffacd9c6f188cd7d012bd6977a2a838f42e" +dependencies = [ + "no-std-net", +] + +[[package]] +name = "pnet_macros" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a780e80005c2e463ec25a6e9f928630049a10b43945fea83207207d4a7606f4" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 1.0.109", +] + +[[package]] +name = "pnet_macros_support" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d932134f32efd7834eb8b16d42418dac87086347d1bc7d142370ef078582bc" +dependencies = [ + "pnet_base", +] + +[[package]] +name = "pnet_packet" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bde678bbd85cb1c2d99dc9fc596e57f03aa725f84f3168b0eaf33eeccb41706" +dependencies = [ + "glob", + "pnet_base", + "pnet_macros", + "pnet_macros_support", +] + [[package]] name = "poly1305" version = "0.8.0" @@ -3362,6 +3416,22 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "surge-ping" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af341b2be485d647b5dc4cfb2da99efac35b5c95748a08fb7233480fedc5ead3" +dependencies = [ + "hex", + "parking_lot", + "pnet_packet", + "rand 0.8.5", + "socket2 0.5.3", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "syn" version = "1.0.109" @@ -3659,6 +3729,7 @@ dependencies = [ "rand 0.8.5", "rtnetlink", "socket2 0.5.3", + "surge-ping", "talpid-dbus", "talpid-routing", "talpid-tunnel", diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index 5c5a14a3fad2..c6f3669f8b27 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -28,6 +28,7 @@ chrono = { workspace = true, features = ["clock"] } tokio = { workspace = true, features = ["process", "rt-multi-thread", "fs"] } tunnel-obfuscation = { path = "../tunnel-obfuscation" } rand = "0.8.5" +surge-ping = "0.8.0" [target.'cfg(target_os="android")'.dependencies] duct = "0.13" diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 662eed8be453..41033173e942 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -74,6 +74,14 @@ pub enum Error { #[error(display = "Failed to setup routing")] SetupRoutingError(#[error(source)] talpid_routing::Error), + /// Failed to set MTU + #[error(display = "Failed to detect MTU because every ping was dropped.")] + MtuDetectionAllDropped, + + /// Failed to set MTU + #[error(display = "Failed to detect MTU because of unexpected ping error.")] + MtuDetectionPingError(#[error(source)] surge_ping::SurgeError), + /// Tunnel timed out #[error(display = "Tunnel timed out")] TimeoutError, @@ -949,6 +957,101 @@ impl WireguardMonitor { } } +/// Detects the maximum MTU that does not cause dropped packets. +/// +/// The detection works by sending evenly spread out range of pings between 576 and the given +/// current tunnel MTU, and returning the maximum packet size that was returned within a timeout. +#[cfg(target_os = "linux")] +async fn auto_mtu_detection( + gateway: std::net::Ipv4Addr, + #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name: String, + current_mtu: u16, +) -> Result { + use futures::{future, stream::FuturesUnordered, TryStreamExt}; + use surge_ping::{Client, Config, PingIdentifier, PingSequence, SurgeError}; + use talpid_tunnel::{ICMP_HEADER_SIZE, MIN_IPV4_MTU}; + use tokio_stream::StreamExt; + + /// Max time to wait for any ping, when this expires, we give up and throw an error. + const PING_TIMEOUT: Duration = Duration::from_secs(10); + /// Max time to wait after the first ping arrives. Every ping after this timeout is considered + /// dropped, so we return the largest collected packet size. + const PING_OFFSET_TIMEOUT: Duration = Duration::from_secs(2); + + let config_builder = Config::builder().kind(surge_ping::ICMP::V4); + #[cfg(any(target_os = "macos", target_os = "linux"))] + let config_builder = config_builder.interface(&iface_name); + let client = Client::new(&config_builder.build()).unwrap(); + + let step_size = 20; + let linspace = mtu_spacing(MIN_IPV4_MTU, current_mtu, step_size); + + let payload_buf = vec![0; current_mtu as usize]; + + let mut ping_stream = linspace + .iter() + .enumerate() + .map(|(i, &mtu)| { + let client = client.clone(); + let payload_size = (mtu - IPV4_HEADER_SIZE - ICMP_HEADER_SIZE) as usize; + let payload = &payload_buf[0..payload_size]; + async move { + log::trace!("Sending ICMP ping of total size {mtu}"); + client + .pinger(IpAddr::V4(gateway), PingIdentifier(0)) + .await + .timeout(PING_TIMEOUT) + .ping(PingSequence(i as u16), payload) + .await + } + }) + .collect::>() + .map_ok(|(packet, _rtt)| { + let surge_ping::IcmpPacket::V4(packet) = packet else { + unreachable!("ICMP ping response was not of IPv4 type"); + }; + let size = packet.get_size() as u16 + IPV4_HEADER_SIZE; + log::trace!("Got ICMP ping response of total size {size}"); + debug_assert_eq!(size, linspace[packet.get_sequence().0 as usize]); + size + }); + + let first_ping_size = ping_stream + .next() + .await + .expect("At least one pings should be sent") + // Short-circuit and return on error + .map_err(|e| match e { + // If the first ping we get back timed out, then all of them did + SurgeError::Timeout { .. } => Error::MtuDetectionAllDropped, + // Unexpected error type + e => Error::MtuDetectionPingError(e), + })?; + + ping_stream + .timeout(PING_OFFSET_TIMEOUT) // Start a new, shorter, timeout + .map_while(|res| res.ok()) // Stop waiting for pings after this timeout + .try_fold(first_ping_size, |acc, mtu| future::ready(Ok(acc.max(mtu)))) // Get largest ping + .await + .map_err(Error::MtuDetectionPingError) +} + +/// Creates a linear spacing of MTU values with the given step size. Always includes the given end +/// points. +#[cfg(target_os = "linux")] +fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec { + if mtu_min > mtu_max { + panic!("Invalid MTU detection range: `mtu_min`={mtu_min}, `mtu_max`={mtu_max}."); + } + let second_mtu = mtu_min.next_multiple_of(step_size); + let in_between = (second_mtu..mtu_max).step_by(step_size as usize); + let mut ret = Vec::with_capacity(((mtu_max - second_mtu).div_ceil(step_size) + 2) as usize); + ret.push(mtu_min); + ret.extend(in_between); + ret.push(mtu_max); + ret +} + #[derive(Debug)] enum CloseMsg { Stop, From 58dc910acf6c7459b0fafed7e9b611056fc927cd Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:48:02 +0100 Subject: [PATCH 5/8] Refactor default MTU calculation --- talpid-core/src/tunnel/mod.rs | 72 +++++++++++++++++----------------- talpid-wireguard/src/config.rs | 18 ++++----- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 3e6df61f76a1..8e9a51986b61 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -17,6 +17,14 @@ use talpid_wireguard; const OPENVPN_LOG_FILENAME: &str = "openvpn.log"; const WIREGUARD_LOG_FILENAME: &str = "wireguard.log"; +/// Set the MTU to the lowest possible whilst still allowing for IPv6 to help with wireless +/// carriers that do a lot of encapsulation. +const DEFAULT_MTU: u16 = if cfg!(target_os = "android") { + 1280 +} else { + 1380 +}; + /// Results from operations in the tunnel module. pub type Result = std::result::Result; @@ -154,13 +162,25 @@ impl TunnelMonitor { + Clone + 'static, { + let default_mtu = DEFAULT_MTU; + #[cfg(any(target_os = "linux", target_os = "windows"))] - args.runtime - .block_on(Self::assign_mtu(&args.route_manager, params)); - let config = talpid_wireguard::config::Config::from_parameters(params)?; + // Detects the MTU of the device and sets the default tunnel MTU to that minus headers and + // the safety margin + let default_mtu = args + .runtime + .block_on( + args.route_manager + .get_mtu_for_route(params.connection.peer.endpoint.ip()), + ) + .map(|mtu| Self::clamp_mtu(params, mtu)) + .unwrap_or(default_mtu); + let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?; let monitor = talpid_wireguard::WireguardMonitor::start( config, params.options.quantum_resistant, + #[cfg(target_os = "linux")] + detect_mtu, log.as_deref(), args, )?; @@ -169,58 +189,36 @@ impl TunnelMonitor { }) } - /// Set the MTU in the tunnel parameters based on the inputted device MTU and some - /// calculations. `peer_mtu` is the detected device MTU. + /// Calculates and appropriate tunnel MTU based on the given peer MTU minus header sizes #[cfg(any(target_os = "linux", target_os = "windows"))] - fn set_mtu(params: &mut wireguard_types::TunnelParameters, peer_mtu: u16) { + fn clamp_mtu(params: &wireguard_types::TunnelParameters, peer_mtu: u16) -> u16 { + use talpid_tunnel::{ + IPV4_HEADER_SIZE, IPV6_HEADER_SIZE, MIN_IPV4_MTU, MIN_IPV6_MTU, WIREGUARD_HEADER_SIZE, + }; // Some users experience fragmentation issues even when we take the interface MTU and // subtract the header sizes. This is likely due to some program that they use which does // not change the interface MTU but adds its own header onto the outgoing packets. For this // reason we subtract some extra bytes from our MTU in order to give other programs some // safety margin. const MTU_SAFETY_MARGIN: u16 = 60; - const IPV4_HEADER_SIZE: u16 = 20; - const IPV6_HEADER_SIZE: u16 = 40; - const WIREGUARD_HEADER_SIZE: u16 = 40; + let total_header_size = WIREGUARD_HEADER_SIZE + match params.connection.peer.endpoint.is_ipv6() { false => IPV4_HEADER_SIZE, true => IPV6_HEADER_SIZE, }; + // The largest peer MTU that we allow - const MAX_PEER_MTU: u16 = 1500 - MTU_SAFETY_MARGIN; - // The minimum allowed MTU size for our tunnel in IPv6 is 1280 and 576 for IPv4 - const MIN_IPV4_MTU: u16 = 576; - const MIN_IPV6_MTU: u16 = 1280; + let max_peer_mtu: u16 = 1500 - MTU_SAFETY_MARGIN - total_header_size; + let min_mtu = match params.generic_options.enable_ipv6 { false => MIN_IPV4_MTU, true => MIN_IPV6_MTU, }; - let tunnel_mtu = peer_mtu - .saturating_sub(total_header_size) - .clamp(min_mtu, MAX_PEER_MTU - total_header_size); - params.options.mtu = Some(tunnel_mtu); - } - /// Detects the MTU of the device, calculates what the virtual device MTU should be and sets - /// that in the tunnel parameters. - #[cfg(any(target_os = "linux", target_os = "windows"))] - async fn assign_mtu( - route_manager: &RouteManagerHandle, - params: &mut wireguard_types::TunnelParameters, - ) { - // Only calculate the mtu automatically if the user has not set any - if params.options.mtu.is_none() { - match route_manager - .get_mtu_for_route(params.connection.peer.endpoint.ip()) - .await - { - Ok(mtu) => Self::set_mtu(params, mtu), - Err(e) => { - log::error!("Could not get the MTU for route {}", e); - } - } - } + peer_mtu + .saturating_sub(total_header_size) + .clamp(min_mtu, max_peer_mtu) } #[cfg(not(target_os = "android"))] diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index 0e462102b244..b30e9053fc89 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -30,14 +30,6 @@ pub struct Config { pub obfuscator_config: Option, } -/// Set the MTU to the lowest possible whilst still allowing for IPv6 to help with wireless -/// carriers that do a lot of encapsulation. -const DEFAULT_MTU: u16 = if cfg!(target_os = "android") { - 1280 -} else { - 1380 -}; - /// Configuration errors #[derive(err_derive::Error, Debug)] pub enum Error { @@ -52,12 +44,16 @@ pub enum Error { impl Config { /// Constructs a Config from parameters - pub fn from_parameters(params: &wireguard::TunnelParameters) -> Result { + pub fn from_parameters( + params: &wireguard::TunnelParameters, + default_mtu: u16, + ) -> Result { Self::new( ¶ms.connection, ¶ms.options, ¶ms.generic_options, ¶ms.obfuscation, + default_mtu, ) } @@ -67,9 +63,11 @@ impl Config { wg_options: &wireguard::TunnelOptions, generic_options: &GenericTunnelOptions, obfuscator_config: &Option, + default_mtu: u16, ) -> Result { let mut tunnel = connection.tunnel.clone(); - let mtu = wg_options.mtu.unwrap_or(DEFAULT_MTU); + + let mtu = wg_options.mtu.unwrap_or(default_mtu); if tunnel.addresses.is_empty() { return Err(Error::InvalidTunnelIpError); From 9963325e5779774b1da08b2720b50dff8280be81 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Mon, 5 Feb 2024 17:48:42 +0100 Subject: [PATCH 6/8] Enable automatic MTU detection on linux --- talpid-core/src/tunnel/mod.rs | 4 ++++ talpid-wireguard/src/lib.rs | 30 ++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 8e9a51986b61..17ad2915d854 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -175,6 +175,10 @@ impl TunnelMonitor { ) .map(|mtu| Self::clamp_mtu(params, mtu)) .unwrap_or(default_mtu); + + #[cfg(target_os = "linux")] + let detect_mtu = params.options.mtu.is_none(); + let config = talpid_wireguard::config::Config::from_parameters(params, default_mtu)?; let monitor = talpid_wireguard::WireguardMonitor::start( config, diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 41033173e942..0875ad0a4917 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -270,6 +270,7 @@ impl WireguardMonitor { >( mut config: Config, psk_negotiation: bool, + #[cfg(target_os = "linux")] detect_mtu: bool, log_path: Option<&Path>, args: TunnelArgs<'_, F>, ) -> Result { @@ -388,7 +389,36 @@ impl WireguardMonitor { ) .await?; } + #[cfg(target_os = "linux")] + if detect_mtu { + let iface_name_clone = iface_name.clone(); + tokio::task::spawn(async move { + log::debug!("Starting MTU detection"); + let verified_mtu = match auto_mtu_detection( + gateway, + #[cfg(any(target_os = "macos", target_os = "linux"))] + iface_name_clone.clone(), + config.mtu, + ) + .await + { + Ok(mtu) => mtu, + Err(e) => { + log::error!("{}", e.display_chain_with_msg("Failed to detect MTU")); + return; + } + }; + if verified_mtu != config.mtu { + log::warn!("Lowering MTU from {} to {verified_mtu}", config.mtu); + if let Err(e) = unix::set_mtu(&iface_name_clone, verified_mtu) { + log::error!("{}", e.display_chain_with_msg("Failed to set MTU")) + }; + } else { + log::debug!("MTU {verified_mtu} verified to not drop packets"); + } + }); + } let mut connectivity_monitor = tokio::task::spawn_blocking(move || { match connectivity_monitor.establish_connectivity(args.retry_attempt) { Ok(true) => Ok(connectivity_monitor), From 058a3e99dedcdccfe71b45b75ad060b8968cbc63 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Tue, 6 Feb 2024 14:56:06 +0100 Subject: [PATCH 7/8] Add `proptest` dependency --- Cargo.lock | 89 ++++++++++++++++++++++++++++++++++++- talpid-wireguard/Cargo.toml | 3 ++ talpid-wireguard/src/lib.rs | 32 ++++++++++--- 3 files changed, 117 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 070d7c9c1a2b..cb8c92e08b0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -289,6 +289,21 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -1635,6 +1650,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -2226,6 +2247,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -2670,6 +2692,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags 2.4.0", + "lazy_static", + "num-traits", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rand_xorshift", + "regex-syntax 0.8.2", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "prost" version = "0.12.0" @@ -2840,6 +2882,15 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core 0.6.4", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -2878,7 +2929,7 @@ dependencies = [ "aho-corasick", "memchr", "regex-automata", - "regex-syntax", + "regex-syntax 0.7.5", ] [[package]] @@ -2889,7 +2940,7 @@ checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.7.5", ] [[package]] @@ -2898,6 +2949,12 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + [[package]] name = "resolv-conf" version = "0.7.0" @@ -3028,6 +3085,18 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "ryu" version = "1.0.15" @@ -3726,6 +3795,7 @@ dependencies = [ "nix 0.23.2", "once_cell", "parking_lot", + "proptest", "rand 0.8.5", "rtnetlink", "socket2 0.5.3", @@ -4222,6 +4292,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-bidi" version = "0.3.13" @@ -4299,6 +4375,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "walkdir" version = "2.4.0" diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index c6f3669f8b27..d28ccca2aefe 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -79,3 +79,6 @@ features = [ "Win32_UI_Shell", "Win32_UI_WindowsAndMessaging", ] + +[dev-dependencies] +proptest = "1.4" diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 0875ad0a4917..09a0fc929a5b 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -1070,18 +1070,40 @@ async fn auto_mtu_detection( /// points. #[cfg(target_os = "linux")] fn mtu_spacing(mtu_min: u16, mtu_max: u16, step_size: u16) -> Vec { - if mtu_min > mtu_max { - panic!("Invalid MTU detection range: `mtu_min`={mtu_min}, `mtu_max`={mtu_max}."); - } - let second_mtu = mtu_min.next_multiple_of(step_size); + assert!(mtu_min < mtu_max); + assert!(step_size < mtu_max); + assert_ne!(step_size, 0); + + let second_mtu = (mtu_min + 1).next_multiple_of(step_size); let in_between = (second_mtu..mtu_max).step_by(step_size as usize); - let mut ret = Vec::with_capacity(((mtu_max - second_mtu).div_ceil(step_size) + 2) as usize); + + let mut ret = Vec::with_capacity(in_between.clone().count() + 2); ret.push(mtu_min); ret.extend(in_between); ret.push(mtu_max); ret } +#[cfg(all(test, target_os = "linux"))] +mod tests { + use crate::mtu_spacing; + use proptest::prelude::*; + + proptest! { + #[test] + fn test_mtu_spacing(mtu_min in 0..800u16, mtu_max in 800..2000u16, step_size in 1..800u16) { + let mtu_spacing = mtu_spacing(mtu_min, mtu_max, step_size); + + prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_min).count(), 1); + prop_assert_eq!(mtu_spacing.iter().filter(|mtu| mtu == &&mtu_max).count(), 1); + prop_assert_eq!(mtu_spacing.capacity(), mtu_spacing.len()); + let mut diffs = mtu_spacing.windows(2).map(|win| win[1]-win[0]); + prop_assert!(diffs.all(|diff| diff <= step_size)); + + } + } +} + #[derive(Debug)] enum CloseMsg { Stop, From eeca931f565dd7699e8d7f1209c65b1520834373 Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Tue, 6 Feb 2024 14:59:01 +0100 Subject: [PATCH 8/8] Remove `quicktest` dependency, replace usages with `proptest`. --- Cargo.lock | 23 +---------------------- talpid-core/Cargo.toml | 3 +-- talpid-core/src/future_retry.rs | 17 ++++++++++------- 3 files changed, 12 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb8c92e08b0e..5f37db970b07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2782,26 +2782,6 @@ dependencies = [ "serde", ] -[[package]] -name = "quickcheck" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" -dependencies = [ - "rand 0.8.5", -] - -[[package]] -name = "quickcheck_macros" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b22a693222d716a9587786f37ac3f6b4faedb5b80c23914e7303ff5a1d8016e9" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "quote" version = "1.0.33" @@ -3584,8 +3564,7 @@ dependencies = [ "once_cell", "parking_lot", "pfctl", - "quickcheck", - "quickcheck_macros", + "proptest", "rand 0.8.5", "resolv-conf", "subslice", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 4d2a54fc3fb5..1ec70876ffcf 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -94,6 +94,5 @@ features = [ tonic-build = { workspace = true, default-features = false, features = ["transport", "prost"] } [dev-dependencies] -quickcheck = { version = "1.0", default-features = false } -quickcheck_macros = "1.0" +proptest = "1.4" tokio = { workspace = true, features = [ "test-util" ] } diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs index 197042e353d7..ee23de312fcf 100644 --- a/talpid-core/src/future_retry.rs +++ b/talpid-core/src/future_retry.rs @@ -153,6 +153,7 @@ fn apply_jitter(duration: Duration, jitter: f64) -> Duration { #[cfg(test)] mod test { use super::*; + use proptest::prelude::*; #[test] fn test_constant_interval() { @@ -220,13 +221,15 @@ mod test { assert_eq!(apply_jitter(second, 1.0), second); } - #[quickcheck_macros::quickcheck] - fn test_jitter(millis: u64, jitter: u64) { - let max_num = 2u64.checked_pow(f64::MANTISSA_DIGITS).unwrap(); - let jitter = (jitter % max_num) as f64 / (max_num as f64); - let unjittered_duration = Duration::from_millis(millis); - let jittered_duration = apply_jitter(unjittered_duration, jitter); - assert!(jittered_duration <= unjittered_duration); + proptest! { + #[test] + fn test_jitter(millis: u64, jitter: u64) { + let max_num = 2u64.checked_pow(f64::MANTISSA_DIGITS).unwrap(); + let jitter = (jitter % max_num) as f64 / (max_num as f64); + let unjittered_duration = Duration::from_millis(millis); + let jittered_duration = apply_jitter(unjittered_duration, jitter); + prop_assert!(jittered_duration <= unjittered_duration); + } } // NOTE: The test is disabled because the clock does not advance.