Skip to content

Commit

Permalink
Replace dyn fn with generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Serock3 committed Dec 2, 2024
1 parent db019b7 commit 96bceb6
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 122 deletions.
104 changes: 48 additions & 56 deletions talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ use talpid_tunnel::tun_provider;
pub use talpid_tunnel::{TunnelArgs, TunnelEvent, TunnelMetadata};
#[cfg(not(target_os = "android"))]
use talpid_types::net::openvpn as openvpn_types;
use talpid_types::net::{wireguard as wireguard_types, TunnelParameters};
use talpid_types::tunnel::ErrorStateCause;
use talpid_types::{
net::{wireguard as wireguard_types, TunnelParameters},
tunnel::ErrorStateCause,
};

const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";
Expand Down Expand Up @@ -113,27 +115,24 @@ impl Error {
}

/// Abstraction for monitoring a generic VPN tunnel.
pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
pub struct TunnelMonitor<F> {
monitor: InternalTunnelMonitor<F>,
}

// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
impl TunnelMonitor {
impl<L, F> TunnelMonitor<L>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
/// on tunnel state changes.
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
pub fn start<L>(
pub fn start(
tunnel_parameters: &TunnelParameters,
log_dir: &Option<path::PathBuf>,
args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Clone
+ Sync
+ 'static,
{
args: TunnelArgs<'_, L, F>,
) -> Result<Self> {
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;

Expand All @@ -156,62 +155,29 @@ impl TunnelMonitor {
}
}

/// Returns a path to an executable that communicates with relay servers.
/// Returns `None` if the executable is unknown.
#[cfg(windows)]
pub fn get_relay_client(
resource_dir: &path::Path,
params: &TunnelParameters,
) -> Option<path::PathBuf> {
use talpid_types::net::proxy::CustomProxy;

let resource_dir = resource_dir.to_path_buf();
match params {
TunnelParameters::OpenVpn(params) => match &params.proxy {
Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
Some(CustomProxy::Socks5Local(_)) => None,
Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
},
_ => Some(std::env::current_exe().unwrap()),
}
}

fn start_wireguard_tunnel<L>(
fn start_wireguard_tunnel(
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
params: &wireguard_types::TunnelParameters,
#[cfg(any(target_os = "linux", target_os = "windows"))]
params: &wireguard_types::TunnelParameters,
log: Option<path::PathBuf>,
args: TunnelArgs<'_, L>,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ Clone
+ 'static,
{
args: TunnelArgs<'_, L, F>,
) -> Result<Self> {
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
monitor: InternalTunnelMonitor::Wireguard(monitor),
})
}

#[cfg(not(target_os = "android"))]
async fn start_openvpn_tunnel<L>(
async fn start_openvpn_tunnel(
config: &openvpn_types::TunnelParameters,
log: Option<path::PathBuf>,
resource_dir: &path::Path,
on_event: L,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
{
) -> Result<Self> {
let monitor = talpid_openvpn::OpenVpnMonitor::start(
on_event,
config,
Expand Down Expand Up @@ -289,13 +255,39 @@ impl TunnelMonitor {
}
}

enum InternalTunnelMonitor {
impl TunnelMonitor<()> {
/// Returns a path to an executable that communicates with relay servers.
/// Returns `None` if the executable is unknown.
#[cfg(windows)]
pub fn get_relay_client(
resource_dir: &path::Path,
params: &TunnelParameters,
) -> Option<path::PathBuf> {
use talpid_types::net::proxy::CustomProxy;

let resource_dir = resource_dir.to_path_buf();
match params {
TunnelParameters::OpenVpn(params) => match &params.proxy {
Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
Some(CustomProxy::Socks5Local(_)) => None,
Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
},
_ => Some(std::env::current_exe().unwrap()),
}
}
}

enum InternalTunnelMonitor<F> {
#[cfg(not(target_os = "android"))]
OpenVpn(talpid_openvpn::OpenVpnMonitor),
Wireguard(talpid_wireguard::WireguardMonitor),
Wireguard(talpid_wireguard::WireguardMonitor<F>),
}

impl InternalTunnelMonitor {
impl<L, F> InternalTunnelMonitor<L>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
fn wait(self) -> Result<()> {
#[cfg(not(target_os = "android"))]
let handle = tokio::runtime::Handle::current();
Expand Down
25 changes: 14 additions & 11 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,13 @@ impl ConnectingState {
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
let on_tunnel_event =
move |event| -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
let (tx, rx) = oneshot::channel();
let _ = event_tx.unbounded_send((event, tx));
Box::pin(async move {
let _ = rx.await;
})
};
let on_tunnel_event = move |event| {
let (tx, rx) = oneshot::channel();
let _ = event_tx.unbounded_send((event, tx));
async move {
let _ = rx.await;
}
};

let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
Expand Down Expand Up @@ -290,10 +289,14 @@ impl ConnectingState {
}
}

fn wait_for_tunnel_monitor(
tunnel_monitor: TunnelMonitor,
fn wait_for_tunnel_monitor<L, F>(
tunnel_monitor: TunnelMonitor<L>,
retry_attempt: u32,
) -> Option<ErrorStateCause> {
) -> Option<ErrorStateCause>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
match tunnel_monitor.wait() {
Ok(_) => None,
Err(error) => match error {
Expand Down
35 changes: 8 additions & 27 deletions talpid-openvpn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,16 @@ impl WintunContextImpl {
impl OpenVpnMonitor<OpenVpnCommand> {
/// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given
/// path.
pub async fn start<L>(
pub async fn start<L, F>(
on_event: L,
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
resource_dir: &Path,
route_manager: talpid_routing::RouteManagerHandle,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
Expand Down Expand Up @@ -808,14 +806,7 @@ mod event_server {
}

/// Implements a gRPC service used to process events sent to by OpenVPN.
pub struct OpenvpnEventProxyImpl<
L: (Fn(
talpid_tunnel::TunnelEvent,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
> {
pub struct OpenvpnEventProxyImpl<L> {
pub on_event: L,
pub user_pass_file_path: super::PathBuf,
pub proxy_auth_file_path: Option<super::PathBuf>,
Expand All @@ -828,13 +819,8 @@ mod event_server {
}

impl<
L: (Fn(
talpid_tunnel::TunnelEvent,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()>,
> OpenvpnEventProxyImpl<L>
{
async fn up_inner(
Expand Down Expand Up @@ -971,13 +957,8 @@ mod event_server {

#[tonic::async_trait]
impl<
L: (Fn(
talpid_tunnel::TunnelEvent,
)
-> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>)
+ Send
+ Sync
+ 'static,
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + 'static + Send,
> OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
{
async fn auth_failed(
Expand Down
8 changes: 5 additions & 3 deletions talpid-tunnel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::Path,
sync::{Arc, Mutex},
Expand All @@ -9,7 +10,7 @@ use std::{
pub mod network_interface;

pub mod tun_provider;
use futures::{channel::oneshot, future::BoxFuture};
use futures::channel::oneshot;
use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
Expand All @@ -28,9 +29,10 @@ pub const MIN_IPV4_MTU: u16 = 576;
pub const MIN_IPV6_MTU: u16 = 1280;

/// Arguments for creating a tunnel.
pub struct TunnelArgs<'a, L>
pub struct TunnelArgs<'a, L, F>
where
L: (Fn(TunnelEvent) -> BoxFuture<'static, ()>) + Send + Clone + Sync + 'static,
L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: Future<Output = ()>,
{
/// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
Expand Down
Loading

0 comments on commit 96bceb6

Please sign in to comment.