From b3d29325c6f96c83c144870499f2095bb09dfb04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Sat, 10 Aug 2024 14:29:45 +0200 Subject: [PATCH] Move ownership of tun config to tun provider --- .../tunnel_state_machine/connected_state.rs | 85 ++++++----- .../tunnel_state_machine/connecting_state.rs | 67 ++++++--- .../disconnected_state.rs | 15 +- .../disconnecting_state.rs | 9 +- .../src/tunnel_state_machine/error_state.rs | 75 ++++----- talpid-core/src/tunnel_state_machine/mod.rs | 113 ++++++-------- talpid-tunnel/src/tun_provider/android/mod.rs | 142 ++++-------------- talpid-tunnel/src/tun_provider/mod.rs | 34 ++++- talpid-tunnel/src/tun_provider/unix.rs | 22 +-- talpid-wireguard/src/wireguard_go/mod.rs | 7 +- 10 files changed, 273 insertions(+), 296 deletions(-) diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 8208223224f5..5de0120ffcde 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -228,23 +228,34 @@ impl ConnectedState { match command { Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { - let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { - self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) - } else { - match self.set_firewall_policy(shared_values) { - Ok(()) => { - if cfg!(target_os = "android") { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } else { - SameState(self) - } + let consequence = if shared_values.set_allow_lan(allow_lan) { + #[cfg(target_os = "android")] + { + if let Err(_err) = shared_values.restart_tunnel(false) { + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ) + } else { + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) + } + } + #[cfg(not(target_os = "android"))] + { + match self.set_firewall_policy(shared_values) { + Ok(()) => SameState(self), + Err(error) => self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError( + error, + )), + ), } - Err(error) => self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), - ), } + } else { + SameState(self) }; + let _ = complete_tx.send(()); consequence } @@ -254,8 +265,20 @@ impl ConnectedState { SameState(self) } Some(TunnelCommand::Dns(servers, complete_tx)) => { - let consequence = match shared_values.set_dns_servers(servers) { - Ok(true) => { + let consequence = if shared_values.set_dns_servers(servers) { + #[cfg(target_os = "android")] + { + if let Err(_err) = shared_values.restart_tunnel(false) { + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ) + } else { + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) + } + } + #[cfg(not(target_os = "android"))] + { if let Err(error) = self.set_firewall_policy(shared_values) { return self.disconnect( shared_values, @@ -266,9 +289,6 @@ impl ConnectedState { } match self.set_dns(shared_values) { - #[cfg(target_os = "android")] - Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - #[cfg(not(target_os = "android"))] Ok(()) => SameState(self), Err(error) => { log::error!( @@ -282,10 +302,8 @@ impl ConnectedState { } } } - Ok(false) => SameState(self), - Err(error_cause) => { - self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) - } + } else { + SameState(self) }; let _ = complete_tx.send(()); consequence @@ -327,22 +345,21 @@ impl ConnectedState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - match shared_values.exclude_paths(paths) { - Ok(changed) => { - let _ = result_tx.send(Ok(())); - if changed { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } else { - SameState(self) - } - } - Err(err) => { - let _ = result_tx.send(Err(err)); + if shared_values.set_excluded_paths(paths) { + if let Err(err) = shared_values.restart_tunnel(false) { + let _ = + result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err))); self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SplitTunnelError), ) + } else { + let _ = result_tx.send(Ok(())); + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } + } else { + let _ = result_tx.send(Ok(())); + SameState(self) } } #[cfg(target_os = "macos")] diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 6f0e0414d0fd..1334b716b2d1 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -96,6 +96,9 @@ impl ConnectingState { ErrorStateCause::SetFirewallPolicyError(error), ) } else { + #[cfg(target_os = "android")] + shared_values.prepare_tun_config(false); + let connecting_state = Self::start_tunnel( shared_values.runtime.clone(), tunnel_parameters, @@ -354,6 +357,7 @@ impl ConnectingState { )) } + #[cfg(not(target_os = "android"))] fn reset_firewall( self: Box, shared_values: &mut SharedTunnelStateValues, @@ -364,13 +368,7 @@ impl ConnectingState { &self.tunnel_metadata, self.allowed_tunnel_traffic.clone(), ) { - Ok(()) => { - if cfg!(target_os = "android") { - self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) - } else { - EventConsequence::SameState(self) - } - } + Ok(()) => EventConsequence::SameState(self), Err(error) => self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), @@ -387,10 +385,22 @@ impl ConnectingState { match command { Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { - let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { - self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) - } else { + let consequence = if shared_values.set_allow_lan(allow_lan) { + #[cfg(target_os = "android")] + { + if let Err(_err) = shared_values.restart_tunnel(false) { + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ) + } else { + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) + } + } + #[cfg(not(target_os = "android"))] self.reset_firewall(shared_values) + } else { + SameState(self) }; let _ = complete_tx.send(()); consequence @@ -415,12 +425,24 @@ impl ConnectingState { SameState(self) } Some(TunnelCommand::Dns(servers, complete_tx)) => { - let consequence = match shared_values.set_dns_servers(servers) { + let consequence = if shared_values.set_dns_servers(servers) { #[cfg(target_os = "android")] - Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - Ok(_) => SameState(self), - Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), + { + if let Err(_err) = shared_values.restart_tunnel(false) { + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::StartTunnelError), + ) + } else { + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) + } + } + #[cfg(not(target_os = "android"))] + SameState(self) + } else { + SameState(self) }; + let _ = complete_tx.send(()); consequence } @@ -461,18 +483,21 @@ impl ConnectingState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - match shared_values.exclude_paths(paths) { - Ok(_changed) => { - let _ = result_tx.send(Ok(())); - SameState(self) - } - Err(error) => { - let _ = result_tx.send(Err(error)); + if shared_values.set_excluded_paths(paths) { + if let Err(err) = shared_values.restart_tunnel(false) { + let _ = + result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err))); self.disconnect( shared_values, AfterDisconnect::Block(ErrorStateCause::SplitTunnelError), ) + } else { + let _ = result_tx.send(Ok(())); + self.disconnect(shared_values, AfterDisconnect::Reconnect(0)) } + } else { + let _ = result_tx.send(Ok(())); + SameState(self) } } #[cfg(target_os = "macos")] diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index e570d3af6441..4ee19a5e16aa 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -138,13 +138,7 @@ impl TunnelState for DisconnectedState { match runtime.block_on(commands.next()) { Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { - if shared_values.allow_lan != allow_lan { - // The only platform that can fail is Android, but Android doesn't support the - // "block when disconnected" option, so the following call never fails. - shared_values - .set_allow_lan(allow_lan) - .expect("Failed to set allow LAN parameter"); - + if shared_values.set_allow_lan(allow_lan) { Self::set_firewall_policy(shared_values, false); } let _ = complete_tx.send(()); @@ -160,9 +154,7 @@ impl TunnelState for DisconnectedState { } Some(TunnelCommand::Dns(servers, complete_tx)) => { // Same situation as allow LAN above. - shared_values - .set_dns_servers(servers) - .expect("Failed to reconnect after changing custom DNS servers"); + shared_values.set_dns_servers(servers); let _ = complete_tx.send(()); SameState(self) } @@ -218,7 +210,8 @@ impl TunnelState for DisconnectedState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - let _ = result_tx.send(shared_values.exclude_paths(paths).map(|_| ())); + shared_values.set_excluded_paths(paths); + let _ = result_tx.send(Ok(())); SameState(self) } #[cfg(target_os = "macos")] diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 099208201e99..ddcb3cebd25f 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -83,7 +83,8 @@ impl DisconnectingState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - let _ = result_tx.send(shared_values.exclude_paths(paths).map(|_| ())); + shared_values.set_excluded_paths(paths); + let _ = result_tx.send(Ok(())); AfterDisconnect::Nothing } #[cfg(target_os = "macos")] @@ -139,7 +140,8 @@ impl DisconnectingState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - let _ = result_tx.send(shared_values.exclude_paths(paths).map(|_| ())); + shared_values.set_excluded_paths(paths); + let _ = result_tx.send(Ok(())); AfterDisconnect::Block(reason) } #[cfg(target_os = "macos")] @@ -196,7 +198,8 @@ impl DisconnectingState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - let _ = result_tx.send(shared_values.exclude_paths(paths).map(|_| ())); + shared_values.set_excluded_paths(paths); + let _ = result_tx.send(Ok(())); AfterDisconnect::Reconnect(retry_attempt) } #[cfg(target_os = "macos")] diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index b4f642d8e2c4..99f8dc17c403 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -51,11 +51,12 @@ impl ErrorState { let block_failure = Self::set_firewall_policy(shared_values).err(); #[cfg(target_os = "android")] - let block_failure = if !Self::create_blocking_tun(shared_values) { + let block_failure = if shared_values.restart_tunnel(true).is_err() { Some(FirewallPolicyError::Generic) } else { None }; + ( Box::new(ErrorState { block_reason: block_reason.clone(), @@ -98,28 +99,6 @@ impl ErrorState { }) } - /// Returns true if a new tunnel device was successfully created. - #[cfg(target_os = "android")] - fn create_blocking_tun(shared_values: &mut SharedTunnelStateValues) -> bool { - match shared_values - .tun_provider - .lock() - .unwrap() - .create_blocking_tun() - { - Ok(()) => true, - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to open tunnel adapter to drop packets for blocked state" - ) - ); - false - } - } - } - fn reset_dns(shared_values: &mut SharedTunnelStateValues) { if let Err(error) = shared_values.dns_monitor.reset() { log::error!("{}", error.display_chain_with_msg("Unable to reset DNS")); @@ -139,13 +118,25 @@ impl TunnelState for ErrorState { match runtime.block_on(commands.next()) { Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { - let consequence = - if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { - NewState(Self::enter(shared_values, error_state_cause)) + let consequence = if shared_values.set_allow_lan(allow_lan) { + #[cfg(target_os = "android")] + if let Err(_err) = shared_values.restart_tunnel(true) { + NewState(Self::enter( + shared_values, + ErrorStateCause::StartTunnelError, + )) } else { + SameState(self) + } + #[cfg(not(target_os = "android"))] + { let _ = Self::set_firewall_policy(shared_values); SameState(self) - }; + } + } else { + SameState(self) + }; + let _ = complete_tx.send(()); consequence } @@ -155,7 +146,7 @@ impl TunnelState for ErrorState { let _ = Self::set_firewall_policy(shared_values); #[cfg(target_os = "android")] - if !Self::create_blocking_tun(shared_values) { + if let Err(_err) = shared_values.restart_tunnel(true) { let _ = tx.send(()); return NewState(Self::enter( shared_values, @@ -167,12 +158,21 @@ impl TunnelState for ErrorState { SameState(self) } Some(TunnelCommand::Dns(servers, complete_tx)) => { - let consequence = - if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { - NewState(Self::enter(shared_values, error_state_cause)) - } else { + let consequence = if shared_values.set_dns_servers(servers) { + #[cfg(target_os = "android")] + { + // DNS is blocked in the error state, so only update tun config + shared_values.prepare_tun_config(true); + SameState(self) + } + #[cfg(not(target_os = "android"))] + { + let _ = Self::set_firewall_policy(shared_values); SameState(self) - }; + } + } else { + SameState(self) + }; let _ = complete_tx.send(()); consequence } @@ -213,7 +213,14 @@ impl TunnelState for ErrorState { } #[cfg(target_os = "android")] Some(TunnelCommand::SetExcludedApps(result_tx, paths)) => { - let _ = result_tx.send(shared_values.exclude_paths(paths).map(|_| ())); + if shared_values.set_excluded_paths(paths) { + if let Err(err) = shared_values.restart_tunnel(true) { + let _ = + result_tx.send(Err(crate::split_tunnel::Error::SetExcludedApps(err))); + } + } else { + let _ = result_tx.send(Ok(())); + } SameState(self) } #[cfg(windows)] diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 5850c5d8d9f0..e0007fcaa7bc 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -136,12 +136,6 @@ pub async fn spawn( let tun_provider = TunProvider::new( #[cfg(target_os = "android")] android_context.clone(), - #[cfg(target_os = "android")] - initial_settings.allow_lan, - #[cfg(target_os = "android")] - initial_settings.dns_servers.clone(), - #[cfg(target_os = "android")] - initial_settings.exclude_paths.clone(), ); let (shutdown_tx, shutdown_rx) = oneshot::channel(); @@ -369,6 +363,8 @@ impl TunnelStateMachine { let mut shared_values = SharedTunnelStateValues { #[cfg(any(target_os = "windows", target_os = "macos"))] split_tunnel, + #[cfg(target_os = "android")] + excluded_packages: args.settings.exclude_paths, runtime, firewall, dns_monitor, @@ -457,6 +453,8 @@ struct SharedTunnelStateValues { split_tunnel: split_tunnel::SplitTunnel, #[cfg(target_os = "macos")] split_tunnel: split_tunnel::Handle, + #[cfg(target_os = "android")] + excluded_packages: Vec, runtime: tokio::runtime::Handle, firewall: Firewall, dns_monitor: DnsMonitor, @@ -548,56 +546,21 @@ impl SharedTunnelStateValues { .map_err(|error| ErrorStateCause::from(&error)) } - pub fn set_allow_lan(&mut self, allow_lan: bool) -> Result<(), ErrorStateCause> { + pub fn set_allow_lan(&mut self, allow_lan: bool) -> bool { if self.allow_lan != allow_lan { self.allow_lan = allow_lan; - - #[cfg(target_os = "android")] - { - if let Err(error) = self.tun_provider.lock().unwrap().set_allow_lan(allow_lan) { - log::error!( - "{}", - error.display_chain_with_msg(&format!( - "Failed to restart tunnel after {} LAN connections", - if allow_lan { "allowing" } else { "blocking" } - )) - ); - return Err(ErrorStateCause::StartTunnelError); - } - } + true + } else { + false } - - Ok(()) } - pub fn set_dns_servers( - &mut self, - dns_servers: Option>, - ) -> Result { + pub fn set_dns_servers(&mut self, dns_servers: Option>) -> bool { if self.dns_servers != dns_servers { self.dns_servers = dns_servers; - - #[cfg(target_os = "android")] - { - if let Err(error) = self - .tun_provider - .lock() - .unwrap() - .set_dns_servers(self.dns_servers.clone()) - { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to restart tunnel after changing DNS servers", - ) - ); - return Err(ErrorStateCause::StartTunnelError); - } - } - - Ok(true) + true } else { - Ok(false) + false } } @@ -645,27 +608,47 @@ impl SharedTunnelStateValues { } /// Update the set of excluded paths (split tunnel apps) for the tunnel provider. - /// Returns `Ok(true)` if the tunnel state machine should issue a tunnel reconnect. #[cfg(target_os = "android")] - pub fn exclude_paths(&mut self, apps: Vec) -> Result { - self.tun_provider - .lock() - .unwrap() - .set_exclude_apps(apps) - .map_err(split_tunnel::Error::SetExcludedApps) - .inspect_err(|error| { + pub fn set_excluded_paths(&mut self, apps: Vec) -> bool { + if apps != self.excluded_packages { + self.excluded_packages = apps; + true + } else { + false + } + } + + /// Update the tunnel provider config. This does not actually create any tunnel. + #[cfg(target_os = "android")] + pub fn prepare_tun_config(&self, blocking: bool) { + let mut tun_provider = self.tun_provider.lock().unwrap(); + + let config = tun_provider.config_mut(); + if blocking { + config.dns_servers = Some(vec![]); + } else { + config.dns_servers = self.dns_servers.clone(); + } + config.allow_lan = self.allow_lan; + config.excluded_packages = self.excluded_packages.clone(); + } + + /// Recreate the tunnel device. Note that this causes the current tunnel fd used by + /// the tunnel monitor to become stale, so a reconnect is needed. + #[cfg(target_os = "android")] + pub fn restart_tunnel(&self, blocking: bool) -> Result<(), talpid_tunnel::tun_provider::Error> { + self.prepare_tun_config(blocking); + + match self.tun_provider.lock().unwrap().get_tun() { + Ok(_tun) => Ok(()), + Err(error) => { log::error!( "{}", - error.display_chain_with_msg( - "Failed to restart tunnel after updating excluded apps", - ) + error.display_chain_with_msg("Failed to restart tunnel") ); - })?; - // NOTE: For now, we tell the TSM to always reconnect when this function has been - // successfully called. We still return a boolean value in case we would like to introduce - // some condition in the future, thus forcing the TSM to be ready to handle both cases - // already. - Ok(true) + Err(error) + } + } } } diff --git a/talpid-tunnel/src/tun_provider/android/mod.rs b/talpid-tunnel/src/tun_provider/android/mod.rs index 35da0870c817..8e7e406a6cc1 100644 --- a/talpid-tunnel/src/tun_provider/android/mod.rs +++ b/talpid-tunnel/src/tun_provider/android/mod.rs @@ -12,7 +12,7 @@ use jnix::{ FromJava, IntoJava, JnixEnv, }; use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr}, + net::IpAddr, os::unix::io::{AsRawFd, RawFd}, sync::Arc, }; @@ -55,20 +55,12 @@ pub struct AndroidTunProvider { jvm: Arc, class: GlobalRef, object: GlobalRef, - last_tun_config: Option<(TunConfig, bool)>, - allow_lan: bool, - custom_dns_servers: Option>, - excluded_packages: Vec, + config: TunConfig, } impl AndroidTunProvider { /// Create a new AndroidTunProvider interfacing with Android's VpnService. - pub fn new( - context: AndroidContext, - allow_lan: bool, - custom_dns_servers: Option>, - excluded_packages: Vec, - ) -> Self { + pub fn new(context: AndroidContext) -> Self { let env = JnixEnv::from( context .jvm @@ -81,65 +73,24 @@ impl AndroidTunProvider { jvm: context.jvm, class: talpid_vpn_service_class, object: context.vpn_service, - last_tun_config: None, - allow_lan, - custom_dns_servers, - excluded_packages, + config: TunConfig::default(), } } - pub fn set_allow_lan(&mut self, allow_lan: bool) -> Result<(), Error> { - if self.allow_lan != allow_lan { - self.allow_lan = allow_lan; - self.recreate_tun_if_open()?; - } - - Ok(()) - } - - pub fn set_dns_servers(&mut self, servers: Option>) -> Result<(), Error> { - if self.custom_dns_servers != servers { - self.custom_dns_servers = servers; - self.recreate_tun_if_open()?; - } - - Ok(()) + /// Get the current tunnel config. Note that the tunnel must be recreated for any changes to + /// take effect. + pub fn config_mut(&mut self) -> &mut TunConfig { + &mut self.config } - /// Update the set of excluded paths (split tunnel apps) for the tunnel provider. - /// This will cause any pre-existing tunnel to be recreated if necessary. See - /// [`AndroidTunProvider::recreate_tun_if_open()`] for details. - pub fn set_exclude_apps(&mut self, excluded_packages: Vec) -> Result<(), Error> { - if self.excluded_packages != excluded_packages { - self.excluded_packages = excluded_packages; - self.recreate_tun_if_open()?; - } - Ok(()) - } - - /// Retrieve a tunnel device with the provided configuration. Custom DNS and LAN routes are - /// appended to the provided config. - pub fn get_tun(&mut self, config: TunConfig) -> Result { - self.get_tun_inner(config, false) + /// Retrieve a tunnel device with the provided configuration. + pub fn get_tun(&mut self) -> Result { + self.get_tun_inner() } /// Retrieve a tunnel device with the provided configuration. - fn get_tun_inner(&mut self, config: TunConfig, blocking: bool) -> Result { - let service_config = VpnServiceConfig::new( - config.clone(), - self.allow_lan, - if !blocking { - self.custom_dns_servers.clone() - } else { - // Disable DNS - Some(vec![]) - }, - self.excluded_packages.clone(), - ); - - let tun_fd = self.get_tun_fd(service_config)?; - - self.last_tun_config = Some((config, blocking)); + fn get_tun_inner(&mut self) -> Result { + let tun_fd = self.get_tun_fd()?; let jvm = unsafe { JavaVM::from_raw(self.jvm.get_java_vm_pointer()) } .map_err(Error::CloneJavaVm)?; @@ -152,7 +103,9 @@ impl AndroidTunProvider { }) } - fn get_tun_fd(&self, config: VpnServiceConfig) -> Result { + fn get_tun_fd(&self) -> Result { + let config = VpnServiceConfig::new(self.config.clone()); + let env = self.env()?; let java_config = config.into_java(&env); @@ -169,16 +122,6 @@ impl AndroidTunProvider { } } - /// Open a tunnel device that routes everything but (potentially) LAN routes via the tunnel - /// device. Excluded apps will also be kept. - /// - /// Will open a new tunnel if there is already an active tunnel. The previous tunnel will be - /// closed. - pub fn create_blocking_tun(&mut self) -> Result<(), Error> { - let _ = self.get_tun_inner(TunConfig::default(), true)?; - Ok(()) - } - /// Close currently active tunnel device. pub fn close_tun(&mut self) { let result = self.call_method("closeTun", "()V", JavaType::Primitive(Primitive::Void), &[]); @@ -192,8 +135,6 @@ impl AndroidTunProvider { Err(error) => Some(error), }; - self.last_tun_config = None; - if let Some(error) = error { log::error!( "{}", @@ -202,13 +143,6 @@ impl AndroidTunProvider { } } - fn recreate_tun_if_open(&mut self) -> Result<(), Error> { - if let Some((config, blocking)) = self.last_tun_config.clone() { - let _ = self.get_tun_inner(config, blocking)?; - } - Ok(()) - } - /// Allow a socket to bypass the tunnel. pub fn bypass(&mut self, socket: RawFd) -> Result<(), Error> { let env = JnixEnv::from( @@ -288,34 +222,32 @@ struct VpnServiceConfig { } impl VpnServiceConfig { - pub fn new( - tun_config: TunConfig, - allow_lan: bool, - dns_servers: Option>, - excluded_packages: Vec, - ) -> VpnServiceConfig { - let dns_servers = Self::resolve_dns_servers(&tun_config, dns_servers); - let routes = Self::resolve_routes(&tun_config, allow_lan); + pub fn new(tun_config: TunConfig) -> VpnServiceConfig { + let dns_servers = Self::resolve_dns_servers(&tun_config); + let routes = Self::resolve_routes(&tun_config); VpnServiceConfig { addresses: tun_config.addresses, dns_servers, routes, - excluded_packages, + excluded_packages: tun_config.excluded_packages, mtu: tun_config.mtu, } } /// Return a list of custom DNS servers. If not specified, gateway addresses are used for DNS. /// Note that `Some(vec![])` is different from `None`. `Some(vec![])` disables DNS. - fn resolve_dns_servers(config: &TunConfig, custom_dns: Option>) -> Vec { - custom_dns.unwrap_or_else(|| config.gateway_ips()) + fn resolve_dns_servers(config: &TunConfig) -> Vec { + config + .dns_servers + .clone() + .unwrap_or_else(|| config.gateway_ips()) } /// Potentially subtract LAN nets from the VPN service routes, excepting gateways. /// This prevents LAN traffic from going in the tunnel. - fn resolve_routes(config: &TunConfig, allow_lan: bool) -> Vec { - if !allow_lan { + fn resolve_routes(config: &TunConfig) -> Vec { + if !config.allow_lan { return config .routes .iter() @@ -429,26 +361,6 @@ impl AsRawFd for VpnServiceTun { } } -impl Default for TunConfig { - fn default() -> Self { - // Default configuration simply intercepts all packets. The only field that matters is - // `routes`, because it determines what must enter the tunnel. All other fields contain - // stub values. - TunConfig { - addresses: vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))], - ipv4_gateway: Ipv4Addr::new(10, 64, 0, 1), - ipv6_gateway: None, - routes: vec![ - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) - .expect("Invalid IP network prefix for IPv4 address"), - IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) - .expect("Invalid IP network prefix for IPv6 address"), - ], - mtu: 1380, - } - } -} - #[derive(FromJava)] #[jnix(package = "net.mullvad.talpid.model")] enum CreateTunResult { diff --git a/talpid-tunnel/src/tun_provider/mod.rs b/talpid-tunnel/src/tun_provider/mod.rs index 66ab43c3dfae..af6285bc083a 100644 --- a/talpid-tunnel/src/tun_provider/mod.rs +++ b/talpid-tunnel/src/tun_provider/mod.rs @@ -34,6 +34,9 @@ pub struct TunConfig { /// IP addresses for the tunnel interface. pub addresses: Vec, + /// MTU of the tunnel interface. + pub mtu: u16, + /// IPv4 address of the VPN server, and the default IPv4 DNS resolver. pub ipv4_gateway: Ipv4Addr, @@ -43,8 +46,15 @@ pub struct TunConfig { /// Routes to configure for the tunnel. pub routes: Vec, - /// MTU of the tunnel interface. - pub mtu: u16, + /// Exclude private IPs from the tunnel + pub allow_lan: bool, + + /// DNS servers to use for the tunnel config. + /// Unless specified, the gateways will be used for DNS + pub dns_servers: Option>, + + /// Applications to exclude from the tunnel. + pub excluded_packages: Vec, } impl TunConfig { @@ -57,3 +67,23 @@ impl TunConfig { servers } } + +impl Default for TunConfig { + fn default() -> Self { + TunConfig { + addresses: vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))], + mtu: 1380, + ipv4_gateway: Ipv4Addr::new(10, 64, 0, 1), + ipv6_gateway: None, + routes: vec![ + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0) + .expect("Invalid IP network prefix for IPv4 address"), + IpNetwork::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0) + .expect("Invalid IP network prefix for IPv6 address"), + ], + allow_lan: false, + dns_servers: None, + excluded_packages: vec![], + } + } +} diff --git a/talpid-tunnel/src/tun_provider/unix.rs b/talpid-tunnel/src/tun_provider/unix.rs index 457f124f64b2..b713bbb2c351 100644 --- a/talpid-tunnel/src/tun_provider/unix.rs +++ b/talpid-tunnel/src/tun_provider/unix.rs @@ -25,23 +25,27 @@ pub enum Error { } /// Factory of tunnel devices on Unix systems. -pub struct UnixTunProvider; - -impl Default for UnixTunProvider { - fn default() -> Self { - Self::new() - } +pub struct UnixTunProvider { + config: TunConfig, } impl UnixTunProvider { pub fn new() -> Self { - UnixTunProvider + UnixTunProvider { + config: TunConfig::default(), + } + } + + /// Get the current tunnel config. Note that the tunnel must be recreated for any changes to + /// take effect. + pub fn config_mut(&mut self) -> &mut TunConfig { + &mut self.config } - pub fn get_tun(&mut self, config: TunConfig) -> Result { + pub fn get_tun(&mut self) -> Result { let mut tunnel_device = TunnelDevice::new().map_err(Error::CreateTunnelDevice)?; - for ip in config.addresses.iter() { + for ip in self.config.addresses.iter() { tunnel_device .set_ip(*ip) .map_err(|cause| Error::SetIpAddr(*ip, cause))?; diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 3f3ed97d6457..859abcb909d2 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -129,17 +129,20 @@ impl WgGoTunnel { let mut last_error = None; let mut tun_provider = tun_provider.lock().unwrap(); - let tunnel_config = TunConfig { + let tun_config = tun_provider.config_mut(); + + *tun_config = TunConfig { addresses: config.tunnel.addresses.clone(), ipv4_gateway: config.ipv4_gateway, ipv6_gateway: config.ipv6_gateway, routes: routes.collect(), mtu: config.mtu, + ..tun_config.clone() }; for _ in 1..=MAX_PREPARE_TUN_ATTEMPTS { let tunnel_device = tun_provider - .get_tun(tunnel_config.clone()) + .get_tun() .map_err(TunnelError::SetupTunnelDevice)?; match nix::unistd::dup(tunnel_device.as_raw_fd()) {