diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index dfa586daf477..0898f8da1f96 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -1,6 +1,8 @@ //! This module keeps track of the last known good API IP address and reads and stores it on disk. use super::API; +use crate::DnsResolver; +use async_trait::async_trait; use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ fs, @@ -23,6 +25,17 @@ pub enum Error { Write(#[source] io::Error), } +/// A DNS resolver which resolves using `AddressCache`. +#[async_trait] +impl DnsResolver for AddressCache { + async fn resolve(&self, host: String) -> Result, io::Error> { + self.resolve_hostname(&host) + .await + .map(|addr| vec![addr]) + .ok_or(io::Error::other("host does not match API host")) + } +} + #[derive(Clone)] pub struct AddressCache { inner: Arc>, @@ -31,34 +44,35 @@ pub struct AddressCache { impl AddressCache { /// Initialize cache using the hardcoded address, and write changes to `write_path`. - pub fn new(write_path: Option>) -> Result { + pub fn new(write_path: Option>) -> Self { Self::new_inner(API.address(), write_path) } pub fn with_static_addr(address: SocketAddr) -> Self { Self::new_inner(address, None) - .expect("Failed to construct an address cache from a static address") } /// Initialize cache using `read_path`, and write changes to `write_path`. pub async fn from_file(read_path: &Path, write_path: Option>) -> Result { log::debug!("Loading API addresses from {}", read_path.display()); - Self::new_inner(read_address_file(read_path).await?, write_path) + Ok(Self::new_inner( + read_address_file(read_path).await?, + write_path, + )) } - fn new_inner(address: SocketAddr, write_path: Option>) -> Result { + fn new_inner(address: SocketAddr, write_path: Option>) -> Self { let cache = AddressCacheInner::from_address(address); log::debug!("Using API address: {}", cache.address); - let address_cache = Self { + Self { inner: Arc::new(Mutex::new(cache)), write_path: write_path.map(Arc::from), - }; - Ok(address_cache) + } } /// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`. - pub async fn resolve_hostname(&self, hostname: &str) -> Option { + async fn resolve_hostname(&self, hostname: &str) -> Option { if hostname.eq_ignore_ascii_case(API.host()) { Some(self.get_address().await) } else { diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index 22190abd63db..def32303eaef 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,15 +2,13 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{ - proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy, -}; +use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver) + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("Failed to load runtime"); let relay_list_request = diff --git a/mullvad-api/src/https_client_with_sni.rs b/mullvad-api/src/https_client_with_sni.rs index 09ce493431ce..3dfd168a9281 100644 --- a/mullvad-api/src/https_client_with_sni.rs +++ b/mullvad-api/src/https_client_with_sni.rs @@ -2,7 +2,7 @@ use crate::{ abortable_stream::{AbortableStream, AbortableStreamHandle}, proxy::{ApiConnection, ApiConnectionMode, ProxyConfig}, tls_stream::TlsStream, - AddressCache, DnsResolver, + DnsResolver, }; use futures::{channel::mpsc, future, pin_mut, StreamExt}; #[cfg(target_os = "android")] @@ -286,8 +286,6 @@ impl TryFrom for InnerConnectionMode { #[derive(Clone)] pub struct HttpsConnectorWithSni { inner: Arc>, - sni_hostname: Option, - address_cache: AddressCache, abort_notify: Arc, dns_resolver: Arc, #[cfg(target_os = "android")] @@ -304,8 +302,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>); impl HttpsConnectorWithSni { pub fn new( - sni_hostname: Option, - address_cache: AddressCache, dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> (Self, HttpsConnectorWithSniHandle) { @@ -352,8 +348,6 @@ impl HttpsConnectorWithSni { ( HttpsConnectorWithSni { inner, - sni_hostname, - address_cache, abort_notify, dns_resolver, #[cfg(target_os = "android")] @@ -390,13 +384,9 @@ impl HttpsConnectorWithSni { } /// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used. - /// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback. + /// Otherwise `dns_resolver` will be used as a fallback. /// If the URI contains a port, then that port will be used. - async fn resolve_address( - address_cache: AddressCache, - dns_resolver: &dyn DnsResolver, - uri: Uri, - ) -> io::Result { + async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result { const DEFAULT_PORT: u16 = 443; let hostname = uri.host().ok_or_else(|| { @@ -407,22 +397,16 @@ impl HttpsConnectorWithSni { return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT))); } - // Preferentially, use cached address. - // - if let Some(addr) = address_cache.resolve_hostname(hostname).await { - return Ok(SocketAddr::new( - addr.ip(), - port.unwrap_or_else(|| addr.port()), - )); - } - - // Use DNS resolution as fallback - // let addrs = dns_resolver.resolve(hostname.to_owned()).await?; let addr = addrs .first() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?; - Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT))) + let port = match (addr.port(), port) { + (_, Some(port)) => port, + (0, None) => DEFAULT_PORT, + (addr_port, None) => addr_port, + }; + Ok(SocketAddr::new(addr.ip(), port)) } } @@ -445,18 +429,10 @@ impl Service for HttpsConnectorWithSni { } fn call(&mut self, uri: Uri) -> Self::Future { - let sni_hostname = self - .sni_hostname - .clone() - .or_else(|| uri.host().map(str::to_owned)) - .ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host") - }); let inner = self.inner.clone(); let abort_notify = self.abort_notify.clone(); #[cfg(target_os = "android")] let socket_bypass_tx = self.socket_bypass_tx.clone(); - let address_cache = self.address_cache.clone(); let dns_resolver = self.dns_resolver.clone(); let fut = async move { @@ -466,9 +442,13 @@ impl Service for HttpsConnectorWithSni { "invalid url, not https", )); } - - let hostname = sni_hostname?; - let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?; + let Some(hostname) = uri.host().map(str::to_owned) else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid url, missing host", + )); + }; + let addr = Self::resolve_address(&*dns_resolver, uri).await?; // Loop until we have established a connection. This starts over if a new endpoint // is selected while connecting. diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 3f6770924275..3b02e4fe98ed 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -308,7 +308,7 @@ impl ApiEndpoint { #[async_trait] pub trait DnsResolver: 'static + Send + Sync { - async fn resolve(&self, host: String) -> io::Result>; + async fn resolve(&self, host: String) -> io::Result>; } /// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`). @@ -316,14 +316,14 @@ pub struct DefaultDnsResolver; #[async_trait] impl DnsResolver for DefaultDnsResolver { - async fn resolve(&self, host: String) -> io::Result> { + async fn resolve(&self, host: String) -> io::Result> { use std::net::ToSocketAddrs; // Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which // blocks and either has no timeout or a very long one. let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs()) .await .expect("DNS task panicked")?; - Ok(addrs.map(|addr| addr.ip()).collect()) + Ok(addrs.collect()) } } @@ -332,7 +332,7 @@ pub struct NullDnsResolver; #[async_trait] impl DnsResolver for NullDnsResolver { - async fn resolve(&self, _host: String) -> io::Result> { + async fn resolve(&self, _host: String) -> io::Result> { Ok(vec![]) } } @@ -342,7 +342,6 @@ pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, - dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, } @@ -364,13 +363,9 @@ pub enum Error { impl Runtime { /// Create a new `Runtime`. - pub fn new( - handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, - ) -> Result { + pub fn new(handle: tokio::runtime::Handle) -> Result { Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] None, ) @@ -381,21 +376,18 @@ impl Runtime { Runtime { handle, address_cache: AddressCache::with_static_addr(address), - dns_resolver: Arc::new(NullDnsResolver), api_availability: ApiAvailability::default(), } } fn new_inner( handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> Result { Ok(Runtime { handle, - address_cache: AddressCache::new(None)?, + address_cache: AddressCache::new(None), api_availability: ApiAvailability::default(), - dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -404,7 +396,6 @@ impl Runtime { /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. pub async fn with_cache( - dns_resolver: impl DnsResolver, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option>, @@ -415,7 +406,6 @@ impl Runtime { if API.disable_address_cache { return Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, ); @@ -439,7 +429,7 @@ impl Runtime { ) ); } - AddressCache::new(write_file)? + AddressCache::new(write_file) } }; @@ -449,38 +439,19 @@ impl Runtime { handle, address_cache, api_availability, - dns_resolver: Arc::new(dns_resolver), #[cfg(target_os = "android")] socket_bypass_tx, }) } - /// Creates a new request service and returns a handle to it. - fn new_request_service( - &self, - sni_hostname: Option, - connection_mode_provider: T, - #[cfg(target_os = "android")] socket_bypass_tx: Option>, - ) -> rest::RequestServiceHandle { - rest::RequestService::spawn( - sni_hostname, - self.api_availability.clone(), - self.address_cache.clone(), - connection_mode_provider, - self.dns_resolver.clone(), - #[cfg(target_os = "android")] - socket_bypass_tx, - ) - } - /// Returns a request factory initialized to create requests for the master API pub fn mullvad_rest_handle( &self, connection_mode_provider: T, ) -> rest::MullvadRestHandle { let service = self.new_request_service( - Some(API.host().to_string()), connection_mode_provider, + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -493,8 +464,8 @@ impl Runtime { /// This is only to be used in test code pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle { let service = self.new_request_service( - Some(hostname.clone()), ApiConnectionMode::Direct.into_provider(), + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -505,15 +476,31 @@ impl Runtime { } /// Returns a new request service handle - pub fn rest_handle(&self) -> rest::RequestServiceHandle { + pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle { self.new_request_service( - None, ApiConnectionMode::Direct.into_provider(), + Arc::new(dns_resolver), #[cfg(target_os = "android")] None, ) } + /// Creates a new request service and returns a handle to it. + fn new_request_service( + &self, + connection_mode_provider: T, + dns_resolver: Arc, + #[cfg(target_os = "android")] socket_bypass_tx: Option>, + ) -> rest::RequestServiceHandle { + rest::RequestService::spawn( + self.api_availability.clone(), + connection_mode_provider, + dns_resolver, + #[cfg(target_os = "android")] + socket_bypass_tx, + ) + } + pub fn handle(&mut self) -> &mut tokio::runtime::Handle { &mut self.handle } diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index 54a32f63f937..5b93eea31142 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -2,7 +2,6 @@ pub use crate::https_client_with_sni::SocketBypassRequest; use crate::{ access::AccessTokenStore, - address_cache::AddressCache, availability::ApiAvailability, https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle}, proxy::ConnectionModeProvider, @@ -151,16 +150,12 @@ pub(crate) struct RequestService { impl RequestService { /// Constructs a new request service. pub fn spawn( - sni_hostname: Option, api_availability: ApiAvailability, - address_cache: AddressCache, connection_mode_provider: T, dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> RequestServiceHandle { let (connector, connector_handle) = HttpsConnectorWithSni::new( - sni_hostname, - address_cache.clone(), dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx.clone(), diff --git a/mullvad-daemon/src/android_dns.rs b/mullvad-daemon/src/android_dns.rs index ed44f5dc8c6d..5cbc9c271a9f 100644 --- a/mullvad-daemon/src/android_dns.rs +++ b/mullvad-daemon/src/android_dns.rs @@ -7,7 +7,7 @@ use hickory_resolver::{ TokioAsyncResolver, }; use mullvad_api::DnsResolver; -use std::{io, net::IpAddr}; +use std::{io, net::SocketAddr}; use talpid_core::connectivity_listener::ConnectivityListener; /// A non-blocking DNS resolver. The default resolver uses `getaddrinfo`, which often prevents the @@ -27,7 +27,7 @@ impl AndroidDnsResolver { #[async_trait] impl DnsResolver for AndroidDnsResolver { - async fn resolve(&self, host: String) -> io::Result> { + async fn resolve(&self, host: String) -> io::Result> { let ips = self .connectivity_listener .current_dns_servers() @@ -44,6 +44,6 @@ impl DnsResolver for AndroidDnsResolver { .await .map_err(|err| io::Error::other(format!("lookup_ip failed: {err}")))?; - Ok(lookup.into_iter().collect()) + Ok(lookup.into_iter().map(|ip| (ip, 0).into()).collect()) } } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 79f08e8a0f09..4f98c73d0189 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -39,8 +39,6 @@ use futures::{ }; use geoip::GeoIpHandler; use management_interface::ManagementInterfaceServer; -#[cfg(not(target_os = "android"))] -use mullvad_api::DefaultDnsResolver; use mullvad_relay_selector::{RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; @@ -622,10 +620,6 @@ impl Daemon { mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( - #[cfg(target_os = "android")] - android_dns::AndroidDnsResolver::new(connectivity_listener.clone()), - #[cfg(not(target_os = "android"))] - DefaultDnsResolver, &cache_dir, true, #[cfg(target_os = "android")] @@ -798,7 +792,7 @@ impl Daemon { #[cfg(target_os = "android")] android_context, #[cfg(target_os = "android")] - connectivity_listener, + connectivity_listener.clone(), #[cfg(target_os = "linux")] tunnel_state_machine::LinuxNetworkingIdentifiers { fwmark: mullvad_types::TUNNEL_FWMARK, @@ -835,7 +829,12 @@ impl Daemon { relay_list_updater.update().await; let location_handler = GeoIpHandler::new( - api_runtime.rest_handle(), + api_runtime.rest_handle( + #[cfg(not(target_os = "android"))] + mullvad_api::DefaultDnsResolver, + #[cfg(target_os = "android")] + android_dns::AndroidDnsResolver::new(connectivity_listener), + ), internal_event_tx.clone().to_specialized_sender(), ); diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 91b790e5f6fe..270de55f9589 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -1,4 +1,4 @@ -use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver}; +use mullvad_api::proxy::ApiConnectionMode; use regex::Regex; use std::{ borrow::Cow, @@ -292,7 +292,6 @@ async fn send_problem_report_inner( ) -> Result<(), Error> { let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect); let api_runtime = mullvad_api::Runtime::with_cache( - NullDnsResolver, cache_dir, false, #[cfg(target_os = "android")] diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 4a444aa63cdc..d3dfd6de8ac4 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -1,7 +1,7 @@ use clap::Parser; use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration}; -use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver, DEVICE_NOT_FOUND}; +use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND}; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::version::ParsedAppVersion; use talpid_core::firewall::{self, Firewall}; @@ -152,7 +152,7 @@ async fn remove_device() -> Result<(), Error> { .await .map_err(Error::ReadDeviceCacheError)?; if let Some(device) = state.into_device() { - let api_runtime = mullvad_api::Runtime::with_cache(NullDnsResolver, &cache_path, false) + let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false) .await .map_err(Error::RpcInitializationError)?; diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 45151070a9d9..7fe14ae58ee1 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -295,11 +295,8 @@ pub async fn new_device_client() -> anyhow::Result { ..api_endpoint }); - let api = mullvad_api::Runtime::new( - tokio::runtime::Handle::current(), - mullvad_api::DefaultDnsResolver, - ) - .expect("failed to create api runtime"); + let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) + .expect("failed to create api runtime"); let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()); Ok(DevicesProxy::new(rest_handle)) }