diff --git a/mullvad-api/src/access.rs b/mullvad-api/src/access.rs index 67c83ac4da0f..276cc1f5613e 100644 --- a/mullvad-api/src/access.rs +++ b/mullvad-api/src/access.rs @@ -1,6 +1,7 @@ use crate::{ rest, rest::{RequestFactory, RequestServiceHandle}, + API, }; use futures::{ channel::{mpsc, oneshot}, @@ -13,7 +14,7 @@ use tokio::select; pub const AUTH_URL_PREFIX: &str = "auth/v1"; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct AccessTokenStore { tx: mpsc::UnboundedSender, } @@ -36,7 +37,8 @@ struct AccountState { } impl AccessTokenStore { - pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self { + pub(crate) fn new(service: RequestServiceHandle) -> Self { + let factory = rest::RequestFactory::new(&API.host, None); let (tx, rx) = mpsc::unbounded(); tokio::spawn(Self::service_requests(rx, service, factory)); Self { tx } @@ -174,8 +176,8 @@ async fn fetch_access_token( account_number: account_token, }; - let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?; - let response = service.request(rest_request).await?; - let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(response).await + let rest_request = factory + .post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)? + .expected_status(&[StatusCode::OK]); + service.request(rest_request).await?.deserialize().await } diff --git a/mullvad-api/src/device.rs b/mullvad-api/src/device.rs index 3d8913e366ec..37410e99b3cc 100644 --- a/mullvad-api/src/device.rs +++ b/mullvad-api/src/device.rs @@ -1,5 +1,5 @@ use chrono::{DateTime, Utc}; -use http::{Method, StatusCode}; +use http::StatusCode; use mullvad_types::{ account::AccountToken, device::{Device, DeviceId, DeviceName}, @@ -51,25 +51,13 @@ impl DevicesProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices"), - Method::POST, - &submission, - Some(access_token), - &[StatusCode::CREATED], - ) - .await; - - access_proxy.check_response(&account, &response); - - let response: DeviceResponse = rest::deserialize_body(response?).await?; + let request = factory + .post_json(&format!("{ACCOUNTS_URL_PREFIX}/devices"), &submission)? + .account(account)? + .expected_status(&[StatusCode::CREATED]); + let response = service.request(request).await?; let DeviceResponse { id, name, @@ -79,7 +67,7 @@ impl DevicesProxy { hijack_dns, created, .. - } = response; + } = response.deserialize().await?; Ok(( Device { @@ -104,21 +92,12 @@ impl DevicesProxy { ) -> impl Future> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let device = rest::deserialize_body(response?).await?; - Ok(device) + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + service.request(request).await?.deserialize().await } } @@ -128,21 +107,12 @@ impl DevicesProxy { ) -> impl Future, rest::Error>> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let devices = rest::deserialize_body(response?).await?; - Ok(devices) + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/device"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + service.request(request).await?.deserialize().await } } @@ -153,21 +123,12 @@ impl DevicesProxy { ) -> impl Future> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"), - Method::DELETE, - Some(access_token), - &[StatusCode::NO_CONTENT], - ) - .await; - access_proxy.check_response(&account, &response); - - response?; + let request = factory + .delete(&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"))? + .expected_status(&[StatusCode::NO_CONTENT]) + .account(account)?; + service.request(request).await?; Ok(()) } } @@ -187,30 +148,21 @@ impl DevicesProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), - Method::PUT, - &req_body, - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account, &response); - - let updated_device: DeviceResponse = rest::deserialize_body(response?).await?; + let request = factory + .put_json( + &format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"), + &req_body, + )? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await?; let DeviceResponse { ipv4_address, ipv6_address, .. - } = updated_device; + } = response.deserialize().await?; Ok(mullvad_types::wireguard::AssociatedAddresses { ipv4_address, ipv6_address, diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 6beb1c8e864f..91e2bc524acc 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -340,7 +340,8 @@ impl Runtime { self.socket_bypass_tx.clone(), ) .await; - let factory = rest::RequestFactory::new(API.host.clone(), None); + let token_store = access::AccessTokenStore::new(service.clone()); + let factory = rest::RequestFactory::new(&API.host, Some(token_store)); rest::MullvadRestHandle::new( service, @@ -392,21 +393,13 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/accounts/me"), - Method::GET, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - - let account: AccountExpiryResponse = rest::deserialize_body(response?).await?; + let request = factory + .get(&format!("{ACCOUNTS_URL_PREFIX}/accounts/me"))? + .expected_status(&[StatusCode::OK]) + .account(account)?; + let response = service.request(request).await?; + let account: AccountExpiryResponse = response.deserialize().await?; Ok(account.expiry) } } @@ -418,24 +411,21 @@ impl AccountsProxy { } let service = self.handle.service.clone(); - let response = rest::send_request( - &self.handle.factory, - service, - &format!("{ACCOUNTS_URL_PREFIX}/accounts"), - Method::POST, - None, - &[StatusCode::CREATED], - ); + let factory = self.handle.factory.clone(); async move { - let account: AccountCreationResponse = rest::deserialize_body(response.await?).await?; + let request = factory + .post(&format!("{ACCOUNTS_URL_PREFIX}/accounts"))? + .expected_status(&[StatusCode::CREATED]); + let response = service.request(request).await?; + let account: AccountCreationResponse = response.deserialize().await?; Ok(account.number) } } pub fn submit_voucher( &mut self, - account_token: AccountToken, + account: AccountToken, voucher_code: String, ) -> impl Future> { #[derive(serde::Serialize)] @@ -445,26 +435,14 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); let submission = VoucherSubmission { voucher_code }; async move { - let access_token = access_proxy.get_token(&account_token).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{APP_URL_PREFIX}/submit-voucher"), - Method::POST, - &submission, - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account_token, &response); - - rest::deserialize_body(response?).await + let request = factory + .post_json(&format!("{APP_URL_PREFIX}/submit-voucher"), &submission)? + .account(account)? + .expected_status(&[StatusCode::OK]); + service.request(request).await?.deserialize().await } } @@ -480,26 +458,15 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), - Method::POST, - &(), - Some(access_token), - &[StatusCode::OK], - ) - .await; - - access_proxy.check_response(&account, &response); + let request = factory + .post_json(&format!("{GOOGLE_PAYMENTS_URL_PREFIX}/init"), &())? + .account(account)? + .expected_status(&[StatusCode::OK]); + let response = service.request(request).await?; - let PlayPurchaseInitResponse { obfuscated_id } = - rest::deserialize_body(response?).await?; + let PlayPurchaseInitResponse { obfuscated_id } = response.deserialize().await?; Ok(obfuscated_id) } @@ -513,24 +480,16 @@ impl AccountsProxy { ) -> impl Future> { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - - let response = rest::send_json_request( - &factory, - service, - &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), - Method::POST, - &play_purchase, - Some(access_token), - &[StatusCode::ACCEPTED], - ) - .await; - - access_proxy.check_response(&account, &response); - response?; + let request = factory + .post_json( + &format!("{GOOGLE_PAYMENTS_URL_PREFIX}/acknowledge"), + &play_purchase, + )? + .account(account)? + .expected_status(&[StatusCode::ACCEPTED]); + service.request(request).await?; Ok(()) } } @@ -546,21 +505,14 @@ impl AccountsProxy { let service = self.handle.service.clone(); let factory = self.handle.factory.clone(); - let access_proxy = self.handle.token_store.clone(); async move { - let access_token = access_proxy.get_token(&account).await?; - let response = rest::send_request( - &factory, - service, - &format!("{APP_URL_PREFIX}/www-auth-token"), - Method::POST, - Some(access_token), - &[StatusCode::OK], - ) - .await; - access_proxy.check_response(&account, &response); - let response: AuthTokenResponse = rest::deserialize_body(response?).await?; + let request = factory + .post(&format!("{APP_URL_PREFIX}/www-auth-token"))? + .account(account)? + .expected_status(&[StatusCode::OK]); + let response = service.request(request).await?; + let response: AuthTokenResponse = response.deserialize().await?; Ok(response.auth_token) } } @@ -598,19 +550,13 @@ impl ProblemReportProxy { }; let service = self.handle.service.clone(); - - let request = rest::send_json_request( - &self.handle.factory, - service, - &format!("{APP_URL_PREFIX}/problem-report"), - Method::POST, - &report, - None, - &[StatusCode::NO_CONTENT], - ); + let factory = self.handle.factory.clone(); async move { - request.await?; + let request = factory + .post_json(&format!("{APP_URL_PREFIX}/problem-report"), &report)? + .expected_status(&[StatusCode::NO_CONTENT]); + service.request(request).await?; Ok(()) } } @@ -646,12 +592,11 @@ impl AppVersionProxy { let request = self.handle.factory.request(&path, Method::GET); async move { - let mut request = request?; - request.add_header("M-Platform-Version", &platform_version)?; - + let request = request? + .expected_status(&[StatusCode::OK]) + .header("M-Platform-Version", &platform_version)?; let response = service.request(request).await?; - let parsed_response = rest::parse_rest_response(response, &[StatusCode::OK]).await?; - rest::deserialize_body(parsed_response).await + response.deserialize().await } } } @@ -667,18 +612,12 @@ impl ApiProxy { } pub async fn get_api_addrs(&self) -> Result, rest::Error> { - let service = self.handle.service.clone(); - - let response = rest::send_request( - &self.handle.factory, - service, - &format!("{APP_URL_PREFIX}/api-addrs"), - Method::GET, - None, - &[StatusCode::OK], - ) - .await?; - - rest::deserialize_body(response).await + let request = self + .handle + .factory + .get(&format!("{APP_URL_PREFIX}/api-addrs"))? + .expected_status(&[StatusCode::OK]); + let response = self.handle.service.request(request).await?; + response.deserialize().await } } diff --git a/mullvad-api/src/relay_list.rs b/mullvad-api/src/relay_list.rs index 0eb2b22afd5a..deaf29ef10d6 100644 --- a/mullvad-api/src/relay_list.rs +++ b/mullvad-api/src/relay_list.rs @@ -36,20 +36,18 @@ impl RelayListProxy { let request = self.handle.factory.request("app/v1/relays", Method::GET); async move { - let mut request = request?; - request.set_timeout(RELAY_LIST_TIMEOUT); + let mut request = request? + .timeout(RELAY_LIST_TIMEOUT) + .expected_status(&[StatusCode::NOT_MODIFIED, StatusCode::OK]); if let Some(ref tag) = etag { - request.add_header(header::IF_NONE_MATCH, tag)?; + request = request.header(header::IF_NONE_MATCH, tag)?; } let response = service.request(request).await?; if etag.is_some() && response.status() == StatusCode::NOT_MODIFIED { return Ok(None); } - if response.status() != StatusCode::OK { - return rest::handle_error_response(response).await; - } let etag = response .headers() @@ -62,11 +60,8 @@ impl RelayListProxy { } }); - Ok(Some( - rest::deserialize_body::(response) - .await? - .into_relay_list(etag), - )) + let relay_list: ServerRelayList = response.deserialize().await?; + Ok(Some(relay_list.into_relay_list(etag))) } } } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 39ed98d370c4..63c909507cea 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -10,17 +10,16 @@ use crate::{ use futures::{ channel::{mpsc, oneshot}, stream::StreamExt, - Stream, TryFutureExt, + Stream, }; use hyper::{ - client::Client, + client::{connect::Connect, Client}, header::{self, HeaderValue}, Method, Uri, }; use mullvad_types::account::AccountToken; use std::{ error::Error as StdError, - future::Future, str::FromStr, sync::{Arc, Weak}, time::Duration, @@ -32,9 +31,6 @@ use crate::API; pub use hyper::StatusCode; -pub type Request = hyper::Request; -pub type Response = hyper::Response; - const USER_AGENT: &str = "mullvad-app"; const API_IP_CHECK_INITIAL: Duration = Duration::from_secs(15 * 60); @@ -47,6 +43,9 @@ const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); /// Describes all the ways a REST request can fail #[derive(err_derive::Error, Debug, Clone)] pub enum Error { + #[error(display = "REST client service is down")] + RestServiceDown, + #[error(display = "Request cancelled")] Aborted, @@ -65,12 +64,6 @@ pub enum Error { #[error(display = "Failed to deserialize data")] DeserializeError(#[error(source)] Arc), - #[error(display = "Failed to send request to rest client")] - SendError, - - #[error(display = "Failed to receive response from rest client")] - ReceiveError, - /// Unexpected response code #[error(display = "Unexpected response status code {} - {}", _0, _1)] ApiError(StatusCode, String), @@ -79,9 +72,8 @@ pub enum Error { #[error(display = "Not a valid URI")] InvalidUri, - /// A new API config was requested, but the request could not be completed. - #[error(display = "Failed to rotate API config")] - NextApiConfigError, + #[error(display = "Set account token on factory with no access token store")] + NoAccessTokenStore, } impl Error { @@ -201,45 +193,7 @@ impl< async fn process_command(&mut self, command: RequestCommand) { match command { RequestCommand::NewRequest(request, completion_tx) => { - let tx = self.command_tx.upgrade(); - let timeout = request.timeout(); - - let hyper_request = request.into_request(); - - let api_availability = self.api_availability.clone(); - let suspend_fut = api_availability.wait_for_unsuspend(); - let request_fut = self.client.request(hyper_request).map_err(Error::from); - - let request_future = async move { - let _ = suspend_fut.await; - request_fut.await - }; - - let future = async move { - let response = tokio::time::timeout(timeout, request_future) - .await - .map_err(|_| Error::TimeoutError); - - let response = flatten_result(response).map_err(|error| error.map_aborted()); - - if let Err(err) = &response { - if err.is_network_error() && !api_availability.get_state().is_offline() { - log::error!("{}", err.display_chain_with_msg("HTTP request failed")); - if let Some(tx) = tx { - let (completion_tx, _completion_rx) = oneshot::channel(); - let _ = - tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx)); - } - } - } - - if completion_tx.send(response).is_err() { - log::trace!( - "Failed to send response to caller, caller channel is shut down" - ); - } - }; - tokio::spawn(future); + self.handle_new_request(request, completion_tx); } RequestCommand::Reset => { self.connector_handle.reset(); @@ -268,6 +222,34 @@ impl< } } + fn handle_new_request( + &mut self, + request: Request, + completion_tx: oneshot::Sender>, + ) { + let tx = self.command_tx.upgrade(); + + let api_availability = self.api_availability.clone(); + let request_future = request.into_future(self.client.clone(), api_availability.clone()); + + tokio::spawn(async move { + let response = request_future.await.map_err(|error| error.map_aborted()); + + // Switch API endpoint if the request failed due to a network error + if let Err(err) = &response { + if err.is_network_error() && !api_availability.get_state().is_offline() { + log::error!("{}", err.display_chain_with_msg("HTTP request failed")); + if let Some(tx) = tx { + let (completion_tx, _completion_rx) = oneshot::channel(); + let _ = tx.unbounded_send(RequestCommand::NextApiConfig(completion_tx)); + } + } + } + + let _ = completion_tx.send(response); + }); + } + async fn into_future(mut self) { while let Some(command) = self.command_rx.next().await { self.process_command(command).await; @@ -289,12 +271,12 @@ impl RequestServiceHandle { } /// Submits a `RestRequest` for execution to the request service. - pub async fn request(&self, request: RestRequest) -> Result { + pub async fn request(&self, request: Request) -> Result { let (completion_tx, completion_rx) = oneshot::channel(); self.tx .unbounded_send(RequestCommand::NewRequest(request, completion_tx)) - .map_err(|_| Error::SendError)?; - completion_rx.await.map_err(|_| Error::ReceiveError)? + .map_err(|_| Error::RestServiceDown)?; + completion_rx.await.map_err(|_| Error::RestServiceDown)? } /// Forcibly update the connection mode. @@ -302,16 +284,15 @@ impl RequestServiceHandle { let (completion_tx, completion_rx) = oneshot::channel(); self.tx .unbounded_send(RequestCommand::NextApiConfig(completion_tx)) - .map_err(|_| Error::SendError)?; - - completion_rx.await.map_err(|_| Error::NextApiConfigError)? + .map_err(|_| Error::RestServiceDown)?; + completion_rx.await.map_err(|_| Error::RestServiceDown)? } } #[derive(Debug)] pub(crate) enum RequestCommand { NewRequest( - RestRequest, + Request, oneshot::Sender>, ), Reset, @@ -320,13 +301,15 @@ pub(crate) enum RequestCommand { /// A REST request that is sent to the RequestService to be executed. #[derive(Debug)] -pub struct RestRequest { - request: Request, +pub struct Request { + request: hyper::Request, timeout: Duration, - auth: Option, + access_token_store: Option, + account: Option, + expected_status: &'static [hyper::StatusCode], } -impl RestRequest { +impl Request { /// Constructs a GET request with the given URI. Returns an error if the URI is not valid. pub fn get(uri: &str) -> Result { let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?; @@ -343,54 +326,112 @@ impl RestRequest { }; let request = builder.uri(uri).body(hyper::Body::empty())?; + Ok(Self::new(request, None)) + } - Ok(RestRequest { - timeout: DEFAULT_TIMEOUT, - auth: None, + fn new( + request: hyper::Request, + access_token_store: Option, + ) -> Self { + Self { request, - }) + timeout: DEFAULT_TIMEOUT, + access_token_store, + account: None, + expected_status: &[], + } } - /// Set the auth header with the following format: `Bearer $auth`. - pub fn set_auth(&mut self, auth: Option) -> Result<()> { - let header = match auth { - Some(auth) => Some( - HeaderValue::from_str(&format!("Bearer {auth}")) - .map_err(|_| Error::InvalidHeaderError)?, - ), - None => None, - }; - - self.auth = header; - Ok(()) + /// Set the account token to obtain authentication for. + /// This fails if no store is set. + pub fn account(mut self, account: AccountToken) -> Result { + if self.access_token_store.is_none() { + return Err(Error::NoAccessTokenStore); + } + self.account = Some(account); + Ok(self) } /// Sets timeout for the request. - pub fn set_timeout(&mut self, timeout: Duration) { + pub fn timeout(mut self, timeout: Duration) -> Self { self.timeout = timeout; + self } - /// Retrieves timeout - pub fn timeout(&self) -> Duration { - self.timeout + pub fn expected_status(mut self, expected_status: &'static [hyper::StatusCode]) -> Self { + self.expected_status = expected_status; + self } - pub fn add_header(&mut self, key: T, value: &str) -> Result<()> { + pub fn header(mut self, key: T, value: &str) -> Result { let header_value = http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?; self.request.headers_mut().insert(key, header_value); - Ok(()) + Ok(self) + } + + async fn into_future( + self, + hyper_client: hyper::Client, + api_availability: ApiAvailabilityHandle, + ) -> Result { + let timeout = self.timeout; + let inner_fut = self.into_future_without_timeout(hyper_client, api_availability); + tokio::time::timeout(timeout, inner_fut) + .await + .map_err(|_| Error::TimeoutError)? } - /// Converts into a `hyper::Request` - fn into_request(self) -> Request { - let Self { - mut request, auth, .. - } = self; - if let Some(auth) = auth { - request.headers_mut().insert(header::AUTHORIZATION, auth); + async fn into_future_without_timeout( + mut self, + hyper_client: hyper::Client, + api_availability: ApiAvailabilityHandle, + ) -> Result { + let _ = api_availability.wait_for_unsuspend().await; + + // Obtain access token first + if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) { + let access_token = store.get_token(account).await?; + let auth = HeaderValue::from_str(&format!("Bearer {access_token}")) + .map_err(|_| Error::InvalidHeaderError)?; + self.request + .headers_mut() + .insert(header::AUTHORIZATION, auth); + } + + // Make request to hyper client + let response = hyper_client + .request(self.request) + .await + .map_err(Error::from); + + // Notify access token store of expired tokens + if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) { + store.check_response(account, &response); } - request + + // Parse unexpected responses and errors + + let response = response?; + + if !self.expected_status.contains(&response.status()) { + if !self.expected_status.is_empty() { + log::error!( + "Unexpected HTTP status code {}, expected codes [{}]", + response.status(), + self.expected_status + .iter() + .map(ToString::to_string) + .collect::>() + .join(",") + ); + } + if !response.status().is_success() { + return handle_error_response(response).await; + } + } + + Ok(Response::new(response)) } /// Returns the URI of the request @@ -399,13 +440,28 @@ impl RestRequest { } } -impl From for RestRequest { - fn from(request: Request) -> Self { - Self { - request, - timeout: DEFAULT_TIMEOUT, - auth: None, - } +/// Successful result of a REST request +#[derive(Debug)] +pub struct Response { + response: hyper::Response, +} + +impl Response { + fn new(response: hyper::Response) -> Self { + Self { response } + } + + pub fn status(&self) -> StatusCode { + self.response.status() + } + + pub fn headers(&self) -> &hyper::HeaderMap { + self.response.headers() + } + + pub async fn deserialize(self) -> Result { + let body_length = get_body_length(&self.response); + deserialize_body_inner(self.response, body_length).await } } @@ -423,48 +479,62 @@ struct NewErrorResponse { #[derive(Clone)] pub struct RequestFactory { - hostname: String, - path_prefix: Option, - pub timeout: Duration, + hostname: &'static str, + token_store: Option, + default_timeout: Duration, } impl RequestFactory { - pub fn new(hostname: String, path_prefix: Option) -> Self { + pub fn new(hostname: &'static str, token_store: Option) -> Self { Self { hostname, - path_prefix, - timeout: DEFAULT_TIMEOUT, + token_store, + default_timeout: DEFAULT_TIMEOUT, } } - pub fn request(&self, path: &str, method: Method) -> Result { - self.hyper_request(path, method) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + pub fn request(&self, path: &str, method: Method) -> Result { + Ok( + Request::new(self.hyper_request(path, method)?, self.token_store.clone()) + .timeout(self.default_timeout), + ) + } + + pub fn get(&self, path: &str) -> Result { + self.request(path, Method::GET) + } + + pub fn post(&self, path: &str) -> Result { + self.request(path, Method::POST) } - pub fn get(&self, path: &str) -> Result { - self.hyper_request(path, Method::GET) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + pub fn put(&self, path: &str) -> Result { + self.request(path, Method::PUT) } - pub fn post(&self, path: &str) -> Result { - self.hyper_request(path, Method::POST) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + pub fn delete(&self, path: &str) -> Result { + self.request(path, Method::DELETE) } - pub fn post_json(&self, path: &str, body: &S) -> Result { + pub fn post_json(&self, path: &str, body: &S) -> Result { self.json_request(Method::POST, path, body) } + pub fn put_json(&self, path: &str, body: &S) -> Result { + self.json_request(Method::PUT, path, body) + } + + pub fn default_timeout(mut self, timeout: Duration) -> Self { + self.default_timeout = timeout; + self + } + fn json_request( &self, method: Method, path: &str, body: &S, - ) -> Result { + ) -> Result { let mut request = self.hyper_request(path, method)?; let json_body = serde_json::to_string(&body)?; @@ -482,94 +552,29 @@ impl RequestFactory { HeaderValue::from_static("application/json"), ); - Ok(self.set_request_timeout(RestRequest::from(request))) - } - - pub fn delete(&self, path: &str) -> Result { - self.hyper_request(path, Method::DELETE) - .map(RestRequest::from) - .map(|req| self.set_request_timeout(req)) + Ok(Request::new(request, self.token_store.clone()).timeout(self.default_timeout)) } - fn hyper_request(&self, path: &str, method: Method) -> Result { + fn hyper_request(&self, path: &str, method: Method) -> Result> { let uri = self.get_uri(path)?; let request = http::request::Builder::new() .method(method) .uri(uri) .header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT)) .header(header::ACCEPT, HeaderValue::from_static("application/json")) - .header(header::HOST, self.hostname.clone()); + .header(header::HOST, HeaderValue::from_static(self.hostname)); let result = request.body(hyper::Body::empty())?; Ok(result) } fn get_uri(&self, path: &str) -> Result { - let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or(""); - let uri = format!("https://{}/{}{}", self.hostname, prefix, path); + let uri = format!("https://{}/{}", self.hostname, path); hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri) } - - fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest { - request.timeout = self.timeout; - request - } -} - -pub fn send_request( - factory: &RequestFactory, - service: RequestServiceHandle, - uri: &str, - method: Method, - access_token: Option, - expected_statuses: &'static [hyper::StatusCode], -) -> impl Future> { - let request = factory.request(uri, method); - - async move { - let mut request = request?; - request.set_auth(access_token)?; - let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await - } -} - -pub fn send_json_request( - factory: &RequestFactory, - service: RequestServiceHandle, - uri: &str, - method: Method, - body: &B, - access_token: Option, - expected_statuses: &'static [hyper::StatusCode], -) -> impl Future> { - let request = factory.json_request(method, uri, body); - async move { - let mut request = request?; - request.set_auth(access_token)?; - let response = service.request(request).await?; - parse_rest_response(response, expected_statuses).await - } } -pub async fn deserialize_body(response: Response) -> Result { - let body_length = get_body_length(&response); - deserialize_body_inner(response, body_length).await -} - -async fn deserialize_body_inner( - mut response: Response, - body_length: usize, -) -> Result { - let mut body: Vec = Vec::with_capacity(body_length); - while let Some(chunk) = response.body_mut().next().await { - body.extend(&chunk?); - } - - serde_json::from_slice(&body).map_err(Error::from) -} - -fn get_body_length(response: &Response) -> usize { +fn get_body_length(response: &hyper::Response) -> usize { response .headers() .get(header::CONTENT_LENGTH) @@ -578,29 +583,7 @@ fn get_body_length(response: &Response) -> usize { .unwrap_or(0) } -pub async fn parse_rest_response( - response: Response, - expected_statuses: &'static [hyper::StatusCode], -) -> Result { - if !expected_statuses.contains(&response.status()) { - log::error!( - "Unexpected HTTP status code {}, expected codes [{}]", - response.status(), - expected_statuses - .iter() - .map(ToString::to_string) - .collect::>() - .join(",") - ); - if !response.status().is_success() { - return handle_error_response(response).await; - } - } - - Ok(response) -} - -pub async fn handle_error_response(response: Response) -> Result { +async fn handle_error_response(response: hyper::Response) -> Result { let status = response.status(); let error_message = match status { hyper::StatusCode::METHOD_NOT_ALLOWED => "Method not allowed", @@ -634,12 +617,23 @@ pub async fn handle_error_response(response: Response) -> Result { Err(Error::ApiError(status, error_message.to_owned())) } +async fn deserialize_body_inner( + mut response: hyper::Response, + body_length: usize, +) -> Result { + let mut body: Vec = Vec::with_capacity(body_length); + while let Some(chunk) = response.body_mut().next().await { + body.extend(&chunk?); + } + + serde_json::from_slice(&body).map_err(Error::from) +} + #[derive(Clone)] pub struct MullvadRestHandle { pub(crate) service: RequestServiceHandle, pub factory: RequestFactory, pub availability: ApiAvailabilityHandle, - pub token_store: AccessTokenStore, } impl MullvadRestHandle { @@ -649,13 +643,10 @@ impl MullvadRestHandle { address_cache: AddressCache, availability: ApiAvailabilityHandle, ) -> Self { - let token_store = AccessTokenStore::new(service.clone(), factory.clone()); - let handle = Self { service, factory, availability, - token_store, }; #[cfg(feature = "api-override")] if API.disable_address_cache { @@ -714,19 +705,6 @@ impl MullvadRestHandle { pub fn service(&self) -> RequestServiceHandle { self.service.clone() } - - pub fn factory(&self) -> &RequestFactory { - &self.factory - } -} - -fn flatten_result( - result: std::result::Result, E>, -) -> std::result::Result { - match result { - Ok(value) => value, - Err(err) => Err(err), - } } macro_rules! impl_into_arc_err { diff --git a/mullvad-daemon/src/geoip.rs b/mullvad-daemon/src/geoip.rs index 0939787e91dd..43ffaf1054e6 100644 --- a/mullvad-daemon/src/geoip.rs +++ b/mullvad-daemon/src/geoip.rs @@ -83,9 +83,8 @@ async fn send_location_request_internal( service: RequestServiceHandle, ) -> Result { let future_service = service.clone(); - let request = mullvad_api::rest::RestRequest::get(uri)?; - let response = future_service.request(request).await?; - mullvad_api::rest::deserialize_body(response).await + let request = mullvad_api::rest::Request::get(uri)?; + future_service.request(request).await?.deserialize().await } fn log_network_error(err: Error, version: &'static str) { diff --git a/mullvad-daemon/src/version_check.rs b/mullvad-daemon/src/version_check.rs index dfe0a26b5fc0..d7c7eaa80565 100644 --- a/mullvad-daemon/src/version_check.rs +++ b/mullvad-daemon/src/version_check.rs @@ -150,7 +150,7 @@ impl VersionUpdater { last_app_version_info: Option, show_beta_releases: bool, ) -> (Self, VersionUpdaterHandle) { - api_handle.factory.timeout = DOWNLOAD_TIMEOUT; + api_handle.factory = api_handle.factory.default_timeout(DOWNLOAD_TIMEOUT); let version_proxy = AppVersionProxy::new(api_handle); let cache_path = cache_dir.join(VERSION_INFO_FILENAME); let (tx, rx) = mpsc::channel(1);