Skip to content

Commit

Permalink
Replace generic with new type
Browse files Browse the repository at this point in the history
  • Loading branch information
Serock3 committed Dec 2, 2024
1 parent 96bceb6 commit ca17aee
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 137 deletions.
75 changes: 34 additions & 41 deletions talpid-core/src/tunnel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use talpid_types::{
tunnel::ErrorStateCause,
};

#[cfg(not(target_os = "android"))]
use talpid_tunnel::EventHook;

const OPENVPN_LOG_FILENAME: &str = "openvpn.log";
const WIREGUARD_LOG_FILENAME: &str = "wireguard.log";

Expand Down Expand Up @@ -115,23 +118,19 @@ impl Error {
}

/// Abstraction for monitoring a generic VPN tunnel.
pub struct TunnelMonitor<F> {
monitor: InternalTunnelMonitor<F>,
pub struct TunnelMonitor {
monitor: InternalTunnelMonitor,
}

// TODO(emilsp) move most of the openvpn tunnel details to OpenVpnTunnelMonitor
impl<L, F> TunnelMonitor<L>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
impl TunnelMonitor {
/// Creates a new `TunnelMonitor` that connects to the given remote and notifies `on_event`
/// on tunnel state changes.
#[cfg_attr(any(target_os = "android", windows), allow(unused_variables))]
pub fn start(
tunnel_parameters: &TunnelParameters,
log_dir: &Option<path::PathBuf>,
args: TunnelArgs<'_, L, F>,
args: TunnelArgs<'_>,
) -> Result<Self> {
Self::ensure_ipv6_can_be_used_if_enabled(tunnel_parameters)?;
let log_file = Self::prepare_tunnel_log_file(tunnel_parameters, log_dir)?;
Expand All @@ -142,7 +141,7 @@ where
config,
log_file,
args.resource_dir,
args.on_event,
args.event_hook,
args.tunnel_close_rx,
args.route_manager,
)),
Expand All @@ -155,13 +154,33 @@ where
}
}

/// Returns a path to an executable that communicates with relay servers.
/// Returns `None` if the executable is unknown.
#[cfg(windows)]
pub fn get_relay_client(
resource_dir: &path::Path,
params: &TunnelParameters,
) -> Option<path::PathBuf> {
use talpid_types::net::proxy::CustomProxy;

let resource_dir = resource_dir.to_path_buf();
match params {
TunnelParameters::OpenVpn(params) => match &params.proxy {
Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
Some(CustomProxy::Socks5Local(_)) => None,
Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
},
_ => Some(std::env::current_exe().unwrap()),
}
}

fn start_wireguard_tunnel(
#[cfg(not(any(target_os = "linux", target_os = "windows")))]
params: &wireguard_types::TunnelParameters,
#[cfg(any(target_os = "linux", target_os = "windows"))]
params: &wireguard_types::TunnelParameters,
log: Option<path::PathBuf>,
args: TunnelArgs<'_, L, F>,
args: TunnelArgs<'_>,
) -> Result<Self> {
let monitor = talpid_wireguard::WireguardMonitor::start(params, log.as_deref(), args)?;
Ok(TunnelMonitor {
Expand All @@ -174,12 +193,12 @@ where
config: &openvpn_types::TunnelParameters,
log: Option<path::PathBuf>,
resource_dir: &path::Path,
on_event: L,
event_hook: EventHook,
tunnel_close_rx: oneshot::Receiver<()>,
route_manager: RouteManagerHandle,
) -> Result<Self> {
let monitor = talpid_openvpn::OpenVpnMonitor::start(
on_event,
event_hook,
config,
log,
resource_dir,
Expand Down Expand Up @@ -255,39 +274,13 @@ where
}
}

impl TunnelMonitor<()> {
/// Returns a path to an executable that communicates with relay servers.
/// Returns `None` if the executable is unknown.
#[cfg(windows)]
pub fn get_relay_client(
resource_dir: &path::Path,
params: &TunnelParameters,
) -> Option<path::PathBuf> {
use talpid_types::net::proxy::CustomProxy;

let resource_dir = resource_dir.to_path_buf();
match params {
TunnelParameters::OpenVpn(params) => match &params.proxy {
Some(CustomProxy::Shadowsocks(_)) => Some(std::env::current_exe().unwrap()),
Some(CustomProxy::Socks5Local(_)) => None,
Some(CustomProxy::Socks5Remote(_)) | None => Some(resource_dir.join("openvpn.exe")),
},
_ => Some(std::env::current_exe().unwrap()),
}
}
}

enum InternalTunnelMonitor<F> {
enum InternalTunnelMonitor {
#[cfg(not(target_os = "android"))]
OpenVpn(talpid_openvpn::OpenVpnMonitor),
Wireguard(talpid_wireguard::WireguardMonitor<F>),
Wireguard(talpid_wireguard::WireguardMonitor),
}

impl<L, F> InternalTunnelMonitor<L>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
impl InternalTunnelMonitor {
fn wait(self) -> Result<()> {
#[cfg(not(target_os = "android"))]
let handle = tokio::runtime::Handle::current();
Expand Down
24 changes: 8 additions & 16 deletions talpid-core/src/tunnel_state_machine/connecting_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ use std::{
time::{Duration, Instant},
};
use talpid_routing::RouteManagerHandle;
use talpid_tunnel::{tun_provider::TunProvider, TunnelArgs, TunnelEvent, TunnelMetadata};
use talpid_tunnel::{
tun_provider::TunProvider, EventHook, TunnelArgs, TunnelEvent, TunnelMetadata,
};
use talpid_types::{
net::{AllowedClients, AllowedEndpoint, AllowedTunnelTraffic, TunnelParameters},
tunnel::{ErrorStateCause, FirewallPolicyError},
Expand Down Expand Up @@ -214,13 +216,7 @@ impl ConnectingState {
retry_attempt: u32,
) -> Self {
let (event_tx, event_rx) = mpsc::unbounded();
let on_tunnel_event = move |event| {
let (tx, rx) = oneshot::channel();
let _ = event_tx.unbounded_send((event, tx));
async move {
let _ = rx.await;
}
};
let event_hook = EventHook::new(event_tx);

let route_manager = route_manager.clone();
let log_dir = log_dir.clone();
Expand All @@ -237,7 +233,7 @@ impl ConnectingState {
let args = TunnelArgs {
runtime,
resource_dir: &resource_dir,
on_event: on_tunnel_event,
event_hook,
tunnel_close_rx,
tun_provider,
retry_attempt,
Expand Down Expand Up @@ -289,14 +285,10 @@ impl ConnectingState {
}
}

fn wait_for_tunnel_monitor<L, F>(
tunnel_monitor: TunnelMonitor<L>,
fn wait_for_tunnel_monitor(
tunnel_monitor: TunnelMonitor,
retry_attempt: u32,
) -> Option<ErrorStateCause>
where
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
) -> Option<ErrorStateCause> {
match tunnel_monitor.wait() {
Ok(_) => None,
Err(error) => match error {
Expand Down
64 changes: 31 additions & 33 deletions talpid-openvpn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::{
};
#[cfg(target_os = "linux")]
use talpid_routing::RequiredRoute;
use talpid_tunnel::TunnelEvent;
use talpid_tunnel::EventHook;
use talpid_types::{
net::{openvpn, proxy::CustomProxy},
ErrorExt,
Expand Down Expand Up @@ -245,17 +245,13 @@ impl WintunContextImpl {
impl OpenVpnMonitor<OpenVpnCommand> {
/// Creates a new `OpenVpnMonitor` with the given listener and using the plugin at the given
/// path.
pub async fn start<L, F>(
on_event: L,
pub async fn start(
event_hook: EventHook,
params: &openvpn::TunnelParameters,
log_path: Option<PathBuf>,
resource_dir: &Path,
route_manager: talpid_routing::RouteManagerHandle,
) -> Result<Self>
where
L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + Send + 'static,
{
) -> Result<Self> {
let user_pass_file =
Self::create_credentials_file(&params.config.username, &params.config.password)
.map_err(Error::CredentialsWriteError)?;
Expand Down Expand Up @@ -306,7 +302,7 @@ impl OpenVpnMonitor<OpenVpnCommand> {
cmd,
openvpn_init_args,
event_server::OpenvpnEventProxyImpl {
on_event,
event_hook,
user_pass_file_path: user_pass_file_path.clone(),
proxy_auth_file_path: proxy_auth_file_path.clone(),
abort_server_tx: event_server_abort_tx,
Expand Down Expand Up @@ -775,7 +771,7 @@ mod event_server {
pin::Pin,
task::{Context, Poll},
};
use talpid_tunnel::TunnelMetadata;
use talpid_tunnel::{EventHook, TunnelMetadata};
#[cfg(any(target_os = "macos", target_os = "windows"))]
use talpid_types::net::proxy::CustomProxy;
use talpid_types::ErrorExt;
Expand Down Expand Up @@ -806,8 +802,8 @@ mod event_server {
}

/// Implements a gRPC service used to process events sent to by OpenVPN.
pub struct OpenvpnEventProxyImpl<L> {
pub on_event: L,
pub struct OpenvpnEventProxyImpl {
pub event_hook: EventHook,
pub user_pass_file_path: super::PathBuf,
pub proxy_auth_file_path: Option<super::PathBuf>,
pub abort_server_tx: triggered::Trigger,
Expand All @@ -818,21 +814,19 @@ mod event_server {
pub ipv6_enabled: bool,
}

impl<
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()>,
> OpenvpnEventProxyImpl<L>
{
impl OpenvpnEventProxyImpl {
async fn up_inner(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
(self.on_event)(talpid_tunnel::TunnelEvent::InterfaceUp(
Self::get_tunnel_metadata(&env)?,
talpid_types::net::AllowedTunnelTraffic::All,
))
.await;
self.event_hook
.clone()
.on_event(talpid_tunnel::TunnelEvent::InterfaceUp(
Self::get_tunnel_metadata(&env)?,
talpid_types::net::AllowedTunnelTraffic::All,
))
.await;
Ok(Response::new(()))
}

Expand Down Expand Up @@ -902,7 +896,10 @@ mod event_server {
return Err(tonic::Status::failed_precondition("Failed to add routes"));
}

(self.on_event)(talpid_tunnel::TunnelEvent::Up(metadata)).await;
self.event_hook
.clone()
.on_event(talpid_tunnel::TunnelEvent::Up(metadata))
.await;

Ok(Response::new(()))
}
Expand Down Expand Up @@ -956,20 +953,18 @@ mod event_server {
}

#[tonic::async_trait]
impl<
L: (Fn(talpid_tunnel::TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: std::future::Future<Output = ()> + 'static + Send,
> OpenvpnEventProxy for OpenvpnEventProxyImpl<L>
{
impl OpenvpnEventProxy for OpenvpnEventProxyImpl {
async fn auth_failed(
&self,
request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
let env = request.into_inner().env;
(self.on_event)(talpid_tunnel::TunnelEvent::AuthFailed(
env.get("auth_failed_reason").cloned(),
))
.await;
self.event_hook
.clone()
.on_event(talpid_tunnel::TunnelEvent::AuthFailed(
env.get("auth_failed_reason").cloned(),
))
.await;
Ok(Response::new(()))
}

Expand All @@ -995,7 +990,10 @@ mod event_server {
&self,
_request: Request<EventDetails>,
) -> std::result::Result<Response<()>, tonic::Status> {
(self.on_event)(talpid_tunnel::TunnelEvent::Down).await;
self.event_hook
.clone()
.on_event(talpid_tunnel::TunnelEvent::Down)
.await;
Ok(Response::new(()))
}
}
Expand Down
35 changes: 27 additions & 8 deletions talpid-tunnel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
future::Future,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::Path,
sync::{Arc, Mutex},
Expand All @@ -10,7 +9,13 @@ use std::{
pub mod network_interface;

pub mod tun_provider;
use futures::channel::oneshot;
use futures::{
channel::{
mpsc::UnboundedSender,
oneshot::{self, Sender},
},
SinkExt,
};
use talpid_routing::RouteManagerHandle;
use talpid_types::net::AllowedTunnelTraffic;
use tun_provider::TunProvider;
Expand All @@ -29,17 +34,13 @@ pub const MIN_IPV4_MTU: u16 = 576;
pub const MIN_IPV6_MTU: u16 = 1280;

/// Arguments for creating a tunnel.
pub struct TunnelArgs<'a, L, F>
where
L: (Fn(TunnelEvent) -> F) + Send + Clone + Sync + 'static,
F: Future<Output = ()>,
{
pub struct TunnelArgs<'a> {
/// Tokio runtime handle.
pub runtime: tokio::runtime::Handle,
/// Resource directory path.
pub resource_dir: &'a Path,
/// Callback function called when an event happens.
pub on_event: L,
pub event_hook: EventHook,
/// Receiver oneshot channel for closing the tunnel.
pub tunnel_close_rx: oneshot::Receiver<()>,
/// Mutex to tunnel provider.
Expand All @@ -50,6 +51,24 @@ where
pub route_manager: RouteManagerHandle,
}

#[derive(Clone)]
pub struct EventHook {
event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>,
}

impl EventHook {
pub fn new(event_tx: UnboundedSender<(TunnelEvent, Sender<()>)>) -> Self {
Self { event_tx }
}

pub async fn on_event(&mut self, event: TunnelEvent) {
let (tx, rx) = oneshot::channel::<()>();
if let Ok(()) = self.event_tx.send((event, tx)).await {
let _ = rx.await;
}
}
}

/// Information about a VPN tunnel.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct TunnelMetadata {
Expand Down
Loading

0 comments on commit ca17aee

Please sign in to comment.