Skip to content

Commit

Permalink
Merge branch 'default-to-connecting-on-corrupt-state-cache-460'
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Dec 6, 2023
2 parents ac3e222 + 07dedbc commit 96d2e3a
Show file tree
Hide file tree
Showing 10 changed files with 431 additions and 150 deletions.
146 changes: 121 additions & 25 deletions mullvad-daemon/src/target_state.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use mullvad_types::states::TargetState;
use std::{
future::Future,
ops::Deref,
path::{Path, PathBuf},
};
Expand All @@ -21,48 +22,76 @@ impl PersistentTargetState {
/// Initialize using the current target state (if there is one)
pub async fn new(cache_dir: &Path) -> Self {
let cache_path = cache_dir.join(TARGET_START_STATE_FILE);
let mut update_cache = false;
let state = match fs::read_to_string(&cache_path).await {
let TargetStateInner {
state,
update_cache,
} = Self::read_target_state(&cache_path, fs::read_to_string).await;
let state = PersistentTargetState {
state,
cache_path,
locked: false,
};
if update_cache {
state.save().await;
}
state
}

/// Construct a [`TargetState`] from cache.
///
/// `read_cache` allows the caller to decide how to read from a cache of
/// [`TargetState`].
///
/// This function will always succeed, even in the presence of IO
/// operations. Errors are handled gracefully by defaulting to safe target
/// states if necessary.
async fn read_target_state<F, R>(cache: &Path, read_cache: F) -> TargetStateInner
where
F: FnOnce(PathBuf) -> R,
R: Future<Output = io::Result<String>>,
{
match read_cache(cache.to_path_buf()).await {
Ok(content) => serde_json::from_str(&content)
.map(|state| {
log::info!(
"Loaded cached target state \"{}\" from {}",
state,
cache_path.display()
cache.display()
);
state
TargetStateInner {
state,
update_cache: false,
}
})
.unwrap_or_else(|error| {
log::error!(
"{}",
error.display_chain_with_msg("Failed to parse cached target tunnel state")
);
update_cache = true;
TargetState::Secured
TargetStateInner {
state: TargetState::Secured,
update_cache: true,
}
}),

Err(error) if error.kind() == io::ErrorKind::NotFound => {
log::debug!("No cached target state to load");
TargetStateInner {
state: DEFAULT_TARGET_STATE,
update_cache: false,
}
}
Err(error) => {
if error.kind() == io::ErrorKind::NotFound {
log::debug!("No cached target state to load");
DEFAULT_TARGET_STATE
} else {
log::error!(
"{}",
error.display_chain_with_msg("Failed to read cached target tunnel state")
);
update_cache = true;
TargetState::Secured
log::error!(
"{}",
error.display_chain_with_msg("Failed to read cached target tunnel state")
);
TargetStateInner {
state: TargetState::Secured,
update_cache: true,
}
}
};
let state = PersistentTargetState {
state,
cache_path,
locked: false,
};
if update_cache {
state.save().await;
}
state
}

/// Override the current target state, if there is one
Expand Down Expand Up @@ -153,3 +182,70 @@ impl Deref for PersistentTargetState {
&self.state
}
}

/// The result of calling `read_target_state`.
struct TargetStateInner {
state: TargetState,
/// In some circumstances, the target state cache should be updated on disk
/// upon initialization a [`PersistentTargetState`]. This is signaled to the
/// constructor of [`PersistentTargetState`] by setting this value to
/// `true`.
update_cache: bool,
}

impl Deref for TargetStateInner {
type Target = TargetState;

fn deref(&self) -> &Self::Target {
&self.state
}
}

#[cfg(test)]
mod test {
use super::*;

static DUMMY_CACHE_DIR: &str = "target-state-test";

/// If no target state cache exist, the default target state is used. This
/// is the most basic check.
#[tokio::test]
async fn test_target_state_initialization_empty() {
let target_state =
PersistentTargetState::read_target_state(Path::new(DUMMY_CACHE_DIR), |_| async {
// A completely blank slate. No target state cache file has been created yet.
Err(io::ErrorKind::NotFound.into())
})
.await;
assert_eq!(*target_state, DEFAULT_TARGET_STATE);
}

/// If a target state cache exist with some target state, the state can be
/// read-back successfully.
#[tokio::test]
async fn test_target_state_initialization_existing() {
for cached_state in [TargetState::Secured, TargetState::Unsecured] {
let target_state =
PersistentTargetState::read_target_state(Path::new(DUMMY_CACHE_DIR), |_| async {
Ok(serde_json::to_string(&cached_state).unwrap())
})
.await;
assert_eq!(*target_state, cached_state);
}
}

/// The state can not be read-back successfully if the state file has become
/// corrupt. In such cases, initializing a [`PersistentTargetState`] should
/// yield a "better safe than sorry"-target state of `Secured`.
#[tokio::test]
async fn test_target_corrupt_state_cache() {
let target_state =
PersistentTargetState::read_target_state(Path::new(DUMMY_CACHE_DIR), |_| async {
// Intentionally corrupt the target state cache.
Ok("Not a valid target state".to_string())
})
.await;
// Reading back a corrupt target state cache should yield `TargetState::Secured`.
assert_eq!(*target_state, TargetState::Secured);
}
}
4 changes: 2 additions & 2 deletions test/test-manager/src/tests/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ pub async fn test_automatic_wireguard_rotation(
.pubkey;

// Stop daemon
rpc.set_mullvad_daemon_service_state(false)
rpc.stop_mullvad_daemon()
.await
.expect("Could not stop system service");

Expand All @@ -334,7 +334,7 @@ pub async fn test_automatic_wireguard_rotation(
.expect("Could not change device.json to have an old created timestamp");

// Start daemon
rpc.set_mullvad_daemon_service_state(true)
rpc.start_mullvad_daemon()
.await
.expect("Could not start system service");

Expand Down
51 changes: 51 additions & 0 deletions test/test-manager/src/tests/tunnel_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,54 @@ pub async fn test_connected_state(

Ok(())
}

/// Verify that the app defaults to the connecting state if it is started with a
/// corrupt state cache.
#[test_function]
pub async fn test_connecting_state_when_corrupted_state_cache(
_: TestContext,
rpc: ServiceClient,
mullvad_client: ManagementServiceClient,
) -> Result<(), Error> {
// The test should start in a disconnected state. Normally this would be
// preserved when restarting the app, i.e. the user would still be
// disconnected after a successfull restart. However, as we will
// intentionally corrupt the state target cache the user should end up in
// the connecting/connected state, *not in the disconnected state*, upon
// restart.

// Stopping the app should write to the state target cache.
log::info!("Stopping the app");
rpc.stop_mullvad_daemon().await?;

// Intentionally corrupt the state cache. Note that we can not simply remove
// the cache, as this will put the app in the default target state which is
// 'unsecured'.
log::info!("Figuring out where state cache resides on test runner ..");
let state_cache = rpc
.find_mullvad_app_cache_dir()
.await?
.join("target-start-state.json");
log::info!(
"Intentionally writing garbage to the state cache {file}",
file = state_cache.display()
);
rpc.write_file(state_cache, "cookie was here".into())
.await?;

// Start the app & make sure that we start in the 'connecting state'. The
// side-effect of this is that no network traffic is allowed to leak.
log::info!("Starting the app back up again");
rpc.start_mullvad_daemon().await?;
wait_for_tunnel_state(mullvad_client.clone(), |state| !state.is_disconnected())
.await
.map_err(|err| {
log::error!("App did not start in an expected state. \
App is not in either `Connecting` or `Connected` state after starting with corrupt state cache! \
There is a possibility of leaks during app startup ");
err
})?;
log::info!("App successfully recovered from a corrupt tunnel state cache.");

Ok(())
}
87 changes: 69 additions & 18 deletions test/test-rpc/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
collections::HashMap,
path::Path,
time::{Duration, SystemTime},
};

Expand Down Expand Up @@ -51,7 +52,7 @@ impl ServiceClient {
self.client.uninstall_app(ctx, env).await?
}

/// Execute a program.
/// Execute a program with additional environment-variables set.
pub async fn exec_env<
I: IntoIterator<Item = T>,
M: IntoIterator<Item = (K, T)>,
Expand Down Expand Up @@ -151,6 +152,13 @@ impl ServiceClient {
.await?
}

/// Returns path of Mullvad app cache directorie on the test runner.
pub async fn find_mullvad_app_cache_dir(&self) -> Result<PathBuf, Error> {
self.client
.get_mullvad_app_cache_dir(tarpc::context::current())
.await?
}

/// Send TCP packet
pub async fn send_tcp(
&self,
Expand Down Expand Up @@ -213,6 +221,51 @@ impl ServiceClient {
.await?
}

/// Restarts the app.
///
/// Shuts down a running app, making it disconnect from any current tunnel
/// connection before starting the app again.
///
/// # Note
/// This function will return *after* the app is running again, thus
/// blocking execution until then.
pub async fn restart_mullvad_daemon(&self) -> Result<(), Error> {
let _ = self
.client
.restart_mullvad_daemon(tarpc::context::current())
.await?;
Ok(())
}

/// Stop the app.
///
/// Shuts down a running app, making it disconnect from any current tunnel
/// connection and making it write to caches.
///
/// # Note
/// This function will return *after* the app has been stopped, thus
/// blocking execution until then.
pub async fn stop_mullvad_daemon(&self) -> Result<(), Error> {
let _ = self
.client
.stop_mullvad_daemon(tarpc::context::current())
.await?;
Ok(())
}

/// Start the app.
///
/// # Note
/// This function will return *after* the app has been started, thus
/// blocking execution until then.
pub async fn start_mullvad_daemon(&self) -> Result<(), Error> {
let _ = self
.client
.start_mullvad_daemon(tarpc::context::current())
.await?;
Ok(())
}

pub async fn set_daemon_log_level(
&self,
verbosity_level: mullvad_daemon::Verbosity,
Expand Down Expand Up @@ -247,6 +300,21 @@ impl ServiceClient {
.await?
}

pub async fn write_file(&self, dest: impl AsRef<Path>, bytes: Vec<u8>) -> Result<(), Error> {
log::debug!(
"Writing {bytes} bytes to \"{file}\"",
bytes = bytes.len(),
file = dest.as_ref().display()
);
self.client
.write_file(
tarpc::context::current(),
dest.as_ref().to_path_buf(),
bytes,
)
.await?
}

pub async fn reboot(&mut self) -> Result<(), Error> {
log::debug!("Rebooting server");

Expand All @@ -262,23 +330,6 @@ impl ServiceClient {
Ok(())
}

pub async fn set_mullvad_daemon_service_state(&self, on: bool) -> Result<(), Error> {
self.client
.set_mullvad_daemon_service_state(tarpc::context::current(), on)
.await??;

self.mullvad_daemon_wait_for_state(|state| {
if on {
state == ServiceStatus::Running
} else {
state == ServiceStatus::NotRunning
}
})
.await?;

Ok(())
}

pub async fn make_device_json_old(&self) -> Result<(), Error> {
self.client
.make_device_json_old(tarpc::context::current())
Expand Down
Loading

0 comments on commit 96d2e3a

Please sign in to comment.