Skip to content

Commit

Permalink
Synchronize mullvad-api and mullvad-daemon when the api-override
Browse files Browse the repository at this point in the history
feature is enabled

Move the logic for using overridden API endpoints for API calls from
`mullvad-api::rest` to `mullvad_daemon::api`. This is in line with how
the interaction between the two crates work for a normal release build,
i.e. when the `api-override` feature is disabled.

This commit also removes references to `force_direct_connection` in the
Android code. The flag does not exist in the `mullvad-*` rust crates
anymore, so it would be erroneous to try to serialize/deserialize the
value from the Android client.
  • Loading branch information
MarkusPettersson98 committed Jan 10, 2024
1 parent e7f11f1 commit 503888d
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 201 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ import kotlinx.parcelize.Parcelize
data class ApiEndpoint(
val address: InetSocketAddress,
val disableAddressCache: Boolean,
val disableTls: Boolean,
val forceDirectConnection: Boolean
val disableTls: Boolean
) : Parcelable
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ data class CustomApiEndpointConfiguration(
val hostname: String,
val port: Int,
val disableAddressCache: Boolean = true,
val disableTls: Boolean = false,
val forceDirectConnection: Boolean = true
val disableTls: Boolean = false
) : ApiEndpointConfiguration {
override fun apiEndpoint() =
ApiEndpoint(
address = InetSocketAddress(hostname, port),
disableAddressCache = disableAddressCache,
disableTls = disableTls,
forceDirectConnection = forceDirectConnection
disableTls = disableTls
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ abstract class MockApiTest {
InetAddress.getLocalHost().hostName,
port,
disableAddressCache = true,
disableTls = true,
forceDirectConnection = true
disableTls = true
)
}
}
2 changes: 1 addition & 1 deletion mullvad-api/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct AccountState {

impl AccessTokenStore {
pub(crate) fn new(service: RequestServiceHandle) -> Self {
let factory = rest::RequestFactory::new(&API.host, None);
let factory = rest::RequestFactory::new(API.host(), None);
let (tx, rx) = mpsc::unbounded();
tokio::spawn(Self::service_requests(rx, service, factory));
Self { tx }
Expand Down
4 changes: 2 additions & 2 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct AddressCache {
impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
Self::new_inner(API.addr, write_path)
Self::new_inner(API.address(), write_path)
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
Expand All @@ -53,7 +53,7 @@ impl AddressCache {

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(&API.host) {
if hostname.eq_ignore_ascii_case(&API.host()) {
Some(self.get_address().await)
} else {
None
Expand Down
200 changes: 136 additions & 64 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,95 +103,167 @@ impl<T> Deref for LazyManual<T> {
/// A hostname and socketaddr to reach the Mullvad REST API over.
#[derive(Debug)]
pub struct ApiEndpoint {
pub host: String,
pub addr: SocketAddr,
/// An overriden API hostname. Initialized with the value of the environment
/// variable `MULLVAD_API_HOSt` if it has been set.
///
/// Use the associated function [`Self::host`] to read this value with a
/// default fallback if `MULLVAD_API_HOST` was not set.
pub host: Option<String>,
/// An overriden API address. Initialized with the value of the environment
/// variable `MULLVAD_API_ADDR` if it has been set.
///
/// Use the associated function [`Self::address()`] to read this value with
/// a default fallback if `MULLVAD_API_ADDR` was not set.
///
/// # Note
///
/// If [`Self::address`] is populated with [`Some(SocketAddr)`], it should
/// always be respected when establishing API connections.
pub address: Option<SocketAddr>,
#[cfg(feature = "api-override")]
pub disable_address_cache: bool,
#[cfg(feature = "api-override")]
pub disable_tls: bool,
#[cfg(feature = "api-override")]
pub force_direct_connection: bool,
}

impl ApiEndpoint {
const API_HOST_DEFAULT: &'static str = "api.mullvad.net";
const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(45, 83, 223, 196));
const API_PORT_DEFAULT: u16 = 443;

const API_HOST_VAR: &'static str = "MULLVAD_API_HOST";
const API_ADDR_VAR: &'static str = "MULLVAD_API_ADDR";
const DISABLE_TLS_VAR: &'static str = "MULLVAD_API_DISABLE_TLS";

/// Returns the endpoint to connect to the API over.
///
/// # Panics
///
/// Panics if `MULLVAD_API_ADDR` has invalid contents or if only one of
/// `MULLVAD_API_ADDR` or `MULLVAD_API_HOST` has been set but not the other.
#[cfg(feature = "api-override")]
pub fn from_env_vars() -> ApiEndpoint {
const API_HOST_DEFAULT: &str = "api.mullvad.net";
const API_IP_DEFAULT: IpAddr = IpAddr::V4(Ipv4Addr::new(45, 83, 223, 196));
const API_PORT_DEFAULT: u16 = 443;

fn read_var(key: &'static str) -> Option<String> {
use std::env;
match env::var(key) {
Ok(v) => Some(v),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => panic!("{key} does not contain valid UTF-8"),
}
}
use std::net::ToSocketAddrs;

let host_var = read_var("MULLVAD_API_HOST");
let address_var = read_var("MULLVAD_API_ADDR");
let disable_tls_var = read_var("MULLVAD_API_DISABLE_TLS");
let host_var = Self::read_var(ApiEndpoint::API_HOST_VAR);
let address_var = Self::read_var(ApiEndpoint::API_ADDR_VAR);
let disable_tls_var = Self::read_var(ApiEndpoint::DISABLE_TLS_VAR);

#[cfg_attr(not(feature = "api-override"), allow(unused_mut))]
let mut api = ApiEndpoint {
host: API_HOST_DEFAULT.to_owned(),
addr: SocketAddr::new(API_IP_DEFAULT, API_PORT_DEFAULT),
#[cfg(feature = "api-override")]
disable_address_cache: false,
#[cfg(feature = "api-override")]
disable_tls: false,
#[cfg(feature = "api-override")]
force_direct_connection: false,
host: host_var.clone(),
address: None,
disable_address_cache: true,
disable_tls: disable_tls_var
.as_ref()
.map(|disable_tls| disable_tls != "0")
.unwrap_or(false),
};

#[cfg(feature = "api-override")]
{
use std::net::ToSocketAddrs;

if host_var.is_none() && address_var.is_none() {
if disable_tls_var.is_some() {
log::warn!("MULLVAD_API_DISABLE_TLS is ignored since MULLVAD_API_HOST and MULLVAD_API_ADDR are not set");
}
return api;
}

let scheme = if let Some(disable_tls_var) = disable_tls_var {
api.disable_tls = disable_tls_var != "0";
"http://"
} else {
"https://"
};

if let Some(user_host) = host_var {
api.host = user_host;
api.address = match address_var {
Some(user_addr) => {
let addr = user_addr.parse().unwrap_or_else(|_| {
panic!(
"{api_addr} is not a valid socketaddr",
api_addr = ApiEndpoint::API_ADDR_VAR,
)
});
Some(addr)
}
if let Some(user_addr) = address_var {
api.addr = user_addr
.parse()
.expect("MULLVAD_API_ADDR is not a valid socketaddr");
} else {
log::warn!("Resolving API IP from MULLVAD_API_HOST");
api.addr = format!("{}:{}", api.host, API_PORT_DEFAULT)
None => {
log::warn!(
"Resolving API host from {api_host}",
api_host = ApiEndpoint::API_HOST_VAR
);
format!("{}:{}", api.host(), ApiEndpoint::API_PORT_DEFAULT)
.to_socket_addrs()
.expect("failed to resolve API host")
.next()
.expect("API host yielded 0 addresses");
}
api.disable_address_cache = true;
api.force_direct_connection = true;
log::debug!("Overriding API. Using {} at {scheme}{}", api.host, api.addr);
};

if api.host.is_none() && api.address.is_none() {
if disable_tls_var.is_some() {
log::warn!(
"{disable_tls} is ignored since {api_host} and {api_addr} are not set",
disable_tls = ApiEndpoint::DISABLE_TLS_VAR,
api_host = ApiEndpoint::API_HOST_VAR,
api_addr = ApiEndpoint::API_ADDR_VAR,
);
}
} else {
log::debug!(
"Overriding API. Using {host} at {scheme}{addr}",
host = api.host(),
addr = api.address(),
scheme = if api.disable_tls {
"http://"
} else {
"https://"
}
);
}
#[cfg(not(feature = "api-override"))]
api
}

/// Returns the endpoint to connect to the API over.
///
/// # Panics
///
/// Panics if `MULLVAD_API_ADDR`, `MULLVAD_API_HOST` or
/// `MULLVAD_API_DISABLE_TLS` has invalid contents.
#[cfg(not(feature = "api-override"))]
pub fn from_env_vars() -> ApiEndpoint {
let host_var = Self::read_var(ApiEndpoint::API_HOST_VAR);
let address_var = Self::read_var(ApiEndpoint::API_ADDR_VAR);
let disable_tls_var = Self::read_var(ApiEndpoint::DISABLE_TLS_VAR);

if host_var.is_some() || address_var.is_some() || disable_tls_var.is_some() {
log::warn!("These variables are ignored in production builds: MULLVAD_API_HOST, MULLVAD_API_ADDR, MULLVAD_API_DISABLE_TLS");
log::warn!(
"These variables are ignored in production builds: {api_host}, {api_addr}, {disable_tls}",
api_host = ApiEndpoint::API_HOST_VAR,
api_addr = ApiEndpoint::API_ADDR_VAR,
disable_tls = ApiEndpoint::DISABLE_TLS_VAR
);
}

ApiEndpoint {
host: None,
address: None,
}
}

/// Read the [`Self::host`] value, falling back to
/// [`Self::API_HOST_DEFAULT`] as default value if it does not exist.
pub fn host(&self) -> String {
self.host
.clone()
.unwrap_or(ApiEndpoint::API_HOST_DEFAULT.to_string())
}

/// Read the [`Self::address`] value, falling back to
/// [`Self::API_IP_DEFAULT`]:[`Self::API_PORT_DEFAULT`] as default if it
/// does not exist.
pub fn address(&self) -> SocketAddr {
self.address.unwrap_or(SocketAddr::new(
ApiEndpoint::API_IP_DEFAULT,
ApiEndpoint::API_PORT_DEFAULT,
))
}

/// Try to read the value of an environment variable. Returns `None` if the
/// environment variable has not been set.
///
/// # Panics
///
/// Panics if the environment variable was found, but it did not contain
/// valid unicode data.
fn read_var(key: &'static str) -> Option<String> {
use std::env;
match env::var(key) {
Ok(v) => Some(v),
Err(env::VarError::NotPresent) => None,
Err(env::VarError::NotUnicode(_)) => panic!("{key} does not contain valid UTF-8"),
}
api
}
}

Expand Down Expand Up @@ -314,14 +386,14 @@ impl Runtime {
) -> rest::MullvadRestHandle {
let service = self
.new_request_service(
Some(API.host.clone()),
Some(API.host()),
proxy_provider,
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
)
.await;
let token_store = access::AccessTokenStore::new(service.clone());
let factory = rest::RequestFactory::new(&API.host, Some(token_store));
let factory = rest::RequestFactory::new(API.host(), Some(token_store));

rest::MullvadRestHandle::new(
service,
Expand Down
37 changes: 12 additions & 25 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ use std::{
};
use talpid_types::ErrorExt;

#[cfg(feature = "api-override")]
use crate::API;

pub use hyper::StatusCode;

const USER_AGENT: &str = "mullvad-app";
Expand Down Expand Up @@ -147,14 +144,7 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
socket_bypass_tx.clone(),
);

#[cfg(feature = "api-override")]
let force_direct_connection = API.force_direct_connection;
#[cfg(not(feature = "api-override"))]
let force_direct_connection = false;

if force_direct_connection {
log::debug!("API proxies are disabled");
} else if let Some(config) = proxy_config_provider.next().await {
if let Some(config) = proxy_config_provider.next().await {
connector_handle.set_connection_mode(config);
}

Expand Down Expand Up @@ -185,17 +175,9 @@ impl<T: Stream<Item = ApiConnectionMode> + Unpin + Send + 'static> RequestServic
self.connector_handle.reset();
}
RequestCommand::NextApiConfig(completion_tx) => {
#[cfg(feature = "api-override")]
let force_direct_connection = API.force_direct_connection;
#[cfg(not(feature = "api-override"))]
let force_direct_connection = false;

if force_direct_connection {
log::debug!("Ignoring API connection mode");
} else if let Some(connection_mode) = self.proxy_config_provider.next().await {
if let Some(connection_mode) = self.proxy_config_provider.next().await {
self.connector_handle.set_connection_mode(connection_mode);
}

let _ = completion_tx.send(Ok(()));
}
}
Expand Down Expand Up @@ -458,13 +440,13 @@ struct NewErrorResponse {

#[derive(Clone)]
pub struct RequestFactory {
hostname: &'static str,
hostname: String,
token_store: Option<AccessTokenStore>,
default_timeout: Duration,
}

impl RequestFactory {
pub fn new(hostname: &'static str, token_store: Option<AccessTokenStore>) -> Self {
pub fn new(hostname: String, token_store: Option<AccessTokenStore>) -> Self {
Self {
hostname,
token_store,
Expand Down Expand Up @@ -545,7 +527,10 @@ impl RequestFactory {
.uri(uri)
.header(header::USER_AGENT, HeaderValue::from_static(USER_AGENT))
.header(header::ACCEPT, HeaderValue::from_static("application/json"))
.header(header::HOST, HeaderValue::from_static(self.hostname));
.header(
header::HOST,
HeaderValue::from_str(&self.hostname).map_err(|_| Error::InvalidHeaderError)?,
);

let result = request.body(hyper::Body::empty())?;
Ok(result)
Expand Down Expand Up @@ -632,8 +617,10 @@ impl MullvadRestHandle {
availability,
};
#[cfg(feature = "api-override")]
if API.disable_address_cache {
return handle;
{
if crate::API.disable_address_cache {
return handle;
}
}
handle.spawn_api_address_fetcher(address_cache);
handle
Expand Down
Loading

0 comments on commit 503888d

Please sign in to comment.