Skip to content

Commit

Permalink
Add non-blocking DNS resolver for Android API requests
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Nov 22, 2024
1 parent 2b2c597 commit 5f7c7af
Show file tree
Hide file tree
Showing 21 changed files with 480 additions and 265 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ single_use_lifetimes = "warn"
unused_async = "deny"

[workspace.dependencies]
hickory-proto = "0.24.1"
hickory-resolver = "0.24.1"
hickory-server = { version = "0.24.1", features = ["resolver"] }
tokio = { version = "1.8" }
parity-tokio-ipc = "0.9"
futures = "0.3.15"
Expand Down
1 change: 1 addition & 0 deletions mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ workspace = true
api-override = []

[dependencies]
async-trait = "0.1"
libc = "0.2"
chrono = { workspace = true }
thiserror = { workspace = true }
Expand Down
6 changes: 4 additions & 2 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
36 changes: 19 additions & 17 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
AddressCache,
AddressCache, DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::Uri;
use hyper_util::{
client::legacy::connect::dns::{GaiResolver, Name},
rt::TokioIo,
};
use hyper_util::rt::TokioIo;
use mullvad_encrypted_dns_proxy::{
config::ProxyConfig as EncryptedDNSConfig, Forwarder as EncryptedDNSForwarder,
};
Expand Down Expand Up @@ -291,6 +288,7 @@ pub struct HttpsConnectorWithSni {
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -307,6 +305,7 @@ impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
let (tx, mut rx) = mpsc::unbounded();
Expand Down Expand Up @@ -355,6 +354,7 @@ impl HttpsConnectorWithSni {
sni_hostname,
address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
},
Expand Down Expand Up @@ -388,7 +388,14 @@ impl HttpsConnectorWithSni {
.map_err(|err| io::Error::new(io::ErrorKind::TimedOut, err))?
}

async fn resolve_address(address_cache: AddressCache, uri: Uri) -> io::Result<SocketAddr> {
/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
/// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
async fn resolve_address(
address_cache: AddressCache,
dns_resolver: &dyn DnsResolver,
uri: Uri,
) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;

let hostname = uri.host().ok_or_else(|| {
Expand All @@ -408,19 +415,13 @@ impl HttpsConnectorWithSni {
));
}

// Use getaddrinfo as a fallback
// Use DNS resolution as fallback
//
let mut addrs = GaiResolver::new()
.call(
Name::from_str(hostname)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?,
)
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.next()
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
Ok(SocketAddr::new(addr.ip(), port.unwrap_or(DEFAULT_PORT)))
Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
}
}

Expand Down Expand Up @@ -455,6 +456,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();

let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
Expand All @@ -465,7 +467,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}

let hostname = sni_hostname?;
let addr = Self::resolve_address(address_cache, uri).await?;
let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
Expand Down
52 changes: 50 additions & 2 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(rustdoc::private_intra_doc_links)]
use async_trait::async_trait;
#[cfg(target_os = "android")]
use futures::channel::mpsc;
#[cfg(target_os = "android")]
Expand All @@ -12,10 +13,11 @@ use std::{
cell::Cell,
collections::BTreeMap,
future::Future,
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
ops::Deref,
path::Path,
sync::OnceLock,
sync::{Arc, OnceLock},
};
use talpid_types::ErrorExt;

Expand Down Expand Up @@ -304,11 +306,43 @@ impl ApiEndpoint {
}
}

#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
}

/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
pub struct DefaultDnsResolver;

#[async_trait]
impl DnsResolver for DefaultDnsResolver {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
Ok(addrs.map(|addr| addr.ip()).collect())
}
}

/// DNS resolver that always returns no results
pub struct NullDnsResolver;

#[async_trait]
impl DnsResolver for NullDnsResolver {
async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
Ok(vec![])
}
}

/// A type that helps with the creation of API connections.
pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -323,13 +357,20 @@ pub enum Error {

#[error("API availability check failed")]
ApiCheckError(#[from] availability::Error),

#[error("DNS resolution error")]
ResolutionFailed(#[from] std::io::Error),
}

impl Runtime {
/// Create a new `Runtime`.
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -346,12 +387,14 @@ impl Runtime {

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -360,15 +403,18 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
let handle = tokio::runtime::Handle::current();

#[cfg(feature = "api-override")]
if API.disable_address_cache {
return Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
Expand Down Expand Up @@ -402,6 +448,7 @@ impl Runtime {
handle,
address_cache,
api_availability,
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -419,6 +466,7 @@ impl Runtime {
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
Expand Down
3 changes: 3 additions & 0 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
DnsResolver,
};
use futures::{
channel::{mpsc, oneshot},
Expand Down Expand Up @@ -154,11 +155,13 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
address_cache.clone(),
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
);
Expand Down
2 changes: 2 additions & 0 deletions mullvad-daemon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ tokio = { workspace = true, features = ["test-util"] }

[target.'cfg(target_os="android")'.dependencies]
android_logger = "0.8"
async-trait = "0.1"
hickory-resolver = { workspace = true }

[target.'cfg(unix)'.dependencies]
nix = "0.23"
Expand Down
49 changes: 49 additions & 0 deletions mullvad-daemon/src/android_dns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#![cfg(target_os = "android")]
//! See [AndroidDnsResolver].
use async_trait::async_trait;
use hickory_resolver::{
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
TokioAsyncResolver,
};
use mullvad_api::DnsResolver;
use std::{io, net::IpAddr};
use talpid_core::connectivity_listener::ConnectivityListener;

/// A non-blocking DNS resolver. The default resolver uses `getaddrinfo`, which often prevents the
/// tokio runtime from being dropped, since it waits indefinitely on blocking threads. This is
/// particularly bad on Android, so we use a non-blocking resolver instead.
pub struct AndroidDnsResolver {
connectivity_listener: ConnectivityListener,
}

impl AndroidDnsResolver {
pub fn new(connectivity_listener: ConnectivityListener) -> Self {
Self {
connectivity_listener,
}
}
}

#[async_trait]
impl DnsResolver for AndroidDnsResolver {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
let ips = self
.connectivity_listener
.current_dns_servers()
.map_err(|err| {
io::Error::other(format!("Failed to retrieve current servers: {err}"))
})?;
let group = NameServerConfigGroup::from_ips_clear(&ips, 53, false);

let config = ResolverConfig::from_parts(None, vec![], group);
let resolver = TokioAsyncResolver::tokio(config, ResolverOpts::default());

let lookup = resolver
.lookup_ip(host)
.await
.map_err(|err| io::Error::other(format!("lookup_ip failed: {err}")))?;

Ok(lookup.into_iter().collect())
}
}
Loading

0 comments on commit 5f7c7af

Please sign in to comment.