diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt.rs index 1b4405eba290..c3fb12de1eba 100644 --- a/talpid-wireguard/src/wireguard_nt.rs +++ b/talpid-wireguard/src/wireguard_nt.rs @@ -7,7 +7,7 @@ use super::{ use bitflags::bitflags; use futures::SinkExt; use ipnetwork::IpNetwork; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use std::{ ffi::CStr, fmt, @@ -38,7 +38,7 @@ use windows_sys::{ }, }; -static WG_NT_DLL: Lazy>>> = Lazy::new(|| Mutex::new(None)); +static WG_NT_DLL: OnceCell = OnceCell::new(); static ADAPTER_TYPE: Lazy = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); static ADAPTER_ALIAS: Lazy = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); @@ -163,7 +163,7 @@ pub enum Error { } pub struct WgNtTunnel { - device: Arc>>, + device: Option>, interface_name: String, setup_handle: tokio::task::JoinHandle<()>, _logger_handle: LoggerHandle, @@ -430,7 +430,7 @@ impl WgNtTunnel { mut done_tx: futures::channel::mpsc::Sender>, ) -> Result { let dll = load_wg_nt_dll(resource_dir)?; - let logger_handle = LoggerHandle::new(dll.clone(), log_path)?; + let logger_handle = LoggerHandle::new(dll, log_path)?; let device = WgNtAdapter::create(dll, &ADAPTER_ALIAS, &ADAPTER_TYPE, Some(ADAPTER_GUID)) .map_err(Error::CreateTunnelDevice)?; @@ -443,10 +443,11 @@ impl WgNtTunnel { ); } device.set_config(config)?; - let device = Arc::new(Mutex::new(Some(device))); + let device2 = Arc::new(device); + let device = Some(device2.clone()); let setup_future = setup_ip_listener( - device.clone(), + device2, u32::from(config.mtu), config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()), ); @@ -466,16 +467,12 @@ impl WgNtTunnel { fn stop_tunnel(&mut self) { self.setup_handle.abort(); - let _ = self.device.lock().unwrap().take(); + let _ = self.device.take(); } } -async fn setup_ip_listener( - device: Arc>>, - mtu: u32, - has_ipv6: bool, -) -> Result<()> { - let luid = { device.lock().unwrap().as_ref().unwrap().luid() }; +async fn setup_ip_listener(device: Arc, mtu: u32, has_ipv6: bool) -> Result<()> { + let luid = device.luid(); let luid = NET_LUID_LH { Value: unsafe { luid.Value }, }; @@ -489,13 +486,9 @@ async fn setup_ip_listener( talpid_tunnel::network_interface::initialize_interfaces(luid, Some(mtu)) .map_err(Error::SetTunnelMtu)?; - if let Some(device) = &*device.lock().unwrap() { - device - .set_state(WgAdapterState::Up) - .map_err(Error::EnableTunnel) - } else { - Ok(()) - } + device + .set_state(WgAdapterState::Up) + .map_err(Error::EnableTunnel) } impl Drop for WgNtTunnel { @@ -507,12 +500,12 @@ impl Drop for WgNtTunnel { static LOG_CONTEXT: Lazy>> = Lazy::new(|| Mutex::new(None)); struct LoggerHandle { - dll: Arc, + dll: &'static WgNtDll, context: u32, } impl LoggerHandle { - fn new(dll: Arc, log_path: Option<&Path>) -> Result { + fn new(dll: &'static WgNtDll, log_path: Option<&Path>) -> Result { let context = logging::initialize_logging(log_path).map_err(Error::InitLogging)?; { *(LOG_CONTEXT.lock().unwrap()) = Some(context); @@ -547,7 +540,7 @@ impl Drop for LoggerHandle { } struct WgNtAdapter { - dll_handle: Arc, + dll_handle: &'static WgNtDll, handle: RawHandle, } @@ -564,7 +557,7 @@ unsafe impl Sync for WgNtAdapter {} impl WgNtAdapter { fn create( - dll_handle: Arc, + dll_handle: &'static WgNtDll, name: &U16CStr, tunnel_type: &U16CStr, requested_guid: Option, @@ -806,16 +799,8 @@ impl Drop for WgNtDll { } } -fn load_wg_nt_dll(resource_dir: &Path) -> Result> { - let mut dll = (*WG_NT_DLL).lock().expect("WireGuardNT mutex poisoned"); - match &*dll { - Some(dll) => Ok(dll.clone()), - None => { - let new_dll = Arc::new(WgNtDll::new(resource_dir).map_err(Error::LoadDll)?); - *dll = Some(new_dll.clone()); - Ok(new_dll) - } - } +fn load_wg_nt_dll(resource_dir: &Path) -> Result<&'static WgNtDll> { + WG_NT_DLL.get_or_try_init(|| WgNtDll::new(resource_dir).map_err(Error::LoadDll)) } fn serialize_config(config: &Config) -> Result>> { @@ -941,7 +926,7 @@ impl Tunnel for WgNtTunnel { } fn get_tunnel_stats(&self) -> std::result::Result { - if let Some(ref device) = &*self.device.lock().unwrap() { + if let Some(ref device) = self.device { let mut map = StatsMap::new(); let (_interface, peers) = device.get_config().map_err(|error| { log::error!( @@ -976,9 +961,12 @@ impl Tunnel for WgNtTunnel { config: Config, ) -> Pin> + Send>> { let device = self.device.clone(); + Box::pin(async move { - let guard = device.lock().unwrap(); - let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?; + let Some(device) = device else { + log::error!("Failed to set config: No tunnel device"); + return Err(super::TunnelError::SetConfigError); + }; device.set_config(&config).map_err(|error| { log::error!( "{}",