Skip to content

Commit

Permalink
Use range types for all port ranges
Browse files Browse the repository at this point in the history
Breaks backwards compatibility with relays.json (which is acceptable)
  • Loading branch information
dlon committed Aug 15, 2024
1 parent aa96f06 commit 6ea0db9
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 94 deletions.
18 changes: 16 additions & 2 deletions mullvad-api/src/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::{
collections::BTreeMap,
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
ops::RangeInclusive,
time::Duration,
};

Expand Down Expand Up @@ -253,15 +254,28 @@ struct Wireguard {
impl From<&Wireguard> for relay_list::WireguardEndpointData {
fn from(wg: &Wireguard) -> Self {
Self {
port_ranges: wg.port_ranges.clone(),
port_ranges: inclusive_range_from_pair_set(wg.port_ranges.clone()).collect(),
ipv4_gateway: wg.ipv4_gateway,
ipv6_gateway: wg.ipv6_gateway,
shadowsocks_port_ranges: wg.shadowsocks_port_ranges.clone(),
shadowsocks_port_ranges: inclusive_range_from_pair_set(
wg.shadowsocks_port_ranges.clone(),
)
.collect(),
udp2tcp_ports: vec![],
}
}
}

fn inclusive_range_from_pair_set<T>(
set: impl IntoIterator<Item = (T, T)>,
) -> impl Iterator<Item = RangeInclusive<T>> {
set.into_iter().map(inclusive_range_from_pair)
}

fn inclusive_range_from_pair<T>(pair: (T, T)) -> RangeInclusive<T> {
RangeInclusive::new(pair.0, pair.1)
}

impl Wireguard {
/// Consumes `self` and appends all its relays to `countries`.
fn extract_relays(
Expand Down
2 changes: 1 addition & 1 deletion mullvad-cli/src/cmds/relay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ impl Relay {
let is_valid_port = wireguard
.port_ranges
.into_iter()
.any(|(first, last)| first <= specific_port && specific_port <= last);
.any(|range| range.contains(&specific_port));
if !is_valid_port {
return Err(anyhow!("The specified port is invalid"));
}
Expand Down
43 changes: 21 additions & 22 deletions mullvad-management-interface/src/types/conversions/relay_list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
net::{Ipv4Addr, Ipv6Addr},
ops::RangeInclusive,
str::FromStr,
};

Expand Down Expand Up @@ -79,11 +80,11 @@ impl From<mullvad_types::relay_list::WireguardEndpointData> for proto::Wireguard
}
}

impl From<(u16, u16)> for proto::PortRange {
fn from(range: (u16, u16)) -> Self {
impl From<RangeInclusive<u16>> for proto::PortRange {
fn from(range: RangeInclusive<u16>) -> Self {
proto::PortRange {
first: u32::from(range.0),
last: u32::from(range.1),
first: u32::from(*range.start()),
last: u32::from(*range.end()),
}
}
}
Expand Down Expand Up @@ -373,14 +374,8 @@ impl TryFrom<proto::WireguardEndpointData> for mullvad_types::relay_list::Wiregu
let port_ranges = wireguard
.port_ranges
.into_iter()
.map(|range| {
let first = u16::try_from(range.first)
.map_err(|_| FromProtobufTypeError::InvalidArgument("invalid wg port"))?;
let last = u16::try_from(range.last)
.map_err(|_| FromProtobufTypeError::InvalidArgument("invalid wg port"))?;
Ok((first, last))
})
.collect::<Result<Vec<(u16, u16)>, FromProtobufTypeError>>()?;
.map(RangeInclusive::try_from)
.collect::<Result<Vec<_>, FromProtobufTypeError>>()?;

let ipv4_gateway = Ipv4Addr::from_str(&wireguard.ipv4_gateway)
.map_err(|_| FromProtobufTypeError::InvalidArgument("Invalid IPv4 gateway"))?;
Expand All @@ -390,16 +385,8 @@ impl TryFrom<proto::WireguardEndpointData> for mullvad_types::relay_list::Wiregu
let shadowsocks_port_ranges = wireguard
.shadowsocks_port_ranges
.into_iter()
.map(|range| {
let first = u16::try_from(range.first).map_err(|_| {
FromProtobufTypeError::InvalidArgument("invalid shadowsocks port")
})?;
let last = u16::try_from(range.last).map_err(|_| {
FromProtobufTypeError::InvalidArgument("invalid shadowsocks port")
})?;
Ok((first, last))
})
.collect::<Result<Vec<(u16, u16)>, FromProtobufTypeError>>()?;
.map(RangeInclusive::try_from)
.collect::<Result<Vec<_>, FromProtobufTypeError>>()?;

let udp2tcp_ports = wireguard
.udp2tcp_ports
Expand All @@ -419,3 +406,15 @@ impl TryFrom<proto::WireguardEndpointData> for mullvad_types::relay_list::Wiregu
})
}
}

impl TryFrom<proto::PortRange> for RangeInclusive<u16> {
type Error = FromProtobufTypeError;

fn try_from(range: proto::PortRange) -> Result<Self, Self::Error> {
let first = u16::try_from(range.first)
.map_err(|_| FromProtobufTypeError::InvalidArgument("invalid port"))?;
let last = u16::try_from(range.last)
.map_err(|_| FromProtobufTypeError::InvalidArgument("invalid port"))?;
Ok(first..=last)
}
}
86 changes: 39 additions & 47 deletions mullvad-relay-selector/src/relay_selector/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! This module contains various helper functions for the relay selector implementation.
use std::net::{IpAddr, SocketAddr};
use std::{
net::{IpAddr, SocketAddr},
ops::{RangeBounds, RangeInclusive},
};

use mullvad_types::{
constraints::Constraint,
Expand All @@ -18,12 +21,10 @@ use crate::SelectedObfuscator;

/// Port ranges available for WireGuard relays that have extra IPs for Shadowsocks.
/// For relays that have no additional IPs, only ports provided by the relay list are available.
const SHADOWSOCKS_EXTRA_PORT_RANGES: &[(u16, u16)] = &[(1, u16::MAX)];
const SHADOWSOCKS_EXTRA_PORT_RANGES: &[RangeInclusive<u16>] = &[1..=u16::MAX];

#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("Port selection algorithm is broken")]
PortSelectionAlgorithm,
#[error("Found no valid port matching the selected settings")]
NoMatchingPort,
}
Expand Down Expand Up @@ -107,7 +108,7 @@ fn get_udp2tcp_obfuscator_port(

pub fn get_shadowsocks_obfuscator(
settings: &ShadowsocksSettings,
non_extra_port_ranges: &[(u16, u16)],
non_extra_port_ranges: &[RangeInclusive<u16>],
relay: Relay,
endpoint: &MullvadWireguardEndpoint,
) -> Result<SelectedObfuscator, Error> {
Expand Down Expand Up @@ -135,9 +136,9 @@ pub fn get_shadowsocks_obfuscator(
/// Return an obfuscation config for the wireguard server at `wg_in_addr` or one of `extra_in_addrs`
/// (unless empty). `wg_in_addr_port_ranges` contains all valid ports for `wg_in_addr`, and
/// `SHADOWSOCKS_EXTRA_PORT_RANGES` contains valid ports for `extra_in_addrs`.
fn get_shadowsocks_obfuscator_inner(
fn get_shadowsocks_obfuscator_inner<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
wg_in_addr: IpAddr,
wg_in_addr_port_ranges: &[(u16, u16)],
wg_in_addr_port_ranges: &[R],
extra_in_addrs: &[IpAddr],
desired_port: Constraint<u16>,
) -> Result<SocketAddr, Error> {
Expand All @@ -154,23 +155,27 @@ fn get_shadowsocks_obfuscator_inner(
.copied()
.unwrap_or(wg_in_addr);

let port_ranges = if extra_in_addrs.is_empty() {
wg_in_addr_port_ranges
let selected_port = if extra_in_addrs.is_empty() {
desired_port_from_range(wg_in_addr_port_ranges, desired_port)
} else {
SHADOWSOCKS_EXTRA_PORT_RANGES
};
desired_port_from_range(SHADOWSOCKS_EXTRA_PORT_RANGES, desired_port)
}?;

Ok(SocketAddr::from((in_ip, selected_port)))
}

let selected_port = match desired_port {
fn desired_port_from_range<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
port_ranges: &[R],
desired_port: Constraint<u16>,
) -> Result<u16, Error> {
match desired_port {
// Selected a specific, in-range port
Constraint::Only(port) if super::helpers::port_in_range(port, port_ranges) => Some(port),
Constraint::Only(port) if port_in_range(port, port_ranges) => Ok(port),
// Selected a specific, out-of-range port
Constraint::Only(_port) => None,
Constraint::Only(_port) => Err(Error::NoMatchingPort),
// Selected no specific port
Constraint::Any => super::helpers::select_random_port(port_ranges).ok(),
Constraint::Any => select_random_port(port_ranges),
}
.ok_or(Error::NoMatchingPort)?;

Ok(SocketAddr::from((in_ip, selected_port)))
}

/// Selects a random port number from a list of provided port ranges.
Expand All @@ -185,45 +190,32 @@ fn get_shadowsocks_obfuscator_inner(
///
/// # Returns
/// - A randomly selected port number within the given ranges.
///
/// # Panic
/// - If port ranges contains no ports, this function panics.
pub fn select_random_port(port_ranges: &[(u16, u16)]) -> Result<u16, Error> {
let get_port_amount = |range: &(u16, u16)| -> u64 { 1 + range.1 as u64 - range.0 as u64 };
let port_amount: u64 = port_ranges.iter().map(get_port_amount).sum();

if port_amount < 1 {
return Err(Error::PortSelectionAlgorithm);
}

let mut port_index = rand::thread_rng().gen_range(0..port_amount);

for range in port_ranges.iter() {
let ports_in_range = get_port_amount(range);
if port_index < ports_in_range {
return Ok(port_index as u16 + range.0);
}
port_index -= ports_in_range;
}
Err(Error::PortSelectionAlgorithm)
}

pub fn port_in_range(port: u16, port_ranges: &[(u16, u16)]) -> bool {
/// - An error if `port_ranges` is empty.
pub fn select_random_port<R: RangeBounds<u16> + Iterator<Item = u16> + Clone>(
port_ranges: &[R],
) -> Result<u16, Error> {
port_ranges
.iter()
.any(|range| (range.0 <= port && port <= range.1))
.cloned()
.flatten()
.choose(&mut rand::thread_rng())
.ok_or(Error::NoMatchingPort)
}

pub fn port_in_range<R: RangeBounds<u16>>(port: u16, port_ranges: &[R]) -> bool {
port_ranges.iter().any(|range| range.contains(&port))
}

#[cfg(test)]
mod tests {
use super::{get_shadowsocks_obfuscator_inner, port_in_range, SHADOWSOCKS_EXTRA_PORT_RANGES};
use mullvad_types::constraints::Constraint;
use std::net::IpAddr;
use std::{net::IpAddr, ops::RangeInclusive};

/// Test whether select ports are available when relay has no extra IPs
#[test]
fn test_shadowsocks_no_extra_addrs() {
const PORT_RANGES: &[(u16, u16)] = &[(100, 200), (1000, 2000)];
const PORT_RANGES: &[RangeInclusive<u16>] = &[100..=200, 1000..=2000];
const WITHIN_RANGE_PORT: u16 = 100;
const OUT_OF_RANGE_PORT: u16 = 1;
let wg_in_ip: IpAddr = "1.2.3.4".parse().unwrap();
Expand Down Expand Up @@ -267,7 +259,7 @@ mod tests {
/// All ports should be available when relay has extra IPs, and only extra IPs should be used
#[test]
fn test_shadowsocks_extra_addrs() {
const PORT_RANGES: &[(u16, u16)] = &[(100, 200), (1000, 2000)];
const PORT_RANGES: &[RangeInclusive<u16>] = &[100..=200, 1000..=2000];
const OUT_OF_RANGE_PORT: u16 = 1;
let wg_in_ip: IpAddr = "1.2.3.4".parse().unwrap();

Expand Down Expand Up @@ -312,7 +304,7 @@ mod tests {
/// Extra addresses that belong to the wrong IP family should be ignored
#[test]
fn test_shadowsocks_irrelevant_extra_addrs() {
const PORT_RANGES: &[(u16, u16)] = &[(100, 200), (1000, 2000)];
const PORT_RANGES: &[RangeInclusive<u16>] = &[100..=200, 1000..=2000];
const IN_RANGE_PORT: u16 = 100;
const OUT_OF_RANGE_PORT: u16 = 1;
let wg_in_ip: IpAddr = "1.2.3.4".parse().unwrap();
Expand Down
8 changes: 3 additions & 5 deletions mullvad-relay-selector/src/relay_selector/matcher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! This module is responsible for filtering the whole relay list based on queries.
use std::collections::HashSet;
use std::{collections::HashSet, ops::RangeInclusive};

use mullvad_types::{
constraints::{Constraint, Match},
Expand Down Expand Up @@ -155,7 +155,7 @@ fn filter_on_obfuscation(

/// Returns whether `relay` satisfies the Shadowsocks filter posed by `port`.
fn filter_on_shadowsocks(
port_ranges: &[(u16, u16)],
port_ranges: &[RangeInclusive<u16>],
ip_version: &Constraint<IpVersion>,
settings: &ShadowsocksSettings,
relay: &Relay,
Expand All @@ -177,9 +177,7 @@ fn filter_on_shadowsocks(
.find(|&&addr| IpVersion::from(addr) == ip_version);

filtered_extra_addrs.is_some()
|| port_ranges
.iter()
.any(|(begin, end)| (*begin..=*end).contains(desired_port))
|| port_ranges.iter().any(|range| range.contains(desired_port))
}

// Otherwise, any relay works.
Expand Down
28 changes: 14 additions & 14 deletions mullvad-relay-selector/tests/relay_selector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,16 @@ static RELAYS: Lazy<RelayList> = Lazy::new(|| RelayList {
},
wireguard: WireguardEndpointData {
port_ranges: vec![
(53, 53),
(443, 443),
(4000, 33433),
(33565, 51820),
(52000, 60000),
53..=53,
443..=443,
4000..=33433,
33565..=51820,
52000..=60000,
],
ipv4_gateway: "10.64.0.1".parse().unwrap(),
ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
udp2tcp_ports: vec![],
shadowsocks_port_ranges: vec![(100, 200), (1000, 2000)],
shadowsocks_port_ranges: vec![100..=200, 1000..=2000],
},
});

Expand Down Expand Up @@ -506,16 +506,16 @@ fn test_wireguard_entry() {
},
wireguard: WireguardEndpointData {
port_ranges: vec![
(53, 53),
(443, 443),
(4000, 33433),
(33565, 51820),
(52000, 60000),
53..=53,
443..=443,
4000..=33433,
33565..=51820,
52000..=60000,
],
ipv4_gateway: "10.64.0.1".parse().unwrap(),
ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
udp2tcp_ports: vec![],
shadowsocks_port_ranges: vec![(100, 200), (1000, 2000)],
shadowsocks_port_ranges: vec![100..=200, 1000..=2000],
},
};

Expand Down Expand Up @@ -1175,7 +1175,7 @@ fn test_include_in_country() {
shadowsocks: vec![],
},
wireguard: WireguardEndpointData {
port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)],
port_ranges: vec![53..=53, 4000..=33433, 33565..=51820, 52000..=60000],
ipv4_gateway: "10.64.0.1".parse().unwrap(),
ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
udp2tcp_ports: vec![],
Expand Down Expand Up @@ -1381,7 +1381,7 @@ fn test_daita() {
shadowsocks: vec![],
},
wireguard: WireguardEndpointData {
port_ranges: vec![(53, 53), (4000, 33433), (33565, 51820), (52000, 60000)],
port_ranges: vec![53..=53, 4000..=33433, 33565..=51820, 52000..=60000],
ipv4_gateway: "10.64.0.1".parse().unwrap(),
ipv6_gateway: "fc00:bbbb:bbbb:bb01::1".parse().unwrap(),
shadowsocks_port_ranges: vec![],
Expand Down
Loading

0 comments on commit 6ea0db9

Please sign in to comment.