From fcbc99c64f55b36f4fd5ebdc160f4474dbf34d58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Thu, 30 Nov 2023 13:06:30 +0100 Subject: [PATCH] Remove hidden assumptions from WireGuard config --- talpid-wireguard/src/config.rs | 107 +++++++++--------- talpid-wireguard/src/lib.rs | 39 +++---- .../src/wireguard_kernel/nm_tunnel.rs | 2 +- .../src/wireguard_kernel/wg_message.rs | 2 +- talpid-wireguard/src/wireguard_nt.rs | 2 +- 5 files changed, 70 insertions(+), 82 deletions(-) diff --git a/talpid-wireguard/src/config.rs b/talpid-wireguard/src/config.rs index d8769e8af14c..0e462102b244 100644 --- a/talpid-wireguard/src/config.rs +++ b/talpid-wireguard/src/config.rs @@ -10,8 +10,10 @@ use talpid_types::net::{obfuscation::ObfuscatorConfig, wireguard, GenericTunnelO pub struct Config { /// Contains tunnel endpoint specific config pub tunnel: wireguard::TunnelConfig, - /// List of peer configurations - pub peers: Vec, + /// Entry peer + pub entry_peer: wireguard::PeerConfig, + /// Multihop exit peer + pub exit_peer: Option, /// IPv4 gateway pub ipv4_gateway: Ipv4Addr, /// IPv6 gateway @@ -46,59 +48,28 @@ pub enum Error { /// Peer has no valid IPs #[error(display = "Supplied peer has no valid IPs")] InvalidPeerIpError, - - /// Parameters don't contain any peers - #[error(display = "No peers supplied")] - NoPeersSuppliedError, } impl Config { /// Constructs a Config from parameters pub fn from_parameters(params: &wireguard::TunnelParameters) -> Result { - let tunnel = params.connection.tunnel.clone(); - let mut peers = vec![params.connection.peer.clone()]; - if let Some(exit_peer) = ¶ms.connection.exit_peer { - peers.push(exit_peer.clone()); - } Self::new( - tunnel, - peers, - params.connection.ipv4_gateway, - params.connection.ipv6_gateway, + ¶ms.connection, ¶ms.options, ¶ms.generic_options, - params.obfuscation.clone(), - #[cfg(target_os = "linux")] - params.connection.fwmark, + ¶ms.obfuscation, ) } /// Constructs a new Config struct fn new( - mut tunnel: wireguard::TunnelConfig, - mut peers: Vec, - ipv4_gateway: Ipv4Addr, - ipv6_gateway: Option, + connection: &wireguard::ConnectionConfig, wg_options: &wireguard::TunnelOptions, generic_options: &GenericTunnelOptions, - obfuscator_config: Option, - #[cfg(target_os = "linux")] fwmark: Option, + obfuscator_config: &Option, ) -> Result { - if peers.is_empty() { - return Err(Error::NoPeersSuppliedError); - } + let mut tunnel = connection.tunnel.clone(); let mtu = wg_options.mtu.unwrap_or(DEFAULT_MTU); - for peer in &mut peers { - peer.allowed_ips = peer - .allowed_ips - .iter() - .cloned() - .filter(|ip| ip.is_ipv4() || generic_options.enable_ipv6) - .collect(); - if peer.allowed_ips.is_empty() { - return Err(Error::InvalidPeerIpError); - } - } if tunnel.addresses.is_empty() { return Err(Error::InvalidTunnelIpError); @@ -107,20 +78,33 @@ impl Config { .addresses .retain(|ip| ip.is_ipv4() || generic_options.enable_ipv6); - let ipv6_gateway = ipv6_gateway.filter(|_opt| generic_options.enable_ipv6); + let ipv6_gateway = connection + .ipv6_gateway + .filter(|_opt| generic_options.enable_ipv6); - Ok(Config { + let mut config = Config { tunnel, - peers, - ipv4_gateway, + entry_peer: connection.peer.clone(), + exit_peer: connection.exit_peer.clone(), + ipv4_gateway: connection.ipv4_gateway, ipv6_gateway, mtu, #[cfg(target_os = "linux")] - fwmark, + fwmark: connection.fwmark, #[cfg(target_os = "linux")] enable_ipv6: generic_options.enable_ipv6, - obfuscator_config, - }) + obfuscator_config: obfuscator_config.to_owned(), + }; + + for peer in config.peers_mut() { + peer.allowed_ips + .retain(|ip| ip.is_ipv4() || generic_options.enable_ipv6); + if peer.allowed_ips.is_empty() { + return Err(Error::InvalidPeerIpError); + } + } + + Ok(config) } /// Returns a CString with the appropriate config for WireGuard-go @@ -139,7 +123,7 @@ impl Config { wg_conf.add("replace_peers", "true"); - for peer in &self.peers { + for peer in self.peers() { wg_conf .add("public_key", peer.public_key.as_bytes().as_ref()) .add("endpoint", peer.endpoint.to_string().as_str()) @@ -157,17 +141,32 @@ impl Config { } /// Return whether the config connects to an exit peer from another remote peer. - /// - /// This relies on the assumption that multiple peers imply that multihop is used. This is - /// misguided in principle but happens to work given that normally only one peer will be - /// present. pub fn is_multihop(&self) -> bool { - self.peers.len() == 2 + self.exit_peer.is_some() + } + + /// Return the exit peer. `exit_peer` if it is set, otherwise `entry_peer`. + pub fn exit_peer_mut(&mut self) -> &mut wireguard::PeerConfig { + if let Some(ref mut peer) = self.exit_peer { + return peer; + } + &mut self.entry_peer + } + + /// Return an iterator over all peers. + pub fn peers(&self) -> impl Iterator { + self.exit_peer + .as_ref() + .into_iter() + .chain(std::iter::once(&self.entry_peer)) } - /// Return the entry peer. This happens to be the first peer. - pub fn entry_peer_mut(&mut self) -> Option<&mut wireguard::PeerConfig> { - self.peers.get_mut(0) + /// Return a mutable iterator over all peers. + pub fn peers_mut(&mut self) -> impl Iterator { + self.exit_peer + .as_mut() + .into_iter() + .chain(std::iter::once(&mut self.entry_peer)) } } diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 45a76eb65574..036388e7f230 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -167,10 +167,6 @@ async fn maybe_create_obfuscator( config: &mut Config, close_msg_sender: sync_mpsc::Sender, ) -> Result> { - // There are one or two peers. - // The first one is always the entry relay. - let first_peer = config.peers.get_mut(0).expect("missing peer"); - if let Some(ref obfuscator_config) = config.obfuscator_config { match obfuscator_config { ObfuscatorConfig::Udp2Tcp { endpoint } => { @@ -186,7 +182,7 @@ async fn maybe_create_obfuscator( let endpoint = obfuscator.endpoint(); log::trace!("Patching first WireGuard peer to become {:?}", endpoint); - first_peer.endpoint = endpoint; + config.entry_peer.endpoint = endpoint; #[cfg(target_os = "android")] let remote_socket_fd = obfuscator.remote_socket_fd(); @@ -234,8 +230,7 @@ impl WireguardMonitor { ) -> Result { let on_event = args.on_event.clone(); - let endpoint_addrs: Vec = - config.peers.iter().map(|peer| peer.endpoint.ip()).collect(); + let endpoint_addrs: Vec = config.peers().map(|peer| peer.endpoint.ip()).collect(); let (close_obfs_sender, close_obfs_listener) = sync_mpsc::channel(); let obfuscator = args.runtime.block_on(maybe_create_obfuscator( @@ -459,7 +454,10 @@ impl WireguardMonitor { // exit peer with these rules and not the broader internet. AllowedTunnelTraffic::Two( allowed_traffic, - Endpoint::from_socket_address(config.peers[1].endpoint, TransportProtocol::Udp), + Endpoint::from_socket_address( + config.exit_peer.as_mut().unwrap().endpoint, + TransportProtocol::Udp, + ), ) } else { AllowedTunnelTraffic::One(allowed_traffic) @@ -474,14 +472,11 @@ impl WireguardMonitor { log::debug!("Successfully exchanged PSK with exit peer"); - let mut entry_psk = None; - if config.is_multihop() { // Set up tunnel to lead to entry let mut entry_tun_config = config.clone(); entry_tun_config - .entry_peer_mut() - .expect("entry peer not found") + .entry_peer .allowed_ips .push(IpNetwork::new(IpAddr::V4(config.ipv4_gateway), 32).unwrap()); @@ -495,7 +490,7 @@ impl WireguardMonitor { &tun_provider, ) .await?; - entry_psk = Some( + let entry_psk = Some( Self::perform_psk_negotiation( retry_attempt, &entry_config, @@ -505,18 +500,13 @@ impl WireguardMonitor { .await?, ); log::debug!("Successfully exchanged PSK with entry peer"); + + config.entry_peer.psk = entry_psk; } - // Set new priv key and psks + config.exit_peer_mut().psk = Some(exit_psk); + config.tunnel.private_key = wg_psk_privkey; - if let Some(entry_psk) = entry_psk { - // The first peer is the entry peer and there is guaranteed to be a second peer - // which is the exit - config.peers.get_mut(0).expect("entry peer not found").psk = Some(entry_psk); - config.peers.get_mut(1).expect("exit peer not found").psk = Some(exit_psk); - } else { - config.peers.get_mut(0).expect("peer not found").psk = Some(exit_psk); - } *config = Self::reconfigure_tunnel( tunnel, @@ -588,7 +578,7 @@ impl WireguardMonitor { let gateway_net_v6 = config .ipv6_gateway .map(|net| ipnetwork::IpNetwork::from(IpAddr::from(net))); - for peer in &mut patched_config.peers { + for peer in patched_config.peers_mut() { peer.allowed_ips = peer .allowed_ips .iter() @@ -939,8 +929,7 @@ impl WireguardMonitor { /// Return routes for all allowed IPs. fn get_tunnel_destinations(config: &Config) -> impl Iterator + '_ { config - .peers - .iter() + .peers() .flat_map(|peer| peer.allowed_ips.iter()) .cloned() } diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs index 3f2661a4dc2b..7b5966b9e41b 100644 --- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs @@ -130,7 +130,7 @@ fn convert_config_to_dbus(config: &Config) -> DeviceConfig { ); wireguard_config.insert("private-key-flags".into(), Variant(Box::new(0x0u32))); - for peer in config.peers.iter() { + for peer in config.peers() { let mut peer_config: VariantMap = HashMap::new(); let allowed_ips = peer .allowed_ips diff --git a/talpid-wireguard/src/wireguard_kernel/wg_message.rs b/talpid-wireguard/src/wireguard_kernel/wg_message.rs index 7ed972c3eaf6..4dc84ed503b4 100644 --- a/talpid-wireguard/src/wireguard_kernel/wg_message.rs +++ b/talpid-wireguard/src/wireguard_kernel/wg_message.rs @@ -78,7 +78,7 @@ impl DeviceMessage { pub fn reset_config(message_type: u16, interface_index: u32, config: &Config) -> DeviceMessage { let mut peers = vec![]; - for peer in config.peers.iter() { + for peer in config.peers() { let peer_endpoint = InetAddr::from_std(&peer.endpoint); let allowed_ips = peer.allowed_ips.iter().map(From::from).collect(); let mut peer_nlas = vec![ diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt.rs index 588d5a7f82a9..78b6f77d3e45 100644 --- a/talpid-wireguard/src/wireguard_nt.rs +++ b/talpid-wireguard/src/wireguard_nt.rs @@ -816,7 +816,7 @@ fn serialize_config(config: &Config) -> Result>> { buffer.extend(as_uninit_byte_slice(&header)); - for peer in &config.peers { + for peer in config.peers() { let flags = if peer.psk.is_some() { WgPeerFlag::HAS_PRESHARED_KEY | WgPeerFlag::HAS_PUBLIC_KEY | WgPeerFlag::HAS_ENDPOINT } else {