Skip to content

Commit

Permalink
Complete certain management interface commands when the tunnel state …
Browse files Browse the repository at this point in the history
…machine has actually handled the request
  • Loading branch information
dlon committed Nov 17, 2023
1 parent c54646e commit 48a9566
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 76 deletions.
48 changes: 41 additions & 7 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1881,9 +1881,15 @@ where
.await
{
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
if settings_changed {
self.send_tunnel_command(TunnelCommand::AllowLan(allow_lan));
self.send_tunnel_command(TunnelCommand::AllowLan(
allow_lan,
oneshot_map(tx, |tx, ()| {
Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
}),
));
} else {
Self::oneshot_send(tx, Ok(()), "set_allow_lan response");
}
}
Err(e) => {
Expand Down Expand Up @@ -1928,11 +1934,15 @@ where
.await
{
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
if settings_changed {
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(
block_when_disconnected,
oneshot_map(tx, |tx, ()| {
Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
}),
));
} else {
Self::oneshot_send(tx, Ok(()), "set_block_when_disconnected response");
}
}
Err(e) => {
Expand Down Expand Up @@ -2130,12 +2140,18 @@ where
.await
{
Ok(settings_changed) => {
Self::oneshot_send(tx, Ok(()), "set_dns_options response");
if settings_changed {
let settings = self.settings.to_settings();
let resolvers =
dns::addresses_from_options(&settings.tunnel_options.dns_options);
self.send_tunnel_command(TunnelCommand::Dns(resolvers));
self.send_tunnel_command(TunnelCommand::Dns(
resolvers,
oneshot_map(tx, |tx, ()| {
Self::oneshot_send(tx, Ok(()), "set_dns_options response");
}),
));
} else {
Self::oneshot_send(tx, Ok(()), "set_dns_options response");
}
}
Err(e) => {
Expand Down Expand Up @@ -2378,7 +2394,8 @@ where
&& (*self.target_state == TargetState::Secured || self.settings.auto_connect)
{
log::debug!("Blocking firewall during shutdown since system is going down");
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true));
let (tx, _rx) = oneshot::channel();
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
}

self.state.shutdown(&self.tunnel_state);
Expand All @@ -2390,7 +2407,8 @@ where
// without causing the service to be restarted.

if *self.target_state == TargetState::Secured {
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true));
let (tx, _rx) = oneshot::channel();
self.send_tunnel_command(TunnelCommand::BlockWhenDisconnected(true, tx));
}
self.target_state.lock();
}
Expand Down Expand Up @@ -2539,3 +2557,19 @@ fn new_selector_config(settings: &Settings) -> SelectorConfig {
relay_overrides: settings.relay_overrides.clone(),
}
}

/// Consume a oneshot sender of `T1` and return a sender that takes a different type `T2`. `forwarder` should map `T1` back to `T2` and
/// send the result back to the original receiver.
fn oneshot_map<T1: Send + 'static, T2: Send + 'static>(
tx: oneshot::Sender<T1>,
forwarder: impl Fn(oneshot::Sender<T1>, T2) + Send + 'static,
) -> oneshot::Sender<T2> {
let (new_tx, new_rx) = oneshot::channel();
tokio::spawn(async move {
match new_rx.await {
Ok(result) => forwarder(tx, result),
Err(oneshot::Canceled) => (),
}
});
new_tx
}
70 changes: 41 additions & 29 deletions talpid-core/src/tunnel_state_machine/connected_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ impl ConnectedState {
use self::EventConsequence::*;

match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
match self.set_firewall_policy(shared_values) {
Expand All @@ -230,43 +230,55 @@ impl ConnectedState {
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
),
}
}
};
let _ = complete_tx.send(());
consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
SameState(self)
}
Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
Ok(true) => {
if let Err(error) = self.set_firewall_policy(shared_values) {
return self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(error)),
);
}

match self.set_dns(shared_values) {
#[cfg(target_os = "android")]
Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
#[cfg(not(target_os = "android"))]
Ok(()) => SameState(self),
Err(error) => {
log::error!("{}", error.display_chain_with_msg("Failed to set DNS"));
self.disconnect(
Some(TunnelCommand::Dns(servers, complete_tx)) => {
let consequence = match shared_values.set_dns_servers(servers) {
Ok(true) => {
if let Err(error) = self.set_firewall_policy(shared_values) {
return self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetDnsError),
)
AfterDisconnect::Block(ErrorStateCause::SetFirewallPolicyError(
error,
)),
);
}

match self.set_dns(shared_values) {
#[cfg(target_os = "android")]
Ok(()) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
#[cfg(not(target_os = "android"))]
Ok(()) => SameState(self),
Err(error) => {
log::error!(
"{}",
error.display_chain_with_msg("Failed to set DNS")
);
self.disconnect(
shared_values,
AfterDisconnect::Block(ErrorStateCause::SetDnsError),
)
}
}
}
}
Ok(false) => SameState(self),
Err(error_cause) => {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
}
},
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Ok(false) => SameState(self),
Err(error_cause) => {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
}
};
let _ = complete_tx.send(());
consequence
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand Down
27 changes: 17 additions & 10 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,14 @@ impl ConnectingState {
use self::EventConsequence::*;

match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let consequence = if let Err(error_cause) = shared_values.set_allow_lan(allow_lan) {
self.disconnect(shared_values, AfterDisconnect::Block(error_cause))
} else {
self.reset_firewall(shared_values)
}
};
let _ = complete_tx.send(());
consequence
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
if shared_values.allowed_endpoint != endpoint {
Expand All @@ -418,14 +420,19 @@ impl ConnectingState {
let _ = tx.send(());
SameState(self)
}
Some(TunnelCommand::Dns(servers)) => match shared_values.set_dns_servers(servers) {
#[cfg(target_os = "android")]
Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
Ok(_) => SameState(self),
Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
},
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Some(TunnelCommand::Dns(servers, complete_tx)) => {
let consequence = match shared_values.set_dns_servers(servers) {
#[cfg(target_os = "android")]
Ok(true) => self.disconnect(shared_values, AfterDisconnect::Reconnect(0)),
Ok(_) => SameState(self),
Err(cause) => self.disconnect(shared_values, AfterDisconnect::Block(cause)),
};
let _ = complete_tx.send(());
consequence
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
shared_values.block_when_disconnected = block_when_disconnected;
let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand Down
10 changes: 6 additions & 4 deletions talpid-core/src/tunnel_state_machine/disconnected_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl TunnelState for DisconnectedState {
use self::EventConsequence::*;

match runtime.block_on(commands.next()) {
Some(TunnelCommand::AllowLan(allow_lan)) => {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
if shared_values.allow_lan != allow_lan {
// The only platform that can fail is Android, but Android doesn't support the
// "block when disconnected" option, so the following call never fails.
Expand All @@ -138,6 +138,7 @@ impl TunnelState for DisconnectedState {

Self::set_firewall_policy(shared_values, false);
}
let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
Expand All @@ -148,15 +149,15 @@ impl TunnelState for DisconnectedState {
let _ = tx.send(());
SameState(self)
}
Some(TunnelCommand::Dns(servers)) => {
Some(TunnelCommand::Dns(servers, complete_tx)) => {
// Same situation as allow LAN above.
shared_values
.set_dns_servers(servers)
.expect("Failed to reconnect after changing custom DNS servers");

let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected, complete_tx)) => {
if shared_values.block_when_disconnected != block_when_disconnected {
shared_values.block_when_disconnected = block_when_disconnected;
Self::set_firewall_policy(shared_values, true);
Expand All @@ -178,6 +179,7 @@ impl TunnelState for DisconnectedState {
Self::reset_dns(shared_values);
}
}
let _ = complete_tx.send(());
SameState(self)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand Down
36 changes: 27 additions & 9 deletions talpid-core/src/tunnel_state_machine/disconnecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,27 @@ impl DisconnectingState {

self.after_disconnect = match after_disconnect {
AfterDisconnect::Nothing => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::Dns(servers)) => {
Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Some(TunnelCommand::BlockWhenDisconnected(
block_when_disconnected,
complete_tx,
)) => {
shared_values.block_when_disconnected = block_when_disconnected;
let _ = complete_tx.send(());
AfterDisconnect::Nothing
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand All @@ -76,21 +82,27 @@ impl DisconnectingState {
}
},
AfterDisconnect::Block(reason) => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::Dns(servers)) => {
Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Some(TunnelCommand::BlockWhenDisconnected(
block_when_disconnected,
complete_tx,
)) => {
shared_values.block_when_disconnected = block_when_disconnected;
let _ = complete_tx.send(());
AfterDisconnect::Block(reason)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand All @@ -117,21 +129,27 @@ impl DisconnectingState {
None => AfterDisconnect::Block(reason),
},
AfterDisconnect::Reconnect(retry_attempt) => match command {
Some(TunnelCommand::AllowLan(allow_lan)) => {
Some(TunnelCommand::AllowLan(allow_lan, complete_tx)) => {
let _ = shared_values.set_allow_lan(allow_lan);
let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::AllowEndpoint(endpoint, tx)) => {
shared_values.allowed_endpoint = endpoint;
let _ = tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::Dns(servers)) => {
Some(TunnelCommand::Dns(servers, complete_tx)) => {
let _ = shared_values.set_dns_servers(servers);
let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::BlockWhenDisconnected(block_when_disconnected)) => {
Some(TunnelCommand::BlockWhenDisconnected(
block_when_disconnected,
complete_tx,
)) => {
shared_values.block_when_disconnected = block_when_disconnected;
let _ = complete_tx.send(());
AfterDisconnect::Reconnect(retry_attempt)
}
Some(TunnelCommand::IsOffline(is_offline)) => {
Expand Down
Loading

0 comments on commit 48a9566

Please sign in to comment.