From 4e5d08ebbc2cb49304ff0c2f4c77ecc11be30c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20L=C3=B6nnhager?= Date: Wed, 18 Oct 2023 19:23:03 +0200 Subject: [PATCH] Fold all access token requests into a single request --- mullvad-api/src/access.rs | 231 ++++++++++++++------- mullvad-api/src/bin/relay_list.rs | 2 +- mullvad-api/src/device.rs | 30 ++- mullvad-api/src/lib.rs | 36 +++- mullvad-api/src/rest.rs | 99 ++++----- mullvad-daemon/src/device/service.rs | 12 +- mullvad-daemon/src/geoip.rs | 3 +- mullvad-daemon/src/management_interface.rs | 2 +- 8 files changed, 262 insertions(+), 153 deletions(-) diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index a3bec3f72599..67c83ac4da0f 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -2,109 +2,180 @@ use crate::{ rest, rest::{RequestFactory, RequestServiceHandle}, }; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; use hyper::StatusCode; use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken}; -use std::{ - collections::HashMap, - sync::{Arc, Mutex}, -}; -use talpid_types::ErrorExt; +use std::collections::HashMap; +use tokio::select; pub const AUTH_URL_PREFIX: &str = "auth/v1"; #[derive(Clone)] -pub struct AccessTokenProxy { - service: RequestServiceHandle, - factory: RequestFactory, - access_from_account: Arc>>, +pub struct AccessTokenStore { + tx: mpsc::UnboundedSender, +} + +enum StoreAction { + /// Request an access token for `AccountToken`, or return a saved one if it's not expired. + GetAccessToken( + AccountToken, + oneshot::Sender>, + ), + /// Forget cached access token for `AccountToken`, and drop any in-flight requests + InvalidateToken(AccountToken), +} + +#[derive(Default)] +struct AccountState { + current_access_token: Option, + inflight_request: Option>, + response_channels: Vec>>, } -impl AccessTokenProxy { +impl AccessTokenStore { pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { - Self { - service, - factory, - access_from_account: Arc::new(Mutex::new(HashMap::new())), + let (tx, rx) = mpsc::unbounded(); + tokio::spawn(Self::service_requests(rx, service, factory)); + Self { tx } + } + + async fn service_requests( + mut rx: mpsc::UnboundedReceiver, + service: RequestServiceHandle, + factory: RequestFactory, + ) { + let mut account_states: HashMap = HashMap::new(); + + let (completed_tx, mut completed_rx) = mpsc::unbounded(); + + loop { + select! { + action = rx.next() => { + let Some(action) = action else { + // We're done + break; + }; + + match action { + StoreAction::GetAccessToken(account, response_tx) => { + let account_state = account_states + .entry(account.clone()) + .or_default(); + + // If there is an unexpired access token, just return it. + // Otherwise, generate a new token + if let Some(ref access_token) = account_state.current_access_token { + if !access_token.is_expired() { + log::trace!("Using stored access token"); + let _ = response_tx.send(Ok(access_token.access_token.clone())); + continue; + } + + log::debug!("Replacing expired access token"); + account_state.current_access_token = None; + } + + // Begin requesting an access token if it's not already underway. + // If there's already an inflight request, just save `response_tx` + account_state + .inflight_request + .get_or_insert_with(|| { + let completed_tx = completed_tx.clone(); + let account = account.clone(); + let service = service.clone(); + let factory = factory.clone(); + + log::debug!("Fetching access token for an account"); + + tokio::spawn(async move { + let result = fetch_access_token(service, factory, account.clone()).await; + let _ = completed_tx.unbounded_send((account, result)); + }) + }); + + // Save the channel to respond to later + account_state.response_channels.push(response_tx); + } + StoreAction::InvalidateToken(account) => { + let account_state = account_states + .entry(account) + .or_default(); + + // Drop in-flight requests for the account + // & forget any existing access token + + log::debug!("Invalidating access token for an account"); + + if let Some(task) = account_state.inflight_request.take() { + task.abort(); + let _ = task.await; + } + + account_state.response_channels.clear(); + account_state.current_access_token = None; + } + } + } + + Some((account, result)) = completed_rx.next() => { + let account_state = account_states + .entry(account) + .or_default(); + + account_state.inflight_request = None; + + // Send response to all channels + for tx in account_state.response_channels.drain(..) { + let _ = tx.send(result.clone().map(|data| data.access_token)); + } + + if let Ok(access_token) = result { + account_state.current_access_token = Some(access_token); + } + } + } } } /// Obtain access token for an account, requesting a new one from the API if necessary. pub async fn get_token(&self, account: &AccountToken) -> Result { - let existing_token = { - self.access_from_account - .lock() - .unwrap() - .get(account.as_str()) - .cloned() - }; - if let Some(access_token) = existing_token { - if access_token.is_expired() { - log::debug!("Replacing expired access token"); - return self.request_new_token(account.clone()).await; - } - log::trace!("Using stored access token"); - return Ok(access_token.access_token.clone()); - } - self.request_new_token(account.clone()).await + let (tx, rx) = oneshot::channel(); + let _ = self + .tx + .unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx)); + rx.await.map_err(|_| rest::Error::Aborted)? } /// Remove an access token if the API response calls for it. - pub fn check_response(&self, account: &AccessToken, response: &Result) { + pub fn check_response(&self, account: &AccountToken, response: &Result) { if let Err(rest::Error::ApiError(_status, code)) = response { if code == crate::INVALID_ACCESS_TOKEN { - log::debug!("Dropping invalid access token"); - self.remove_token(account); + let _ = self + .tx + .unbounded_send(StoreAction::InvalidateToken(account.to_owned())); } } } +} - /// Removes a stored access token. - fn remove_token(&self, account: &AccountToken) -> Option { - self.access_from_account - .lock() - .unwrap() - .remove(account) - .map(|v| v.access_token) - } - - async fn request_new_token(&self, account: AccountToken) -> Result { - log::debug!("Fetching access token for an account"); - let access_token = self - .fetch_access_token(account.clone()) - .await - .map_err(|error| { - log::error!( - "{}", - error.display_chain_with_msg("Failed to obtain access token") - ); - error - })?; - self.access_from_account - .lock() - .unwrap() - .insert(account, access_token.clone()); - Ok(access_token.access_token) +async fn fetch_access_token( + service: RequestServiceHandle, + factory: RequestFactory, + account_token: AccountToken, +) -> Result { + #[derive(serde::Serialize)] + struct AccessTokenRequest { + account_number: String, } + let request = AccessTokenRequest { + account_number: account_token, + }; - async fn fetch_access_token( - &self, - account_token: AccountToken, - ) -> Result { - #[derive(serde::Serialize)] - struct AccessTokenRequest { - account_number: String, - } - let request = AccessTokenRequest { - account_number: account_token, - }; - - let service = self.service.clone(); - - let rest_request = self - .factory - .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; - let response = service.request(rest_request).await?; - let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(response).await - } + let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; + let response = service.request(rest_request).await?; + let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; + rest::deserialize_body(response).await } diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index 2139e51f54c4..ffb65c28b282 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -21,7 +21,7 @@ async fn main() { let relay_list = match relay_list_request { Ok(relay_list) => relay_list, - Err(RestError::TimeoutError(_)) => { + Err(RestError::TimeoutError) => { eprintln!("Request timed out"); process::exit(2); } diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs index 585e863ccf7d..3d8913e366ec 100644 --- a/mullvad-api/src/device.rs +++ b/mullvad-api/src/device.rs @@ -54,17 +54,21 @@ impl DevicesProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices"), Method::POST, &submission, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::CREATED], ) .await; + access_proxy.check_response(&account, &response); + let response: DeviceResponse = rest::deserialize_body(response?).await?; let DeviceResponse { id, @@ -102,16 +106,19 @@ impl DevicesProxy { let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; - rest::deserialize_body(response?).await + access_proxy.check_response(&account, &response); + let device = rest::deserialize_body(response?).await?; + Ok(device) } } @@ -123,16 +130,19 @@ impl DevicesProxy { let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; - rest::deserialize_body(response?).await + access_proxy.check_response(&account, &response); + let devices = rest::deserialize_body(response?).await?; + Ok(devices) } } @@ -145,15 +155,17 @@ impl DevicesProxy { let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), Method::DELETE, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::NO_CONTENT], ) .await; + access_proxy.check_response(&account, &response); response?; Ok(()) @@ -178,17 +190,21 @@ impl DevicesProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), Method::PUT, &req_body, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); + let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; let DeviceResponse { ipv4_address, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 63f5c2ad5b46..6beb1c8e864f 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -394,15 +394,17 @@ impl AccountsProxy { let factory = self.handle.factory.clone(); let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"), Method::GET, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); let account: AccountExpiryResponse = rest::deserialize_body(response?).await?; Ok(account.expiry) @@ -447,16 +449,21 @@ impl AccountsProxy { let submission = VoucherSubmission { voucher_code }; async move { + let access_token = access_proxy.get_token(&account_token).await?; + let response = rest::send_json_request( &factory, service, &format!("{APP_URL_PREFIX}/submit-voucher"), Method::POST, &submission, - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::OK], ) .await; + + access_proxy.check_response(&account_token, &response); + rest::deserialize_body(response?).await } } @@ -464,7 +471,7 @@ impl AccountsProxy { #[cfg(target_os = "android")] pub fn init_play_purchase( &mut self, - account_token: AccountToken, + account: AccountToken, ) -> impl Future> { #[derive(serde::Deserialize)] struct PlayPurchaseInitResponse { @@ -476,17 +483,21 @@ impl AccountsProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; + let response = rest::send_json_request( &factory, service, &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), Method::POST, &(), - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); + let PlayPurchaseInitResponse { obfuscated_id } = rest::deserialize_body(response?).await?; @@ -497,7 +508,7 @@ impl AccountsProxy { #[cfg(target_os = "android")] pub fn verify_play_purchase( &mut self, - account_token: AccountToken, + account: AccountToken, play_purchase: PlayPurchase, ) -> impl Future> { let service = self.handle.service.clone(); @@ -505,16 +516,21 @@ impl AccountsProxy { let access_proxy = self.handle.token_store.clone(); async move { - rest::send_json_request( + let access_token = access_proxy.get_token(&account).await?; + + let response = rest::send_json_request( &factory, service, &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), Method::POST, &play_purchase, - Some((access_proxy, account_token)), + Some(access_token), &[StatusCode::ACCEPTED], ) - .await?; + .await; + + access_proxy.check_response(&account, &response); + response?; Ok(()) } } @@ -533,15 +549,17 @@ impl AccountsProxy { let access_proxy = self.handle.token_store.clone(); async move { + let access_token = access_proxy.get_token(&account).await?; let response = rest::send_request( &factory, service, &format!("{APP_URL_PREFIX}/www-auth-token"), Method::POST, - Some((access_proxy, account)), + Some(access_token), &[StatusCode::OK], ) .await; + access_proxy.check_response(&account, &response); let response: AuthTokenResponse = rest::deserialize_body(response?).await?; Ok(response.auth_token) } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index c3687a1eee9d..3690f1450c87 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -1,7 +1,7 @@ #[cfg(target_os = "android")] pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ - access::AccessTokenProxy, + access::AccessTokenStore, address_cache::AddressCache, availability::ApiAvailabilityHandle, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, @@ -44,25 +44,25 @@ pub type Result = std::result::Result; const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); /// Describes all the ways a REST request can fail -#[derive(err_derive::Error, Debug)] +#[derive(err_derive::Error, Debug, Clone)] pub enum Error { #[error(display = "Request cancelled")] Aborted, #[error(display = "Hyper error")] - HyperError(#[error(source)] hyper::Error), + HyperError(#[error(source)] Arc), #[error(display = "Invalid header value")] - InvalidHeaderError(#[error(source)] http::header::InvalidHeaderValue), + InvalidHeaderError, #[error(display = "HTTP error")] - HttpError(#[error(source)] http::Error), + HttpError(#[error(source)] Arc), #[error(display = "Request timed out")] - TimeoutError(#[error(source)] tokio::time::error::Elapsed), + TimeoutError, #[error(display = "Failed to deserialize data")] - DeserializeError(#[error(source)] serde_json::Error), + DeserializeError(#[error(source)] Arc), #[error(display = "Failed to send request to rest client")] SendError, @@ -76,7 +76,7 @@ pub enum Error { /// The string given was not a valid URI. #[error(display = "Not a valid URI")] - UriError(#[error(source)] http::uri::InvalidUri), + InvalidUri, /// A new API config was requested, but the request could not be completed. #[error(display = "Failed to rotate API config")] @@ -85,7 +85,7 @@ pub enum Error { impl Error { pub fn is_network_error(&self) -> bool { - matches!(self, Error::HyperError(_) | Error::TimeoutError(_)) + matches!(self, Error::HyperError(_) | Error::TimeoutError) } pub fn is_aborted(&self) -> bool { @@ -203,7 +203,7 @@ impl< let future = async move { let response = tokio::time::timeout(timeout, request_future) .await - .map_err(Error::TimeoutError); + .map_err(|_| Error::TimeoutError); let response = flatten_result(response).map_err(|error| error.map_aborted()); @@ -314,20 +314,20 @@ pub struct RestRequest { impl RestRequest { /// Constructs a GET request with the given URI. Returns an error if the URI is not valid. pub fn get(uri: &str) -> Result { - let uri = hyper::Uri::from_str(uri).map_err(Error::UriError)?; + let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; let mut builder = http::request::Builder::new() .method(Method::GET) .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) .header(header::ACCEPT, HeaderValue::from_static("application/json")); if let Some(host) = uri.host() { - builder = builder.header(header::HOST, HeaderValue::from_str(host)?); + builder = builder.header( + header::HOST, + HeaderValue::from_str(host).map_err(|_| Error::InvalidHeaderError)?, + ); }; - let request = builder - .uri(uri) - .body(hyper::Body::empty()) - .map_err(Error::HttpError)?; + let request = builder.uri(uri).body(hyper::Body::empty())?; Ok(RestRequest { timeout: DEFAULT_TIMEOUT, @@ -341,7 +341,7 @@ impl RestRequest { let header = match auth { Some(auth) => Some( HeaderValue::from_str(&format!("Bearer {auth}")) - .map_err(Error::InvalidHeaderError)?, + .map_err(|_| Error::InvalidHeaderError)?, ), None => None, }; @@ -361,7 +361,8 @@ impl RestRequest { } pub fn add_header(&mut self, key: T, value: &str) -> Result<()> { - let header_value = http::HeaderValue::from_str(value).map_err(Error::InvalidHeaderError)?; + let header_value = + http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?; self.request.headers_mut().insert(key, header_value); Ok(()) } @@ -458,7 +459,8 @@ impl RequestFactory { let headers = request.headers_mut(); headers.insert( header::CONTENT_LENGTH, - HeaderValue::from_str(&body_length.to_string()).map_err(Error::InvalidHeaderError)?, + HeaderValue::from_str(&body_length.to_string()) + .map_err(|_| Error::InvalidHeaderError)?, ); headers.insert( header::CONTENT_TYPE, @@ -483,13 +485,14 @@ impl RequestFactory { .header(header::ACCEPT, HeaderValue::from_static("application/json")) .header(header::HOST, self.hostname.clone()); - request.body(hyper::Body::empty()).map_err(Error::HttpError) + let result = request.body(hyper::Body::empty())?; + Ok(result) } fn get_uri(&self, path: &str) -> Result { let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or(""); let uri = format!("https://{}/{}{}", self.hostname, prefix, path); - hyper::Uri::from_str(&uri).map_err(Error::UriError) + hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri) } fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest { @@ -503,25 +506,16 @@ pub fn send_request( service: RequestServiceHandle, uri: &str, method: Method, - auth: Option<(AccessTokenProxy, AccountToken)>, + access_token: Option, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future> { let request = factory.request(uri, method); async move { let mut request = request?; - if let Some((store, account)) = &auth { - let access_token = store.get_token(account).await?; - request.set_auth(Some(access_token))?; - } + request.set_auth(access_token)?; let response = service.request(request).await?; - let result = parse_rest_response(response, expected_statuses).await; - - if let Some((store, account)) = &auth { - store.check_response(account, &result); - } - - result + parse_rest_response(response, expected_statuses).await } } @@ -531,24 +525,15 @@ pub fn send_json_request( uri: &str, method: Method, body: &B, - auth: Option<(AccessTokenProxy, AccountToken)>, + access_token: Option, expected_statuses: &'static [hyper::StatusCode], ) -> impl Future> { let request = factory.json_request(method, uri, body); async move { let mut request = request?; - if let Some((store, account)) = &auth { - let access_token = store.get_token(account).await?; - request.set_auth(Some(access_token))?; - } + request.set_auth(access_token)?; let response = service.request(request).await?; - let result = parse_rest_response(response, expected_statuses).await; - - if let Some((store, account)) = &auth { - store.check_response(account, &result); - } - - result + parse_rest_response(response, expected_statuses).await } } @@ -566,7 +551,7 @@ async fn deserialize_body_inner( body.extend(&chunk?); } - serde_json::from_slice(&body).map_err(Error::DeserializeError) + serde_json::from_slice(&body).map_err(Error::from) } fn get_body_length(response: &Response) -> usize { @@ -639,7 +624,7 @@ pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, pub availability: ApiAvailabilityHandle, - pub token_store: AccessTokenProxy, + pub token_store: AccessTokenStore, } impl MullvadRestHandle { @@ -649,7 +634,7 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { - let token_store = AccessTokenProxy::new(service.clone(), factory.clone()); + let token_store = AccessTokenStore::new(service.clone(), factory.clone()); let handle = Self { service, @@ -728,3 +713,21 @@ fn flatten_result( Err(err) => Err(err), } } + +impl From for Error { + fn from(value: hyper::Error) -> Self { + Error::HyperError(Arc::new(value)) + } +} + +impl From for Error { + fn from(value: serde_json::Error) -> Self { + Error::DeserializeError(Arc::new(value)) + } +} + +impl From for Error { + fn from(value: http::Error) -> Self { + Error::HttpError(Arc::new(value)) + } +} diff --git a/mullvad-daemon/src/device/service.rs b/mullvad-daemon/src/device/service.rs index fdda61297fd0..a56cf10b4829 100644 --- a/mullvad-daemon/src/device/service.rs +++ b/mullvad-daemon/src/device/service.rs @@ -14,7 +14,7 @@ use talpid_types::net::wireguard::PrivateKey; use super::{Error, PrivateAccountAndDevice, PrivateDevice}; use mullvad_api::{ availability::ApiAvailabilityHandle, - rest::{self, Error as RestError, MullvadRestHandle}, + rest::{self, MullvadRestHandle}, AccountsProxy, DevicesProxy, }; use talpid_core::future_retry::{retry_future, ConstantInterval, ExponentialBackoff, Jittered}; @@ -402,7 +402,7 @@ pub fn spawn_account_service( } fn handle_expiry_result_inner( - result: &Result, mullvad_api::rest::Error>, + result: &Result, rest::Error>, api_availability: &ApiAvailabilityHandle, ) -> bool { match result { @@ -425,18 +425,18 @@ fn handle_expiry_result_inner( } } -fn should_retry(result: &Result, api_handle: &ApiAvailabilityHandle) -> bool { +fn should_retry(result: &Result, api_handle: &ApiAvailabilityHandle) -> bool { match result { Err(error) if error.is_network_error() => !api_handle.get_state().is_offline(), _ => false, } } -fn should_retry_backoff(result: &Result) -> bool { +fn should_retry_backoff(result: &Result) -> bool { match result { Ok(_) => false, Err(error) => { - if let RestError::ApiError(status, code) = error { + if let rest::Error::ApiError(status, code) = error { *status != rest::StatusCode::NOT_FOUND && code != mullvad_api::DEVICE_NOT_FOUND && code != mullvad_api::INVALID_ACCOUNT @@ -451,7 +451,7 @@ fn should_retry_backoff(result: &Result) -> bool { fn map_rest_error(error: rest::Error) -> Error { match error { - RestError::ApiError(_status, ref code) => match code.as_str() { + rest::Error::ApiError(_status, ref code) => match code.as_str() { // TODO: Implement invalid payment mullvad_api::DEVICE_NOT_FOUND => Error::InvalidDevice, mullvad_api::INVALID_ACCOUNT => Error::InvalidAccount, diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs index 527e06cf6105..261993921505 100644 --- a/mullvad-daemon/src/geoip.rs +++ b/mullvad-daemon/src/geoip.rs @@ -89,10 +89,11 @@ async fn send_location_request_internal( } fn log_network_error(err: Error, version: &'static str) { + use std::sync::Arc; let err_message = &format!("Unable to fetch {version} GeoIP location"); match err { Error::HyperError(hyper_err) if hyper_err.is_connect() => { - if let Some(cause) = hyper_err.into_cause() { + if let Some(cause) = Arc::into_inner(hyper_err).and_then(|x| x.into_cause()) { if let Some(err) = cause.downcast_ref::() { // Don't log ENETUNREACH errors, they are not informative. if err.raw_os_error() == Some(libc::ENETUNREACH) { diff --git a/mullvad-daemon/src/management_interface.rs b/mullvad-daemon/src/management_interface.rs index 993f0f9ece0b..61e4b025bad4 100644 --- a/mullvad-daemon/src/management_interface.rs +++ b/mullvad-daemon/src/management_interface.rs @@ -1088,7 +1088,7 @@ fn map_rest_error(error: &RestError) -> Status { { Status::new(Code::Unauthenticated, message) } - RestError::TimeoutError(_elapsed) => Status::deadline_exceeded("API request timed out"), + RestError::TimeoutError => Status::deadline_exceeded("API request timed out"), RestError::HyperError(_) => Status::unavailable("Cannot reach the API"), error => Status::unknown(format!("REST error: {error}")), }