Skip to content

Commit

Permalink
Use access method-specific transport protocol for firewall exemption.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Oct 18, 2023
1 parent ff90ea7 commit 8b0bf1e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 28 deletions.
6 changes: 3 additions & 3 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{
ops::Deref,
path::Path,
};
use talpid_types::ErrorExt;
use talpid_types::{net::Endpoint, ErrorExt};

pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
Expand Down Expand Up @@ -221,13 +221,13 @@ pub enum Error {

/// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`.
/// It should return a future that determines whether to reject the new endpoint or not.
pub trait ApiEndpointUpdateCallback: Fn(SocketAddr) -> Self::AcceptedNewEndpoint {
pub trait ApiEndpointUpdateCallback: Fn(Endpoint) -> Self::AcceptedNewEndpoint {
type AcceptedNewEndpoint: Future<Output = bool> + Send;
}

impl<U, T: Future<Output = bool> + Send> ApiEndpointUpdateCallback for U
where
U: Fn(SocketAddr) -> T,
U: Fn(Endpoint) -> T,
{
type AcceptedNewEndpoint = T;
}
Expand Down
45 changes: 31 additions & 14 deletions mullvad-api/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use mullvad_types::access_method;
use serde::{Deserialize, Serialize};
use std::{
fmt, io,
net::SocketAddr,
path::Path,
pin::Pin,
task::{self, Poll},
};
use talpid_types::ErrorExt;
use talpid_types::{
net::{Endpoint, TransportProtocol},
ErrorExt,
};
use tokio::{
fs,
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
Expand Down Expand Up @@ -41,28 +43,42 @@ pub enum ProxyConfig {
}

impl ProxyConfig {
/// Returns the remote address to reach the proxy.
fn get_endpoint(&self) -> SocketAddr {
/// Returns the remote endpoint describing how to reach the proxy.
fn get_endpoint(&self) -> Endpoint {
match self {
ProxyConfig::Shadowsocks(ss) => ss.peer,
ProxyConfig::Shadowsocks(shadowsocks) => {
Endpoint::from_socket_address(shadowsocks.peer, TransportProtocol::Tcp)
}
ProxyConfig::Socks(socks) => match socks {
access_method::Socks5::Local(s) => s.peer,
access_method::Socks5::Remote(s) => s.peer,
access_method::Socks5::Local(local) => {
Endpoint::from_socket_address(local.peer, local.peer_transport_protol)
}
access_method::Socks5::Remote(remote) => {
Endpoint::from_socket_address(remote.peer, TransportProtocol::Tcp)
}
},
}
}
}

impl fmt::Display for ProxyConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let endpoint = self.get_endpoint();
match self {
// TODO: Do not hardcode TCP
ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer),
ProxyConfig::Shadowsocks(_) => {
write!(f, "Shadowsocks {}/{}", endpoint.address, endpoint.protocol)
}
ProxyConfig::Socks(socks) => match socks {
access_method::Socks5::Local(s) => {
write!(f, "Socks5 {}/TCP via localhost:{}", s.peer, s.port)
access_method::Socks5::Local(local) => {
write!(
f,
"Socks5 {}/{} via localhost:{}",
endpoint.address, endpoint.protocol, local.port
)
}
access_method::Socks5::Remote(_) => {
write!(f, "Socks5 {}/{}", endpoint.address, endpoint.protocol)
}
access_method::Socks5::Remote(s) => write!(f, "Socks5 {}/TCP", s.peer),
},
}
}
Expand Down Expand Up @@ -128,8 +144,9 @@ impl ApiConnectionMode {
}
}

/// Returns the remote address required to reach the API, or `None` for `ApiConnectionMode::Direct`.
pub fn get_endpoint(&self) -> Option<SocketAddr> {
/// Returns the remote endpoint required to reach the API, or `None` for
/// `ApiConnectionMode::Direct`.
pub fn get_endpoint(&self) -> Option<Endpoint> {
match self {
ApiConnectionMode::Direct => None,
ApiConnectionMode::Proxied(proxy_config) => Some(proxy_config.get_endpoint()),
Expand Down
10 changes: 8 additions & 2 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use std::{
sync::{Arc, Weak},
time::Duration,
};
use talpid_types::ErrorExt;
use talpid_types::{
net::{Endpoint, TransportProtocol},
ErrorExt,
};

#[cfg(feature = "api-override")]
use crate::API;
Expand Down Expand Up @@ -240,7 +243,10 @@ impl<
if let Some(new_config) = self.proxy_config_provider.next().await {
let endpoint = match new_config.get_endpoint() {
Some(endpoint) => endpoint,
None => self.address_cache.get_address().await,
None => Endpoint::from_socket_address(
self.address_cache.get_address().await,
TransportProtocol::Tcp,
),
};
// Switch to new connection mode unless rejected by address change callback
if (self.new_address_callback)(endpoint).await {
Expand Down
13 changes: 5 additions & 8 deletions mullvad-daemon/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use mullvad_api::{
use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
use std::{
net::SocketAddr,
path::PathBuf,
pin::Pin,
sync::{Arc, Mutex, Weak},
Expand All @@ -22,7 +21,7 @@ use std::{
use talpid_core::mpsc::Sender;
use talpid_core::tunnel_state_machine::TunnelCommand;
use talpid_types::{
net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint, TransportProtocol},
net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint},
ErrorExt,
};

Expand Down Expand Up @@ -240,7 +239,7 @@ impl ApiEndpointUpdaterHandle {

pub fn callback(&self) -> impl ApiEndpointUpdateCallback {
let tunnel_tx = self.tunnel_cmd_tx.clone();
move |address: SocketAddr| {
move |endpoint: Endpoint| {
let inner_tx = tunnel_tx.clone();
async move {
let tunnel_tx = if let Some(tunnel_tx) = { inner_tx.lock().unwrap().as_ref() }
Expand All @@ -253,21 +252,19 @@ impl ApiEndpointUpdaterHandle {
};
let (result_tx, result_rx) = oneshot::channel();
let _ = tunnel_tx.unbounded_send(TunnelCommand::AllowEndpoint(
get_allowed_endpoint(address),
get_allowed_endpoint(endpoint),
result_tx,
));
// Wait for the firewall policy to be updated.
let _ = result_rx.await;
log::debug!("API endpoint: {}", address);
log::debug!("API endpoint: {}", endpoint);
true
}
}
}
}

pub(super) fn get_allowed_endpoint(api_address: SocketAddr) -> AllowedEndpoint {
let endpoint = Endpoint::from_socket_address(api_address, TransportProtocol::Tcp);

pub(super) fn get_allowed_endpoint(endpoint: Endpoint) -> AllowedEndpoint {
#[cfg(windows)]
let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
#[cfg(windows)]
Expand Down
6 changes: 5 additions & 1 deletion mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,12 @@ where
vec![]
};

// TODO(markus): Solve this more succintly
let initial_api_endpoint =
api::get_allowed_endpoint(api_runtime.address_cache.get_address().await);
api::get_allowed_endpoint(talpid_types::net::Endpoint::from_socket_address(
api_runtime.address_cache.get_address().await,
talpid_types::net::TransportProtocol::Tcp,
));
let parameters_generator = tunnel::ParametersGenerator::new(
account_manager.clone(),
relay_selector.clone(),
Expand Down

0 comments on commit 8b0bf1e

Please sign in to comment.