diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 36b3c22bb8b6..013cf9d7ce51 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -2,7 +2,10 @@ use crate::{ settings::{self, MadeChanges}, Daemon, EventListener, }; -use mullvad_types::access_method::{self, AccessMethod, AccessMethodSetting}; +use mullvad_types::{ + access_method::{self, AccessMethod, AccessMethodSetting}, + settings::Settings, +}; #[derive(err_derive::Error, Debug)] pub enum Error { @@ -19,15 +22,35 @@ pub enum Error { /// the user should do a factory reset. #[error(display = "No access methods are configured")] NoMethodsExist, + /// Access method could not be rotate + #[error(display = "Access method could not be rotated")] + RotationError, /// Access methods settings error #[error(display = "Settings error")] Settings(#[error(source)] settings::Error), } +/// A tiny datastructure used for signaling whether the daemon should force a +/// rotation of the currently used [`AccessMethodSetting`] or not, and if so: +/// how it should do it. +pub enum Command { + /// There is no need to force a rotation of [`AccessMethodSetting`] + Nothing, + /// Select the next available [`AccessMethodSetting`], whichever that is + Rotate, + /// Select the [`AccessMethodSetting`] with a certain [`access_method::Id`] + Set(access_method::Id), +} + impl Daemon where L: EventListener + Clone + Send + 'static, { + /// Add a [`AccessMethod`] to the daemon's settings. + /// + /// If the daemon settings are successfully updated, the + /// [`access_method::Id`] of the newly created [`AccessMethodSetting`] + /// (which has been derived from the [`AccessMethod`]) is returned. pub async fn add_access_method( &mut self, name: String, @@ -44,77 +67,119 @@ where .map_err(Error::Settings) } + /// Remove a [`AccessMethodSetting`] from the daemon's saved settings. + /// + /// If the [`AccessMethodSetting`] which is currently in use happens to be + /// removed, the daemon should force a rotation of the active API endpoint. pub async fn remove_access_method( &mut self, access_method: access_method::Id, ) -> Result<(), Error> { // Make sure that we are not trying to remove a built-in API access // method - match self.settings.api_access_methods.find(&access_method) { - None => return Ok(()), + let command = match self.settings.api_access_methods.find(&access_method) { Some(api_access_method) => { if api_access_method.is_builtin() { - return Err(Error::RemoveBuiltIn); + Err(Error::RemoveBuiltIn) + } else if api_access_method.get_id() == self.get_current_access_method()?.get_id() { + Ok(Command::Rotate) + } else { + Ok(Command::Nothing) } } - }; + None => Ok(Command::Nothing), + }?; self.settings .update(|settings| settings.api_access_methods.remove(&access_method)) .await .map(|did_change| self.notify_on_change(did_change)) - .map_err(Error::Settings) + .map_err(Error::Settings)? + .process_command(command) + .await } + /// Set a [`AccessMethodSetting`] as the current API access method. + /// + /// If successful, the daemon will force a rotation of the active API access + /// method, which means that subsequent API calls will use the new + /// [`AccessMethodSetting`] to figure out the API endpoint. pub async fn set_api_access_method( &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - if let Some(access_method) = self.settings.api_access_methods.find(&access_method) { - { - let mut connection_modes = self.connection_modes.lock().unwrap(); - connection_modes.set_access_method(access_method.clone()); - } - // Force a rotation of Access Methods. - - if let Err(error) = self.api_handle.service().next_api_endpoint().await { - log::error!("Failed to rotate API endpoint: {}", error); - } - Ok(()) - } else { - Err(Error::NoSuchMethod(access_method)) + let access_method = self + .settings + .api_access_methods + .find(&access_method) + .ok_or(Error::NoSuchMethod(access_method))?; + { + let mut connection_modes = self.connection_modes.lock().unwrap(); + connection_modes.set_access_method(access_method.clone()); } + // Force a rotation of Access Methods. + // + // This is not a call to `process_command` due to the restrictions on + // recursively calling async functions. + self.force_api_endpoint_rotation().await } /// "Updates" an [`AccessMethodSetting`] by replacing the existing entry /// with the argument `access_method_update` if an existing entry with - /// matching UUID is found. + /// matching [`access_method::Id`] is found. + /// + /// If the currently active [`AccessMethodSetting`] is updated, the daemon + /// will automatically use this updated [`AccessMethodSetting`] when + /// performing subsequent API calls. pub async fn update_access_method( &mut self, access_method_update: AccessMethodSetting, ) -> Result<(), Error> { - self.settings - .update(|settings| { - let access_methods = &mut settings.api_access_methods; - if let Some(access_method) = access_methods.find_mut(&access_method_update.get_id()) - { - *access_method = access_method_update + let current = self.get_current_access_method()?; + let mut command = Command::Nothing; + let settings_update = |settings: &mut Settings| { + if let Some(access_method) = settings + .api_access_methods + .find_mut(&access_method_update.get_id()) + { + *access_method = access_method_update; + if access_method.get_id() == current.get_id() { + command = Command::Set(access_method.get_id()) } - }) + } + }; + + self.settings + .update(settings_update) .await .map(|did_change| self.notify_on_change(did_change)) - .map_err(Error::Settings) + .map_err(Error::Settings)? + .process_command(command) + .await } /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. - pub fn get_current_access_method(&mut self) -> Result { + pub fn get_current_access_method(&self) -> Result { let connections_modes = self.connection_modes.lock().unwrap(); Ok(connections_modes.peek()) } + /// Change which [`AccessMethodSetting`] which will be used to figure out + /// the Mullvad API endpoint. + async fn force_api_endpoint_rotation(&self) -> Result<(), Error> { + self.api_handle + .service() + .next_api_endpoint() + .await + .map_err(|error| { + log::error!("Failed to rotate API endpoint: {}", error); + Error::RotationError + }) + } + /// If settings were changed due to an update, notify all listeners. - fn notify_on_change(&mut self, settings_changed: MadeChanges) { + fn notify_on_change(&mut self, settings_changed: MadeChanges) -> &mut Self { if settings_changed { self.event_listener .notify_settings(self.settings.to_settings()); @@ -130,5 +195,15 @@ where .collect(), ) }; + self + } + + /// The semantics of the [`Command`] datastructure. + async fn process_command(&mut self, command: Command) -> Result<(), Error> { + match command { + Command::Nothing => Ok(()), + Command::Rotate => self.force_api_endpoint_rotation().await, + Command::Set(id) => self.set_api_access_method(id).await, + } } }