diff --git a/talpid-core/src/tunnel/mod.rs b/talpid-core/src/tunnel/mod.rs index db2d83346a42..5a12b67eb39e 100644 --- a/talpid-core/src/tunnel/mod.rs +++ b/talpid-core/src/tunnel/mod.rs @@ -9,8 +9,10 @@ use talpid_tunnel::tun_provider; pub use talpid_tunnel::{TunnelArgs, TunnelEvent, TunnelMetadata}; #[cfg(not(target_os = "android"))] use talpid_types::net::openvpn as openvpn_types; -use talpid_types::net::{wireguard as wireguard_types, TunnelParameters}; -use talpid_types::tunnel::ErrorStateCause; +use talpid_types::{ + net::{wireguard as wireguard_types, TunnelParameters}, + tunnel::ErrorStateCause, +}; const OPENVPN_LOG_FILENAME: &str = "openvpn.log"; const WIREGUARD_LOG_FILENAME: &str = "wireguard.log"; @@ -113,27 +115,24 @@ impl Error { } /// Abstraction for monitoring a generic VPN tunnel. -pub struct TunnelMonitor { - monitor: InternalTunnelMonitor, +pub struct TunnelMonitor { + monitor: InternalTunnelMonitor, } // TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor -impl TunnelMonitor { +impl TunnelMonitor +where + L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future + Send + 'static, +{ /// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event` /// on tunnel state changes. #[cfg_attr(any(target_os = "android", windows), allow(unused_variables))] - pub fn start( + pub fn start( tunnel_parameters: &TunnelParameters, log_dir: &Option, - args: TunnelArgs<'_, L>, - ) -> Result - where - L: (Fn(TunnelEvent) -> std::pin::Pin + Send>>) - + Send - + Clone - + Sync - + 'static, - { + args: TunnelArgs<'_, L, F>, + ) -> Result { Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?; let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?; @@ -156,41 +155,14 @@ impl TunnelMonitor { } } - /// Returns a path to an executable that communicates with relay servers. - /// Returns `None` if the executable is unknown. - #[cfg(windows)] - pub fn get_relay_client( - resource_dir: &path::Path, - params: &TunnelParameters, - ) -> Option { - use talpid_types::net::proxy::CustomProxy; - - let resource_dir = resource_dir.to_path_buf(); - match params { - TunnelParameters::OpenVpn(params) => match ¶ms.proxy { - Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()), - Some(CustomProxy::Socks5Local(_)) => None, - Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")), - }, - _ => Some(std::env::current_exe().unwrap()), - } - } - - fn start_wireguard_tunnel( + fn start_wireguard_tunnel( #[cfg(not(any(target_os = "linux", target_os = "windows")))] params: &wireguard_types::TunnelParameters, #[cfg(any(target_os = "linux", target_os = "windows"))] params: &wireguard_types::TunnelParameters, log: Option, - args: TunnelArgs<'_, L>, - ) -> Result - where - L: (Fn(TunnelEvent) -> std::pin::Pin + Send>>) - + Send - + Sync - + Clone - + 'static, - { + args: TunnelArgs<'_, L, F>, + ) -> Result { let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?; Ok(TunnelMonitor { monitor: InternalTunnelMonitor::Wireguard(monitor), @@ -198,20 +170,14 @@ impl TunnelMonitor { } #[cfg(not(target_os = "android"))] - async fn start_openvpn_tunnel( + async fn start_openvpn_tunnel( config: &openvpn_types::TunnelParameters, log: Option, resource_dir: &path::Path, on_event: L, tunnel_close_rx: oneshot::Receiver<()>, route_manager: RouteManagerHandle, - ) -> Result - where - L: (Fn(TunnelEvent) -> std::pin::Pin + Send>>) - + Send - + Sync - + 'static, - { + ) -> Result { let monitor = talpid_openvpn::OpenVpnMonitor::start( on_event, config, @@ -289,13 +255,39 @@ impl TunnelMonitor { } } -enum InternalTunnelMonitor { +impl TunnelMonitor<()> { + /// Returns a path to an executable that communicates with relay servers. + /// Returns `None` if the executable is unknown. + #[cfg(windows)] + pub fn get_relay_client( + resource_dir: &path::Path, + params: &TunnelParameters, + ) -> Option { + use talpid_types::net::proxy::CustomProxy; + + let resource_dir = resource_dir.to_path_buf(); + match params { + TunnelParameters::OpenVpn(params) => match ¶ms.proxy { + Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()), + Some(CustomProxy::Socks5Local(_)) => None, + Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")), + }, + _ => Some(std::env::current_exe().unwrap()), + } + } +} + +enum InternalTunnelMonitor { #[cfg(not(target_os = "android"))] OpenVpn(talpid_openvpn::OpenVpnMonitor), - Wireguard(talpid_wireguard::WireguardMonitor), + Wireguard(talpid_wireguard::WireguardMonitor), } -impl InternalTunnelMonitor { +impl InternalTunnelMonitor +where + L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future + Send + 'static, +{ fn wait(self) -> Result<()> { #[cfg(not(target_os = "android"))] let handle = tokio::runtime::Handle::current(); diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 7387d51c03f4..1f1ea1be4fe2 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -214,14 +214,13 @@ impl ConnectingState { retry_attempt: u32, ) -> Self { let (event_tx, event_rx) = mpsc::unbounded(); - let on_tunnel_event = - move |event| -> std::pin::Pin + Send>> { - let (tx, rx) = oneshot::channel(); - let _ = event_tx.unbounded_send((event, tx)); - Box::pin(async move { - let _ = rx.await; - }) - }; + let on_tunnel_event = move |event| { + let (tx, rx) = oneshot::channel(); + let _ = event_tx.unbounded_send((event, tx)); + async move { + let _ = rx.await; + } + }; let route_manager = route_manager.clone(); let log_dir = log_dir.clone(); @@ -290,10 +289,14 @@ impl ConnectingState { } } - fn wait_for_tunnel_monitor( - tunnel_monitor: TunnelMonitor, + fn wait_for_tunnel_monitor( + tunnel_monitor: TunnelMonitor, retry_attempt: u32, - ) -> Option { + ) -> Option + where + L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future + Send + 'static, + { match tunnel_monitor.wait() { Ok(_) => None, Err(error) => match error { diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 421c2076efaa..0939b2e3434a 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -245,7 +245,7 @@ impl WintunContextImpl { impl OpenVpnMonitor { /// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given /// path. - pub async fn start( + pub async fn start( on_event: L, params: &openvpn::TunnelParameters, log_path: Option, @@ -253,10 +253,8 @@ impl OpenVpnMonitor { route_manager: talpid_routing::RouteManagerHandle, ) -> Result where - L: (Fn(TunnelEvent) -> std::pin::Pin + Send>>) - + Send - + Sync - + 'static, + L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future + Send + 'static, { let user_pass_file = Self::create_credentials_file(¶ms.config.username, ¶ms.config.password) @@ -808,14 +806,7 @@ mod event_server { } /// Implements a gRPC service used to process events sent to by OpenVPN. - pub struct OpenvpnEventProxyImpl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) -> std::pin::Pin + Send>>) - + Send - + Sync - + 'static, - > { + pub struct OpenvpnEventProxyImpl { pub on_event: L, pub user_pass_file_path: super::PathBuf, pub proxy_auth_file_path: Option, @@ -828,13 +819,8 @@ mod event_server { } impl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) - -> std::pin::Pin + Send>>) - + Send - + Sync - + 'static, + L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future, > OpenvpnEventProxyImpl { async fn up_inner( @@ -971,13 +957,8 @@ mod event_server { #[tonic::async_trait] impl< - L: (Fn( - talpid_tunnel::TunnelEvent, - ) - -> std::pin::Pin + Send>>) - + Send - + Sync - + 'static, + L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: std::future::Future + 'static + Send, > OpenvpnEventProxy for OpenvpnEventProxyImpl { async fn auth_failed( diff --git a/talpid-tunnel/src/lib.rs b/talpid-tunnel/src/lib.rs index 53a746d10244..9fe6e0074ea8 100644 --- a/talpid-tunnel/src/lib.rs +++ b/talpid-tunnel/src/lib.rs @@ -1,4 +1,5 @@ use std::{ + future::Future, net::{IpAddr, Ipv4Addr, Ipv6Addr}, path::Path, sync::{Arc, Mutex}, @@ -9,7 +10,7 @@ use std::{ pub mod network_interface; pub mod tun_provider; -use futures::{channel::oneshot, future::BoxFuture}; +use futures::channel::oneshot; use talpid_routing::RouteManagerHandle; use talpid_types::net::AllowedTunnelTraffic; use tun_provider::TunProvider; @@ -28,9 +29,10 @@ pub const MIN_IPV4_MTU: u16 = 576; pub const MIN_IPV6_MTU: u16 = 1280; /// Arguments for creating a tunnel. -pub struct TunnelArgs<'a, L> +pub struct TunnelArgs<'a, L, F> where - L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static, + L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static, + F: Future, { /// Tokio runtime handle. pub runtime: tokio::runtime::Handle, diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 06726f9c69c0..fa52717e5011 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -5,7 +5,7 @@ use self::config::Config; #[cfg(windows)] use futures::channel::mpsc; -use futures::future::{BoxFuture, Future}; +use futures::future::Future; use obfuscation::ObfuscatorHandle; #[cfg(target_os = "android")] use std::borrow::Cow; @@ -26,9 +26,8 @@ use talpid_routing::{self, RequiredRoute}; use talpid_tunnel::tun_provider; use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata}; -use talpid_types::net::wireguard::TunnelParameters; use talpid_types::{ - net::{AllowedTunnelTraffic, Endpoint, TransportProtocol}, + net::{wireguard::TunnelParameters, AllowedTunnelTraffic, Endpoint, TransportProtocol}, BoxedError, ErrorExt, }; use tokio::sync::Mutex as AsyncMutex; @@ -60,7 +59,6 @@ type TunnelType = Box; type TunnelType = WgGoTunnel; type Result = std::result::Result; -type EventCallback = Box BoxFuture<'static, ()>) + Send + Sync + 'static>; /// Errors that can happen in the Wireguard tunnel monitor. #[derive(thiserror::Error, Debug)] @@ -136,12 +134,12 @@ impl Error { } /// Spawns and monitors a wireguard tunnel -pub struct WireguardMonitor { +pub struct WireguardMonitor { runtime: tokio::runtime::Handle, /// Tunnel implementation tunnel: Arc>>, /// Callback to signal tunnel events - event_callback: EventCallback, + event_callback: F, close_msg_receiver: sync_mpsc::Receiver, pinger_stop_sender: sync_mpsc::Sender<()>, obfuscator: Arc>>, @@ -155,22 +153,18 @@ static FORCE_USERSPACE_WIREGUARD: LazyLock = LazyLock::new(|| { .unwrap_or(false) }); -impl WireguardMonitor { +impl WireguardMonitor +where + F: (Fn(TunnelEvent) -> Fut) + Send + Sync + Clone + 'static, + Fut: Future + Send, +{ /// Starts a WireGuard tunnel with the given config #[cfg(not(target_os = "android"))] - pub fn start< - F: (Fn(TunnelEvent) -> Pin + Send>>) - + Send - + Sync - + Clone - + 'static, - >( + pub fn start( params: &TunnelParameters, log_path: Option<&Path>, - args: TunnelArgs<'_, F>, - ) -> Result { - let on_event = args.on_event.clone(); - + args: TunnelArgs<'_, F, Fut>, + ) -> Result> { #[cfg(any(target_os = "windows", target_os = "linux"))] let desired_mtu = args .runtime @@ -225,11 +219,10 @@ impl WireguardMonitor { .map_err(Error::ConnectivityMonitorError)? .with_cancellation(); - let event_callback = Box::new(on_event.clone()); let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), - event_callback, + event_callback: args.on_event.clone(), close_msg_receiver: close_obfs_listener, pinger_stop_sender: pinger_tx, obfuscator, @@ -249,7 +242,7 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); let allowed_traffic = Self::allowed_traffic_during_tunnel_config(&config); - (on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; + (args.on_event)(TunnelEvent::InterfaceUp(metadata.clone(), allowed_traffic)).await; // Add non-default routes before establishing the tunnel. #[cfg(target_os = "linux")] @@ -281,7 +274,7 @@ impl WireguardMonitor { .await?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (on_event)(TunnelEvent::InterfaceUp( + (args.on_event)(TunnelEvent::InterfaceUp( metadata, Self::allowed_traffic_after_tunnel_config(), )) @@ -350,7 +343,7 @@ impl WireguardMonitor { .map_err(CloseMsg::SetupError)?; let metadata = Self::tunnel_metadata(&iface_name, &config); - (on_event)(TunnelEvent::Up(metadata)).await; + (args.on_event)(TunnelEvent::Up(metadata)).await; let monitored_tunnel = Arc::downgrade(&tunnel); tokio::task::spawn_blocking(move || { @@ -568,7 +561,6 @@ impl WireguardMonitor { /// Replace `0.0.0.0/0`/`::/0` with the gateway IPs when `gateway_only` is true. /// Used to block traffic to other destinations while connecting on Android. - /// #[cfg(target_os = "android")] fn patch_allowed_ips(config: &Config, gateway_only: bool) -> Cow<'_, Config> { if gateway_only { @@ -910,7 +902,10 @@ impl WireguardMonitor { fn get_post_tunnel_routes<'a>( iface_name: &str, config: &'a Config, - ) -> impl Iterator + 'a { + ) -> impl Iterator + 'a + where + Fut: 'a, + { let (node_v4, node_v6) = Self::get_tunnel_nodes(iface_name, config); let iter = config .get_tunnel_destinations()