From ca17aee44561ebd23321c82f1f0d05476d96d49d Mon Sep 17 00:00:00 2001 From: Sebastian Holmin Date: Thu, 28 Nov 2024 11:06:16 +0100 Subject: [PATCH] Replace generic with new type --- talpid-core/src/tunnel/mod.rs | 75 +++++++++---------- .../tunnel_state_machine/connecting_state.rs | 24 ++---- talpid-openvpn/src/lib.rs | 64 ++++++++-------- talpid-tunnel/src/lib.rs | 35 +++++++-- talpid-wireguard/src/lib.rs | 74 +++++++++--------- 5 files changed, 135 insertions(+), 137 deletions(-) diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index 5a12b67eb39e..63bd01c57ab1 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -14,6 +14,9 @@ use talpid_types::{ tunnel::ErrorStateCause, }; +#[cfg(not(target_os = "android"))] +use talpid_tunnel::EventHook; + const OPENVPN_LOG_FILENAME: &str = "openvpn.log"; const WIREGUARD_LOG_FILENAME: &str = "wireguard.log"; @@ -115,23 +118,19 @@ impl Error { } /// Abstraction for monitoring a generic VPN tunnel. -pub struct TunnelMonitor { - monitor: InternalTunnelMonitor, +pub struct TunnelMonitor { + monitor: InternalTunnelMonitor, } // TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor -impl TunnelMonitor -where - L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future + Send + 'static, -{ +impl TunnelMonitor { /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` /// on tunnel state changes. #[cfg_attr(any(target_os = "android", windows), allow(unused_variables))] pub fn start( tunnel_parameters: &TunnelParameters, log_dir: &Option, - args: TunnelArgs<'_, L, F>, + args: TunnelArgs<'_>, ) -> Result { Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?; let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?; @@ -142,7 +141,7 @@ where config, log_file, args.resource_dir, - args.on_event, + args.event_hook, args.tunnel_close_rx, args.route_manager, )), @@ -155,13 +154,33 @@ where } } + /// Returns a path to an executable that communicates with relay servers. + /// Returns `None` if the executable is unknown. + #[cfg(windows)] + pub fn get_relay_client( + resource_dir: &path::Path, + params: &TunnelParameters, + ) -> Option { + use talpid_types::net::proxy::CustomProxy; + + let resource_dir = resource_dir.to_path_buf(); + match params { + TunnelParameters::OpenVpn(params) => match ¶ms.proxy { + Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()), + Some(CustomProxy::Socks5Local(_)) => None, + Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")), + }, + _ => Some(std::env::current_exe().unwrap()), + } + } + fn start_wireguard_tunnel( #[cfg(not(any(target_os = "linux", target_os = "windows")))] params: &wireguard_types::TunnelParameters, #[cfg(any(target_os = "linux", target_os = "windows"))] params: &wireguard_types::TunnelParameters, log: Option, - args: TunnelArgs<'_, L, F>, + args: TunnelArgs<'_>, ) -> Result { let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?; Ok(TunnelMonitor { @@ -174,12 +193,12 @@ where config: &openvpn_types::TunnelParameters, log: Option, resource_dir: &path::Path, - on_event: L, + event_hook: EventHook, tunnel_close_rx: oneshot::Receiver<()>, route_manager: RouteManagerHandle, ) -> Result { let monitor = talpid_openvpn::OpenVpnMonitor::start( - on_event, + event_hook, config, log, resource_dir, @@ -255,39 +274,13 @@ where } } -impl TunnelMonitor<()> { - /// Returns a path to an executable that communicates with relay servers. - /// Returns `None` if the executable is unknown. - #[cfg(windows)] - pub fn get_relay_client( - resource_dir: &path::Path, - params: &TunnelParameters, - ) -> Option { - use talpid_types::net::proxy::CustomProxy; - - let resource_dir = resource_dir.to_path_buf(); - match params { - TunnelParameters::OpenVpn(params) => match ¶ms.proxy { - Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()), - Some(CustomProxy::Socks5Local(_)) => None, - Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")), - }, - _ => Some(std::env::current_exe().unwrap()), - } - } -} - -enum InternalTunnelMonitor { +enum InternalTunnelMonitor { #[cfg(not(target_os = "android"))] OpenVpn(talpid_openvpn::OpenVpnMonitor), - Wireguard(talpid_wireguard::WireguardMonitor), + Wireguard(talpid_wireguard::WireguardMonitor), } -impl InternalTunnelMonitor -where - L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future + Send + 'static, -{ +impl InternalTunnelMonitor { fn wait(self) -> Result<()> { #[cfg(not(target_os = "android"))] let handle = tokio::runtime::Handle::current(); diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 1f1ea1be4fe2..53ef61475e82 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -19,7 +19,9 @@ use std::{ time::{Duration, Instant}, }; use talpid_routing::RouteManagerHandle; -use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; +use talpid_tunnel::{ + tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, +}; use talpid_types::{ net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters}, tunnel::{ErrorStateCause, FirewallPolicyError}, @@ -214,13 +216,7 @@ impl ConnectingState { retry_attempt: u32, ) -> Self { let (event_tx, event_rx) = mpsc::unbounded(); - let on_tunnel_event = move |event| { - let (tx, rx) = oneshot::channel(); - let _ = event_tx.unbounded_send((event, tx)); - async move { - let _ = rx.await; - } - }; + let event_hook = EventHook::new(event_tx); let route_manager = route_manager.clone(); let log_dir = log_dir.clone(); @@ -237,7 +233,7 @@ impl ConnectingState { let args = TunnelArgs { runtime, resource_dir: &resource_dir, - on_event: on_tunnel_event, + event_hook, tunnel_close_rx, tun_provider, retry_attempt, @@ -289,14 +285,10 @@ impl ConnectingState { } } - fn wait_for_tunnel_monitor( - tunnel_monitor: TunnelMonitor, + fn wait_for_tunnel_monitor( + tunnel_monitor: TunnelMonitor, retry_attempt: u32, - ) -> Option - where - L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future + Send + 'static, - { + ) -> Option { match tunnel_monitor.wait() { Ok(_) => None, Err(error) => match error { diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 0939b2e3434a..16d0d0e4fc6c 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -19,7 +19,7 @@ use std::{ }; #[cfg(target_os = "linux")] use talpid_routing::RequiredRoute; -use talpid_tunnel::TunnelEvent; +use talpid_tunnel::EventHook; use talpid_types::{ net::{openvpn, proxy::CustomProxy}, ErrorExt, @@ -245,17 +245,13 @@ impl WintunContextImpl { impl OpenVpnMonitor { /// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given /// path. - pub async fn start( - on_event: L, + pub async fn start( + event_hook: EventHook, params: &openvpn::TunnelParameters, log_path: Option, resource_dir: &Path, route_manager: talpid_routing::RouteManagerHandle, - ) -> Result - where - L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future + Send + 'static, - { + ) -> Result { let user_pass_file = Self::create_credentials_file(¶ms.config.username, ¶ms.config.password) .map_err(Error::CredentialsWriteError)?; @@ -306,7 +302,7 @@ impl OpenVpnMonitor { cmd, openvpn_init_args, event_server::OpenvpnEventProxyImpl { - on_event, + event_hook, user_pass_file_path: user_pass_file_path.clone(), proxy_auth_file_path: proxy_auth_file_path.clone(), abort_server_tx: event_server_abort_tx, @@ -775,7 +771,7 @@ mod event_server { pin::Pin, task::{Context, Poll}, }; - use talpid_tunnel::TunnelMetadata; + use talpid_tunnel::{EventHook, TunnelMetadata}; #[cfg(any(target_os = "macos", target_os = "windows"))] use talpid_types::net::proxy::CustomProxy; use talpid_types::ErrorExt; @@ -806,8 +802,8 @@ mod event_server { } /// Implements a gRPC service used to process events sent to by OpenVPN. - pub struct OpenvpnEventProxyImpl { - pub on_event: L, + pub struct OpenvpnEventProxyImpl { + pub event_hook: EventHook, pub user_pass_file_path: super::PathBuf, pub proxy_auth_file_path: Option, pub abort_server_tx: triggered::Trigger, @@ -818,21 +814,19 @@ mod event_server { pub ipv6_enabled: bool, } - impl< - L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future, - > OpenvpnEventProxyImpl - { + impl OpenvpnEventProxyImpl { async fn up_inner( &self, request: Request, ) -> std::result::Result, tonic::Status> { let env = request.into_inner().env; - (self.on_event)(talpid_tunnel::TunnelEvent::InterfaceUp( - Self::get_tunnel_metadata(&env)?, - talpid_types::net::AllowedTunnelTraffic::All, - )) - .await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::InterfaceUp( + Self::get_tunnel_metadata(&env)?, + talpid_types::net::AllowedTunnelTraffic::All, + )) + .await; Ok(Response::new(())) } @@ -902,7 +896,10 @@ mod event_server { return Err(tonic::Status::failed_precondition("Failed to add routes")); } - (self.on_event)(talpid_tunnel::TunnelEvent::Up(metadata)).await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::Up(metadata)) + .await; Ok(Response::new(())) } @@ -956,20 +953,18 @@ mod event_server { } #[tonic::async_trait] - impl< - L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: std::future::Future + 'static + Send, - > OpenvpnEventProxy for OpenvpnEventProxyImpl - { + impl OpenvpnEventProxy for OpenvpnEventProxyImpl { async fn auth_failed( &self, request: Request, ) -> std::result::Result, tonic::Status> { let env = request.into_inner().env; - (self.on_event)(talpid_tunnel::TunnelEvent::AuthFailed( - env.get("auth_failed_reason").cloned(), - )) - .await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::AuthFailed( + env.get("auth_failed_reason").cloned(), + )) + .await; Ok(Response::new(())) } @@ -995,7 +990,10 @@ mod event_server { &self, _request: Request, ) -> std::result::Result, tonic::Status> { - (self.on_event)(talpid_tunnel::TunnelEvent::Down).await; + self.event_hook + .clone() + .on_event(talpid_tunnel::TunnelEvent::Down) + .await; Ok(Response::new(())) } } diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs index 9fe6e0074ea8..ddffb0d3026f 100644 --- a/talpid-tunnel/src/lib.rs +++ b/talpid-tunnel/src/lib.rs @@ -1,5 +1,4 @@ use std::{ - future::Future, net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::Path, sync::{Arc, Mutex}, @@ -10,7 +9,13 @@ use std::{ pub mod network_interface; pub mod tun_provider; -use futures::channel::oneshot; +use futures::{ + channel::{ + mpsc::UnboundedSender, + oneshot::{self, Sender}, + }, + SinkExt, +}; use talpid_routing::RouteManagerHandle; use talpid_types::net::AllowedTunnelTraffic; use tun_provider::TunProvider; @@ -29,17 +34,13 @@ pub const MIN_IPV4_MTU: u16 = 576; pub const MIN_IPV6_MTU: u16 = 1280; /// Arguments for creating a tunnel. -pub struct TunnelArgs<'a, L, F> -where - L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static, - F: Future, -{ +pub struct TunnelArgs<'a> { /// Tokio runtime handle. pub runtime: tokio::runtime::Handle, /// Resource directory path. pub resource_dir: &'a Path, /// Callback function called when an event happens. - pub on_event: L, + pub event_hook: EventHook, /// Receiver oneshot channel for closing the tunnel. pub tunnel_close_rx: oneshot::Receiver<()>, /// Mutex to tunnel provider. @@ -50,6 +51,24 @@ where pub route_manager: RouteManagerHandle, } +#[derive(Clone)] +pub struct EventHook { + event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>, +} + +impl EventHook { + pub fn new(event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>) -> Self { + Self { event_tx } + } + + pub async fn on_event(&mut self, event: TunnelEvent) { + let (tx, rx) = oneshot::channel::<()>(); + if let Ok(()) = self.event_tx.send((event, tx)).await { + let _ = rx.await; + } + } +} + /// Information about a VPN tunnel. #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct TunnelMetadata { diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index fa52717e5011..c3cf9a554fa8 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -24,7 +24,9 @@ use std::{env, sync::LazyLock}; use talpid_routing::{self, RequiredRoute}; #[cfg(not(windows))] use talpid_tunnel::tun_provider; -use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; +use talpid_tunnel::{ + tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata, +}; use talpid_types::{ net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol}, @@ -134,12 +136,12 @@ impl Error { } /// Spawns and monitors a wireguard tunnel -pub struct WireguardMonitor { +pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation tunnel: Arc>>, /// Callback to signal tunnel events - event_callback: F, + event_hook: EventHook, close_msg_receiver: sync_mpsc::Receiver, pinger_stop_sender: sync_mpsc::Sender<()>, obfuscator: Arc>>, @@ -153,18 +155,14 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock = LazyLock::new(|| { .unwrap_or(false) }); -impl WireguardMonitor -where - F: (Fn(TunnelEvent) -> Fut) + Send + Sync + Clone + 'static, - Fut: Future + Send, -{ +impl WireguardMonitor { /// Starts a WireGuard tunnel with the given config #[cfg(not(target_os = "android"))] pub fn start( params: &TunnelParameters, log_path: Option<&Path>, - args: TunnelArgs<'_, F, Fut>, - ) -> Result> { + args: TunnelArgs<'_>, + ) -> Result { #[cfg(any(target_os = "windows", target_os = "linux"))] let desired_mtu = args .runtime @@ -222,12 +220,13 @@ where let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), - event_callback: args.on_event.clone(), + event_hook: args.event_hook.clone(), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator, }; + let mut event_hook = args.event_hook.clone(); let moved_tunnel = monitor.tunnel.clone(); let moved_close_obfs_sender = close_obfs_sender.clone(); let moved_obfuscator = monitor.obfuscator.clone(); @@ -242,7 +241,9 @@ where let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - (args.on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; + event_hook + .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + .await; // Add non-default routes before establishing the tunnel. #[cfg(target_os = "linux")] @@ -274,11 +275,12 @@ where .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event)(TunnelEvent::InterfaceUp( - metadata, - Self::allowed_traffic_after_tunnel_config(), - )) - .await; + event_hook + .on_event(TunnelEvent::InterfaceUp( + metadata, + Self::allowed_traffic_after_tunnel_config(), + )) + .await; } if detect_mtu { @@ -343,7 +345,7 @@ where .map_err(CloseMsg::SetupError)?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (args.on_event)(TunnelEvent::Up(metadata)).await; + event_hook.on_event(TunnelEvent::Up(metadata)).await; let monitored_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { @@ -388,16 +390,10 @@ where /// being ready to serve traffic. /// - No routes are configured on android. #[cfg(target_os = "android")] - pub fn start< - F: (Fn(TunnelEvent) -> Pin + Send>>) - + Send - + Sync - + Clone - + 'static, - >( + pub fn start( params: &TunnelParameters, log_path: Option<&Path>, - args: TunnelArgs<'_, F>, + args: TunnelArgs<'_>, ) -> Result { let desired_mtu = get_desired_mtu(params); let mut config = @@ -441,10 +437,11 @@ where let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); + let mut event_hook = args.event_hook; let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::clone(&tunnel), - event_callback: Box::new(args.on_event.clone()), + event_hook: event_hook.clone(), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator: Arc::new(AsyncMutex::new(obfuscator)), @@ -458,7 +455,8 @@ where let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - args.on_event.clone()(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) + event_hook + .on_event(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)) .await; if should_negotiate_ephemeral_peer { @@ -475,15 +473,16 @@ where .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - args.on_event.clone()(TunnelEvent::InterfaceUp( - metadata, - Self::allowed_traffic_after_tunnel_config(), - )) - .await; + event_hook + .on_event(TunnelEvent::InterfaceUp( + metadata, + Self::allowed_traffic_after_tunnel_config(), + )) + .await; } let metadata = Self::tunnel_metadata(&iface_name, &config); - args.on_event.clone()(TunnelEvent::Up(metadata)).await; + event_hook.on_event(TunnelEvent::Up(metadata)).await; // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it let connectivity_check = { @@ -795,7 +794,7 @@ where let _ = self.pinger_stop_sender.send(()); self.runtime - .block_on((self.event_callback)(TunnelEvent::Down)); + .block_on(self.event_hook.on_event(TunnelEvent::Down)); self.stop_tunnel(); @@ -902,10 +901,7 @@ where fn get_post_tunnel_routes<'a>( iface_name: &str, config: &'a Config, - ) -> impl Iterator + 'a - where - Fut: 'a, - { + ) -> impl Iterator + 'a { let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); let iter = config .get_tunnel_destinations()