Skip to content

Commit

Permalink
Remove pointless locks from wireguard_nt
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Oct 20, 2023
1 parent b2a9781 commit f2c6041
Showing 1 changed file with 25 additions and 37 deletions.
62 changes: 25 additions & 37 deletions talpid-wireguard/src/wireguard_nt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::{
use bitflags::bitflags;
use futures::SinkExt;
use ipnetwork::IpNetwork;
use once_cell::sync::Lazy;
use once_cell::sync::{Lazy, OnceCell};
use std::{
ffi::CStr,
fmt,
Expand Down Expand Up @@ -38,7 +38,7 @@ use windows_sys::{
},
};

static WG_NT_DLL: Lazy<Mutex<Option<Arc<WgNtDll>>>> = Lazy::new(|| Mutex::new(None));
static WG_NT_DLL: OnceCell<WgNtDll> = OnceCell::new();
static ADAPTER_TYPE: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap());
static ADAPTER_ALIAS: Lazy<U16CString> = Lazy::new(|| U16CString::from_str("Mullvad").unwrap());

Expand Down Expand Up @@ -163,7 +163,7 @@ pub enum Error {
}

pub struct WgNtTunnel {
device: Arc<Mutex<Option<WgNtAdapter>>>,
device: Option<Arc<WgNtAdapter>>,
interface_name: String,
setup_handle: tokio::task::JoinHandle<()>,
_logger_handle: LoggerHandle,
Expand Down Expand Up @@ -430,7 +430,7 @@ impl WgNtTunnel {
mut done_tx: futures::channel::mpsc::Sender<std::result::Result<(), BoxedError>>,
) -> Result<Self> {
let dll = load_wg_nt_dll(resource_dir)?;
let logger_handle = LoggerHandle::new(dll.clone(), log_path)?;
let logger_handle = LoggerHandle::new(dll, log_path)?;
let device = WgNtAdapter::create(dll, &ADAPTER_ALIAS, &ADAPTER_TYPE, Some(ADAPTER_GUID))
.map_err(Error::CreateTunnelDevice)?;

Expand All @@ -443,10 +443,11 @@ impl WgNtTunnel {
);
}
device.set_config(config)?;
let device = Arc::new(Mutex::new(Some(device)));
let device2 = Arc::new(device);
let device = Some(device2.clone());

let setup_future = setup_ip_listener(
device.clone(),
device2,
u32::from(config.mtu),
config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()),
);
Expand All @@ -466,16 +467,12 @@ impl WgNtTunnel {

fn stop_tunnel(&mut self) {
self.setup_handle.abort();
let _ = self.device.lock().unwrap().take();
let _ = self.device.take();
}
}

async fn setup_ip_listener(
device: Arc<Mutex<Option<WgNtAdapter>>>,
mtu: u32,
has_ipv6: bool,
) -> Result<()> {
let luid = { device.lock().unwrap().as_ref().unwrap().luid() };
async fn setup_ip_listener(device: Arc<WgNtAdapter>, mtu: u32, has_ipv6: bool) -> Result<()> {
let luid = device.luid();
let luid = NET_LUID_LH {
Value: unsafe { luid.Value },
};
Expand All @@ -489,13 +486,9 @@ async fn setup_ip_listener(
talpid_tunnel::network_interface::initialize_interfaces(luid, Some(mtu))
.map_err(Error::SetTunnelMtu)?;

if let Some(device) = &*device.lock().unwrap() {
device
.set_state(WgAdapterState::Up)
.map_err(Error::EnableTunnel)
} else {
Ok(())
}
device
.set_state(WgAdapterState::Up)
.map_err(Error::EnableTunnel)
}

impl Drop for WgNtTunnel {
Expand All @@ -507,12 +500,12 @@ impl Drop for WgNtTunnel {
static LOG_CONTEXT: Lazy<Mutex<Option<u32>>> = Lazy::new(|| Mutex::new(None));

struct LoggerHandle {
dll: Arc<WgNtDll>,
dll: &'static WgNtDll,
context: u32,
}

impl LoggerHandle {
fn new(dll: Arc<WgNtDll>, log_path: Option<&Path>) -> Result<Self> {
fn new(dll: &'static WgNtDll, log_path: Option<&Path>) -> Result<Self> {
let context = logging::initialize_logging(log_path).map_err(Error::InitLogging)?;
{
*(LOG_CONTEXT.lock().unwrap()) = Some(context);
Expand Down Expand Up @@ -547,7 +540,7 @@ impl Drop for LoggerHandle {
}

struct WgNtAdapter {
dll_handle: Arc<WgNtDll>,
dll_handle: &'static WgNtDll,
handle: RawHandle,
}

Expand All @@ -564,7 +557,7 @@ unsafe impl Sync for WgNtAdapter {}

impl WgNtAdapter {
fn create(
dll_handle: Arc<WgNtDll>,
dll_handle: &'static WgNtDll,
name: &U16CStr,
tunnel_type: &U16CStr,
requested_guid: Option<GUID>,
Expand Down Expand Up @@ -806,16 +799,8 @@ impl Drop for WgNtDll {
}
}

fn load_wg_nt_dll(resource_dir: &Path) -> Result<Arc<WgNtDll>> {
let mut dll = (*WG_NT_DLL).lock().expect("WireGuardNT mutex poisoned");
match &*dll {
Some(dll) => Ok(dll.clone()),
None => {
let new_dll = Arc::new(WgNtDll::new(resource_dir).map_err(Error::LoadDll)?);
*dll = Some(new_dll.clone());
Ok(new_dll)
}
}
fn load_wg_nt_dll(resource_dir: &Path) -> Result<&'static WgNtDll> {
WG_NT_DLL.get_or_try_init(|| WgNtDll::new(resource_dir).map_err(Error::LoadDll))
}

fn serialize_config(config: &Config) -> Result<Vec<MaybeUninit<u8>>> {
Expand Down Expand Up @@ -941,7 +926,7 @@ impl Tunnel for WgNtTunnel {
}

fn get_tunnel_stats(&self) -> std::result::Result<StatsMap, super::TunnelError> {
if let Some(ref device) = &*self.device.lock().unwrap() {
if let Some(ref device) = self.device {
let mut map = StatsMap::new();
let (_interface, peers) = device.get_config().map_err(|error| {
log::error!(
Expand Down Expand Up @@ -976,9 +961,12 @@ impl Tunnel for WgNtTunnel {
config: Config,
) -> Pin<Box<dyn Future<Output = std::result::Result<(), super::TunnelError>> + Send>> {
let device = self.device.clone();

Box::pin(async move {
let guard = device.lock().unwrap();
let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?;
let Some(device) = device else {
log::error!("Failed to set config: No tunnel device");
return Err(super::TunnelError::SetConfigError);
};
device.set_config(&config).map_err(|error| {
log::error!(
"{}",
Expand Down

0 comments on commit f2c6041

Please sign in to comment.