From 0b3444dfb26cb9b7da72bc55625f4c79703d5561 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Sat, 25 Nov 2023 16:34:34 +0100 Subject: [PATCH] Subtract multihop overhead from default route MTU on Linux --- talpid-routing/Cargo.toml | 2 +- talpid-routing/src/lib.rs | 20 ++++++++++++++++++++ talpid-routing/src/unix/linux.rs | 20 +++++++++++++++++--- talpid-wireguard/src/lib.rs | 27 +++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml index e5cfb9d9f904..31545fe7053d 100644 --- a/talpid-routing/Cargo.toml +++ b/talpid-routing/Cargo.toml @@ -23,7 +23,7 @@ talpid-types = { path = "../talpid-types" } libc = "0.2" once_cell = { workspace = true } rtnetlink = "0.11" -netlink-packet-route = "0.13" +netlink-packet-route = { version = "0.13", features = ["rich_nlas"] } netlink-sys = "0.8.3" [target.'cfg(target_os = "macos")'.dependencies] diff --git a/talpid-routing/src/lib.rs b/talpid-routing/src/lib.rs index dd5fd3a7616a..f1bf28d1a2c8 100644 --- a/talpid-routing/src/lib.rs +++ b/talpid-routing/src/lib.rs @@ -38,6 +38,8 @@ pub struct Route { metric: Option, #[cfg(target_os = "linux")] table_id: u32, + #[cfg(target_os = "linux")] + mtu: Option, } impl Route { @@ -49,6 +51,8 @@ impl Route { metric: None, #[cfg(target_os = "linux")] table_id: u32::from(RT_TABLE_MAIN), + #[cfg(target_os = "linux")] + mtu: None, } } @@ -72,6 +76,10 @@ impl fmt::Display for Route { } #[cfg(target_os = "linux")] write!(f, " table {}", self.table_id)?; + #[cfg(target_os = "linux")] + if let Some(mtu) = self.mtu { + write!(f, " mtu {mtu}")?; + } Ok(()) } } @@ -87,6 +95,9 @@ pub struct RequiredRoute { /// Specifies whether the route should be added to the main routing table or not. #[cfg(target_os = "linux")] main_table: bool, + /// Specifies route MTU + #[cfg(target_os = "linux")] + mtu: Option, } impl RequiredRoute { @@ -97,6 +108,8 @@ impl RequiredRoute { prefix, #[cfg(target_os = "linux")] main_table: true, + #[cfg(target_os = "linux")] + mtu: None, } } @@ -106,6 +119,13 @@ impl RequiredRoute { self.main_table = main_table; self } + + /// Set route MTU to the given value. + #[cfg(target_os = "linux")] + pub fn mtu(mut self, mtu: u16) -> Self { + self.mtu = Some(mtu); + self + } } /// A NetNode represents a network node - either a real one or a symbolic default one. diff --git a/talpid-routing/src/unix/linux.rs b/talpid-routing/src/unix/linux.rs index a2cf19d3b48c..600ddd6794bd 100644 --- a/talpid-routing/src/unix/linux.rs +++ b/talpid-routing/src/unix/linux.rs @@ -20,7 +20,7 @@ use libc::{AF_INET, AF_INET6}; use netlink_packet_route::{ constants::{ARPHRD_LOOPBACK, FIB_RULE_INVERT, FR_ACT_TO_TBL, NLM_F_REQUEST}, link::{nlas::Nla as LinkNla, LinkMessage}, - route::{nlas::Nla as RouteNla, RouteHeader, RouteMessage}, + route::{nlas::Nla as RouteNla, Metrics, RouteHeader, RouteMessage}, rtnl::{ constants::{ RTN_UNSPEC, RTPROT_UNSPEC, RT_SCOPE_LINK, RT_SCOPE_UNIVERSE, RT_TABLE_COMPAT, @@ -293,7 +293,9 @@ impl RouteManagerImpl { } else { self.table_id }; - required_normal_routes.insert(Route::new(node, route.prefix).table(table)); + let mut new_route = Route::new(node, route.prefix).table(table); + new_route.mtu = route.mtu.map(u32::from); + required_normal_routes.insert(new_route); } } } @@ -450,12 +452,13 @@ impl RouteManagerImpl { destination_length, ) .map_err(Error::InvalidNetworkPrefix)?; + let mut node_addr = None; let mut device = None; let mut metric = None; let mut gateway: Option = None; - let mut table_id = u32::from(msg.header.table); + let mut route_mtu = None; for nla in msg.nlas.iter() { match nla { @@ -501,6 +504,10 @@ impl RouteManagerImpl { RouteNla::Table(id) => { table_id = *id; } + + RouteNla::Metrics(Metrics::Mtu(mtu)) => { + route_mtu = Some(*mtu); + } _ => continue, } } @@ -519,6 +526,7 @@ impl RouteManagerImpl { prefix, metric, table_id, + mtu: route_mtu, })) } @@ -700,6 +708,11 @@ impl RouteManagerImpl { add_message.nlas.push(RouteNla::Priority(metric)); } + // Set route MTU + if let Some(mtu) = route.mtu { + add_message.nlas.push(RouteNla::Metrics(Metrics::Mtu(mtu))); + } + // Need to modify the request in place to set the correct flags to be able to replace any // existing routes - self.handle.route().add_v4().execute() sets the NLM_F_EXCL flag which // will make the request fail if a route with the same destination already exists. @@ -743,6 +756,7 @@ impl RouteManagerImpl { async fn get_mtu_for_route(&self, ip: IpAddr) -> Result { // RECURSION_LIMIT controls how many times we recurse to find the device name by looking up // an IP with `get_destination_route`. + // TODO: Check route MTU first const RECURSION_LIMIT: usize = 10; const STANDARD_MTU: u16 = 1500; let mut attempted_ip = ip; diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index c95a5d371b67..9ca230b652d5 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -880,6 +880,10 @@ impl WireguardMonitor { let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); + #[cfg(target_os = "linux")] + let gateway_routes = + gateway_routes.map(|route| Self::apply_route_mtu_for_multihop(route, config)); + let routes = gateway_routes.chain( Self::get_tunnel_destinations(config) .filter(|allowed_ip| allowed_ip.prefix() != 0) @@ -916,6 +920,29 @@ impl WireguardMonitor { #[cfg(target_os = "linux")] iter.map(|route| route.use_main_table(false)) + .map(|route| Self::apply_route_mtu_for_multihop(route, config)) + } + + #[cfg(target_os = "linux")] + fn apply_route_mtu_for_multihop(route: RequiredRoute, config: &Config) -> RequiredRoute { + if config.peers.len() == 1 { + route + } else { + // Set route MTU by subtracting the WireGuard overhead from the tunnel MTU. + // NOTE: Somewhat incorrect since it doesn't account for packet padding/alignment? + // TODO: Move consts to shared location + const IPV4_HEADER_SIZE: u16 = 20; + const IPV6_HEADER_SIZE: u16 = 40; + const WIREGUARD_HEADER_SIZE: u16 = 40; + + let ip_overhead = match route.prefix.is_ipv4() { + true => IPV4_HEADER_SIZE, + false => IPV6_HEADER_SIZE, + }; + let mtu = config.mtu - ip_overhead - WIREGUARD_HEADER_SIZE; + + route.mtu(mtu) + } } /// Return routes for all allowed IPs.