From 885c8df5b5b4a08ef9c8926bc23495ef73b86695 Mon Sep 17 00:00:00 2001 From: Markus Pettersson Date: Wed, 18 Oct 2023 13:09:50 +0200 Subject: [PATCH] Use access method-specific transport protocol for firewall exemption. --- mullvad-api/src/lib.rs | 6 +++--- mullvad-api/src/proxy.rs | 45 +++++++++++++++++++++++++++------------ mullvad-api/src/rest.rs | 10 +++++++-- mullvad-daemon/src/api.rs | 13 +++++------ mullvad-daemon/src/lib.rs | 6 +++++- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/mullvad-api/src/lib.rs b/mullvad-api/src/lib.rs index 63f5c2ad5b46..189fe4d32c27 100644 --- a/mullvad-api/src/lib.rs +++ b/mullvad-api/src/lib.rs @@ -21,7 +21,7 @@ use std::{ ops::Deref, path::Path, }; -use talpid_types::ErrorExt; +use talpid_types::{net::Endpoint, ErrorExt}; pub mod availability; use availability::{ApiAvailability, ApiAvailabilityHandle}; @@ -221,13 +221,13 @@ pub enum Error { /// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`. /// It should return a future that determines whether to reject the new endpoint or not. -pub trait ApiEndpointUpdateCallback: Fn(SocketAddr) -> Self::AcceptedNewEndpoint { +pub trait ApiEndpointUpdateCallback: Fn(Endpoint) -> Self::AcceptedNewEndpoint { type AcceptedNewEndpoint: Future + Send; } impl + Send> ApiEndpointUpdateCallback for U where - U: Fn(SocketAddr) -> T, + U: Fn(Endpoint) -> T, { type AcceptedNewEndpoint = T; } diff --git a/mullvad-api/src/proxy.rs b/mullvad-api/src/proxy.rs index 44a2309587e5..5b2a3e18436b 100644 --- a/mullvad-api/src/proxy.rs +++ b/mullvad-api/src/proxy.rs @@ -4,12 +4,14 @@ use mullvad_types::access_method; use serde::{Deserialize, Serialize}; use std::{ fmt, io, - net::SocketAddr, path::Path, pin::Pin, task::{self, Poll}, }; -use talpid_types::ErrorExt; +use talpid_types::{ + net::{Endpoint, TransportProtocol}, + ErrorExt, +}; use tokio::{ fs, io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, @@ -41,13 +43,19 @@ pub enum ProxyConfig { } impl ProxyConfig { - /// Returns the remote address to reach the proxy. - fn get_endpoint(&self) -> SocketAddr { + /// Returns the remote endpoint describing how to reach the proxy. + fn get_endpoint(&self) -> Endpoint { match self { - ProxyConfig::Shadowsocks(ss) => ss.peer, + ProxyConfig::Shadowsocks(shadowsocks) => { + Endpoint::from_socket_address(shadowsocks.peer, TransportProtocol::Tcp) + } ProxyConfig::Socks(socks) => match socks { - access_method::Socks5::Local(s) => s.peer, - access_method::Socks5::Remote(s) => s.peer, + access_method::Socks5::Local(local) => { + Endpoint::from_socket_address(local.peer, local.peer_transport_protol) + } + access_method::Socks5::Remote(remote) => { + Endpoint::from_socket_address(remote.peer, TransportProtocol::Tcp) + } }, } } @@ -55,14 +63,22 @@ impl ProxyConfig { impl fmt::Display for ProxyConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + let endpoint = self.get_endpoint(); match self { - // TODO: Do not hardcode TCP - ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer), + ProxyConfig::Shadowsocks(_) => { + write!(f, "Shadowsocks {}/{}", endpoint.address, endpoint.protocol) + } ProxyConfig::Socks(socks) => match socks { - access_method::Socks5::Local(s) => { - write!(f, "Socks5 {}/TCP via localhost:{}", s.peer, s.port) + access_method::Socks5::Local(local) => { + write!( + f, + "Socks5 {}/{} via localhost:{}", + endpoint.address, endpoint.protocol, local.port + ) + } + access_method::Socks5::Remote(_) => { + write!(f, "Socks5 {}/{}", endpoint.address, endpoint.protocol) } - access_method::Socks5::Remote(s) => write!(f, "Socks5 {}/TCP", s.peer), }, } } @@ -128,8 +144,9 @@ impl ApiConnectionMode { } } - /// Returns the remote address required to reach the API, or `None` for `ApiConnectionMode::Direct`. - pub fn get_endpoint(&self) -> Option { + /// Returns the remote endpoint required to reach the API, or `None` for + /// `ApiConnectionMode::Direct`. + pub fn get_endpoint(&self) -> Option { match self { ApiConnectionMode::Direct => None, ApiConnectionMode::Proxied(proxy_config) => Some(proxy_config.get_endpoint()), diff --git a/mullvad-api/src/rest.rs b/mullvad-api/src/rest.rs index c3687a1eee9d..53fc118dc962 100644 --- a/mullvad-api/src/rest.rs +++ b/mullvad-api/src/rest.rs @@ -24,7 +24,10 @@ use std::{ sync::{Arc, Weak}, time::Duration, }; -use talpid_types::ErrorExt; +use talpid_types::{ + net::{Endpoint, TransportProtocol}, + ErrorExt, +}; #[cfg(feature = "api-override")] use crate::API; @@ -240,7 +243,10 @@ impl< if let Some(new_config) = self.proxy_config_provider.next().await { let endpoint = match new_config.get_endpoint() { Some(endpoint) => endpoint, - None => self.address_cache.get_address().await, + None => Endpoint::from_socket_address( + self.address_cache.get_address().await, + TransportProtocol::Tcp, + ), }; // Switch to new connection mode unless rejected by address change callback if (self.new_address_callback)(endpoint).await { diff --git a/mullvad-daemon/src/api.rs b/mullvad-daemon/src/api.rs index c548f0a293c4..38bf18ca8cfb 100644 --- a/mullvad-daemon/src/api.rs +++ b/mullvad-daemon/src/api.rs @@ -12,7 +12,6 @@ use mullvad_api::{ use mullvad_relay_selector::RelaySelector; use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod}; use std::{ - net::SocketAddr, path::PathBuf, pin::Pin, sync::{Arc, Mutex, Weak}, @@ -22,7 +21,7 @@ use std::{ use talpid_core::mpsc::Sender; use talpid_core::tunnel_state_machine::TunnelCommand; use talpid_types::{ - net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint, TransportProtocol}, + net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint}, ErrorExt, }; @@ -240,7 +239,7 @@ impl ApiEndpointUpdaterHandle { pub fn callback(&self) -> impl ApiEndpointUpdateCallback { let tunnel_tx = self.tunnel_cmd_tx.clone(); - move |address: SocketAddr| { + move |endpoint: Endpoint| { let inner_tx = tunnel_tx.clone(); async move { let tunnel_tx = if let Some(tunnel_tx) = { inner_tx.lock().unwrap().as_ref() } @@ -253,21 +252,19 @@ impl ApiEndpointUpdaterHandle { }; let (result_tx, result_rx) = oneshot::channel(); let _ = tunnel_tx.unbounded_send(TunnelCommand::AllowEndpoint( - get_allowed_endpoint(address), + get_allowed_endpoint(endpoint), result_tx, )); // Wait for the firewall policy to be updated. let _ = result_rx.await; - log::debug!("API endpoint: {}", address); + log::debug!("API endpoint: {}", endpoint); true } } } } -pub(super) fn get_allowed_endpoint(api_address: SocketAddr) -> AllowedEndpoint { - let endpoint = Endpoint::from_socket_address(api_address, TransportProtocol::Tcp); - +pub(super) fn get_allowed_endpoint(endpoint: Endpoint) -> AllowedEndpoint { #[cfg(windows)] let daemon_exe = std::env::current_exe().expect("failed to obtain executable path"); #[cfg(windows)] diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index 1077185ca38d..2dd76c9b20bb 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -725,8 +725,12 @@ where vec![] }; + // TODO(markus): Solve this more succintly let initial_api_endpoint = - api::get_allowed_endpoint(api_runtime.address_cache.get_address().await); + api::get_allowed_endpoint(talpid_types::net::Endpoint::from_socket_address( + api_runtime.address_cache.get_address().await, + talpid_types::net::TransportProtocol::Tcp, + )); let parameters_generator = tunnel::ParametersGenerator::new( account_manager.clone(), relay_selector.clone(),