From ae6a4a1d9fba88aa6b3a23f6fe1252bf474e99eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Sun, 17 Sep 2023 15:54:18 +0200 Subject: [PATCH 1/3] Define retry strategy constants --- mullvad-daemon/src/device/service.rs | 41 ++++++++------------------- mullvad-relay-selector/src/updater.rs | 12 ++++---- talpid-core/src/future_retry.rs | 27 ++++++++++++------ 3 files changed, 35 insertions(+), 45 deletions(-) diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs index c1aba7b74823..74dc15b94f56 100644 --- a/mullvad-daemon/src/device/service.rs +++ b/mullvad-daemon/src/device/service.rs @@ -21,9 +21,10 @@ use talpid_core::future_retry::{ const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; const RETRY_ACTION_MAX_RETRIES: usize = 2; -const RETRY_BACKOFF_INTERVAL_INITIAL: Duration = Duration::from_secs(4); -const RETRY_BACKOFF_INTERVAL_FACTOR: u32 = 5; -const RETRY_BACKOFF_INTERVAL_MAX: Duration = Duration::from_secs(24 * 60 * 60); +const RETRY_BACKOFF_STRATEGY: Jittered = Jittered::jitter( + ExponentialBackoff::new(Duration::from_secs(4), 5) + .max_delay(Some(Duration::from_secs(24 * 60 * 60))), +); #[derive(Clone)] pub struct DeviceService { @@ -87,7 +88,7 @@ impl DeviceService { let (device, addresses) = retry_future( move || api_handle.when_online(proxy.create(token_copy.clone(), pubkey.clone())), should_retry_backoff, - retry_strategy(), + RETRY_BACKOFF_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -141,18 +142,12 @@ impl DeviceService { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - let retry_strategy = Jittered::jitter( - ExponentialBackoff::new( - RETRY_BACKOFF_INTERVAL_INITIAL, - RETRY_BACKOFF_INTERVAL_FACTOR, - ), // Not setting a maximum interval - ); - retry_future( // NOTE: Not honoring "paused" state, because the account may have no time on it. move || api_handle.when_online(proxy.remove(token.clone(), device.clone())), should_retry_backoff, - retry_strategy, + // Not setting a maximum interval + RETRY_BACKOFF_STRATEGY.clone().max_delay(None), ) .await .map_err(map_rest_error)?; @@ -197,6 +192,8 @@ impl DeviceService { let api_handle = self.api_availability.clone(); let pubkey = private_key.public_key(); + let rotate_retry_strategy = std::iter::repeat(Duration::from_secs(24 * 60 * 60)); + let addresses = retry_future( move || { api_handle.when_bg_resumes(proxy.replace_wg_key( @@ -206,7 +203,7 @@ impl DeviceService { )) }, should_retry_backoff, - rotate_retry_strategy(), + rotate_retry_strategy, ) .await .map_err(map_rest_error)?; @@ -241,7 +238,7 @@ impl DeviceService { retry_future( move || api_handle.when_online(proxy.list(token.clone())), should_retry_backoff, - retry_strategy(), + RETRY_BACKOFF_STRATEGY, ) .await .map_err(map_rest_error) @@ -361,7 +358,7 @@ pub fn spawn_account_service( async move { handle_expiry_result_inner(&expiry_fut.await, &api_availability_copy) } }; let should_retry = move |state_was_updated: &bool| -> bool { !*state_was_updated }; - retry_future(future_generator, should_retry, retry_strategy()).await; + retry_future(future_generator, should_retry, RETRY_BACKOFF_STRATEGY).await; }); tokio::spawn(future); @@ -433,17 +430,3 @@ fn map_rest_error(error: rest::Error) -> Error { error => Error::OtherRestError(error), } } - -fn retry_strategy() -> Jittered { - Jittered::jitter( - ExponentialBackoff::new( - RETRY_BACKOFF_INTERVAL_INITIAL, - RETRY_BACKOFF_INTERVAL_FACTOR, - ) - .max_delay(RETRY_BACKOFF_INTERVAL_MAX), - ) -} - -fn rotate_retry_strategy() -> impl Iterator { - std::iter::repeat(RETRY_BACKOFF_INTERVAL_MAX) -} diff --git a/mullvad-relay-selector/src/updater.rs b/mullvad-relay-selector/src/updater.rs index 31299eea14aa..6600efb02f42 100644 --- a/mullvad-relay-selector/src/updater.rs +++ b/mullvad-relay-selector/src/updater.rs @@ -23,8 +23,10 @@ const UPDATE_CHECK_INTERVAL: Duration = Duration::from_secs(60 * 15); /// How old the cached relays need to be to trigger an update const UPDATE_INTERVAL: Duration = Duration::from_secs(60 * 60); -const EXPONENTIAL_BACKOFF_INITIAL: Duration = Duration::from_secs(16); -const EXPONENTIAL_BACKOFF_FACTOR: u32 = 8; +const DOWNLOAD_RETRY_STRATEGY: Jittered = Jittered::jitter( + ExponentialBackoff::new(Duration::from_secs(16), 8) + .max_delay(Some(Duration::from_secs(2 * 60 * 60))), +); #[derive(Clone)] pub struct RelayListUpdaterHandle { @@ -161,14 +163,10 @@ impl RelayListUpdater { } }; - let exponential_backoff = - ExponentialBackoff::new(EXPONENTIAL_BACKOFF_INITIAL, EXPONENTIAL_BACKOFF_FACTOR) - .max_delay(UPDATE_INTERVAL * 2); - retry_future( download_futures, |result| result.is_err(), - Jittered::jitter(exponential_backoff), + DOWNLOAD_RETRY_STRATEGY, ) } diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs index 604e4513e139..f5dc7a8f72ec 100644 --- a/talpid-core/src/future_retry.rs +++ b/talpid-core/src/future_retry.rs @@ -1,5 +1,5 @@ use rand::{distributions::OpenClosed01, Rng}; -use std::{future::Future, time::Duration}; +use std::{future::Future, ops::Deref, time::Duration}; use talpid_time::sleep; /// Convenience function that works like [`retry_future`] but limits the number @@ -50,6 +50,7 @@ pub fn constant_interval(interval: Duration) -> impl Iterator { } /// Provides an exponential back-off timer to delay the next retry of a failed operation. +#[derive(Clone)] pub struct ExponentialBackoff { next: Duration, factor: u32, @@ -61,7 +62,7 @@ impl ExponentialBackoff { /// /// All else staying the same, the first delay will be `initial` long, the second /// one will be `initial * factor`, third `initial * factor^2` and so on. - pub fn new(initial: Duration, factor: u32) -> ExponentialBackoff { + pub const fn new(initial: Duration, factor: u32) -> ExponentialBackoff { ExponentialBackoff { next: initial, factor, @@ -71,8 +72,8 @@ impl ExponentialBackoff { /// Set the maximum delay. By default, there is no maximum value set. The limit is /// `Duration::MAX`. - pub fn max_delay(mut self, duration: Duration) -> ExponentialBackoff { - self.max_delay = Some(duration); + pub const fn max_delay(mut self, duration: Option) -> ExponentialBackoff { + self.max_delay = duration; self } @@ -100,13 +101,13 @@ impl Iterator for ExponentialBackoff { } /// Adds jitter to a duration iterator -pub struct Jittered> { +pub struct Jittered { inner: I, } -impl> Jittered { +impl Jittered { /// Create an iterator of jittered durations - pub fn jitter(inner: I) -> Self { + pub const fn jitter(inner: I) -> Self { Self { inner } } } @@ -119,6 +120,14 @@ impl> Iterator for Jittered { } } +impl Deref for Jittered { + type Target = I; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + /// Apply a jitter to a duration. fn jitter(dur: Duration) -> Duration { apply_jitter(dur, rand::thread_rng().sample(OpenClosed01)) @@ -158,7 +167,7 @@ mod test { #[test] fn test_maximum_bound() { let mut backoff = ExponentialBackoff::new(Duration::from_millis(2), 3) - .max_delay(Duration::from_millis(7)); + .max_delay(Some(Duration::from_millis(7))); assert_eq!(backoff.next(), Some(Duration::from_millis(2))); assert_eq!(backoff.next(), Some(Duration::from_millis(6))); @@ -207,7 +216,7 @@ mod test { || async { 0 }, |_| true, ExponentialBackoff::new(retry_interval_initial, retry_interval_factor) - .max_delay(retry_interval_max), + .max_delay(Some(retry_interval_max)), 5, ) .await; From c8b9da33986f9d64b4c4e84324d2bb92fffc52f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Sun, 17 Sep 2023 16:42:18 +0200 Subject: [PATCH 2/3] Simplify immediate retry strategy --- mullvad-daemon/src/device/service.rs | 55 +++++++++++---------------- mullvad-daemon/src/version_check.rs | 12 +++--- mullvad-setup/src/main.rs | 10 ++--- talpid-core/src/future_retry.rs | 56 ++++++++++++++++------------ 4 files changed, 64 insertions(+), 69 deletions(-) diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs index 74dc15b94f56..9c967e441379 100644 --- a/mullvad-daemon/src/device/service.rs +++ b/mullvad-daemon/src/device/service.rs @@ -15,12 +15,10 @@ use mullvad_api::{ rest::{self, Error as RestError, MullvadRestHandle}, AccountsProxy, DevicesProxy, }; -use talpid_core::future_retry::{ - constant_interval, retry_future, retry_future_n, ExponentialBackoff, Jittered, -}; -const RETRY_ACTION_INTERVAL: Duration = Duration::ZERO; -const RETRY_ACTION_MAX_RETRIES: usize = 2; - +use talpid_core::future_retry::{retry_future, ConstantInterval, ExponentialBackoff, Jittered}; +/// Retry strategy used for user-initiated actions that require immediate feedback +const RETRY_ACTION_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(3)); +/// Retry strategy used for background tasks const RETRY_BACKOFF_STRATEGY: Jittered = Jittered::jitter( ExponentialBackoff::new(Duration::from_secs(4), 5) .max_delay(Some(Duration::from_secs(24 * 60 * 60))), @@ -52,11 +50,10 @@ impl DeviceService { let api_handle = self.api_availability.clone(); let token_copy = account_token.clone(); async move { - let (device, addresses) = retry_future_n( + let (device, addresses) = retry_future( move || proxy.create(token_copy.clone(), pubkey.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -123,11 +120,10 @@ impl DeviceService { ) -> Result<(), Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.remove(token.clone(), device.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -165,11 +161,10 @@ impl DeviceService { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); let pubkey = private_key.public_key(); - let addresses = retry_future_n( + let addresses = retry_future( move || proxy.replace_wg_key(token.clone(), device.clone(), pubkey.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error)?; @@ -218,11 +213,10 @@ impl DeviceService { pub async fn list_devices(&self, token: AccountToken) -> Result, Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.list(token.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error) @@ -247,11 +241,10 @@ impl DeviceService { pub async fn get(&self, token: AccountToken, device: DeviceId) -> Result { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.get(token.clone(), device.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await .map_err(map_rest_error) @@ -269,11 +262,10 @@ impl AccountService { pub fn create_account(&self) -> impl Future> { let mut proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.create_account(), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) } @@ -283,22 +275,20 @@ impl AccountService { ) -> impl Future> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - retry_future_n( + retry_future( move || proxy.get_www_auth_token(account.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) } pub async fn check_expiry(&self, token: AccountToken) -> Result, rest::Error> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - let result = retry_future_n( + let result = retry_future( move || proxy.get_expiry(token.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await; if handle_expiry_result_inner(&result, &self.api_availability) { @@ -318,11 +308,10 @@ impl AccountService { ) -> Result { let mut proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); - let result = retry_future_n( + let result = retry_future( move || proxy.submit_voucher(account_token.clone(), voucher.clone()), move |result| should_retry(result, &api_handle), - constant_interval(RETRY_ACTION_INTERVAL), - RETRY_ACTION_MAX_RETRIES, + RETRY_ACTION_STRATEGY, ) .await; if result.is_ok() { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index b488b0b3245a..dfe0a26b5fc0 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -15,7 +15,7 @@ use std::{ str::FromStr, time::Duration, }; -use talpid_core::mpsc::Sender; +use talpid_core::{future_retry::ConstantInterval, mpsc::Sender}; use talpid_types::ErrorExt; use tokio::fs::{self, File}; @@ -31,9 +31,8 @@ const DOWNLOAD_TIMEOUT: Duration = Duration::from_secs(15); const UPDATE_INTERVAL: Duration = Duration::from_secs(60 * 60 * 24); /// Wait this long until next try if an update failed const UPDATE_INTERVAL_ERROR: Duration = Duration::from_secs(60 * 60 * 6); -/// Retry interval for `RunVersionCheck`. -const IMMEDIATE_UPDATE_INTERVAL_ERROR: Duration = Duration::ZERO; -const IMMEDIATE_UPDATE_MAX_RETRIES: usize = 2; +/// Retry strategy for `RunVersionCheck`. +const IMMEDIATE_RETRY_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(3)); #[cfg(target_os = "linux")] const PLATFORM: &str = "linux"; @@ -194,11 +193,10 @@ impl VersionUpdater { .map_err(Error::Download) }; - Box::pin(talpid_core::future_retry::retry_future_n( + Box::pin(talpid_core::future_retry::retry_future( download_future_factory, move |result| Self::should_retry_immediate(result, &api_handle), - std::iter::repeat(IMMEDIATE_UPDATE_INTERVAL_ERROR), - IMMEDIATE_UPDATE_MAX_RETRIES, + IMMEDIATE_RETRY_STRATEGY, )) } diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index a349554ca656..bcae45944234 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -6,15 +6,14 @@ use once_cell::sync::Lazy; use std::{path::PathBuf, process, str::FromStr, time::Duration}; use talpid_core::{ firewall::{self, Firewall}, - future_retry::{constant_interval, retry_future_n}, + future_retry::{retry_future, ConstantInterval}, }; use talpid_types::ErrorExt; static APP_VERSION: Lazy = Lazy::new(|| ParsedAppVersion::from_str(mullvad_version::VERSION).unwrap()); -const KEY_RETRY_INTERVAL: Duration = Duration::ZERO; -const KEY_RETRY_MAX_RETRIES: usize = 4; +const DEVICE_REMOVAL_STRATEGY: ConstantInterval = ConstantInterval::new(Duration::ZERO, Some(5)); #[repr(i32)] enum ExitStatus { @@ -171,14 +170,13 @@ async fn remove_device() -> Result<(), Error> { .await, ); - let device_removal = retry_future_n( + let device_removal = retry_future( move || proxy.remove(device.account_token.clone(), device.device.id.clone()), move |result| match result { Err(error) => error.is_network_error(), _ => false, }, - constant_interval(KEY_RETRY_INTERVAL), - KEY_RETRY_MAX_RETRIES, + DEVICE_REMOVAL_STRATEGY, ) .await; diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs index f5dc7a8f72ec..f7b68a3f2d1f 100644 --- a/talpid-core/src/future_retry.rs +++ b/talpid-core/src/future_retry.rs @@ -2,23 +2,6 @@ use rand::{distributions::OpenClosed01, Rng}; use std::{future::Future, ops::Deref, time::Duration}; use talpid_time::sleep; -/// Convenience function that works like [`retry_future`] but limits the number -/// of retries to `max_retries`. -pub async fn retry_future_n< - F: FnMut() -> O + 'static, - R: FnMut(&T) -> bool + 'static, - D: Iterator + 'static, - O: Future, - T, ->( - factory: F, - should_retry: R, - delays: D, - max_retries: usize, -) -> T { - retry_future(factory, should_retry, delays.take(max_retries)).await -} - /// Retries a future until it should stop as determined by the retry function, or when /// the iterator returns `None`. pub async fn retry_future< @@ -44,9 +27,36 @@ pub async fn retry_future< } } -/// Returns an iterator that repeats the same interval. -pub fn constant_interval(interval: Duration) -> impl Iterator { - std::iter::repeat(interval) +/// Iterator that repeats the same interval, with an optional maximum no. of attempts. +pub struct ConstantInterval { + interval: Duration, + attempt: usize, + max_attempts: Option, +} + +impl ConstantInterval { + /// Creates a `ConstantInterval` that repeats `interval`, at most `max_attempts` times. + pub const fn new(interval: Duration, max_attempts: Option) -> ConstantInterval { + ConstantInterval { + interval, + attempt: 0, + max_attempts, + } + } +} + +impl Iterator for ConstantInterval { + type Item = Duration; + + fn next(&mut self) -> Option { + if let Some(max_attempts) = self.max_attempts { + if self.attempt >= max_attempts { + return None; + } + } + self.attempt = self.attempt.saturating_add(1); + Some(self.interval) + } } /// Provides an exponential back-off timer to delay the next retry of a failed operation. @@ -212,12 +222,12 @@ mod test { let retry_interval_max = Duration::from_secs(24 * 60 * 60); tokio::time::pause(); - let _ = retry_future_n( + let _ = retry_future( || async { 0 }, |_| true, ExponentialBackoff::new(retry_interval_initial, retry_interval_factor) - .max_delay(Some(retry_interval_max)), - 5, + .max_delay(Some(retry_interval_max)) + .take(5), ) .await; } From 2ce9379ac5d7b7344fc9eedaddb19a337653c9c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Tue, 19 Sep 2023 14:00:53 +0200 Subject: [PATCH 3/3] Add unit test for ConstantInterval --- talpid-core/src/future_retry.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/talpid-core/src/future_retry.rs b/talpid-core/src/future_retry.rs index f7b68a3f2d1f..197042e353d7 100644 --- a/talpid-core/src/future_retry.rs +++ b/talpid-core/src/future_retry.rs @@ -154,6 +154,22 @@ fn apply_jitter(duration: Duration, jitter: f64) -> Duration { mod test { use super::*; + #[test] + fn test_constant_interval() { + let mut ivl = ConstantInterval::new(Duration::from_secs(2), Some(3)); + + assert_eq!(ivl.next(), Some(Duration::from_secs(2))); + assert_eq!(ivl.next(), Some(Duration::from_secs(2))); + assert_eq!(ivl.next(), Some(Duration::from_secs(2))); + assert_eq!(ivl.next(), None); + } + + #[test] + fn test_constant_interval_no_max() { + let mut ivl = ConstantInterval::new(Duration::from_secs(2), None); + assert_eq!(ivl.next(), Some(Duration::from_secs(2))); + } + #[test] fn test_exponential_backoff() { let mut backoff = ExponentialBackoff::new(Duration::from_secs(2), 3);