Skip to content

Commit

Permalink
Rotate access method if the currently active one is updated or removed
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Sep 29, 2023
1 parent b5f6c95 commit b3fa20e
Showing 1 changed file with 105 additions and 30 deletions.
135 changes: 105 additions & 30 deletions mullvad-daemon/src/access_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ use crate::{
settings::{self, MadeChanges},
Daemon, EventListener,
};
use mullvad_types::access_method::{self, AccessMethod, AccessMethodSetting};
use mullvad_types::{
access_method::{self, AccessMethod, AccessMethodSetting},
settings::Settings,
};

#[derive(err_derive::Error, Debug)]
pub enum Error {
Expand All @@ -19,15 +22,35 @@ pub enum Error {
/// the user should do a factory reset.
#[error(display = "No access methods are configured")]
NoMethodsExist,
/// Access method could not be rotate
#[error(display = "Access method could not be rotated")]
RotationError,
/// Access methods settings error
#[error(display = "Settings error")]
Settings(#[error(source)] settings::Error),
}

/// A tiny datastructure used for signaling whether the daemon should force a
/// rotation of the currently used [`AccessMethodSetting`] or not, and if so:
/// how it should do it.
pub enum Command {
/// There is no need to force a rotation of [`AccessMethodSetting`]
Nothing,
/// Select the next available [`AccessMethodSetting`], whichever that is
Rotate,
/// Select the [`AccessMethodSetting`] with a certain [`access_method::Id`]
Set(access_method::Id),
}

impl<L> Daemon<L>
where
L: EventListener + Clone + Send + 'static,
{
/// Add a [`AccessMethod`] to the daemon's settings.
///
/// If the daemon settings are successfully updated, the
/// [`access_method::Id`] of the newly created [`AccessMethodSetting`]
/// (which has been derived from the [`AccessMethod`]) is returned.
pub async fn add_access_method(
&mut self,
name: String,
Expand All @@ -44,77 +67,119 @@ where
.map_err(Error::Settings)
}

/// Remove a [`AccessMethodSetting`] from the daemon's saved settings.
///
/// If the [`AccessMethodSetting`] which is currently in use happens to be
/// removed, the daemon should force a rotation of the active API endpoint.
pub async fn remove_access_method(
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
// Make sure that we are not trying to remove a built-in API access
// method
match self.settings.api_access_methods.find(&access_method) {
None => return Ok(()),
let command = match self.settings.api_access_methods.find(&access_method) {
Some(api_access_method) => {
if api_access_method.is_builtin() {
return Err(Error::RemoveBuiltIn);
Err(Error::RemoveBuiltIn)
} else if api_access_method.get_id() == self.get_current_access_method()?.get_id() {
Ok(Command::Rotate)
} else {
Ok(Command::Nothing)
}
}
};
None => Ok(Command::Nothing),
}?;

self.settings
.update(|settings| settings.api_access_methods.remove(&access_method))
.await
.map(|did_change| self.notify_on_change(did_change))
.map_err(Error::Settings)
.map_err(Error::Settings)?
.process_command(command)
.await
}

/// Set a [`AccessMethodSetting`] as the current API access method.
///
/// If successful, the daemon will force a rotation of the active API access
/// method, which means that subsequent API calls will use the new
/// [`AccessMethodSetting`] to figure out the API endpoint.
pub async fn set_api_access_method(
&mut self,
access_method: access_method::Id,
) -> Result<(), Error> {
if let Some(access_method) = self.settings.api_access_methods.find(&access_method) {
{
let mut connection_modes = self.connection_modes.lock().unwrap();
connection_modes.set_access_method(access_method.clone());
}
// Force a rotation of Access Methods.

if let Err(error) = self.api_handle.service().next_api_endpoint().await {
log::error!("Failed to rotate API endpoint: {}", error);
}
Ok(())
} else {
Err(Error::NoSuchMethod(access_method))
let access_method = self
.settings
.api_access_methods
.find(&access_method)
.ok_or(Error::NoSuchMethod(access_method))?;
{
let mut connection_modes = self.connection_modes.lock().unwrap();
connection_modes.set_access_method(access_method.clone());
}
// Force a rotation of Access Methods.
//
// This is not a call to `process_command` due to the restrictions on
// recursively calling async functions.
self.force_api_endpoint_rotation().await
}

/// "Updates" an [`AccessMethodSetting`] by replacing the existing entry
/// with the argument `access_method_update` if an existing entry with
/// matching UUID is found.
/// matching [`access_method::Id`] is found.
///
/// If the currently active [`AccessMethodSetting`] is updated, the daemon
/// will automatically use this updated [`AccessMethodSetting`] when
/// performing subsequent API calls.
pub async fn update_access_method(
&mut self,
access_method_update: AccessMethodSetting,
) -> Result<(), Error> {
self.settings
.update(|settings| {
let access_methods = &mut settings.api_access_methods;
if let Some(access_method) = access_methods.find_mut(&access_method_update.get_id())
{
*access_method = access_method_update
let current = self.get_current_access_method()?;
let mut command = Command::Nothing;
let settings_update = |settings: &mut Settings| {
if let Some(access_method) = settings
.api_access_methods
.find_mut(&access_method_update.get_id())
{
*access_method = access_method_update;
if access_method.get_id() == current.get_id() {
command = Command::Set(access_method.get_id())
}
})
}
};

self.settings
.update(settings_update)
.await
.map(|did_change| self.notify_on_change(did_change))
.map_err(Error::Settings)
.map_err(Error::Settings)?
.process_command(command)
.await
}

/// Return the [`AccessMethodSetting`] which is currently used to access the
/// Mullvad API.
pub fn get_current_access_method(&mut self) -> Result<AccessMethodSetting, Error> {
pub fn get_current_access_method(&self) -> Result<AccessMethodSetting, Error> {
let connections_modes = self.connection_modes.lock().unwrap();
Ok(connections_modes.peek())
}

/// Change which [`AccessMethodSetting`] which will be used to figure out
/// the Mullvad API endpoint.
async fn force_api_endpoint_rotation(&self) -> Result<(), Error> {
self.api_handle
.service()
.next_api_endpoint()
.await
.map_err(|error| {
log::error!("Failed to rotate API endpoint: {}", error);
Error::RotationError
})
}

/// If settings were changed due to an update, notify all listeners.
fn notify_on_change(&mut self, settings_changed: MadeChanges) {
fn notify_on_change(&mut self, settings_changed: MadeChanges) -> &mut Self {
if settings_changed {
self.event_listener
.notify_settings(self.settings.to_settings());
Expand All @@ -130,5 +195,15 @@ where
.collect(),
)
};
self
}

/// The semantics of the [`Command`] datastructure.
async fn process_command(&mut self, command: Command) -> Result<(), Error> {
match command {
Command::Nothing => Ok(()),
Command::Rotate => self.force_api_endpoint_rotation().await,
Command::Set(id) => self.set_api_access_method(id).await,
}
}
}

0 comments on commit b3fa20e

Please sign in to comment.