From 48a9566815abcbce77a2c0ceca0e4519e7bcffd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Fri, 17 Nov 2023 15:37:36 +0100 Subject: [PATCH] Complete certain management interface commands when the tunnel state machine has actually handled the request --- mullvad-daemon/src/lib.rs | 48 +++++++++++-- .../tunnel_state_machine/connected_state.rs | 70 +++++++++++-------- .../tunnel_state_machine/connecting_state.rs | 27 ++++--- .../disconnected_state.rs | 10 +-- .../disconnecting_state.rs | 36 +++++++--- .../src/tunnel_state_machine/error_state.rs | 35 ++++++---- talpid-core/src/tunnel_state_machine/mod.rs | 6 +- 7 files changed, 156 insertions(+), 76 deletions(-) diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d4964a8b807e..33d8e82b046a 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -1881,9 +1881,15 @@ where .await { Ok(settings_changed) => { - Self::oneshot_send(tx, Ok(()), "set_allow_lan response"); if settings_changed { - self.send_tunnel_command(TunnelCommand::AllowLan(allow_lan)); + self.send_tunnel_command(TunnelCommand::AllowLan( + allow_lan, + oneshot_map(tx, |tx, ()| { + Self::oneshot_send(tx, Ok(()), "set_allow_lan response"); + }), + )); + } else { + Self::oneshot_send(tx, Ok(()), "set_allow_lan response"); } } Err(e) => { @@ -1928,11 +1934,15 @@ where .await { Ok(settings_changed) => { - Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response"); if settings_changed { self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected( block_when_disconnected, + oneshot_map(tx, |tx, ()| { + Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response"); + }), )); + } else { + Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response"); } } Err(e) => { @@ -2130,12 +2140,18 @@ where .await { Ok(settings_changed) => { - Self::oneshot_send(tx, Ok(()), "set_dns_options response"); if settings_changed { let settings = self.settings.to_settings(); let resolvers = dns::addresses_from_options(&settings.tunnel_options.dns_options); - self.send_tunnel_command(TunnelCommand::Dns(resolvers)); + self.send_tunnel_command(TunnelCommand::Dns( + resolvers, + oneshot_map(tx, |tx, ()| { + Self::oneshot_send(tx, Ok(()), "set_dns_options response"); + }), + )); + } else { + Self::oneshot_send(tx, Ok(()), "set_dns_options response"); } } Err(e) => { @@ -2378,7 +2394,8 @@ where && (*self.target_state == TargetState::Secured || self.settings.auto_connect) { log::debug!("Blocking firewall during shutdown since system is going down"); - self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true)); + let (tx, _rx) = oneshot::channel(); + self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx)); } self.state.shutdown(&self.tunnel_state); @@ -2390,7 +2407,8 @@ where // without causing the service to be restarted. if *self.target_state == TargetState::Secured { - self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true)); + let (tx, _rx) = oneshot::channel(); + self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx)); } self.target_state.lock(); } @@ -2539,3 +2557,19 @@ fn new_selector_config(settings: &Settings) -> SelectorConfig { relay_overrides: settings.relay_overrides.clone(), } } + +/// Consume a oneshot sender of `T1` and return a sender that takes a different type `T2`. `forwarder` should map `T1` back to `T2` and +/// send the result back to the original receiver. +fn oneshot_map( + tx: oneshot::Sender, + forwarder: impl Fn(oneshot::Sender, T2) + Send + 'static, +) -> oneshot::Sender { + let (new_tx, new_rx) = oneshot::channel(); + tokio::spawn(async move { + match new_rx.await { + Ok(result) => forwarder(tx, result), + Err(oneshot::Canceled) => (), + } + }); + new_tx +} diff --git a/talpid-core/src/tunnel_state_machine/connected_state.rs b/talpid-core/src/tunnel_state_machine/connected_state.rs index 687e941aa729..21375f056ba2 100644 --- a/talpid-core/src/tunnel_state_machine/connected_state.rs +++ b/talpid-core/src/tunnel_state_machine/connected_state.rs @@ -213,8 +213,8 @@ impl ConnectedState { use self::EventConsequence::*; match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { match self.set_firewall_policy(shared_values) { @@ -230,43 +230,55 @@ impl ConnectedState { AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), ), } - } + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { shared_values.allowed_endpoint = endpoint; let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { - Ok(true) => { - if let Err(error) = self.set_firewall_policy(shared_values) { - return self.disconnect( - shared_values, - AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)), - ); - } - - match self.set_dns(shared_values) { - #[cfg(target_os = "android")] - Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - #[cfg(not(target_os = "android"))] - Ok(()) => SameState(self), - Err(error) => { - log::error!("{}", error.display_chain_with_msg("Failed to set DNS")); - self.disconnect( + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = match shared_values.set_dns_servers(servers) { + Ok(true) => { + if let Err(error) = self.set_firewall_policy(shared_values) { + return self.disconnect( shared_values, - AfterDisconnect::Block(ErrorStateCause::SetDnsError), - ) + AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError( + error, + )), + ); + } + + match self.set_dns(shared_values) { + #[cfg(target_os = "android")] + Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), + #[cfg(not(target_os = "android"))] + Ok(()) => SameState(self), + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to set DNS") + ); + self.disconnect( + shared_values, + AfterDisconnect::Block(ErrorStateCause::SetDnsError), + ) + } } } - } - Ok(false) => SameState(self), - Err(error_cause) => { - self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) - } - }, - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Ok(false) => SameState(self), + Err(error_cause) => { + self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) + } + }; + let _ = complete_tx.send(()); + consequence + } + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 2a728513ff88..d7d93da9d43e 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -392,12 +392,14 @@ impl ConnectingState { use self::EventConsequence::*; match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) { self.disconnect(shared_values, AfterDisconnect::Block(error_cause)) } else { self.reset_firewall(shared_values) - } + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { if shared_values.allowed_endpoint != endpoint { @@ -418,14 +420,19 @@ impl ConnectingState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) { - #[cfg(target_os = "android")] - Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), - Ok(_) => SameState(self), - Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), - }, - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = match shared_values.set_dns_servers(servers) { + #[cfg(target_os = "android")] + Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)), + Ok(_) => SameState(self), + Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)), + }; + let _ = complete_tx.send(()); + consequence + } + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 5a2cf6fc4d2f..d46f06e782f7 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -128,7 +128,7 @@ impl TunnelState for DisconnectedState { use self::EventConsequence::*; match runtime.block_on(commands.next()) { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { if shared_values.allow_lan != allow_lan { // The only platform that can fail is Android, but Android doesn't support the // "block when disconnected" option, so the following call never fails. @@ -138,6 +138,7 @@ impl TunnelState for DisconnectedState { Self::set_firewall_policy(shared_values, false); } + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -148,15 +149,15 @@ impl TunnelState for DisconnectedState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { // Same situation as allow LAN above. shared_values .set_dns_servers(servers) .expect("Failed to reconnect after changing custom DNS servers"); - + let _ = complete_tx.send(()); SameState(self) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { if shared_values.block_when_disconnected != block_when_disconnected { shared_values.block_when_disconnected = block_when_disconnected; Self::set_firewall_policy(shared_values, true); @@ -178,6 +179,7 @@ impl TunnelState for DisconnectedState { Self::reset_dns(shared_values); } } + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs index 08248fbac2d8..185d2f7d0afc 100644 --- a/talpid-core/src/tunnel_state_machine/disconnecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnecting_state.rs @@ -40,8 +40,9 @@ impl DisconnectingState { self.after_disconnect = match after_disconnect { AfterDisconnect::Nothing => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Nothing } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -49,12 +50,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Nothing } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Nothing } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Nothing } Some(TunnelCommand::IsOffline(is_offline)) => { @@ -76,8 +82,9 @@ impl DisconnectingState { } }, AfterDisconnect::Block(reason) => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -85,12 +92,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Block(reason) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Block(reason) } Some(TunnelCommand::IsOffline(is_offline)) => { @@ -117,8 +129,9 @@ impl DisconnectingState { None => AfterDisconnect::Block(reason), }, AfterDisconnect::Reconnect(retry_attempt) => match command { - Some(TunnelCommand::AllowLan(allow_lan)) => { + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { let _ = shared_values.set_allow_lan(allow_lan); + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { @@ -126,12 +139,17 @@ impl DisconnectingState { let _ = tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } - Some(TunnelCommand::Dns(servers)) => { + Some(TunnelCommand::Dns(servers, complete_tx)) => { let _ = shared_values.set_dns_servers(servers); + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected( + block_when_disconnected, + complete_tx, + )) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); AfterDisconnect::Reconnect(retry_attempt) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 11a805f7dc0f..2f82cb4cf567 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -138,13 +138,16 @@ impl TunnelState for ErrorState { use self::EventConsequence::*; match runtime.block_on(commands.next()) { - Some(TunnelCommand::AllowLan(allow_lan)) => { - if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { - NewState(Self::enter(shared_values, error_state_cause)) - } else { - let _ = Self::set_firewall_policy(shared_values); - SameState(self) - } + Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => { + let consequence = + if let Err(error_state_cause) = shared_values.set_allow_lan(allow_lan) { + NewState(Self::enter(shared_values, error_state_cause)) + } else { + let _ = Self::set_firewall_policy(shared_values); + SameState(self) + }; + let _ = complete_tx.send(()); + consequence } Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => { if shared_values.allowed_endpoint != endpoint { @@ -163,15 +166,19 @@ impl TunnelState for ErrorState { let _ = tx.send(()); SameState(self) } - Some(TunnelCommand::Dns(servers)) => { - if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { - NewState(Self::enter(shared_values, error_state_cause)) - } else { - SameState(self) - } + Some(TunnelCommand::Dns(servers, complete_tx)) => { + let consequence = + if let Err(error_state_cause) = shared_values.set_dns_servers(servers) { + NewState(Self::enter(shared_values, error_state_cause)) + } else { + SameState(self) + }; + let _ = complete_tx.send(()); + consequence } - Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => { + Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => { shared_values.block_when_disconnected = block_when_disconnected; + let _ = complete_tx.send(()); SameState(self) } Some(TunnelCommand::IsOffline(is_offline)) => { diff --git a/talpid-core/src/tunnel_state_machine/mod.rs b/talpid-core/src/tunnel_state_machine/mod.rs index 12bc4cfc86fc..5957b2f73152 100644 --- a/talpid-core/src/tunnel_state_machine/mod.rs +++ b/talpid-core/src/tunnel_state_machine/mod.rs @@ -189,15 +189,15 @@ pub async fn spawn( /// Representation of external commands for the tunnel state machine. pub enum TunnelCommand { /// Enable or disable LAN access in the firewall. - AllowLan(bool), + AllowLan(bool, oneshot::Sender<()>), /// Endpoint that should never be blocked. `()` is sent to the /// channel after attempting to set the firewall policy, regardless /// of whether it succeeded. AllowEndpoint(AllowedEndpoint, oneshot::Sender<()>), /// Set DNS servers to use. - Dns(Option>), + Dns(Option>, oneshot::Sender<()>), /// Enable or disable the block_when_disconnected feature. - BlockWhenDisconnected(bool), + BlockWhenDisconnected(bool, oneshot::Sender<()>), /// Notify the state machine of the connectivity of the device. IsOffline(bool), /// Open tunnel connection.