diff --git a/mullvad-api/src/address_cache.rs b/mullvad-api/src/address_cache.rs index bfdc16d5e0ec..0898f8da1f96 100644 --- a/mullvad-api/src/address_cache.rs +++ b/mullvad-api/src/address_cache.rs @@ -1,7 +1,7 @@ //! This module keeps track of the last known good API IP address and reads and stores it on disk. -use crate::DnsResolver; use super::API; +use crate::DnsResolver; use async_trait::async_trait; use std::{io, net::SocketAddr, path::Path, sync::Arc}; use tokio::{ @@ -25,28 +25,14 @@ pub enum Error { Write(#[source] io::Error), } -/// A DNS resolver which resolves using `AddressCache`, or else a fallback resolver. -pub struct AddressCacheResolver { - address_cache: AddressCache, - fallback_resolver: Arc, -} - -impl AddressCacheResolver { - pub fn new(address_cache: AddressCache, fallback_resolver: impl DnsResolver) -> Self { - Self { - address_cache, - fallback_resolver: Arc::new(fallback_resolver), - } - } -} - +/// A DNS resolver which resolves using `AddressCache`. #[async_trait] -impl DnsResolver for AddressCacheResolver { +impl DnsResolver for AddressCache { async fn resolve(&self, host: String) -> Result, io::Error> { - match self.address_cache.resolve_hostname(&host).await { - Some(addr) => Ok(vec![addr]), - None => self.fallback_resolver.resolve(host).await, - } + self.resolve_hostname(&host) + .await + .map(|addr| vec![addr]) + .ok_or(io::Error::other("host does not match API host")) } } diff --git a/mullvad-api/src/bin/relay_list.rs b/mullvad-api/src/bin/relay_list.rs index 22190abd63db..def32303eaef 100644 --- a/mullvad-api/src/bin/relay_list.rs +++ b/mullvad-api/src/bin/relay_list.rs @@ -2,15 +2,13 @@ //! Used by the installer artifact packer to bundle the latest available //! relay list at the time of creating the installer. -use mullvad_api::{ - proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy, -}; +use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy}; use std::process; use talpid_types::ErrorExt; #[tokio::main] async fn main() { - let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver) + let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) .expect("Failed to load runtime"); let relay_list_request = diff --git a/mullvad-api/src/ffi/mod.rs b/mullvad-api/src/ffi/mod.rs index a68ea40ed6cb..29d4db652288 100644 --- a/mullvad-api/src/ffi/mod.rs +++ b/mullvad-api/src/ffi/mod.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ rest::{self, MullvadRestHandle}, - AccountsProxy, DevicesProxy, + AccountsProxy, DevicesProxy, NullDnsResolver, }; mod device; @@ -209,11 +209,11 @@ impl FfiClient { } fn device_proxy(&self) -> DevicesProxy { - crate::DevicesProxy::new(self.rest_handle()) + crate::DevicesProxy::new(self.rest_handle(NullDnsResolver)) } fn accounts_proxy(&self) -> AccountsProxy { - crate::AccountsProxy::new(self.rest_handle()) + crate::AccountsProxy::new(self.rest_handle(NullDnsResolver)) } fn tokio_handle(&self) -> tokio::runtime::Handle { diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index c3cc17c9bc2e..4af0f06d4094 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -1,5 +1,4 @@ #![allow(rustdoc::private_intra_doc_links)] -use address_cache::AddressCacheResolver; use async_trait::async_trait; #[cfg(target_os = "android")] use futures::channel::mpsc; @@ -343,7 +342,6 @@ pub struct Runtime { handle: tokio::runtime::Handle, address_cache: AddressCache, api_availability: availability::ApiAvailability, - dns_resolver: Arc, #[cfg(target_os = "android")] socket_bypass_tx: Option>, } @@ -365,13 +363,9 @@ pub enum Error { impl Runtime { /// Create a new `Runtime`. - pub fn new( - handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, - ) -> Result { + pub fn new(handle: tokio::runtime::Handle) -> Result { Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] None, ) @@ -389,15 +383,12 @@ impl Runtime { fn new_inner( handle: tokio::runtime::Handle, - dns_resolver: impl DnsResolver, #[cfg(target_os = "android")] socket_bypass_tx: Option>, ) -> Result { - let address_cache = AddressCache::new(None); Ok(Runtime { handle, - address_cache: address_cache.clone(), + address_cache: AddressCache::new(None), api_availability: ApiAvailability::default(), - dns_resolver: Arc::new(AddressCacheResolver::new(address_cache, dns_resolver)), #[cfg(target_os = "android")] socket_bypass_tx, }) @@ -406,7 +397,6 @@ impl Runtime { /// Create a new `Runtime` using the specified directories. /// Try to use the cache directory first, and fall back on the bundled address otherwise. pub async fn with_cache( - dns_resolver: impl DnsResolver, cache_dir: &Path, write_changes: bool, #[cfg(target_os = "android")] socket_bypass_tx: Option>, @@ -417,7 +407,6 @@ impl Runtime { if API.disable_address_cache { return Self::new_inner( handle, - dns_resolver, #[cfg(target_os = "android")] socket_bypass_tx, ); @@ -449,31 +438,13 @@ impl Runtime { Ok(Runtime { handle, - address_cache: address_cache.clone(), + address_cache, api_availability, - dns_resolver: Arc::new(AddressCacheResolver::new(address_cache, dns_resolver)), #[cfg(target_os = "android")] socket_bypass_tx, }) } - /// Creates a new request service and returns a handle to it. - fn new_request_service( - &self, - sni_hostname: Option, - connection_mode_provider: T, - #[cfg(target_os = "android")] socket_bypass_tx: Option>, - ) -> rest::RequestServiceHandle { - rest::RequestService::spawn( - sni_hostname, - self.api_availability.clone(), - connection_mode_provider, - self.dns_resolver.clone(), - #[cfg(target_os = "android")] - socket_bypass_tx, - ) - } - /// Returns a request factory initialized to create requests for the master API pub fn mullvad_rest_handle( &self, @@ -482,6 +453,7 @@ impl Runtime { let service = self.new_request_service( Some(API.host().to_string()), connection_mode_provider, + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -496,6 +468,7 @@ impl Runtime { let service = self.new_request_service( Some(hostname.clone()), ApiConnectionMode::Direct.into_provider(), + Arc::new(self.address_cache.clone()), #[cfg(target_os = "android")] self.socket_bypass_tx.clone(), ); @@ -506,15 +479,34 @@ impl Runtime { } /// Returns a new request service handle - pub fn rest_handle(&self) -> rest::RequestServiceHandle { + pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle { self.new_request_service( None, ApiConnectionMode::Direct.into_provider(), + Arc::new(dns_resolver), #[cfg(target_os = "android")] None, ) } + /// Creates a new request service and returns a handle to it. + fn new_request_service( + &self, + sni_hostname: Option, + connection_mode_provider: T, + dns_resolver: Arc, + #[cfg(target_os = "android")] socket_bypass_tx: Option>, + ) -> rest::RequestServiceHandle { + rest::RequestService::spawn( + sni_hostname, + self.api_availability.clone(), + connection_mode_provider, + dns_resolver, + #[cfg(target_os = "android")] + socket_bypass_tx, + ) + } + pub fn handle(&mut self) -> &mut tokio::runtime::Handle { &mut self.handle } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 79f08e8a0f09..2160390c041a 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -39,8 +39,6 @@ use futures::{ }; use geoip::GeoIpHandler; use management_interface::ManagementInterfaceServer; -#[cfg(not(target_os = "android"))] -use mullvad_api::DefaultDnsResolver; use mullvad_relay_selector::{RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken}; @@ -622,10 +620,6 @@ impl Daemon { mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await; let api_runtime = mullvad_api::Runtime::with_cache( - #[cfg(target_os = "android")] - android_dns::AndroidDnsResolver::new(connectivity_listener.clone()), - #[cfg(not(target_os = "android"))] - DefaultDnsResolver, &cache_dir, true, #[cfg(target_os = "android")] @@ -835,7 +829,12 @@ impl Daemon { relay_list_updater.update().await; let location_handler = GeoIpHandler::new( - api_runtime.rest_handle(), + api_runtime.rest_handle( + #[cfg(not(target_os = "android"))] + mullvad_api::DefaultDnsResolver, + #[cfg(target_os = "android")] + android_dns::AndroidDnsResolver::new(connectivity_listener.clone()), + ), internal_event_tx.clone().to_specialized_sender(), ); diff --git a/mullvad-problem-report/src/lib.rs b/mullvad-problem-report/src/lib.rs index 91b790e5f6fe..270de55f9589 100644 --- a/mullvad-problem-report/src/lib.rs +++ b/mullvad-problem-report/src/lib.rs @@ -1,4 +1,4 @@ -use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver}; +use mullvad_api::proxy::ApiConnectionMode; use regex::Regex; use std::{ borrow::Cow, @@ -292,7 +292,6 @@ async fn send_problem_report_inner( ) -> Result<(), Error> { let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect); let api_runtime = mullvad_api::Runtime::with_cache( - NullDnsResolver, cache_dir, false, #[cfg(target_os = "android")] diff --git a/mullvad-setup/src/main.rs b/mullvad-setup/src/main.rs index 4a444aa63cdc..d3dfd6de8ac4 100644 --- a/mullvad-setup/src/main.rs +++ b/mullvad-setup/src/main.rs @@ -1,7 +1,7 @@ use clap::Parser; use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration}; -use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver, DEVICE_NOT_FOUND}; +use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND}; use mullvad_management_interface::MullvadProxyClient; use mullvad_types::version::ParsedAppVersion; use talpid_core::firewall::{self, Firewall}; @@ -152,7 +152,7 @@ async fn remove_device() -> Result<(), Error> { .await .map_err(Error::ReadDeviceCacheError)?; if let Some(device) = state.into_device() { - let api_runtime = mullvad_api::Runtime::with_cache(NullDnsResolver, &cache_path, false) + let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false) .await .map_err(Error::RpcInitializationError)?; diff --git a/test/test-manager/src/tests/account.rs b/test/test-manager/src/tests/account.rs index 45151070a9d9..7fe14ae58ee1 100644 --- a/test/test-manager/src/tests/account.rs +++ b/test/test-manager/src/tests/account.rs @@ -295,11 +295,8 @@ pub async fn new_device_client() -> anyhow::Result { ..api_endpoint }); - let api = mullvad_api::Runtime::new( - tokio::runtime::Handle::current(), - mullvad_api::DefaultDnsResolver, - ) - .expect("failed to create api runtime"); + let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current()) + .expect("failed to create api runtime"); let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()); Ok(DevicesProxy::new(rest_handle)) }