Skip to content

Commit

Permalink
Amazing cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 19, 2023
1 parent 8169c7a commit a261f95
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 107 deletions.
11 changes: 4 additions & 7 deletions mullvad-api/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use futures::{
};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use tokio::select;

pub const AUTH_URL_PREFIX: &str = "auth/v1";
Expand All @@ -22,7 +22,7 @@ enum StoreAction {
/// Request an access token for `AccountToken`, or return a saved one if it's not expired.
GetAccessToken(
AccountToken,
oneshot::Sender<Result<AccessToken, Arc<rest::Error>>>,
oneshot::Sender<Result<AccessToken, rest::Error>>,
),
/// Forget cached access token for `AccountToken`, and drop any in-flight requests
InvalidateToken(AccountToken),
Expand All @@ -32,7 +32,7 @@ enum StoreAction {
struct AccountState {
current_access_token: Option<AccessTokenData>,
inflight_request: Option<tokio::task::JoinHandle<()>>,
response_channels: Vec<oneshot::Sender<Result<AccessToken, Arc<rest::Error>>>>,
response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}

impl AccessTokenStore {
Expand Down Expand Up @@ -127,9 +127,6 @@ impl AccessTokenStore {

account_state.inflight_request = None;

// Sadly, rest::Error is not cloneable
let result = result.map_err(Arc::new);

// Send response to all channels
for tx in account_state.response_channels.drain(..) {
let _ = tx.send(result.clone().map(|data| data.access_token));
Expand All @@ -144,7 +141,7 @@ impl AccessTokenStore {
}

/// Obtain access token for an account, requesting a new one from the API if necessary.
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, Arc<rest::Error>> {
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
let (tx, rx) = oneshot::channel();
let _ = self
.tx
Expand Down
2 changes: 1 addition & 1 deletion mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() {

let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
Err(RestError::TimeoutError(_)) => {
Err(RestError::TimeoutError) => {
eprintln!("Request timed out");
process::exit(2);
}
Expand Down
15 changes: 7 additions & 8 deletions mullvad-api/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use mullvad_types::{
account::AccountToken,
device::{Device, DeviceId, DeviceName},
};
use std::{future::Future, sync::Arc};
use std::future::Future;
use talpid_types::net::wireguard;

use crate::rest;
Expand Down Expand Up @@ -36,9 +36,8 @@ impl DevicesProxy {
&self,
account: AccountToken,
pubkey: wireguard::PublicKey,
) -> impl Future<
Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), Arc<rest::Error>>,
> {
) -> impl Future<Output = Result<(Device, mullvad_types::wireguard::AssociatedAddresses), rest::Error>>
{
#[derive(serde::Serialize)]
struct DeviceSubmission {
pubkey: wireguard::PublicKey,
Expand Down Expand Up @@ -102,7 +101,7 @@ impl DevicesProxy {
&self,
account: AccountToken,
id: DeviceId,
) -> impl Future<Output = Result<Device, Arc<rest::Error>>> {
) -> impl Future<Output = Result<Device, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
Expand All @@ -126,7 +125,7 @@ impl DevicesProxy {
pub fn list(
&self,
account: AccountToken,
) -> impl Future<Output = Result<Vec<Device>, Arc<rest::Error>>> {
) -> impl Future<Output = Result<Vec<Device>, rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
Expand All @@ -151,7 +150,7 @@ impl DevicesProxy {
&self,
account: AccountToken,
id: DeviceId,
) -> impl Future<Output = Result<(), Arc<rest::Error>>> {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
Expand All @@ -178,7 +177,7 @@ impl DevicesProxy {
account: AccountToken,
id: DeviceId,
pubkey: wireguard::PublicKey,
) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, Arc<rest::Error>>>
) -> impl Future<Output = Result<mullvad_types::wireguard::AssociatedAddresses, rest::Error>>
{
#[derive(serde::Serialize)]
struct RotateDevicePubkey {
Expand Down
22 changes: 10 additions & 12 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use mullvad_types::{
version::AppVersion,
};
use proxy::ApiConnectionMode;
use std::sync::{Arc, OnceLock};
use std::sync::OnceLock;
use std::{
cell::Cell,
collections::BTreeMap,
Expand Down Expand Up @@ -384,7 +384,7 @@ impl AccountsProxy {
pub fn get_expiry(
&self,
account: AccountToken,
) -> impl Future<Output = Result<DateTime<Utc>, Arc<rest::Error>>> {
) -> impl Future<Output = Result<DateTime<Utc>, rest::Error>> {
#[derive(serde::Deserialize)]
struct AccountExpiryResponse {
expiry: DateTime<Utc>,
Expand All @@ -411,9 +411,7 @@ impl AccountsProxy {
}
}

pub fn create_account(
&mut self,
) -> impl Future<Output = Result<AccountToken, Arc<rest::Error>>> {
pub fn create_account(&mut self) -> impl Future<Output = Result<AccountToken, rest::Error>> {
#[derive(serde::Deserialize)]
struct AccountCreationResponse {
number: AccountToken,
Expand All @@ -439,7 +437,7 @@ impl AccountsProxy {
&mut self,
account_token: AccountToken,
voucher_code: String,
) -> impl Future<Output = Result<VoucherSubmission, Arc<rest::Error>>> {
) -> impl Future<Output = Result<VoucherSubmission, rest::Error>> {
#[derive(serde::Serialize)]
struct VoucherSubmission {
voucher_code: String,
Expand Down Expand Up @@ -475,7 +473,7 @@ impl AccountsProxy {
pub fn init_play_purchase(
&mut self,
account: AccountToken,
) -> impl Future<Output = Result<PlayPurchasePaymentToken, Arc<rest::Error>>> {
) -> impl Future<Output = Result<PlayPurchasePaymentToken, rest::Error>> {
#[derive(serde::Deserialize)]
struct PlayPurchaseInitResponse {
obfuscated_id: String,
Expand Down Expand Up @@ -513,7 +511,7 @@ impl AccountsProxy {
&mut self,
account: AccountToken,
play_purchase: PlayPurchase,
) -> impl Future<Output = Result<(), Arc<rest::Error>>> {
) -> impl Future<Output = Result<(), rest::Error>> {
let service = self.handle.service.clone();
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
Expand Down Expand Up @@ -541,7 +539,7 @@ impl AccountsProxy {
pub fn get_www_auth_token(
&self,
account: AccountToken,
) -> impl Future<Output = Result<String, Arc<rest::Error>>> {
) -> impl Future<Output = Result<String, rest::Error>> {
#[derive(serde::Deserialize)]
struct AuthTokenResponse {
auth_token: String,
Expand Down Expand Up @@ -584,7 +582,7 @@ impl ProblemReportProxy {
message: &str,
log: &str,
metadata: &BTreeMap<String, String>,
) -> impl Future<Output = Result<(), Arc<rest::Error>>> {
) -> impl Future<Output = Result<(), rest::Error>> {
#[derive(serde::Serialize)]
struct ProblemReport {
address: String,
Expand Down Expand Up @@ -642,7 +640,7 @@ impl AppVersionProxy {
app_version: AppVersion,
platform: &str,
platform_version: String,
) -> impl Future<Output = Result<AppVersionResponse, Arc<rest::Error>>> {
) -> impl Future<Output = Result<AppVersionResponse, rest::Error>> {
let service = self.handle.service.clone();

let path = format!("{APP_URL_PREFIX}/releases/{platform}/{app_version}");
Expand Down Expand Up @@ -670,7 +668,7 @@ impl ApiProxy {
Self { handle }
}

pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, Arc<rest::Error>> {
pub async fn get_api_addrs(&self) -> Result<Vec<SocketAddr>, rest::Error> {
let service = self.handle.service.clone();

let response = rest::send_request(
Expand Down
63 changes: 42 additions & 21 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,25 @@ pub type Result<T> = std::result::Result<T, Error>;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// Describes all the ways a REST request can fail
#[derive(err_derive::Error, Debug)]
#[derive(err_derive::Error, Debug, Clone)]
pub enum Error {
#[error(display = "Request cancelled")]
Aborted,

#[error(display = "Hyper error")]
HyperError(#[error(source)] hyper::Error),
HyperError(#[error(source)] Arc<hyper::Error>),

#[error(display = "Invalid header value")]
InvalidHeaderError(#[error(source)] http::header::InvalidHeaderValue),
InvalidHeaderError,

#[error(display = "HTTP error")]
HttpError(#[error(source)] http::Error),
HttpError(#[error(source)] Arc<http::Error>),

#[error(display = "Request timed out")]
TimeoutError(#[error(source)] tokio::time::error::Elapsed),
TimeoutError,

#[error(display = "Failed to deserialize data")]
DeserializeError(#[error(source)] serde_json::Error),
DeserializeError(#[error(source)] Arc<serde_json::Error>),

#[error(display = "Failed to send request to rest client")]
SendError,
Expand All @@ -76,7 +76,7 @@ pub enum Error {

/// The string given was not a valid URI.
#[error(display = "Not a valid URI")]
UriError(#[error(source)] http::uri::InvalidUri),
InvalidUri,

/// A new API config was requested, but the request could not be completed.
#[error(display = "Failed to rotate API config")]
Expand All @@ -85,7 +85,7 @@ pub enum Error {

impl Error {
pub fn is_network_error(&self) -> bool {
matches!(self, Error::HyperError(_) | Error::TimeoutError(_))
matches!(self, Error::HyperError(_) | Error::TimeoutError)
}

pub fn is_aborted(&self) -> bool {
Expand Down Expand Up @@ -203,7 +203,7 @@ impl<
let future = async move {
let response = tokio::time::timeout(timeout, request_future)
.await
.map_err(Error::TimeoutError);
.map_err(|_| Error::TimeoutError);

let response = flatten_result(response).map_err(|error| error.map_aborted());

Expand Down Expand Up @@ -314,20 +314,20 @@ pub struct RestRequest {
impl RestRequest {
/// Constructs a GET request with the given URI. Returns an error if the URI is not valid.
pub fn get(uri: &str) -> Result<Self> {
let uri = hyper::Uri::from_str(uri).map_err(Error::UriError)?;
let uri = hyper::Uri::from_str(uri).map_err(|_| Error::InvalidUri)?;

let mut builder = http::request::Builder::new()
.method(Method::GET)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"));
if let Some(host) = uri.host() {
builder = builder.header(header::HOST, HeaderValue::from_str(host)?);
builder = builder.header(
header::HOST,
HeaderValue::from_str(host).map_err(|_| Error::InvalidHeaderError)?,
);
};

let request = builder
.uri(uri)
.body(hyper::Body::empty())
.map_err(Error::HttpError)?;
let request = builder.uri(uri).body(hyper::Body::empty())?;

Ok(RestRequest {
timeout: DEFAULT_TIMEOUT,
Expand All @@ -341,7 +341,7 @@ impl RestRequest {
let header = match auth {
Some(auth) => Some(
HeaderValue::from_str(&format!("Bearer {auth}"))
.map_err(Error::InvalidHeaderError)?,
.map_err(|_| Error::InvalidHeaderError)?,
),
None => None,
};
Expand All @@ -361,7 +361,8 @@ impl RestRequest {
}

pub fn add_header<T: header::IntoHeaderName>(&mut self, key: T, value: &str) -> Result<()> {
let header_value = http::HeaderValue::from_str(value).map_err(Error::InvalidHeaderError)?;
let header_value =
http::HeaderValue::from_str(value).map_err(|_| Error::InvalidHeaderError)?;
self.request.headers_mut().insert(key, header_value);
Ok(())
}
Expand Down Expand Up @@ -458,7 +459,8 @@ impl RequestFactory {
let headers = request.headers_mut();
headers.insert(
header::CONTENT_LENGTH,
HeaderValue::from_str(&body_length.to_string()).map_err(Error::InvalidHeaderError)?,
HeaderValue::from_str(&body_length.to_string())
.map_err(|_| Error::InvalidHeaderError)?,
);
headers.insert(
header::CONTENT_TYPE,
Expand All @@ -483,13 +485,14 @@ impl RequestFactory {
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
.header(header::HOST, self.hostname.clone());

request.body(hyper::Body::empty()).map_err(Error::HttpError)
let result = request.body(hyper::Body::empty())?;
Ok(result)
}

fn get_uri(&self, path: &str) -> Result<Uri> {
let prefix = self.path_prefix.as_ref().map(AsRef::as_ref).unwrap_or("");
let uri = format!("https://{}/{}{}", self.hostname, prefix, path);
hyper::Uri::from_str(&uri).map_err(Error::UriError)
hyper::Uri::from_str(&uri).map_err(|_| Error::InvalidUri)
}

fn set_request_timeout(&self, mut request: RestRequest) -> RestRequest {
Expand Down Expand Up @@ -548,7 +551,7 @@ async fn deserialize_body_inner<T: serde::de::DeserializeOwned>(
body.extend(&chunk?);
}

serde_json::from_slice(&body).map_err(Error::DeserializeError)
serde_json::from_slice(&body).map_err(Error::from)
}

fn get_body_length(response: &Response) -> usize {
Expand Down Expand Up @@ -710,3 +713,21 @@ fn flatten_result<T, E>(
Err(err) => Err(err),
}
}

impl From<hyper::Error> for Error {
fn from(value: hyper::Error) -> Self {
Error::HyperError(Arc::new(value))
}
}

impl From<serde_json::Error> for Error {
fn from(value: serde_json::Error) -> Self {
Error::DeserializeError(Arc::new(value))
}
}

impl From<http::Error> for Error {
fn from(value: http::Error) -> Self {
Error::HttpError(Arc::new(value))
}
}
2 changes: 1 addition & 1 deletion mullvad-daemon/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub enum Error {
#[error(display = "Failed parse device cache")]
ParseDeviceCache(#[error(source)] serde_json::Error),
#[error(display = "Unexpected HTTP request error")]
OtherRestError(#[error(source)] Arc<rest::Error>),
OtherRestError(#[error(source)] rest::Error),
#[error(display = "The device update task is not running")]
Cancelled,
/// Intended to be broadcast to requesters
Expand Down
Loading

0 comments on commit a261f95

Please sign in to comment.