diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index a3bec3f72599..2ed69241a152 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -2,109 +2,172 @@ use crate::{ rest, rest::{RequestFactory, RequestServiceHandle}, }; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; use hyper::StatusCode; use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; -use talpid_types::ErrorExt; +use std::{collections::HashMap, sync::Arc}; +use tokio::select; pub const AUTH_URL_PREFIX: &str = "auth/v1"; #[derive(Clone)] -pub struct AccessTokenProxy { - service: RequestServiceHandle, - factory: RequestFactory, - access_from_account: Arc>>, +pub struct AccessTokenStore { + tx: mpsc::UnboundedSender, } -impl AccessTokenProxy { +enum StoreAction { + /// Request an access token for `AccountToken`, or return a saved one if it's not expired. + GetAccessToken( + AccountToken, + oneshot::Sender>>, + ), + /// Forget cached access token for `AccountToken`, and drop any in-flight requests + InvalidateToken(AccountToken), +} + +/// Keep track of access token for the given account. +impl AccessTokenStore { pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { - Self { - service, - factory, - access_from_account: Arc::new(Mutex::new(HashMap::new())), - } + let (tx, rx) = mpsc::unbounded(); + tokio::spawn(Self::service_requests(rx, service, factory)); + Self { tx } } - /// Obtain access token for an account, requesting a new one from the API if necessary. - pub async fn get_token(&self, account: &AccountToken) -> Result { - let existing_token = { - self.access_from_account - .lock() - .unwrap() - .get(account.as_str()) - .cloned() - }; - if let Some(access_token) = existing_token { - if access_token.is_expired() { - log::debug!("Replacing expired access token"); - return self.request_new_token(account.clone()).await; + async fn service_requests( + mut rx: mpsc::UnboundedReceiver, + service: RequestServiceHandle, + factory: RequestFactory, + ) { + let mut access_from_account: HashMap = HashMap::new(); + let mut inflight_requests = HashMap::new(); + let mut response_channels = HashMap::new(); + + let (completed_tx, mut completed_rx) = mpsc::unbounded(); + + loop { + select! { + action = rx.next() => { + let Some(action) = action else { + // We're done + break; + }; + + match action { + StoreAction::GetAccessToken(account, response_tx) => { + // If there is an unexpired access token, just return it. + // Otherwise, generate a new token + if let Some(access_token) = access_from_account.get_mut(&account) { + if !access_token.is_expired() { + log::trace!("Using stored access token"); + let _ = response_tx.send(Ok(access_token.access_token.clone())); + continue; + } + + log::debug!("Replacing expired access token"); + access_from_account.remove(&account); + } + + // Begin requesting an access token if it's not already underway. + // If there's already an inflight request, just save `response_tx` + inflight_requests + .entry(account.clone()) + .or_insert_with(|| { + let completed_tx = completed_tx.clone(); + let account = account.clone(); + let service = service.clone(); + let factory = factory.clone(); + + log::debug!("Fetching access token for an account"); + + tokio::spawn(async move { + let result = fetch_access_token(service, factory, account.clone()).await; + let _ = completed_tx.unbounded_send((account, result)); + }) + }); + + // Save the channel to respond to later + response_channels + .entry(account) + .or_insert_with(|| vec![]) + .push(response_tx); + } + StoreAction::InvalidateToken(account) => { + // Drop in-flight requests for the account + // & forget any existing access token + + log::debug!("Invalidating access token for an account"); + + if let Some(task) = inflight_requests.remove(&account) { + task.abort(); + let _ = task.await; + } + + response_channels.remove(&account); + access_from_account.remove(&account); + } + } + } + + Some((account, result)) = completed_rx.next() => { + inflight_requests.remove(&account); + + // Sadly, rest::Error is not cloneable + let result = result.map_err(|error| Arc::new(error)); + + // Send response to all channels + if let Some(channels) = response_channels.remove(&account) { + for tx in channels { + let _ = tx.send(result.clone().map(|data| data.access_token)); + } + } + + if let Ok(access_token) = result { + access_from_account.insert(account, access_token); + } + } } - log::trace!("Using stored access token"); - return Ok(access_token.access_token.clone()); } - self.request_new_token(account.clone()).await + } + + /// Obtain access token for an account, requesting a new one from the API if necessary. + pub async fn get_token(&self, account: &AccountToken) -> Result> { + let (tx, rx) = oneshot::channel(); + let _ = self + .tx + .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx)); + rx.await.map_err(|_| rest::Error::Aborted)? } /// Remove an access token if the API response calls for it. - pub fn check_response(&self, account: &AccessToken, response: &Result) { + pub fn check_response(&self, account: &AccountToken, response: &Result) { if let Err(rest::Error::ApiError(_status, code)) = response { if code == crate::INVALID_ACCESS_TOKEN { - log::debug!("Dropping invalid access token"); - self.remove_token(account); + let _ = self + .tx + .unbounded_send(StoreAction::InvalidateToken(account.to_owned())); } } } +} - /// Removes a stored access token. - fn remove_token(&self, account: &AccountToken) -> Option { - self.access_from_account - .lock() - .unwrap() - .remove(account) - .map(|v| v.access_token) - } - - async fn request_new_token(&self, account: AccountToken) -> Result { - log::debug!("Fetching access token for an account"); - let access_token = self - .fetch_access_token(account.clone()) - .await - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to obtain access token") - ); - error - })?; - self.access_from_account - .lock() - .unwrap() - .insert(account, access_token.clone()); - Ok(access_token.access_token) +async fn fetch_access_token( + service: RequestServiceHandle, + factory: RequestFactory, + account_token: AccountToken, +) -> Result { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_number: String, } + let request = AccessTokenRequest { + account_number: account_token, + }; - async fn fetch_access_token( - &self, - account_token: AccountToken, - ) -> Result { - #[derive(serde::Serialize)] - struct AccessTokenRequest { - account_number: String, - } - let request = AccessTokenRequest { - account_number: account_token, - }; - - let service = self.service.clone(); - - let rest_request = self - .factory - .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; - let response = service.request(rest_request).await?; - let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(response).await - } + let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await } diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs index 585e863ccf7d..5f99cf7d0818 100644 --- a/mullvad-api/src/device.rs +++ b/mullvad-api/src/device.rs @@ -4,7 +4,7 @@ use mullvad_types::{ account::AccountToken, device::{Device, DeviceId, DeviceName}, }; -use std::future::Future; +use std::{future::Future, sync::Arc}; use talpid_types::net::wireguard; use crate::rest; @@ -36,8 +36,9 @@ impl DevicesProxy { &self, account: AccountToken, pubkey: wireguard::PublicKey, - ) -> impl Future> - { + ) -> impl Future< + Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), Arc>, + > { #[derive(serde::Serialize)] struct DeviceSubmission { pubkey: wireguard::PublicKey, @@ -54,17 +55,21 @@ impl DevicesProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices"), Method::POST, &submission, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::CREATED], ) .await; + access_proxy.check_response(&account, &response); + let response: DeviceResponse = rest::deserialize_body(response?).await?; let DeviceResponse { id, @@ -97,42 +102,48 @@ impl DevicesProxy { &self, account: AccountToken, id: DeviceId, - ) -> impl Future> { + ) -> impl Future>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; - rest::deserialize_body(response?).await + access_proxy.check_response(&account, &response); + let device = rest::deserialize_body(response?).await?; + Ok(device) } } pub fn list( &self, account: AccountToken, - ) -> impl Future, rest::Error>> { + ) -> impl Future, Arc>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; - rest::deserialize_body(response?).await + access_proxy.check_response(&account, &response); + let devices = rest::deserialize_body(response?).await?; + Ok(devices) } } @@ -140,20 +151,22 @@ impl DevicesProxy { &self, account: AccountToken, id: DeviceId, - ) -> impl Future> { + ) -> impl Future>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), Method::DELETE, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::NO_CONTENT], ) .await; + access_proxy.check_response(&account, &response); response?; Ok(()) @@ -165,7 +178,7 @@ impl DevicesProxy { account: AccountToken, id: DeviceId, pubkey: wireguard::PublicKey, - ) -> impl Future> + ) -> impl Future>> { #[derive(serde::Serialize)] struct RotateDevicePubkey { @@ -178,17 +191,21 @@ impl DevicesProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), Method::PUT, &req_body, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); + let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; let DeviceResponse { ipv4_address, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 63f5c2ad5b46..2eb741897035 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -12,7 +12,7 @@ use mullvad_types::{ version::AppVersion, }; use proxy::ApiConnectionMode; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use std::{ cell::Cell, collections::BTreeMap, @@ -384,7 +384,7 @@ impl AccountsProxy { pub fn get_expiry( &self, account: AccountToken, - ) -> impl Future, rest::Error>> { + ) -> impl Future, Arc>> { #[derive(serde::Deserialize)] struct AccountExpiryResponse { expiry: DateTime, @@ -394,22 +394,26 @@ impl AccountsProxy { let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); let account: AccountExpiryResponse = rest::deserialize_body(response?).await?; Ok(account.expiry) } } - pub fn create_account(&mut self) -> impl Future> { + pub fn create_account( + &mut self, + ) -> impl Future>> { #[derive(serde::Deserialize)] struct AccountCreationResponse { number: AccountToken, @@ -435,7 +439,7 @@ impl AccountsProxy { &mut self, account_token: AccountToken, voucher_code: String, - ) -> impl Future> { + ) -> impl Future>> { #[derive(serde::Serialize)] struct VoucherSubmission { voucher_code: String, @@ -447,17 +451,23 @@ impl AccountsProxy { let submission = VoucherSubmission { voucher_code }; async move { + let access_token = access_proxy.get_token(&account_token).await?; + let response = rest::send_json_request( &factory, service, &format!("{APP_URL_PREFIX}/submit-voucher"), Method::POST, &submission, - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::OK], ) .await; - rest::deserialize_body(response?).await + + access_proxy.check_response(&account_token, &response); + + let submission = rest::deserialize_body(response?).await?; + Ok(submission) } } @@ -465,7 +475,7 @@ impl AccountsProxy { pub fn init_play_purchase( &mut self, account_token: AccountToken, - ) -> impl Future> { + ) -> impl Future>> { #[derive(serde::Deserialize)] struct PlayPurchaseInitResponse { obfuscated_id: String, @@ -476,17 +486,21 @@ impl AccountsProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), Method::POST, &(), - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); + let PlayPurchaseInitResponse { obfuscated_id } = rest::deserialize_body(response?).await?; @@ -499,22 +513,26 @@ impl AccountsProxy { &mut self, account_token: AccountToken, play_purchase: PlayPurchase, - ) -> impl Future> { + ) -> impl Future>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + rest::send_json_request( &factory, service, &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), Method::POST, &play_purchase, - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::ACCEPTED], ) .await?; + + access_proxy.check_response(&account, &response); Ok(()) } } @@ -522,7 +540,7 @@ impl AccountsProxy { pub fn get_www_auth_token( &self, account: AccountToken, - ) -> impl Future> { + ) -> impl Future>> { #[derive(serde::Deserialize)] struct AuthTokenResponse { auth_token: String, @@ -533,15 +551,17 @@ impl AccountsProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{APP_URL_PREFIX}/www-auth-token"), Method::POST, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); let response: AuthTokenResponse = rest::deserialize_body(response?).await?; Ok(response.auth_token) } @@ -563,7 +583,7 @@ impl ProblemReportProxy { message: &str, log: &str, metadata: &BTreeMap, - ) -> impl Future> { + ) -> impl Future>> { #[derive(serde::Serialize)] struct ProblemReport { address: String, @@ -621,7 +641,7 @@ impl AppVersionProxy { app_version: AppVersion, platform: &str, platform_version: String, - ) -> impl Future> { + ) -> impl Future>> { let service = self.handle.service.clone(); let path = format!("{APP_URL_PREFIX}/releases/{platform}/{app_version}"); @@ -633,7 +653,8 @@ impl AppVersionProxy { let response = service.request(request).await?; let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(parsed_response).await + let response = rest::deserialize_body(parsed_response).await?; + Ok(response) } } } @@ -648,7 +669,7 @@ impl ApiProxy { Self { handle } } - pub async fn get_api_addrs(&self) -> Result, rest::Error> { + pub async fn get_api_addrs(&self) -> Result, Arc> { let service = self.handle.service.clone(); let response = rest::send_request( @@ -661,6 +682,7 @@ impl ApiProxy { ) .await?; - rest::deserialize_body(response).await + let addrs = rest::deserialize_body(response).await?; + Ok(addrs) } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index c3687a1eee9d..a280e3431125 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -1,7 +1,7 @@ #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ - access::AccessTokenProxy, + access::AccessTokenStore, address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, @@ -503,25 +503,16 @@ pub fn send_request( service: RequestServiceHandle, uri: &str, method: Method, - auth: Option<(AccessTokenProxy, AccountToken)>, + access_token: Option, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future> { let request = factory.request(uri, method); async move { let mut request = request?; - if let Some((store, account)) = &auth { - let access_token = store.get_token(account).await?; - request.set_auth(Some(access_token))?; - } + request.set_auth(access_token)?; let response = service.request(request).await?; - let result = parse_rest_response(response, expected_statuses).await; - - if let Some((store, account)) = &auth { - store.check_response(account, &result); - } - - result + parse_rest_response(response, expected_statuses).await } } @@ -531,24 +522,15 @@ pub fn send_json_request( uri: &str, method: Method, body: &B, - auth: Option<(AccessTokenProxy, AccountToken)>, + access_token: Option, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future> { let request = factory.json_request(method, uri, body); async move { let mut request = request?; - if let Some((store, account)) = &auth { - let access_token = store.get_token(account).await?; - request.set_auth(Some(access_token))?; - } + request.set_auth(access_token)?; let response = service.request(request).await?; - let result = parse_rest_response(response, expected_statuses).await; - - if let Some((store, account)) = &auth { - store.check_response(account, &result); - } - - result + parse_rest_response(response, expected_statuses).await } } @@ -639,7 +621,7 @@ pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, pub availability: ApiAvailabilityHandle, - pub token_store: AccessTokenProxy, + pub token_store: AccessTokenStore, } impl MullvadRestHandle { @@ -649,7 +631,7 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { - let token_store = AccessTokenProxy::new(service.clone(), factory.clone()); + let token_store = AccessTokenStore::new(service.clone(), factory.clone()); let handle = Self { service, diff --git a/mullvad-daemon/src/device/mod.rs b/mullvad-daemon/src/device/mod.rs index df5606a14375..6285d24010f2 100644 --- a/mullvad-daemon/src/device/mod.rs +++ b/mullvad-daemon/src/device/mod.rs @@ -67,7 +67,7 @@ pub enum Error { #[error(display = "Failed parse device cache")] ParseDeviceCache(#[error(source)] serde_json::Error), #[error(display = "Unexpected HTTP request error")] - OtherRestError(#[error(source)] rest::Error), + OtherRestError(#[error(source)] Arc), #[error(display = "The device update task is not running")] Cancelled, /// Intended to be broadcast to requesters diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs index fdda61297fd0..e68223bd4678 100644 --- a/mullvad-daemon/src/device/service.rs +++ b/mullvad-daemon/src/device/service.rs @@ -1,4 +1,4 @@ -use std::{future::Future, time::Duration}; +use std::{future::Future, sync::Arc, time::Duration}; use chrono::{DateTime, Utc}; use futures::future::{abortable, AbortHandle}; @@ -261,7 +261,7 @@ pub struct AccountService { } impl AccountService { - pub fn create_account(&self) -> impl Future> { + pub fn create_account(&self) -> impl Future>> { let mut proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); retry_future( @@ -274,7 +274,7 @@ impl AccountService { pub fn get_www_auth_token( &self, account: AccountToken, - ) -> impl Future> { + ) -> impl Future>> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); retry_future( @@ -284,7 +284,10 @@ impl AccountService { ) } - pub async fn check_expiry(&self, token: AccountToken) -> Result, rest::Error> { + pub async fn check_expiry( + &self, + token: AccountToken, + ) -> Result, Arc> { let proxy = self.proxy.clone(); let api_handle = self.api_availability.clone(); let result = retry_future( @@ -402,7 +405,7 @@ pub fn spawn_account_service( } fn handle_expiry_result_inner( - result: &Result, mullvad_api::rest::Error>, + result: &Result, Arc>, api_availability: &ApiAvailabilityHandle, ) -> bool { match result { @@ -414,29 +417,31 @@ fn handle_expiry_result_inner( api_availability.pause_background(); true } - Err(mullvad_api::rest::Error::ApiError(_status, code)) => { - if code == mullvad_api::INVALID_ACCOUNT { - api_availability.pause_background(); - return true; + Err(error) => match error.as_ref() { + mullvad_api::rest::Error::ApiError(_status, code) => { + if code == mullvad_api::INVALID_ACCOUNT { + api_availability.pause_background(); + return true; + } + false } - false - } - Err(_) => false, + _ => false, + }, } } -fn should_retry(result: &Result, api_handle: &ApiAvailabilityHandle) -> bool { +fn should_retry(result: &Result>, api_handle: &ApiAvailabilityHandle) -> bool { match result { Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), _ => false, } } -fn should_retry_backoff(result: &Result) -> bool { +fn should_retry_backoff(result: &Result>) -> bool { match result { Ok(_) => false, Err(error) => { - if let RestError::ApiError(status, code) = error { + if let RestError::ApiError(status, code) = error.as_ref() { *status != rest::StatusCode::NOT_FOUND && code != mullvad_api::DEVICE_NOT_FOUND && code != mullvad_api::INVALID_ACCOUNT @@ -449,8 +454,8 @@ fn should_retry_backoff(result: &Result) -> bool { } } -fn map_rest_error(error: rest::Error) -> Error { - match error { +fn map_rest_error(error: Arc) -> Error { + match error.as_ref() { RestError::ApiError(_status, ref code) => match code.as_str() { // TODO: Implement invalid payment mullvad_api::DEVICE_NOT_FOUND => Error::InvalidDevice, @@ -460,6 +465,6 @@ fn map_rest_error(error: rest::Error) -> Error { mullvad_api::VOUCHER_USED => Error::UsedVoucher, _ => Error::OtherRestError(error), }, - error => Error::OtherRestError(error), + _error => Error::OtherRestError(error), } } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 1077185ca38d..1a4493fad781 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -104,7 +104,7 @@ pub enum Error { InitRpcFactory(#[error(source)] mullvad_api::Error), #[error(display = "REST request failed")] - RestError(#[error(source)] mullvad_api::rest::Error), + RestError(#[error(source)] Arc), #[error(display = "API availability check failed")] ApiCheckError(#[error(source)] mullvad_api::availability::Error), @@ -203,7 +203,7 @@ pub enum DaemonCommand { CreateNewAccount(ResponseTx), /// Request the metadata for an account. GetAccountData( - ResponseTx, + ResponseTx>, AccountToken, ), /// Request www auth token for an account @@ -1376,7 +1376,7 @@ where fn on_get_account_data( &mut self, - tx: ResponseTx, + tx: ResponseTx>, account_token: AccountToken, ) { let account = self.account_manager.account_service.clone(); diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 993f0f9ece0b..bfd8191db669 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -424,7 +424,7 @@ impl ManagementService for ManagementServiceImpl { let result = self.wait_for_result(rx).await?; result .map(|account_data| Response::new(types::AccountData::from(account_data))) - .map_err(|error: RestError| { + .map_err(|error: Arc| { log::error!( "Unable to get account data from API: {}", error.display_chain() @@ -1081,8 +1081,8 @@ fn map_split_tunnel_error(error: talpid_core::split_tunnel::Error) -> Status { } /// Converts a REST API error into a tonic status. -fn map_rest_error(error: &RestError) -> Status { - match error { +fn map_rest_error(error: &Arc) -> Status { + match error.as_ref() { RestError::ApiError(status, message) if *status == StatusCode::UNAUTHORIZED || *status == StatusCode::FORBIDDEN => { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index dfe0a26b5fc0..8c00d3bb7b2a 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -13,6 +13,7 @@ use std::{ io, path::{Path, PathBuf}, str::FromStr, + sync::Arc, time::Duration, }; use talpid_core::{future_retry::ConstantInterval, mpsc::Sender}; @@ -75,7 +76,7 @@ pub enum Error { Deserialize(#[error(source)] serde_json::Error), #[error(display = "Failed to check the latest app version")] - Download(#[error(source)] mullvad_api::rest::Error), + Download(#[error(source)] Arc), #[error(display = "API availability check failed")] ApiCheck(#[error(source)] mullvad_api::availability::Error), diff --git a/mullvad-jni/src/daemon_interface.rs b/mullvad-jni/src/daemon_interface.rs index c64c94041e62..9f90418a4e4a 100644 --- a/mullvad-jni/src/daemon_interface.rs +++ b/mullvad-jni/src/daemon_interface.rs @@ -26,7 +26,7 @@ pub enum Error { NoSender, #[error(display = "Error performing RPC with the remote API")] - Api(#[error(source)] mullvad_api::rest::Error), + Api(#[error(source)] Arc), #[error(display = "Failed to update settings")] UpdateSettings, diff --git a/mullvad-jni/src/lib.rs b/mullvad-jni/src/lib.rs index a40bf73fc13f..751db7f113ed 100644 --- a/mullvad-jni/src/lib.rs +++ b/mullvad-jni/src/lib.rs @@ -81,12 +81,15 @@ impl From> for GetAccountDataResult match result { Ok(account_data) => GetAccountDataResult::Ok(account_data), Err(error) => match error { - daemon_interface::Error::Api(RestError::ApiError(status, _code)) - if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN => - { - GetAccountDataResult::InvalidAccount - } - daemon_interface::Error::Api(_) => GetAccountDataResult::RpcError, + daemon_interface::Error::Api(error) => match error.as_ref() { + RestError::ApiError(status, _code) + if status == StatusCode::UNAUTHORIZED + || status == StatusCode::FORBIDDEN => + { + GetAccountDataResult::InvalidAccount + } + _ => GetAccountDataResult::RpcError, + }, _ => GetAccountDataResult::OtherError, }, } @@ -1375,7 +1378,7 @@ pub extern "system" fn Java_net_mullvad_mullvadvpn_service_MullvadDaemon_setQuan fn log_request_error(request: &str, error: &daemon_interface::Error) { match error { - daemon_interface::Error::Api(RestError::Aborted) => { + daemon_interface::Error::Api(error) if error.is_aborted() => { log::debug!("Request to {} cancelled", request); } error => { diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 10c2ec1b9618..46c500c11aa0 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -11,6 +11,7 @@ use std::{ fs::{self, File}, io::{self, BufWriter, Read, Seek, SeekFrom, Write}, path::{Path, PathBuf}, + sync::Arc, }; use talpid_types::ErrorExt; @@ -67,7 +68,7 @@ pub enum Error { CreateRpcClientError(#[error(source)] mullvad_api::Error), #[error(display = "Failed to send problem report")] - SendProblemReportError(#[error(source)] mullvad_api::rest::Error), + SendProblemReportError(#[error(source)] Arc), #[error(display = "Failed to send problem report {} times", MAX_SEND_ATTEMPTS)] SendFailedTooManyTimes, diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index bcae45944234..eebed732ed84 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -3,7 +3,7 @@ use mullvad_api::{self, proxy::ApiConnectionMode, DEVICE_NOT_FOUND}; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::version::ParsedAppVersion; use once_cell::sync::Lazy; -use std::{path::PathBuf, process, str::FromStr, time::Duration}; +use std::{path::PathBuf, process, str::FromStr, sync::Arc, time::Duration}; use talpid_core::{ firewall::{self, Firewall}, future_retry::{retry_future, ConstantInterval}, @@ -51,7 +51,7 @@ pub enum Error { RpcInitializationError(#[error(source)] mullvad_api::Error), #[error(display = "Failed to remove device from account")] - RemoveDeviceError(#[error(source)] mullvad_api::rest::Error), + RemoveDeviceError(#[error(source)] Arc), #[error(display = "Failed to obtain settings directory path")] SettingsPathError(#[error(source)] mullvad_paths::Error), @@ -183,10 +183,12 @@ async fn remove_device() -> Result<(), Error> { // `DEVICE_NOT_FOUND` is not considered to be an error in this context. match device_removal { Ok(_) => Ok(()), - Err(mullvad_api::rest::Error::ApiError(_status, code)) if code == DEVICE_NOT_FOUND => { - Ok(()) - } - Err(e) => Err(Error::RemoveDeviceError(e)), + Err(error) => match error.as_ref() { + mullvad_api::rest::Error::ApiError(_status, code) if code == DEVICE_NOT_FOUND => { + Ok(()) + } + _ => Err(Error::RemoveDeviceError(error)), + }, }?; cacher