Skip to content

Commit

Permalink
Merge branch 'simplify-api-addr-cache'
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Dec 2, 2024
2 parents 8bc1412 + 354665e commit 8bda210
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 113 deletions.
30 changes: 22 additions & 8 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use super::API;
use crate::DnsResolver;
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
Expand All @@ -23,6 +25,17 @@ pub enum Error {
Write(#[source] io::Error),
}

/// A DNS resolver which resolves using `AddressCache`.
#[async_trait]
impl DnsResolver for AddressCache {
async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
self.resolve_hostname(&host)
.await
.map(|addr| vec![addr])
.ok_or(io::Error::other("host does not match API host"))
}
}

#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
Expand All @@ -31,34 +44,35 @@ pub struct AddressCache {

impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
pub fn new(write_path: Option<Box<Path>>) -> Self {
Self::new_inner(API.address(), write_path)
}

pub fn with_static_addr(address: SocketAddr) -> Self {
Self::new_inner(address, None)
.expect("Failed to construct an address cache from a static address")
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
Self::new_inner(read_address_file(read_path).await?, write_path)
Ok(Self::new_inner(
read_address_file(read_path).await?,
write_path,
))
}

fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> {
fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
let cache = AddressCacheInner::from_address(address);
log::debug!("Using API address: {}", cache.address);

let address_cache = Self {
Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(Arc::from),
};
Ok(address_cache)
}
}

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
Expand Down
6 changes: 2 additions & 4 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
52 changes: 16 additions & 36 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
AddressCache, DnsResolver,
DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -286,8 +286,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
#[derive(Clone)]
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
Expand All @@ -304,8 +302,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);

impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
Expand Down Expand Up @@ -352,8 +348,6 @@ impl HttpsConnectorWithSni {
(
HttpsConnectorWithSni {
inner,
sni_hostname,
address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -390,13 +384,9 @@ impl HttpsConnectorWithSni {
}

/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
/// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
/// Otherwise `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
async fn resolve_address(
address_cache: AddressCache,
dns_resolver: &dyn DnsResolver,
uri: Uri,
) -> io::Result<SocketAddr> {
async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;

let hostname = uri.host().ok_or_else(|| {
Expand All @@ -407,22 +397,16 @@ impl HttpsConnectorWithSni {
return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT)));
}

// Preferentially, use cached address.
//
if let Some(addr) = address_cache.resolve_hostname(hostname).await {
return Ok(SocketAddr::new(
addr.ip(),
port.unwrap_or_else(|| addr.port()),
));
}

// Use DNS resolution as fallback
//
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
let port = match (addr.port(), port) {
(_, Some(port)) => port,
(0, None) => DEFAULT_PORT,
(addr_port, None) => addr_port,
};
Ok(SocketAddr::new(addr.ip(), port))
}
}

Expand All @@ -445,18 +429,10 @@ impl Service<Uri> for HttpsConnectorWithSni {
}

fn call(&mut self, uri: Uri) -> Self::Future {
let sni_hostname = self
.sni_hostname
.clone()
.or_else(|| uri.host().map(str::to_owned))
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
let inner = self.inner.clone();
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();

let fut = async move {
Expand All @@ -466,9 +442,13 @@ impl Service<Uri> for HttpsConnectorWithSni {
"invalid url, not https",
));
}

let hostname = sni_hostname?;
let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
let Some(hostname) = uri.host().map(str::to_owned) else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, missing host",
));
};
let addr = Self::resolve_address(&*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
Expand Down
67 changes: 27 additions & 40 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,22 @@ impl ApiEndpoint {

#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>;
}

/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
pub struct DefaultDnsResolver;

#[async_trait]
impl DnsResolver for DefaultDnsResolver {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
Ok(addrs.map(|addr| addr.ip()).collect())
Ok(addrs.collect())
}
}

Expand All @@ -332,7 +332,7 @@ pub struct NullDnsResolver;

#[async_trait]
impl DnsResolver for NullDnsResolver {
async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> {
Ok(vec![])
}
}
Expand All @@ -342,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -364,13 +363,9 @@ pub enum Error {

impl Runtime {
/// Create a new `Runtime`.
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -381,21 +376,18 @@ impl Runtime {
Runtime {
handle,
address_cache: AddressCache::with_static_addr(address),
dns_resolver: Arc::new(NullDnsResolver),
api_availability: ApiAvailability::default(),
}
}

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -404,7 +396,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
Expand All @@ -415,7 +406,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
Expand All @@ -439,7 +429,7 @@ impl Runtime {
)
);
}
AddressCache::new(write_file)?
AddressCache::new(write_file)
}
};

Expand All @@ -449,38 +439,19 @@ impl Runtime {
handle,
address_cache,
api_availability,
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
connection_mode_provider: T,
) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -493,8 +464,8 @@ impl Runtime {
/// This is only to be used in test code
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -505,15 +476,31 @@ impl Runtime {
}

/// Returns a new request service handle
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
self.api_availability.clone(),
connection_mode_provider,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
Expand Down
Loading

0 comments on commit 8bda210

Please sign in to comment.