Skip to content

Commit

Permalink
Merge branch 'fold-access-token-requests'
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 19, 2023
2 parents 58b09e8 + f68eaf6 commit 437dc5b
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 153 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Line wrap the file at 100 chars. Th

### Fixed
- Show correct endpoint in CLI for custom relays.
- Lower risk of being rate limited.

#### Windows
- Correctly detect whether OS is Windows Server (primarily for logging in daemon.log).
Expand Down
231 changes: 151 additions & 80 deletions mullvad-api/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,109 +2,180 @@ use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
};
use futures::{
channel::{mpsc, oneshot},
StreamExt,
};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use talpid_types::ErrorExt;
use std::collections::HashMap;
use tokio::select;

pub const AUTH_URL_PREFIX: &str = "auth/v1";

#[derive(Clone)]
pub struct AccessTokenProxy {
service: RequestServiceHandle,
factory: RequestFactory,
access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
pub struct AccessTokenStore {
tx: mpsc::UnboundedSender<StoreAction>,
}

enum StoreAction {
/// Request an access token for `AccountToken`, or return a saved one if it's not expired.
GetAccessToken(
AccountToken,
oneshot::Sender<Result<AccessToken, rest::Error>>,
),
/// Forget cached access token for `AccountToken`, and drop any in-flight requests
InvalidateToken(AccountToken),
}

#[derive(Default)]
struct AccountState {
current_access_token: Option<AccessTokenData>,
inflight_request: Option<tokio::task::JoinHandle<()>>,
response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}

impl AccessTokenProxy {
impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
Self {
service,
factory,
access_from_account: Arc::new(Mutex::new(HashMap::new())),
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
}

async fn service_requests(
mut rx: mpsc::UnboundedReceiver<StoreAction>,
service: RequestServiceHandle,
factory: RequestFactory,
) {
let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new();

let (completed_tx, mut completed_rx) = mpsc::unbounded();

loop {
select! {
action = rx.next() => {
let Some(action) = action else {
// We're done
break;
};

match action {
StoreAction::GetAccessToken(account, response_tx) => {
let account_state = account_states
.entry(account.clone())
.or_default();

// If there is an unexpired access token, just return it.
// Otherwise, generate a new token
if let Some(ref access_token) = account_state.current_access_token {
if !access_token.is_expired() {
log::trace!("Using stored access token");
let _ = response_tx.send(Ok(access_token.access_token.clone()));
continue;
}

log::debug!("Replacing expired access token");
account_state.current_access_token = None;
}

// Begin requesting an access token if it's not already underway.
// If there's already an inflight request, just save `response_tx`
account_state
.inflight_request
.get_or_insert_with(|| {
let completed_tx = completed_tx.clone();
let account = account.clone();
let service = service.clone();
let factory = factory.clone();

log::debug!("Fetching access token for an account");

tokio::spawn(async move {
let result = fetch_access_token(service, factory, account.clone()).await;
let _ = completed_tx.unbounded_send((account, result));
})
});

// Save the channel to respond to later
account_state.response_channels.push(response_tx);
}
StoreAction::InvalidateToken(account) => {
let account_state = account_states
.entry(account)
.or_default();

// Drop in-flight requests for the account
// & forget any existing access token

log::debug!("Invalidating access token for an account");

if let Some(task) = account_state.inflight_request.take() {
task.abort();
let _ = task.await;
}

account_state.response_channels.clear();
account_state.current_access_token = None;
}
}
}

Some((account, result)) = completed_rx.next() => {
let account_state = account_states
.entry(account)
.or_default();

account_state.inflight_request = None;

// Send response to all channels
for tx in account_state.response_channels.drain(..) {
let _ = tx.send(result.clone().map(|data| data.access_token));
}

if let Ok(access_token) = result {
account_state.current_access_token = Some(access_token);
}
}
}
}
}

/// Obtain access token for an account, requesting a new one from the API if necessary.
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
let existing_token = {
self.access_from_account
.lock()
.unwrap()
.get(account.as_str())
.cloned()
};
if let Some(access_token) = existing_token {
if access_token.is_expired() {
log::debug!("Replacing expired access token");
return self.request_new_token(account.clone()).await;
}
log::trace!("Using stored access token");
return Ok(access_token.access_token.clone());
}
self.request_new_token(account.clone()).await
let (tx, rx) = oneshot::channel();
let _ = self
.tx
.unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx));
rx.await.map_err(|_| rest::Error::Aborted)?
}

/// Remove an access token if the API response calls for it.
pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) {
if let Err(rest::Error::ApiError(_status, code)) = response {
if code == crate::INVALID_ACCESS_TOKEN {
log::debug!("Dropping invalid access token");
self.remove_token(account);
let _ = self
.tx
.unbounded_send(StoreAction::InvalidateToken(account.to_owned()));
}
}
}
}

/// Removes a stored access token.
fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
self.access_from_account
.lock()
.unwrap()
.remove(account)
.map(|v| v.access_token)
}

async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
log::debug!("Fetching access token for an account");
let access_token = self
.fetch_access_token(account.clone())
.await
.map_err(|error| {
log::error!(
"{}",
error.display_chain_with_msg("Failed to obtain access token")
);
error
})?;
self.access_from_account
.lock()
.unwrap()
.insert(account, access_token.clone());
Ok(access_token.access_token)
async fn fetch_access_token(
service: RequestServiceHandle,
factory: RequestFactory,
account_token: AccountToken,
) -> Result<AccessTokenData, rest::Error> {
#[derive(serde::Serialize)]
struct AccessTokenRequest {
account_number: String,
}
let request = AccessTokenRequest {
account_number: account_token,
};

async fn fetch_access_token(
&self,
account_token: AccountToken,
) -> Result<AccessTokenData, rest::Error> {
#[derive(serde::Serialize)]
struct AccessTokenRequest {
account_number: String,
}
let request = AccessTokenRequest {
account_number: account_token,
};

let service = self.service.clone();

let rest_request = self
.factory
.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
let response = service.request(rest_request).await?;
let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
rest::deserialize_body(response).await
}
let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
let response = service.request(rest_request).await?;
let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
rest::deserialize_body(response).await
}
2 changes: 1 addition & 1 deletion mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() {

let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
Err(RestError::TimeoutError(_)) => {
Err(RestError::TimeoutError) => {
eprintln!("Request timed out");
process::exit(2);
}
Expand Down
30 changes: 23 additions & 7 deletions mullvad-api/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();

async move {
let access_token = access_proxy.get_token(&account).await?;

let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::POST,
&submission,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::CREATED],
)
.await;

access_proxy.check_response(&account, &response);

let response: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
id,
Expand Down Expand Up @@ -102,16 +106,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::GET,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;
rest::deserialize_body(response?).await
access_proxy.check_response(&account, &response);
let device = rest::deserialize_body(response?).await?;
Ok(device)
}
}

Expand All @@ -123,16 +130,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::GET,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;
rest::deserialize_body(response?).await
access_proxy.check_response(&account, &response);
let devices = rest::deserialize_body(response?).await?;
Ok(devices)
}
}

Expand All @@ -145,15 +155,17 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::DELETE,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::NO_CONTENT],
)
.await;
access_proxy.check_response(&account, &response);

response?;
Ok(())
Expand All @@ -178,17 +190,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();

async move {
let access_token = access_proxy.get_token(&account).await?;

let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
Method::PUT,
&req_body,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;

access_proxy.check_response(&account, &response);

let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
ipv4_address,
Expand Down
Loading

0 comments on commit 437dc5b

Please sign in to comment.