Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coalesce concurrent access token requests #5323

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Line wrap the file at 100 chars. Th

### Fixed
- Show correct endpoint in CLI for custom relays.
- Lower risk of being rate limited.

#### Windows
- Correctly detect whether OS is Windows Server (primarily for logging in daemon.log).
Expand Down
231 changes: 151 additions & 80 deletions mullvad-api/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,109 +2,180 @@ use crate::{
rest,
rest::{RequestFactory, RequestServiceHandle},
};
use futures::{
channel::{mpsc, oneshot},
StreamExt,
};
use hyper::StatusCode;
use mullvad_types::account::{AccessToken, AccessTokenData, AccountToken};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
use talpid_types::ErrorExt;
use std::collections::HashMap;
use tokio::select;

pub const AUTH_URL_PREFIX: &str = "auth/v1";

#[derive(Clone)]
pub struct AccessTokenProxy {
service: RequestServiceHandle,
factory: RequestFactory,
access_from_account: Arc<Mutex<HashMap<AccountToken, AccessTokenData>>>,
pub struct AccessTokenStore {
tx: mpsc::UnboundedSender<StoreAction>,
}

enum StoreAction {
/// Request an access token for `AccountToken`, or return a saved one if it's not expired.
GetAccessToken(
AccountToken,
oneshot::Sender<Result<AccessToken, rest::Error>>,
),
/// Forget cached access token for `AccountToken`, and drop any in-flight requests
InvalidateToken(AccountToken),
}

#[derive(Default)]
struct AccountState {
current_access_token: Option<AccessTokenData>,
inflight_request: Option<tokio::task::JoinHandle<()>>,
response_channels: Vec<oneshot::Sender<Result<AccessToken, rest::Error>>>,
}

impl AccessTokenProxy {
impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle, factory: RequestFactory) -> Self {
Self {
service,
factory,
access_from_account: Arc::new(Mutex::new(HashMap::new())),
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
}

async fn service_requests(
mut rx: mpsc::UnboundedReceiver<StoreAction>,
service: RequestServiceHandle,
factory: RequestFactory,
) {
let mut account_states: HashMap<AccountToken, AccountState> = HashMap::new();

let (completed_tx, mut completed_rx) = mpsc::unbounded();

loop {
select! {
action = rx.next() => {
let Some(action) = action else {
// We're done
break;
};

match action {
StoreAction::GetAccessToken(account, response_tx) => {
let account_state = account_states
.entry(account.clone())
.or_default();

// If there is an unexpired access token, just return it.
// Otherwise, generate a new token
if let Some(ref access_token) = account_state.current_access_token {
if !access_token.is_expired() {
log::trace!("Using stored access token");
let _ = response_tx.send(Ok(access_token.access_token.clone()));
continue;
}

log::debug!("Replacing expired access token");
account_state.current_access_token = None;
}

// Begin requesting an access token if it's not already underway.
// If there's already an inflight request, just save `response_tx`
account_state
.inflight_request
.get_or_insert_with(|| {
let completed_tx = completed_tx.clone();
let account = account.clone();
let service = service.clone();
let factory = factory.clone();

log::debug!("Fetching access token for an account");

tokio::spawn(async move {
let result = fetch_access_token(service, factory, account.clone()).await;
let _ = completed_tx.unbounded_send((account, result));
})
});

// Save the channel to respond to later
account_state.response_channels.push(response_tx);
}
StoreAction::InvalidateToken(account) => {
let account_state = account_states
.entry(account)
.or_default();

// Drop in-flight requests for the account
// & forget any existing access token

log::debug!("Invalidating access token for an account");

if let Some(task) = account_state.inflight_request.take() {
task.abort();
let _ = task.await;
}

account_state.response_channels.clear();
account_state.current_access_token = None;
}
}
}

Some((account, result)) = completed_rx.next() => {
let account_state = account_states
.entry(account)
.or_default();

account_state.inflight_request = None;

// Send response to all channels
for tx in account_state.response_channels.drain(..) {
let _ = tx.send(result.clone().map(|data| data.access_token));
}

if let Ok(access_token) = result {
account_state.current_access_token = Some(access_token);
}
}
}
}
}

/// Obtain access token for an account, requesting a new one from the API if necessary.
pub async fn get_token(&self, account: &AccountToken) -> Result<AccessToken, rest::Error> {
let existing_token = {
self.access_from_account
.lock()
.unwrap()
.get(account.as_str())
.cloned()
};
if let Some(access_token) = existing_token {
if access_token.is_expired() {
log::debug!("Replacing expired access token");
return self.request_new_token(account.clone()).await;
}
log::trace!("Using stored access token");
return Ok(access_token.access_token.clone());
}
self.request_new_token(account.clone()).await
let (tx, rx) = oneshot::channel();
let _ = self
.tx
.unbounded_send(StoreAction::GetAccessToken(account.to_owned(), tx));
rx.await.map_err(|_| rest::Error::Aborted)?
}

/// Remove an access token if the API response calls for it.
pub fn check_response<T>(&self, account: &AccessToken, response: &Result<T, rest::Error>) {
pub fn check_response<T>(&self, account: &AccountToken, response: &Result<T, rest::Error>) {
if let Err(rest::Error::ApiError(_status, code)) = response {
if code == crate::INVALID_ACCESS_TOKEN {
log::debug!("Dropping invalid access token");
self.remove_token(account);
let _ = self
.tx
.unbounded_send(StoreAction::InvalidateToken(account.to_owned()));
}
}
}
}

/// Removes a stored access token.
fn remove_token(&self, account: &AccountToken) -> Option<AccessToken> {
self.access_from_account
.lock()
.unwrap()
.remove(account)
.map(|v| v.access_token)
}

async fn request_new_token(&self, account: AccountToken) -> Result<AccessToken, rest::Error> {
log::debug!("Fetching access token for an account");
let access_token = self
.fetch_access_token(account.clone())
.await
.map_err(|error| {
log::error!(
"{}",
error.display_chain_with_msg("Failed to obtain access token")
);
error
})?;
self.access_from_account
.lock()
.unwrap()
.insert(account, access_token.clone());
Ok(access_token.access_token)
async fn fetch_access_token(
service: RequestServiceHandle,
factory: RequestFactory,
account_token: AccountToken,
) -> Result<AccessTokenData, rest::Error> {
#[derive(serde::Serialize)]
struct AccessTokenRequest {
account_number: String,
}
let request = AccessTokenRequest {
account_number: account_token,
};

async fn fetch_access_token(
&self,
account_token: AccountToken,
) -> Result<AccessTokenData, rest::Error> {
#[derive(serde::Serialize)]
struct AccessTokenRequest {
account_number: String,
}
let request = AccessTokenRequest {
account_number: account_token,
};

let service = self.service.clone();

let rest_request = self
.factory
.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
let response = service.request(rest_request).await?;
let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
rest::deserialize_body(response).await
}
let rest_request = factory.post_json(&format!("{AUTH_URL_PREFIX}/token"), &request)?;
let response = service.request(rest_request).await?;
let response = rest::parse_rest_response(response, &[StatusCode::OK]).await?;
rest::deserialize_body(response).await
}
2 changes: 1 addition & 1 deletion mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn main() {

let relay_list = match relay_list_request {
Ok(relay_list) => relay_list,
Err(RestError::TimeoutError(_)) => {
Err(RestError::TimeoutError) => {
eprintln!("Request timed out");
process::exit(2);
}
Expand Down
30 changes: 23 additions & 7 deletions mullvad-api/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();

async move {
let access_token = access_proxy.get_token(&account).await?;

let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::POST,
&submission,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::CREATED],
)
.await;

access_proxy.check_response(&account, &response);

let response: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
id,
Expand Down Expand Up @@ -102,16 +106,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::GET,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;
rest::deserialize_body(response?).await
access_proxy.check_response(&account, &response);
let device = rest::deserialize_body(response?).await?;
Ok(device)
}
}

Expand All @@ -123,16 +130,19 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices"),
Method::GET,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;
rest::deserialize_body(response?).await
access_proxy.check_response(&account, &response);
let devices = rest::deserialize_body(response?).await?;
Ok(devices)
}
}

Expand All @@ -145,15 +155,17 @@ impl DevicesProxy {
let factory = self.handle.factory.clone();
let access_proxy = self.handle.token_store.clone();
async move {
let access_token = access_proxy.get_token(&account).await?;
let response = rest::send_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}"),
Method::DELETE,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::NO_CONTENT],
)
.await;
access_proxy.check_response(&account, &response);

response?;
Ok(())
Expand All @@ -178,17 +190,21 @@ impl DevicesProxy {
let access_proxy = self.handle.token_store.clone();

async move {
let access_token = access_proxy.get_token(&account).await?;

let response = rest::send_json_request(
&factory,
service,
&format!("{ACCOUNTS_URL_PREFIX}/devices/{id}/pubkey"),
Method::PUT,
&req_body,
Some((access_proxy, account)),
Some(access_token),
&[StatusCode::OK],
)
.await;

access_proxy.check_response(&account, &response);

let updated_device: DeviceResponse = rest::deserialize_body(response?).await?;
let DeviceResponse {
ipv4_address,
Expand Down
Loading
Loading