From 81a31434893fea6245ff3fb9b49462ee1eb3867a Mon Sep 17 00:00:00 2001 From: Markus Pettersson Date: Wed, 20 Mar 2024 14:21:15 +0100 Subject: [PATCH] Remove wrong blanket implementation of `Intersection` --- .../src/relay_selector/query.rs | 71 +++++++++++++++++-- mullvad-types/src/relay_constraints.rs | 2 +- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/mullvad-relay-selector/src/relay_selector/query.rs b/mullvad-relay-selector/src/relay_selector/query.rs index 5a6206b4dc34..ecb4bd3d872f 100644 --- a/mullvad-relay-selector/src/relay_selector/query.rs +++ b/mullvad-relay-selector/src/relay_selector/query.rs @@ -429,11 +429,10 @@ impl From for OpenVpnConstraints { pub trait Intersection { fn intersection(self, other: Self) -> Option where - Self: PartialEq, Self: Sized; } -impl Intersection for Constraint { +impl Intersection for Constraint { /// Define the intersection between two arbitrary [`Constraint`]s. /// /// This operation may be compared to the set operation with the same name. @@ -446,13 +445,75 @@ impl Intersection for Constraint { match (self, other) { (Any, Any) => Some(Any), (Only(t), Any) | (Any, Only(t)) => Some(Only(t)), - // Pick any of `left` or `right` if they are the same. - (Only(left), Only(right)) if left == right => Some(Only(left)), - _ => None, + // Recurse on `left` and `right` to see if there exist an intersection + (Only(left), Only(right)) => Some(Only(left.intersection(right)?)), } } } +// Implement `Intersection` for different types + +impl Intersection for Providers { + fn intersection(self, other: Self) -> Option + where + Self: Sized, + { + Providers::new(self.providers.intersection(&other.providers)).ok() + } +} + +impl Intersection for Udp2TcpObfuscationSettings { + fn intersection(self, other: Self) -> Option + where + Self: Sized, + { + Some(Udp2TcpObfuscationSettings { + port: self.port.intersection(other.port)?, + }) + } +} + +impl Intersection for TransportPort { + fn intersection(self, other: Self) -> Option + where + Self: Sized, + { + let protocol = if self.protocol == other.protocol { + Some(self.protocol) + } else { + None + }?; + let port = self.port.intersection(other.port)?; + Some(TransportPort { protocol, port }) + } +} + +/// Auto-implement `Intersection` for trivial cases where the logic should just check if +/// `self` is equal to `other`. +macro_rules! impl_intersection_partialeq { + ($ty:ty) => { + impl Intersection for $ty { + fn intersection(self, other: Self) -> Option { + if self == other { + Some(self) + } else { + None + } + } + } + }; +} +impl_intersection_partialeq!(u16); +impl_intersection_partialeq!(bool); +// FIXME: [`LocationConstraint`] deserves a hand-rolled implementation of [`Intersection`], but +// it would probably be best to implement it for [`ResolvedLocationConstraint`] instead to properly +// handle custom lists. +impl_intersection_partialeq!(LocationConstraint); +impl_intersection_partialeq!(Ownership); +impl_intersection_partialeq!(talpid_types::net::TransportProtocol); +impl_intersection_partialeq!(talpid_types::net::TunnelType); +impl_intersection_partialeq!(talpid_types::net::IpVersion); + #[allow(unused)] pub mod builder { //! Strongly typed Builder pattern for of relay constraints though the use of the Typestate pattern. diff --git a/mullvad-types/src/relay_constraints.rs b/mullvad-types/src/relay_constraints.rs index 7b027050e82f..ffb13a2792eb 100644 --- a/mullvad-types/src/relay_constraints.rs +++ b/mullvad-types/src/relay_constraints.rs @@ -393,7 +393,7 @@ pub type Provider = String; #[cfg_attr(target_os = "android", jnix(package = "net.mullvad.mullvadvpn.model"))] #[derive(Debug, Clone, Eq, PartialEq, Deserialize, Serialize)] pub struct Providers { - providers: HashSet, + pub providers: HashSet, } /// Returned if the iterator contained no providers.