Skip to content

Commit

Permalink
Remove DNS fallback except for conncheck
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Dec 2, 2024
1 parent 3e5f660 commit dff45ab
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 77 deletions.
28 changes: 7 additions & 21 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use crate::DnsResolver;
use super::API;
use crate::DnsResolver;
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
Expand All @@ -25,28 +25,14 @@ pub enum Error {
Write(#[source] io::Error),
}

/// A DNS resolver which resolves using `AddressCache`, or else a fallback resolver.
pub struct AddressCacheResolver {
address_cache: AddressCache,
fallback_resolver: Arc<dyn DnsResolver>,
}

impl AddressCacheResolver {
pub fn new(address_cache: AddressCache, fallback_resolver: impl DnsResolver) -> Self {
Self {
address_cache,
fallback_resolver: Arc::new(fallback_resolver),
}
}
}

/// A DNS resolver which resolves using `AddressCache`.
#[async_trait]
impl DnsResolver for AddressCacheResolver {
impl DnsResolver for AddressCache {
async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
match self.address_cache.resolve_hostname(&host).await {
Some(addr) => Ok(vec![addr]),
None => self.fallback_resolver.resolve(host).await,
}
self.resolve_hostname(&host)
.await
.map(|addr| vec![addr])
.ok_or(io::Error::other("host does not match API host"))
}
}

Expand Down
6 changes: 2 additions & 4 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
6 changes: 3 additions & 3 deletions mullvad-api/src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use crate::{
rest::{self, MullvadRestHandle},
AccountsProxy, DevicesProxy,
AccountsProxy, DevicesProxy, NullDnsResolver,
};

mod device;
Expand Down Expand Up @@ -209,11 +209,11 @@ impl FfiClient {
}

fn device_proxy(&self) -> DevicesProxy {
crate::DevicesProxy::new(self.rest_handle())
crate::DevicesProxy::new(self.rest_handle(NullDnsResolver))
}

fn accounts_proxy(&self) -> AccountsProxy {
crate::AccountsProxy::new(self.rest_handle())
crate::AccountsProxy::new(self.rest_handle(NullDnsResolver))
}

fn tokio_handle(&self) -> tokio::runtime::Handle {
Expand Down
58 changes: 25 additions & 33 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#![allow(rustdoc::private_intra_doc_links)]
use address_cache::AddressCacheResolver;
use async_trait::async_trait;
#[cfg(target_os = "android")]
use futures::channel::mpsc;
Expand Down Expand Up @@ -343,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<AddressCacheResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -365,13 +363,9 @@ pub enum Error {

impl Runtime {
/// Create a new `Runtime`.
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -389,15 +383,12 @@ impl Runtime {

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
let address_cache = AddressCache::new(None);
Ok(Runtime {
handle,
address_cache: address_cache.clone(),
address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(AddressCacheResolver::new(address_cache, dns_resolver)),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -406,7 +397,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
Expand All @@ -417,7 +407,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
Expand Down Expand Up @@ -449,31 +438,13 @@ impl Runtime {

Ok(Runtime {
handle,
address_cache: address_cache.clone(),
address_cache,
api_availability,
dns_resolver: Arc::new(AddressCacheResolver::new(address_cache, dns_resolver)),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
Expand All @@ -482,6 +453,7 @@ impl Runtime {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -496,6 +468,7 @@ impl Runtime {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -506,15 +479,34 @@ impl Runtime {
}

/// Returns a new request service handle
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
connection_mode_provider,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
Expand Down
13 changes: 6 additions & 7 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ use futures::{
};
use geoip::GeoIpHandler;
use management_interface::ManagementInterfaceServer;
#[cfg(not(target_os = "android"))]
use mullvad_api::DefaultDnsResolver;
use mullvad_relay_selector::{RelaySelector, SelectorConfig};
#[cfg(target_os = "android")]
use mullvad_types::account::{PlayPurchase, PlayPurchasePaymentToken};
Expand Down Expand Up @@ -622,10 +620,6 @@ impl Daemon {

mullvad_api::proxy::ApiConnectionMode::try_delete_cache(&cache_dir).await;
let api_runtime = mullvad_api::Runtime::with_cache(
#[cfg(target_os = "android")]
android_dns::AndroidDnsResolver::new(connectivity_listener.clone()),
#[cfg(not(target_os = "android"))]
DefaultDnsResolver,
&cache_dir,
true,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -835,7 +829,12 @@ impl Daemon {
relay_list_updater.update().await;

let location_handler = GeoIpHandler::new(
api_runtime.rest_handle(),
api_runtime.rest_handle(
#[cfg(not(target_os = "android"))]
mullvad_api::DefaultDnsResolver,
#[cfg(target_os = "android")]
android_dns::AndroidDnsResolver::new(connectivity_listener.clone()),
),
internal_event_tx.clone().to_specialized_sender(),
);

Expand Down
3 changes: 1 addition & 2 deletions mullvad-problem-report/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver};
use mullvad_api::proxy::ApiConnectionMode;
use regex::Regex;
use std::{
borrow::Cow,
Expand Down Expand Up @@ -292,7 +292,6 @@ async fn send_problem_report_inner(
) -> Result<(), Error> {
let metadata = ProblemReport::parse_metadata(report_content).unwrap_or_else(metadata::collect);
let api_runtime = mullvad_api::Runtime::with_cache(
NullDnsResolver,
cache_dir,
false,
#[cfg(target_os = "android")]
Expand Down
4 changes: 2 additions & 2 deletions mullvad-setup/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use clap::Parser;
use std::{path::PathBuf, process, str::FromStr, sync::LazyLock, time::Duration};

use mullvad_api::{proxy::ApiConnectionMode, NullDnsResolver, DEVICE_NOT_FOUND};
use mullvad_api::{proxy::ApiConnectionMode, DEVICE_NOT_FOUND};
use mullvad_management_interface::MullvadProxyClient;
use mullvad_types::version::ParsedAppVersion;
use talpid_core::firewall::{self, Firewall};
Expand Down Expand Up @@ -152,7 +152,7 @@ async fn remove_device() -> Result<(), Error> {
.await
.map_err(Error::ReadDeviceCacheError)?;
if let Some(device) = state.into_device() {
let api_runtime = mullvad_api::Runtime::with_cache(NullDnsResolver, &cache_path, false)
let api_runtime = mullvad_api::Runtime::with_cache(&cache_path, false)
.await
.map_err(Error::RpcInitializationError)?;

Expand Down
7 changes: 2 additions & 5 deletions test/test-manager/src/tests/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,8 @@ pub async fn new_device_client() -> anyhow::Result<DevicesProxy> {
..api_endpoint
});

let api = mullvad_api::Runtime::new(
tokio::runtime::Handle::current(),
mullvad_api::DefaultDnsResolver,
)
.expect("failed to create api runtime");
let api = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("failed to create api runtime");
let rest_handle = api.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider());
Ok(DevicesProxy::new(rest_handle))
}
Expand Down

0 comments on commit dff45ab

Please sign in to comment.