diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c81c2e56ee1..adc4ceee3c04 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,7 @@ Line wrap the file at 100 chars. Th
   `SOCKS5`.
 - Add social media content blocker.
 - Add ability to override server IPs to the CLI.
+- Add CLI support for applying patches to the settings with `mullvad import-settings`.
 
 ### Changed
 - Update Electron from 25.2.0 to 26.3.0.
diff --git a/mullvad-cli/src/cmds/import_settings.rs b/mullvad-cli/src/cmds/import_settings.rs
new file mode 100644
index 000000000000..a354eb4b95d0
--- /dev/null
+++ b/mullvad-cli/src/cmds/import_settings.rs
@@ -0,0 +1,101 @@
+use anyhow::{anyhow, Context, Result};
+use mullvad_management_interface::MullvadProxyClient;
+use std::{
+    fs::File,
+    io::{stdin, BufRead, BufReader, Read},
+    path::Path,
+};
+
+/// Maximum size of a settings patch. Bigger files/streams cause the read to fail.
+const MAX_PATCH_BYTES: usize = 10 * 1024;
+
+/// If source is specified, read from the provided file and send it as a settings patch to the daemon.
+/// Otherwise, read the patch from standard input.
+pub async fn handle(source: String) -> Result<()> {
+    let json_blob = tokio::task::spawn_blocking(|| get_blob(source))
+        .await
+        .unwrap()?;
+
+    let mut rpc = MullvadProxyClient::new().await?;
+    rpc.apply_json_settings(json_blob)
+        .await
+        .context("Error applying patch")?;
+
+    println!("Settings applied");
+
+    Ok(())
+}
+
+fn get_blob(source: String) -> Result<String> {
+    match source.as_str() {
+        "-" => read_settings_from_stdin().context("Failed to read from stdin"),
+        _ => read_settings_from_file(&source)
+            .with_context(|| format!("Failed to read from path: {source}")),
+    }
+}
+
+/// Read settings from standard input
+fn read_settings_from_stdin() -> Result<String> {
+    read_settings_from_reader(BufReader::new(stdin()))
+}
+
+/// Read settings from a path
+fn read_settings_from_file(path: impl AsRef<Path>) -> Result<String> {
+    read_settings_from_reader(BufReader::new(File::open(path)?))
+}
+
+/// Read until EOF or until newline when the last pair of braces has been closed
+fn read_settings_from_reader(mut reader: impl BufRead) -> Result<String> {
+    let mut buf = [0u8; MAX_PATCH_BYTES];
+
+    let mut was_open = false;
+    let mut close_after_newline = false;
+    let mut brace_count: usize = 0;
+    let mut cursor_pos = 0;
+
+    loop {
+        let Some(cursor) = buf.get_mut(cursor_pos..) else {
+            return Err(anyhow!(
+                "Patch too long: maximum length is {MAX_PATCH_BYTES} bytes"
+            ));
+        };
+
+        let prev_cursor_pos = cursor_pos;
+        let read_n = reader.read(cursor)?;
+        if read_n == 0 {
+            // EOF
+            break;
+        }
+        cursor_pos += read_n;
+
+        let additional_bytes = &buf[prev_cursor_pos..cursor_pos];
+
+        if !close_after_newline {
+            for next in additional_bytes {
+                match next {
+                    b'{' => brace_count += 1,
+                    b'}' => {
+                        brace_count = brace_count.checked_sub(1).with_context(|| {
+                            // exit: too many closing braces
+                            "syntax error: unexpected '}'"
+                        })?
+ } + _ => (), + } + was_open |= brace_count > 0; + } + if brace_count == 0 && was_open { + // complete settings + close_after_newline = true; + } + } + + if close_after_newline && additional_bytes.contains(&b'\n') { + // done + break; + } + } + + Ok(std::str::from_utf8(&buf[0..cursor_pos]) + .context("settings must be utf8 encoded")? + .to_owned()) +} diff --git a/mullvad-cli/src/cmds/mod.rs b/mullvad-cli/src/cmds/mod.rs index 43d224233ee7..7944e8bdc07a 100644 --- a/mullvad-cli/src/cmds/mod.rs +++ b/mullvad-cli/src/cmds/mod.rs @@ -9,6 +9,7 @@ pub mod beta_program; pub mod bridge; pub mod custom_list; pub mod dns; +pub mod import_settings; pub mod lan; pub mod lockdown; pub mod obfuscation; diff --git a/mullvad-cli/src/main.rs b/mullvad-cli/src/main.rs index 7a09a4eebdf7..d1c518119cf6 100644 --- a/mullvad-cli/src/main.rs +++ b/mullvad-cli/src/main.rs @@ -133,6 +133,12 @@ enum Cli { /// Manage custom lists #[clap(subcommand)] CustomList(custom_list::CustomList), + + /// Apply a JSON patch + ImportSettings { + /// File to read from. If this is "-", read from standard input + file: String, + }, } #[tokio::main] @@ -160,6 +166,7 @@ async fn main() -> Result<()> { Cli::SplitTunnel(cmd) => cmd.handle().await, Cli::Status { cmd, args } => status::handle(cmd, args).await, Cli::CustomList(cmd) => cmd.handle().await, + Cli::ImportSettings { file } => import_settings::handle(file).await, #[cfg(all(unix, not(target_os = "android")))] Cli::ShellCompletions { shell, dir } => { diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d4964a8b807e..890cdfb13e04 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -349,6 +349,8 @@ pub enum DaemonCommand { /// Verify that a google play payment was successful through the API. 
#[cfg(target_os = "android")] VerifyPlayPurchase(ResponseTx<(), Error>, PlayPurchase), + /// Patch the settings using a blob of JSON settings + ApplyJsonSettings(ResponseTx<(), settings::patch::Error>, String), } /// All events that can happen in the daemon. Sent from various threads and exposed interfaces. @@ -1171,6 +1173,7 @@ where VerifyPlayPurchase(tx, play_purchase) => { self.on_verify_play_purchase(tx, play_purchase) } + ApplyJsonSettings(tx, blob) => self.on_apply_json_settings(tx, blob).await, } } @@ -2439,6 +2442,18 @@ where }); } + async fn on_apply_json_settings( + &mut self, + tx: ResponseTx<(), settings::patch::Error>, + blob: String, + ) { + let result = settings::patch::merge_validate_patch(&mut self.settings, &blob).await; + if result.is_ok() { + self.reconnect_tunnel(); + } + Self::oneshot_send(tx, result, "apply_json_settings response"); + } + /// Set the target state of the client. If it changed trigger the operations needed to /// progress towards that state. /// Returns a bool representing whether or not a state change was initiated. diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index e67a02117ce3..f042a923e51e 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -1,4 +1,4 @@ -use crate::{account_history, device, settings, DaemonCommand, DaemonCommandSender, EventListener}; +use crate::{account_history, device, DaemonCommand, DaemonCommandSender, EventListener}; use futures::{ channel::{mpsc, oneshot}, StreamExt, @@ -177,10 +177,8 @@ impl ManagementService for ManagementServiceImpl { let message = DaemonCommand::SetRelaySettings(tx, constraints_update); self.send_command_to_daemon(message)?; - self.wait_for_result(rx) - .await? 
- .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn get_relay_locations(&self, _: Request<()>) -> ServiceResult { @@ -215,10 +213,8 @@ impl ManagementService for ManagementServiceImpl { let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetBridgeSettings(tx, settings))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_obfuscation_settings( @@ -230,10 +226,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_obfuscation_settings({:?})", settings); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetObfuscationSettings(tx, settings))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_bridge_state(&self, request: Request) -> ServiceResult<()> { @@ -243,10 +237,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_bridge_state({:?})", bridge_state); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetBridgeState(tx, bridge_state))?; - let settings_result = self.wait_for_result(rx).await?; - settings_result - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } // Settings @@ -266,10 +258,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_allow_lan({})", allow_lan); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetAllowLan(tx, allow_lan))?; - self.wait_for_result(rx) - .await? 
- .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_show_beta_releases(&self, request: Request) -> ServiceResult<()> { @@ -277,10 +267,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_show_beta_releases({})", enabled); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetShowBetaReleases(tx, enabled))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_block_when_disconnected(&self, request: Request) -> ServiceResult<()> { @@ -291,10 +279,8 @@ impl ManagementService for ManagementServiceImpl { tx, block_when_disconnected, ))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_auto_connect(&self, request: Request) -> ServiceResult<()> { @@ -302,10 +288,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_auto_connect({})", auto_connect); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetAutoConnect(tx, auto_connect))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_openvpn_mssfix(&self, request: Request) -> ServiceResult<()> { @@ -318,10 +302,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_openvpn_mssfix({:?})", mssfix); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetOpenVpnMssfix(tx, mssfix))?; - self.wait_for_result(rx) - .await? 
- .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_wireguard_mtu(&self, request: Request) -> ServiceResult<()> { @@ -330,10 +312,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_wireguard_mtu({:?})", mtu); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetWireguardMtu(tx, mtu))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_enable_ipv6(&self, request: Request) -> ServiceResult<()> { @@ -341,10 +321,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_enable_ipv6({})", enable_ipv6); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetEnableIpv6(tx, enable_ipv6))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn set_quantum_resistant_tunnel( @@ -357,10 +335,8 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_quantum_resistant_tunnel({state:?})"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetQuantumResistantTunnel(tx, state))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } #[cfg(not(target_os = "android"))] @@ -370,10 +346,8 @@ impl ManagementService for ManagementServiceImpl { let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetDnsOptions(tx, options))?; - self.wait_for_result(rx) - .await? 
- .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } #[cfg(target_os = "android")] @@ -390,20 +364,16 @@ impl ManagementService for ManagementServiceImpl { log::debug!("set_relay_override"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetRelayOverride(tx, relay_override))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn clear_all_relay_overrides(&self, _: Request<()>) -> ServiceResult<()> { log::debug!("clear_all_relay_overrides"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::ClearAllRelayOverrides(tx))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } // Account management @@ -571,20 +541,16 @@ impl ManagementService for ManagementServiceImpl { tx, Some(interval), ))?; - self.wait_for_result(rx) - .await? - .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn reset_wireguard_rotation_interval(&self, _: Request<()>) -> ServiceResult<()> { log::debug!("reset_wireguard_rotation_interval"); let (tx, rx) = oneshot::channel(); self.send_command_to_daemon(DaemonCommand::SetWireguardRotationInterval(tx, None))?; - self.wait_for_result(rx) - .await? 
- .map(Response::new) - .map_err(map_settings_error) + self.wait_for_result(rx).await??; + Ok(Response::new(())) } async fn rotate_wireguard_key(&self, _: Request<()>) -> ServiceResult<()> { @@ -929,6 +895,14 @@ impl ManagementService for ManagementServiceImpl { async fn check_volumes(&self, _: Request<()>) -> ServiceResult<()> { Ok(Response::new(())) } + + async fn apply_json_settings(&self, blob: Request) -> ServiceResult<()> { + log::debug!("apply_json_settings"); + let (tx, rx) = oneshot::channel(); + self.send_command_to_daemon(DaemonCommand::ApplyJsonSettings(tx, blob.into_inner()))?; + self.wait_for_result(rx).await??; + Ok(Response::new(())) + } } impl ManagementServiceImpl { @@ -1061,7 +1035,7 @@ fn map_daemon_error(error: crate::Error) -> Status { match error { DaemonError::RestError(error) => map_rest_error(&error), - DaemonError::SettingsError(error) => map_settings_error(error), + DaemonError::SettingsError(error) => Status::from(error), DaemonError::AlreadyLoggedIn => Status::already_exists(error.to_string()), DaemonError::LoginError(error) => map_device_error(&error), DaemonError::LogoutError(error) => map_device_error(&error), @@ -1121,20 +1095,6 @@ fn map_rest_error(error: &RestError) -> Status { } } -/// Converts an instance of [`mullvad_daemon::settings::Error`] into a tonic status. -fn map_settings_error(error: settings::Error) -> Status { - match error { - settings::Error::DeleteError(..) - | settings::Error::WriteError(..) - | settings::Error::ReadError(..) => { - Status::new(Code::FailedPrecondition, error.to_string()) - } - settings::Error::SerializeError(..) | settings::Error::ParseError(..) => { - Status::new(Code::Internal, error.to_string()) - } - } -} - /// Converts an instance of [`mullvad_daemon::device::Error`] into a tonic status. 
 fn map_device_error(error: &device::Error) -> Status {
     match error {
diff --git a/mullvad-daemon/src/settings.rs b/mullvad-daemon/src/settings/mod.rs
similarity index 95%
rename from mullvad-daemon/src/settings.rs
rename to mullvad-daemon/src/settings/mod.rs
index f5c5e31e94ce..99f637c023a4 100644
--- a/mullvad-daemon/src/settings.rs
+++ b/mullvad-daemon/src/settings/mod.rs
@@ -16,6 +16,8 @@ use tokio::{
     io::{self, AsyncWriteExt},
 };
 
+pub mod patch;
+
 const SETTINGS_FILE: &str = "settings.json";
 
 #[derive(err_derive::Error, Debug)]
@@ -38,6 +40,23 @@ pub enum Error {
     WriteError(String, #[error(source)] io::Error),
 }
 
+/// Converts an [Error] to a management interface status
+#[cfg(not(target_os = "android"))]
+impl From<Error> for mullvad_management_interface::Status {
+    fn from(error: Error) -> mullvad_management_interface::Status {
+        use mullvad_management_interface::{Code, Status};
+
+        match error {
+            Error::DeleteError(..) | Error::WriteError(..) | Error::ReadError(..) => {
+                Status::new(Code::FailedPrecondition, error.to_string())
+            }
+            Error::SerializeError(..) | Error::ParseError(..) => {
+                Status::new(Code::Internal, error.to_string())
+            }
+        }
+    }
+}
+
 pub struct SettingsPersister {
     settings: Settings,
     path: PathBuf,
diff --git a/mullvad-daemon/src/settings/patch.rs b/mullvad-daemon/src/settings/patch.rs
new file mode 100644
index 000000000000..f006b21056cf
--- /dev/null
+++ b/mullvad-daemon/src/settings/patch.rs
@@ -0,0 +1,481 @@
+//! This module provides functionality for updating settings using a JSON string, i.e. applying a
+//! patch. It is intended to be relatively safe, preventing editing of "dangerous" settings such as
+//! custom DNS.
+//!
+//! Patching the settings is a three-step procedure:
+//! 1. Validating the input. Only a subset of settings is allowed to be edited using this method.
+//!    Attempting to edit prohibited or invalid settings results in an error.
+//! 2. Merging the changes. When the patch has been accepted, it can be applied to the existing
+//!    settings. How they're merged depends on the actual setting. See [MergeStrategy].
+//! 3. Deserialize the resulting JSON back to a [Settings] instance, and, if valid, replace the
+//!    existing settings.
+//!
+//! Permitted settings and merge strategies are defined in the [PERMITTED_SUBKEYS] constant.
+
+use super::SettingsPersister;
+use mullvad_types::settings::Settings;
+
+#[derive(err_derive::Error, Debug)]
+#[error(no_from)]
+pub enum Error {
+    /// Missing expected JSON object
+    #[error(display = "Incorrect or missing value: {}", _0)]
+    InvalidOrMissingValue(&'static str),
+    /// Unknown or prohibited key
+    #[error(display = "Invalid or prohibited key: {}", _0)]
+    UnknownOrProhibitedKey(String),
+    /// Failed to parse patch json
+    #[error(display = "Failed to parse settings patch")]
+    ParsePatch(#[error(source)] serde_json::Error),
+    /// Failed to deserialize patched settings
+    #[error(display = "Failed to deserialize patched settings")]
+    DeserializePatched(#[error(source)] serde_json::Error),
+    /// Failed to serialize settings
+    #[error(display = "Failed to serialize current settings")]
+    SerializeSettings(#[error(source)] serde_json::Error),
+    /// Recursion limit reached
+    #[error(display = "Maximum JSON object depth reached")]
+    RecursionLimit,
+    /// Settings error
+    #[error(display = "Settings error")]
+    Settings(#[error(source)] super::Error),
+}
+
+/// Converts an [Error] to a management interface status
+#[cfg(not(target_os = "android"))]
+impl From<Error> for mullvad_management_interface::Status {
+    fn from(error: Error) -> mullvad_management_interface::Status {
+        use mullvad_management_interface::Status;
+
+        match error {
+            Error::InvalidOrMissingValue(_)
+            | Error::UnknownOrProhibitedKey(_)
+            | Error::ParsePatch(_)
+            | Error::DeserializePatched(_)
+            | Error::RecursionLimit => Status::invalid_argument(error.to_string()),
+            Error::Settings(error) => Status::from(error),
+            Error::SerializeSettings(error) => Status::internal(error.to_string()),
+        }
+    }
+}
+
+enum MergeStrategy {
+    /// Replace or append keys to objects, and replace everything else
+    Replace,
+    /// Call a function to combine an existing setting (which may be null) with the patch.
+    /// The returned value replaces the existing node.
+    Custom(fn(&serde_json::Value, &serde_json::Value) -> Result<serde_json::Value, Error>),
+}
+
+// TODO: Use Default trait when `const_trait_impl` is available.
+const DEFAULT_MERGE_STRATEGY: MergeStrategy = MergeStrategy::Replace;
+
+struct PermittedKey {
+    key_type: PermittedKeyValue,
+    merge_strategy: MergeStrategy,
+}
+
+impl PermittedKey {
+    const fn object(keys: &'static [(&'static str, PermittedKey)]) -> Self {
+        Self {
+            key_type: PermittedKeyValue::Object(keys),
+            merge_strategy: DEFAULT_MERGE_STRATEGY,
+        }
+    }
+
+    const fn array(key: &'static PermittedKey) -> Self {
+        Self {
+            key_type: PermittedKeyValue::Array(key),
+            merge_strategy: DEFAULT_MERGE_STRATEGY,
+        }
+    }
+
+    const fn any() -> Self {
+        Self {
+            key_type: PermittedKeyValue::Any,
+            merge_strategy: DEFAULT_MERGE_STRATEGY,
+        }
+    }
+
+    const fn merge_strategy(mut self, merge_strategy: MergeStrategy) -> Self {
+        self.merge_strategy = merge_strategy;
+        self
+    }
+}
+
+enum PermittedKeyValue {
+    /// Select subkeys that can be modified at this level
+    Object(&'static [(&'static str, PermittedKey)]),
+    /// Array that can be modified at this level
+    Array(&'static PermittedKey),
+    /// Accept any object at this level
+    Any,
+}
+
+const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[(
+    "relay_overrides",
+    PermittedKey::array(&PermittedKey::object(&[
+        ("hostname", PermittedKey::any()),
+        ("ipv4_addr_in", PermittedKey::any()),
+        ("ipv6_addr_in", PermittedKey::any()),
+    ]))
+    .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)),
+)]);
+
+/// Prohibit stack overflow via excessive recursion. It might be possible to forgo this when
+/// tail-call optimization can be enforced?
+const RECURSE_LIMIT: usize = 15; + +/// Update the settings with the supplied patch. Only settings specified in `PERMITTED_SUBKEYS` can +/// be updated. All other changes are rejected +pub async fn merge_validate_patch( + settings: &mut SettingsPersister, + json_patch: &str, +) -> Result<(), Error> { + let mut settings_value: serde_json::Value = + serde_json::to_value(settings.to_settings()).map_err(Error::SerializeSettings)?; + let patch_value: serde_json::Value = + serde_json::from_str(json_patch).map_err(Error::ParsePatch)?; + + validate_patch_value(PERMITTED_SUBKEYS, &patch_value, 0)?; + merge_patch_to_value(PERMITTED_SUBKEYS, &mut settings_value, &patch_value, 0)?; + + let new_settings: Settings = + serde_json::from_value(settings_value).map_err(Error::DeserializePatched)?; + + settings + .update(move |settings| *settings = new_settings) + .await + .map_err(Error::Settings)?; + + Ok(()) +} + +/// Replace overrides for existing values in the array if there's a matching hostname. For hostnames +/// that do not exist, just append the overrides. 
+fn merge_relay_overrides( + current_settings: &serde_json::Value, + patch: &serde_json::Value, +) -> Result { + if current_settings.is_null() { + return Ok(patch.to_owned()); + } + + let patch_array = patch.as_array().ok_or(Error::InvalidOrMissingValue( + "relay overrides must be array", + ))?; + let current_array = current_settings + .as_array() + .ok_or(Error::InvalidOrMissingValue( + "existing overrides should be an array", + ))?; + let mut new_array = current_array.clone(); + + for patch_override in patch_array.iter().cloned() { + let patch_obj = patch_override + .as_object() + .ok_or(Error::InvalidOrMissingValue("override entry"))?; + let patch_hostname = patch_obj + .get("hostname") + .and_then(|hostname| hostname.as_str()) + .ok_or(Error::InvalidOrMissingValue("hostname"))?; + + let existing_obj = new_array.iter_mut().find(|value| { + value + .as_object() + .and_then(|obj| obj.get("hostname")) + .map(|hostname| hostname.as_str() == Some(patch_hostname)) + .unwrap_or(false) + }); + + match existing_obj { + Some(existing_val) => { + // Replace or append to existing values + match (existing_val, patch_override) { + ( + serde_json::Value::Object(ref mut current), + serde_json::Value::Object(ref patch), + ) => { + for (k, v) in patch { + current.insert(k.to_owned(), v.to_owned()); + } + } + _ => { + return Err(Error::InvalidOrMissingValue( + "all override entries must be objects", + )); + } + } + } + None => new_array.push(patch_override), + } + } + + Ok(serde_json::Value::Array(new_array)) +} + +fn merge_patch_to_value( + permitted_key: &'static PermittedKey, + current_value: &mut serde_json::Value, + patch_value: &serde_json::Value, + recurse_level: usize, +) -> Result<(), Error> { + if recurse_level >= RECURSE_LIMIT { + return Err(Error::RecursionLimit); + } + + match permitted_key.merge_strategy { + MergeStrategy::Replace => { + match (&permitted_key.key_type, current_value, patch_value) { + // Append or replace keys to objects + ( + 
PermittedKeyValue::Object(sub_permitteds), + serde_json::Value::Object(ref mut current), + serde_json::Value::Object(ref patch), + ) => { + for (k, sub_patch) in patch { + let Some((_, sub_permitted)) = sub_permitteds + .iter() + .find(|(permitted_key, _)| k == permitted_key) + else { + return Err(Error::UnknownOrProhibitedKey(k.to_owned())); + }; + let sub_current = current.entry(k).or_insert(serde_json::Value::Null); + merge_patch_to_value( + sub_permitted, + sub_current, + sub_patch, + recurse_level + 1, + )?; + } + } + // Totally replace anything else + (_, current, patch) => { + *current = patch.clone(); + } + } + } + MergeStrategy::Custom(merge_function) => { + *current_value = merge_function(current_value, patch_value)?; + } + } + + Ok(()) +} + +fn validate_patch_value( + permitted_key: &'static PermittedKey, + json_value: &serde_json::Value, + recurse_level: usize, +) -> Result<(), Error> { + if recurse_level >= RECURSE_LIMIT { + return Err(Error::RecursionLimit); + } + + match permitted_key.key_type { + PermittedKeyValue::Object(subkeys) => { + let map = json_value.as_object().ok_or(Error::InvalidOrMissingValue( + "expected JSON object in patch", + ))?; + for (k, v) in map.into_iter() { + // NOTE: We're relying on the parser to shed duplicate keys here. + // As of this writing, `Map` is implemented using BTreeMap. 
+ let Some((_, subkey)) = + subkeys.iter().find(|(permitted_key, _)| k == permitted_key) + else { + return Err(Error::UnknownOrProhibitedKey(k.to_owned())); + }; + validate_patch_value(subkey, v, recurse_level + 1)?; + } + Ok(()) + } + PermittedKeyValue::Array(subkey) => { + let values = json_value + .as_array() + .ok_or(Error::InvalidOrMissingValue("expected JSON array in patch"))?; + for v in values { + validate_patch_value(subkey, v, recurse_level + 1)?; + } + Ok(()) + } + PermittedKeyValue::Any => Ok(()), + } +} + +#[test] +fn test_permitted_value() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "key", + PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])), + )]); + + let patch = r#"{"key": [ {"a": "test" } ] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); +} + +#[test] +fn test_prohibited_value() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "key", + PermittedKey::array(&PermittedKey::object(&[("a", PermittedKey::any())])), + )]); + + let patch = r#"{"keyx": [] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); + + let patch = r#"{"key": { "b": 1 } }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); +} + +#[test] +fn test_merge_append_to_object() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[ + ("test0", PermittedKey::any()), + ("test1", PermittedKey::any()), + ]); + + let current = r#"{ "test0": 1 }"#; + let patch = r#"{ "test1": [] }"#; + let expected = r#"{ "test0": 1, "test1": [] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = 
serde_json::from_str(expected).unwrap(); + + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} + +#[test] +fn test_merge_replace_in_object() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[ + ("test0", PermittedKey::any()), + ( + "test1", + PermittedKey::object(&[("a", PermittedKey::any()), ("test0", PermittedKey::any())]), + ), + ]); + + let current = r#"{ "test0": 1, "test1": { "a": 1, "test0": [] } }"#; + let patch = r#"{ "test1": { "test0": [1, 2, 3] } }"#; + let expected = r#"{ "test0": 1, "test1": { "a": 1, "test0": [1, 2, 3] } }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} + +#[test] +fn test_overflow() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::array(&PermittedKey::array( + &PermittedKey::array(&PermittedKey::any()), + ))), + ))), + ))), + ))), + ))), + )); + + let patch = r#"[[[[[[[[[[[[[[[[[[[[[[]]]]]]]]]]]]]]]]]]]]]]"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + + assert!(matches!( + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0), + Err(Error::RecursionLimit) + )); +} + +#[test] +fn test_patch_relay_override() { + const PERMITTED_SUBKEYS: &PermittedKey = &PermittedKey::object(&[( + "relay_overrides", + PermittedKey::array(&PermittedKey::object(&[ + ("hostname", PermittedKey::any()), + 
("ipv4_addr_in", PermittedKey::any()), + ("ipv6_addr_in", PermittedKey::any()), + ])) + .merge_strategy(MergeStrategy::Custom(merge_relay_overrides)), + )]); + + // If override has no hostname, fail + // + let patch = r#"{ "relay_overrides": [ { "invalid": 0 } ] }"#; + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap_err(); + + // If there are no overrides, append new override + // + let current = r#"{ "other": 1 }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let expected = r#"{ "other": 1, "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // If there are overrides, append new override to existing list + // + let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "new", "ipv4_addr_in": "1.2.3.4" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // If there are overrides, replace existing 
overrides but keep rest + // + let current = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "1.2.3.4" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" }, { "hostname": "test2", "ipv4_addr_in": "0.0.0.0" }, { "hostname": "test3", "ipv4_addr_in": "192.168.1.1" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); + + // For same hostname, only update specified overrides + // + let current = + r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7" } ] }"#; + let patch = r#"{ "relay_overrides": [ { "hostname": "test", "ipv6_addr_in": "::1" } ] }"#; + let expected = r#"{ "relay_overrides": [ { "hostname": "test", "ipv4_addr_in": "1.3.3.7", "ipv6_addr_in": "::1" } ] }"#; + + let mut current: serde_json::Value = serde_json::from_str(current).unwrap(); + let patch: serde_json::Value = serde_json::from_str(patch).unwrap(); + let expected: serde_json::Value = serde_json::from_str(expected).unwrap(); + + validate_patch_value(&PERMITTED_SUBKEYS, &patch, 0).unwrap(); + merge_patch_to_value(&PERMITTED_SUBKEYS, &mut current, &patch, 0).unwrap(); + + assert_eq!(current, expected); +} diff --git a/mullvad-management-interface/proto/management_interface.proto b/mullvad-management-interface/proto/management_interface.proto index 64d87d37effd..a27698f3176b 100644 --- a/mullvad-management-interface/proto/management_interface.proto +++ 
b/mullvad-management-interface/proto/management_interface.proto @@ -99,6 +99,9 @@ service ManagementService { // Notify the split tunnel monitor that a volume was mounted or dismounted // (Windows). rpc CheckVolumes(google.protobuf.Empty) returns (google.protobuf.Empty) {} + + // Apply a JSON blob to the settings + rpc ApplyJsonSettings(google.protobuf.StringValue) returns (google.protobuf.Empty) {} } message UUID { string value = 1; } diff --git a/mullvad-management-interface/src/client.rs b/mullvad-management-interface/src/client.rs index 140eddc08a75..c1b27fea6561 100644 --- a/mullvad-management-interface/src/client.rs +++ b/mullvad-management-interface/src/client.rs @@ -677,6 +677,11 @@ impl MullvadProxyClient { } // check_volumes + + pub async fn apply_json_settings(&mut self, blob: String) -> Result<()> { + self.0.apply_json_settings(blob).await.map_err(Error::Rpc)?; + Ok(()) + } } fn map_device_error(status: Status) -> Error {