Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to select TCP/UDP for localhost SOCKS5 API access method #5320

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::{
ops::Deref,
path::Path,
};
use talpid_types::ErrorExt;
use talpid_types::{net::AllowedEndpoint, ErrorExt};

pub mod availability;
use availability::{ApiAvailability, ApiAvailabilityHandle};
Expand Down Expand Up @@ -221,13 +221,13 @@ pub enum Error {

/// Closure that receives the next API (real or proxy) endpoint to use for `api.mullvad.net`.
/// It should return a future that determines whether to reject the new endpoint or not.
pub trait ApiEndpointUpdateCallback: Fn(SocketAddr) -> Self::AcceptedNewEndpoint {
pub trait ApiEndpointUpdateCallback: Fn(AllowedEndpoint) -> Self::AcceptedNewEndpoint {
type AcceptedNewEndpoint: Future<Output = bool> + Send;
}

impl<U, T: Future<Output = bool> + Send> ApiEndpointUpdateCallback for U
where
U: Fn(SocketAddr) -> T,
U: Fn(AllowedEndpoint) -> T,
{
type AcceptedNewEndpoint = T;
}
Expand Down
69 changes: 51 additions & 18 deletions mullvad-api/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use mullvad_types::access_method;
use serde::{Deserialize, Serialize};
use std::{
fmt, io,
net::SocketAddr,
path::Path,
pin::Pin,
task::{self, Poll},
};
use talpid_types::ErrorExt;
use talpid_types::{
net::{AllowedClients, Endpoint, TransportProtocol},
ErrorExt,
};
use tokio::{
fs,
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
Expand Down Expand Up @@ -41,32 +43,32 @@ pub enum ProxyConfig {
}

impl ProxyConfig {
/// Returns the remote address to reach the proxy.
fn get_endpoint(&self) -> SocketAddr {
/// Returns the remote endpoint describing how to reach the proxy.
fn get_endpoint(&self) -> Endpoint {
match self {
ProxyConfig::Shadowsocks(ss) => ss.peer,
ProxyConfig::Shadowsocks(shadowsocks) => {
Endpoint::from_socket_address(shadowsocks.peer, TransportProtocol::Tcp)
}
ProxyConfig::Socks(socks) => match socks {
access_method::Socks5::Local(s) => s.remote_peer,
access_method::Socks5::Remote(s) => s.peer,
access_method::Socks5::Local(local) => local.remote_endpoint,
access_method::Socks5::Remote(remote) => {
Endpoint::from_socket_address(remote.peer, TransportProtocol::Tcp)
}
},
}
}
}

impl fmt::Display for ProxyConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
let endpoint = self.get_endpoint();
match self {
// TODO: Do not hardcode TCP
ProxyConfig::Shadowsocks(ss) => write!(f, "Shadowsocks {}/TCP", ss.peer),
ProxyConfig::Shadowsocks(_) => write!(f, "Shadowsocks {}", endpoint),
ProxyConfig::Socks(socks) => match socks {
access_method::Socks5::Local(s) => {
write!(
f,
"Socks5 {}/TCP via localhost:{}",
s.remote_peer, s.local_port
)
access_method::Socks5::Remote(_) => write!(f, "Socks5 {}", endpoint),
access_method::Socks5::Local(local) => {
write!(f, "Socks5 {} via localhost:{}", endpoint, local.local_port)
}
access_method::Socks5::Remote(s) => write!(f, "Socks5 {}/TCP", s.peer),
},
}
}
Expand Down Expand Up @@ -132,14 +134,45 @@ impl ApiConnectionMode {
}
}

/// Returns the remote address required to reach the API, or `None` for `ApiConnectionMode::Direct`.
pub fn get_endpoint(&self) -> Option<SocketAddr> {
/// Returns the remote endpoint required to reach the API, or `None` for
/// `ApiConnectionMode::Direct`.
pub fn get_endpoint(&self) -> Option<Endpoint> {
match self {
ApiConnectionMode::Direct => None,
ApiConnectionMode::Proxied(proxy_config) => Some(proxy_config.get_endpoint()),
}
}

#[cfg(unix)]
pub fn allowed_clients(&self) -> AllowedClients {
use access_method::Socks5;
match self {
ApiConnectionMode::Proxied(ProxyConfig::Socks(Socks5::Local(_))) => AllowedClients::All,
ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => AllowedClients::Root,
}
}

#[cfg(windows)]
pub fn allowed_clients(&self) -> AllowedClients {
use access_method::Socks5;
match self {
ApiConnectionMode::Proxied(ProxyConfig::Socks(Socks5::Local(_))) => {
AllowedClients::all()
}
ApiConnectionMode::Direct | ApiConnectionMode::Proxied(_) => {
let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
vec![
daemon_exe
.parent()
.expect("missing executable parent directory")
.join("mullvad-problem-report.exe"),
daemon_exe,
]
.into()
}
}
}

pub fn is_proxy(&self) -> bool {
*self != ApiConnectionMode::Direct
}
Expand Down
14 changes: 11 additions & 3 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use std::{
sync::{Arc, Weak},
time::Duration,
};
use talpid_types::ErrorExt;
use talpid_types::{
net::{AllowedEndpoint, Endpoint, TransportProtocol},
ErrorExt,
};

#[cfg(feature = "api-override")]
use crate::API;
Expand Down Expand Up @@ -209,10 +212,15 @@ impl<
if let Some(new_config) = self.proxy_config_provider.next().await {
let endpoint = match new_config.get_endpoint() {
Some(endpoint) => endpoint,
None => self.address_cache.get_address().await,
None => Endpoint::from_socket_address(
self.address_cache.get_address().await,
TransportProtocol::Tcp,
),
};
let clients = new_config.allowed_clients();
let allowed_endpoint = AllowedEndpoint { endpoint, clients };
// Switch to new connection mode unless rejected by address change callback
if (self.new_address_callback)(endpoint).await {
if (self.new_address_callback)(allowed_endpoint).await {
self.connector_handle.set_connection_mode(new_config);
}
}
Expand Down
48 changes: 41 additions & 7 deletions mullvad-cli/src/cmds/api_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, CustomAcce
use std::net::IpAddr;

use clap::{Args, Subcommand};
use talpid_types::net::openvpn::SHADOWSOCKS_CIPHERS;
use talpid_types::net::{openvpn::SHADOWSOCKS_CIPHERS, TransportProtocol};

#[derive(Subcommand, Debug, Clone)]
pub enum ApiAccess {
Expand Down Expand Up @@ -118,10 +118,21 @@ impl ApiAccess {
}
CustomAccessMethod::Socks5(socks) => match socks {
Socks5::Local(local) => {
let remote_ip = cmd.params.ip.unwrap_or(local.remote_peer.ip());
let remote_port = cmd.params.port.unwrap_or(local.remote_peer.port());
let remote_ip = cmd.params.ip.unwrap_or(local.remote_endpoint.address.ip());
let remote_port = cmd
.params
.port
.unwrap_or(local.remote_endpoint.address.port());
let local_port = cmd.params.local_port.unwrap_or(local.local_port);
AccessMethod::from(Socks5Local::new((remote_ip, remote_port), local_port))
let remote_peer_transport_protocol = cmd
.params
.transport_protocol
.unwrap_or(local.remote_endpoint.protocol);
AccessMethod::from(Socks5Local::new_with_transport_protocol(
(remote_ip, remote_port),
local_port,
remote_peer_transport_protocol,
))
}
Socks5::Remote(remote) => {
let ip = cmd.params.ip.unwrap_or(remote.peer.ip());
Expand Down Expand Up @@ -306,6 +317,14 @@ pub enum AddSocks5Commands {
remote_ip: IpAddr,
/// The port of the remote peer
remote_port: u16,
/// The Mullvad App can not know which transport protocol that the
/// remote peer accepts, but it needs to know this in order to correctly
/// exempt the connection traffic in the firewall.
///
/// By default, the transport protocol is assumed to be `TCP`, but it
/// can optionally be set to `UDP` as well.
#[arg(long, default_value_t = TransportProtocol::Tcp)]
transport_protocol: TransportProtocol,
/// Disable the use of this custom access method. It has to be manually
/// enabled at a later stage to be used when accessing the Mullvad API.
#[arg(default_value_t = false, short, long)]
Expand Down Expand Up @@ -398,6 +417,9 @@ pub struct EditParams {
/// The port that the server on localhost is listening on [Socks5 (Local proxy)]
#[arg(long)]
local_port: Option<u16>,
/// The transport protocol used by the remote proxy [Socks5 (Local proxy)]
#[arg(long)]
transport_protocol: Option<TransportProtocol>,
}

/// Implement conversions from CLI types to Daemon types.
Expand All @@ -418,9 +440,15 @@ mod conversions {
remote_port,
name: _,
disabled: _,
transport_protocol,
} => {
println!("Adding SOCKS5-proxy: localhost:{local_port} => {remote_ip}:{remote_port}");
daemon_types::Socks5Local::new((remote_ip, remote_port), local_port).into()
println!("Adding SOCKS5-proxy: localhost:{local_port} => {remote_ip}:{remote_port}/{transport_protocol}");
daemon_types::Socks5Local::new_with_transport_protocol(
(remote_ip, remote_port),
local_port,
transport_protocol,
)
.into()
}
AddSocks5Commands::Remote {
remote_ip,
Expand Down Expand Up @@ -559,7 +587,13 @@ mod pp {
}
writeln!(f)?;
print_option!("Protocol", "Socks5 (local)");
print_option!("Peer", local.remote_peer);
print_option!(
"Peer",
format!(
"{}/{}",
local.remote_endpoint.address, local.remote_endpoint.protocol
)
);
print_option!("Local port", local.local_port);
Ok(())
}
Expand Down
44 changes: 22 additions & 22 deletions mullvad-daemon/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use mullvad_api::{
use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
use std::{
net::SocketAddr,
path::PathBuf,
pin::Pin,
sync::{Arc, Mutex, Weak},
Expand All @@ -22,7 +21,7 @@ use std::{
use talpid_core::mpsc::Sender;
use talpid_core::tunnel_state_machine::TunnelCommand;
use talpid_types::{
net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint, TransportProtocol},
net::{openvpn::ProxySettings, AllowedEndpoint, Endpoint},
ErrorExt,
};

Expand Down Expand Up @@ -240,7 +239,7 @@ impl ApiEndpointUpdaterHandle {

pub fn callback(&self) -> impl ApiEndpointUpdateCallback {
let tunnel_tx = self.tunnel_cmd_tx.clone();
move |address: SocketAddr| {
move |allowed_endpoint: AllowedEndpoint| {
let inner_tx = tunnel_tx.clone();
async move {
let tunnel_tx = if let Some(tunnel_tx) = { inner_tx.lock().unwrap().as_ref() }
Expand All @@ -253,37 +252,38 @@ impl ApiEndpointUpdaterHandle {
};
let (result_tx, result_rx) = oneshot::channel();
let _ = tunnel_tx.unbounded_send(TunnelCommand::AllowEndpoint(
get_allowed_endpoint(address),
allowed_endpoint.clone(),
result_tx,
));
// Wait for the firewall policy to be updated.
let _ = result_rx.await;
log::debug!("API endpoint: {}", address);
log::debug!(
"API endpoint: {endpoint}",
endpoint = allowed_endpoint.endpoint
);
true
}
}
}
}

pub(super) fn get_allowed_endpoint(api_address: SocketAddr) -> AllowedEndpoint {
let endpoint = Endpoint::from_socket_address(api_address, TransportProtocol::Tcp);

pub(super) fn get_allowed_endpoint(endpoint: Endpoint) -> AllowedEndpoint {
#[cfg(unix)]
let clients = talpid_types::net::AllowedClients::Root;
#[cfg(windows)]
let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
#[cfg(windows)]
let clients = vec![
daemon_exe
.parent()
.expect("missing executable parent directory")
.join("mullvad-problem-report.exe"),
daemon_exe,
];
let clients = {
let daemon_exe = std::env::current_exe().expect("failed to obtain executable path");
vec![
daemon_exe
.parent()
.expect("missing executable parent directory")
.join("mullvad-problem-report.exe"),
daemon_exe,
]
.into()
};

AllowedEndpoint {
#[cfg(windows)]
clients,
endpoint,
}
AllowedEndpoint { endpoint, clients }
}

pub(crate) fn forward_offline_state(
Expand Down
5 changes: 4 additions & 1 deletion mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,10 @@ where
};

let initial_api_endpoint =
api::get_allowed_endpoint(api_runtime.address_cache.get_address().await);
api::get_allowed_endpoint(talpid_types::net::Endpoint::from_socket_address(
api_runtime.address_cache.get_address().await,
talpid_types::net::TransportProtocol::Tcp,
));
let parameters_generator = tunnel::ParametersGenerator::new(
account_manager.clone(),
relay_selector.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ message AccessMethod {
message Socks5Local {
string remote_ip = 1;
uint32 remote_port = 2;
uint32 local_port = 3;
TransportProtocol remote_transport_protocol = 3;
uint32 local_port = 4;
}
message SocksAuth {
string username = 1;
Expand Down
Loading
Loading