diff --git a/Cargo.lock b/Cargo.lock index c929ce575478..50f68e0cd984 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1863,6 +1863,7 @@ dependencies = [ "talpid-platform-metadata", "talpid-time", "talpid-types", + "talpid-windows", "tokio", "tokio-stream", "winapi", @@ -3456,7 +3457,7 @@ dependencies = [ "talpid-tunnel", "talpid-tunnel-config-client", "talpid-types", - "talpid-windows-net", + "talpid-windows", "talpid-wireguard", "tokio", "tonic-build", @@ -3498,7 +3499,7 @@ dependencies = [ "talpid-routing", "talpid-tunnel", "talpid-types", - "talpid-windows-net", + "talpid-windows", "tokio", "tonic", "tonic-build", @@ -3555,7 +3556,7 @@ dependencies = [ "rtnetlink", "system-configuration", "talpid-types", - "talpid-windows-net", + "talpid-windows", "tokio", "widestring", "windows-sys 0.48.0", @@ -3583,7 +3584,7 @@ dependencies = [ "nix 0.23.2", "talpid-routing", "talpid-types", - "talpid-windows-net", + "talpid-windows", "tokio", "tun", "windows-sys 0.48.0", @@ -3622,7 +3623,7 @@ dependencies = [ ] [[package]] -name = "talpid-windows-net" +name = "talpid-windows" version = "0.0.0" dependencies = [ "err-derive", @@ -3662,7 +3663,7 @@ dependencies = [ "talpid-tunnel", "talpid-tunnel-config-client", "talpid-types", - "talpid-windows-net", + "talpid-windows", "tokio", "tokio-stream", "tunnel-obfuscation", diff --git a/Cargo.toml b/Cargo.toml index f68e491d9fb5..a94e5514b7bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ members = [ "talpid-time", "talpid-tunnel", "talpid-tunnel-config-client", - "talpid-windows-net", + "talpid-windows", "talpid-wireguard", "mullvad-management-interface", "tunnel-obfuscation", diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml index cc89500abdcf..352db7d9f99a 100644 --- a/mullvad-daemon/Cargo.toml +++ b/mullvad-daemon/Cargo.toml @@ -60,6 +60,7 @@ ctrlc = "3.0" windows-service = "0.6.0" winapi = { version = "0.3", features = ["winnt", "excpt"] } dirs = "5.0.1" +talpid-windows = { path = "../talpid-windows" } [target.'cfg(windows)'.dependencies.windows-sys] workspace = true diff --git a/mullvad-daemon/src/exception_logging/win.rs b/mullvad-daemon/src/exception_logging/win.rs index 906a8533fcdc..85f16019b05b 100644 --- a/mullvad-daemon/src/exception_logging/win.rs +++ b/mullvad-daemon/src/exception_logging/win.rs @@ -1,27 +1,25 @@ use mullvad_paths::log_dir; use std::{ borrow::Cow, - ffi::{c_char, c_void, CStr}, + ffi::c_void, fmt::Write, - fs, io, mem, + fs, io, os::windows::io::AsRawHandle, path::{Path, PathBuf}, ptr, }; use talpid_types::ErrorExt; +use talpid_windows::process::{ModuleEntry, ProcessSnapshot}; use winapi::{ um::winnt::{CONTEXT_CONTROL, CONTEXT_INTEGER, CONTEXT_SEGMENTS}, vc::excpt::EXCEPTION_EXECUTE_HANDLER, }; use windows_sys::Win32::{ - Foundation::{CloseHandle, BOOL, ERROR_NO_MORE_FILES, HANDLE, INVALID_HANDLE_VALUE}, + Foundation::{BOOL, HANDLE}, System::{ Diagnostics::{ Debug::{SetUnhandledExceptionFilter, CONTEXT, EXCEPTION_POINTERS, EXCEPTION_RECORD}, - ToolHelp::{ - CreateToolhelp32Snapshot, Module32First, Module32Next, MODULEENTRY32, - TH32CS_SNAPMODULE, - }, + ToolHelp::TH32CS_SNAPMODULE, }, Threading::{GetCurrentProcess, GetCurrentProcessId, GetCurrentThreadId}, }, @@ -291,7 +289,7 @@ fn get_context_info(context: &CONTEXT) -> String { } /// Return module info for the current process and given memory address. -fn find_address_module(address: *mut c_void) -> io::Result> { +fn find_address_module(address: *mut c_void) -> io::Result> { let snap = ProcessSnapshot::new(TH32CS_SNAPMODULE, 0)?; for module in snap.modules() { @@ -306,85 +304,3 @@ fn find_address_module(address: *mut c_void) -> io::Result> { Ok(None) } - -struct ModuleInfo { - name: String, - base_address: *const u8, - size: usize, -} - -struct ProcessSnapshot { - handle: HANDLE, -} - -impl ProcessSnapshot { - fn new(flags: u32, process_id: u32) -> io::Result { - let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; - - if snap == INVALID_HANDLE_VALUE { - Err(io::Error::last_os_error()) - } else { - Ok(ProcessSnapshot { handle: snap }) - } - } - - fn handle(&self) -> HANDLE { - self.handle - } - - fn modules(&self) -> ProcessSnapshotModules<'_> { - let mut entry: MODULEENTRY32 = unsafe { mem::zeroed() }; - entry.dwSize = mem::size_of::() as u32; - - ProcessSnapshotModules { - snapshot: self, - iter_started: false, - temp_entry: entry, - } - } -} - -impl Drop for ProcessSnapshot { - fn drop(&mut self) { - unsafe { - CloseHandle(self.handle); - } - } -} - -struct ProcessSnapshotModules<'a> { - snapshot: &'a ProcessSnapshot, - iter_started: bool, - temp_entry: MODULEENTRY32, -} - -impl Iterator for ProcessSnapshotModules<'_> { - type Item = io::Result; - - fn next(&mut self) -> Option> { - if self.iter_started { - if unsafe { Module32Next(self.snapshot.handle(), &mut self.temp_entry) } == 0 { - let last_error = io::Error::last_os_error(); - - return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { - None - } else { - Some(Err(last_error)) - }; - } - } else { - if unsafe { Module32First(self.snapshot.handle(), &mut self.temp_entry) } == 0 { - return Some(Err(io::Error::last_os_error())); - } - self.iter_started = true; - } - - let cstr_ref = &self.temp_entry.szModule[0]; - let cstr = unsafe { CStr::from_ptr(cstr_ref as *const u8 as *const c_char) }; - Some(Ok(ModuleInfo { - name: cstr.to_string_lossy().into_owned(), - base_address: self.temp_entry.modBaseAddr, - size: self.temp_entry.modBaseSize as usize, - })) - } -} diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index e39b07f24462..eecc07e388cb 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -60,7 +60,7 @@ widestring = "1.0" winreg = { version = "0.51", features = ["transactions"] } memoffset = "0.6" windows-service = "0.6.0" -talpid-windows-net = { path = "../talpid-windows-net" } +talpid-windows = { path = "../talpid-windows" } [target.'cfg(windows)'.dependencies.windows-sys] workspace = true diff --git a/talpid-core/src/dns/windows/iphlpapi.rs b/talpid-core/src/dns/windows/iphlpapi.rs index 2d62e1860e19..f9aba7006564 100644 --- a/talpid-core/src/dns/windows/iphlpapi.rs +++ b/talpid-core/src/dns/windows/iphlpapi.rs @@ -13,7 +13,7 @@ use std::{ ptr, }; use talpid_types::win32_err; -use talpid_windows_net::{guid_from_luid, luid_from_alias}; +use talpid_windows::net::{guid_from_luid, luid_from_alias}; use windows_sys::{ core::GUID, s, w, diff --git a/talpid-core/src/dns/windows/netsh.rs b/talpid-core/src/dns/windows/netsh.rs index 17cf551b1feb..7de3fe790013 100644 --- a/talpid-core/src/dns/windows/netsh.rs +++ b/talpid-core/src/dns/windows/netsh.rs @@ -9,7 +9,7 @@ use std::{ time::Duration, }; use talpid_types::{net::IpVersion, ErrorExt}; -use talpid_windows_net::{index_from_luid, luid_from_alias}; +use talpid_windows::net::{index_from_luid, luid_from_alias}; use windows_sys::Win32::{ Foundation::{MAX_PATH, WAIT_OBJECT_0, WAIT_TIMEOUT}, System::{ diff --git a/talpid-core/src/dns/windows/tcpip.rs b/talpid-core/src/dns/windows/tcpip.rs index c1d577f5fe30..244417e119ca 100644 --- a/talpid-core/src/dns/windows/tcpip.rs +++ b/talpid-core/src/dns/windows/tcpip.rs @@ -1,7 +1,7 @@ use crate::dns::DnsMonitorT; use std::{io, net::IpAddr}; use talpid_types::ErrorExt; -use talpid_windows_net::{guid_from_luid, luid_from_alias}; +use talpid_windows::net::{guid_from_luid, luid_from_alias}; use windows_sys::{core::GUID, Win32::System::Com::StringFromGUID2}; use winreg::{ enums::{HKEY_LOCAL_MACHINE, KEY_SET_VALUE}, diff --git a/talpid-core/src/offline/windows.rs b/talpid-core/src/offline/windows.rs index 2756f9c6900b..6539e6e256b5 100644 --- a/talpid-core/src/offline/windows.rs +++ b/talpid-core/src/offline/windows.rs @@ -9,7 +9,7 @@ use std::{ time::Duration, }; use talpid_types::ErrorExt; -use talpid_windows_net::AddressFamily; +use talpid_windows::net::AddressFamily; #[derive(err_derive::Error, Debug)] pub enum Error { diff --git a/talpid-core/src/split_tunnel/windows/driver.rs b/talpid-core/src/split_tunnel/windows/driver.rs index c8688345189a..3f5a4e6a6d1b 100644 --- a/talpid-core/src/split_tunnel/windows/driver.rs +++ b/talpid-core/src/split_tunnel/windows/driver.rs @@ -1,6 +1,6 @@ use super::windows::{ - get_device_path, get_process_creation_time, get_process_device_path, open_process, Event, - Overlapped, ProcessAccess, ProcessSnapshot, + get_device_path, get_process_creation_time, get_process_device_path, open_process, + ProcessAccess, }; use bitflags::bitflags; use memoffset::offset_of; @@ -22,6 +22,7 @@ use std::{ time::Duration, }; use talpid_types::ErrorExt; +use talpid_windows::{io::Overlapped, process::ProcessSnapshot, sync::Event}; use windows_sys::Win32::{ Foundation::{ ERROR_ACCESS_DENIED, ERROR_FILE_NOT_FOUND, ERROR_INVALID_PARAMETER, ERROR_IO_PENDING, @@ -485,7 +486,7 @@ fn build_process_tree() -> io::Result> { let mut process_info = HashMap::new(); let snap = ProcessSnapshot::new(TH32CS_SNAPPROCESS, 0)?; - for entry in snap.entries() { + for entry in snap.processes() { let entry = entry?; let process = match open_process(ProcessAccess::QueryLimitedInformation, false, entry.pid) { @@ -877,7 +878,7 @@ pub fn get_overlapped_result( let event = overlapped.get_event().unwrap(); // SAFETY: This is a valid event object. - unsafe { wait_for_single_object(event.as_handle(), None) }?; + unsafe { wait_for_single_object(event.as_raw(), None) }?; // SAFETY: The handle and overlapped object are valid. let mut returned_bytes = 0u32; diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 40d7c340ad3c..e168ad38aecf 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -21,7 +21,11 @@ use std::{ }; use talpid_routing::{get_best_default_route, CallbackHandle, EventType, RouteManagerHandle}; use talpid_types::{split_tunnel::ExcludedProcess, tunnel::ErrorStateCause, ErrorExt}; -use talpid_windows_net::{get_ip_address_for_interface, AddressFamily}; +use talpid_windows::{ + io::Overlapped, + net::{get_ip_address_for_interface, AddressFamily}, + sync::Event, +}; use windows_sys::Win32::Foundation::ERROR_OPERATION_ABORTED; const DRIVER_EVENT_BUFFER_SIZE: usize = 2048; @@ -69,7 +73,7 @@ pub enum Error { /// Failed to obtain an IP address given a network interface LUID #[error(display = "Failed to obtain IP address for interface LUID")] - LuidToIp(#[error(source)] talpid_windows_net::Error), + LuidToIp(#[error(source)] talpid_windows::net::Error), /// Failed to set up callback for monitoring default route changes #[error(display = "Failed to register default route change callback")] @@ -105,7 +109,7 @@ pub struct SplitTunnel { runtime: tokio::runtime::Handle, request_tx: RequestTx, event_thread: Option>, - quit_event: Arc, + quit_event: Arc, excluded_processes: Arc>>, _route_change_callback: Option, daemon_tx: Weak>, @@ -191,14 +195,13 @@ impl SplitTunnel { fn spawn_event_listener( handle: Arc, excluded_processes: Arc>>, - ) -> Result<(std::thread::JoinHandle<()>, Arc), Error> { - let mut event_overlapped = windows::Overlapped::new(Some( - windows::Event::new(true, false).map_err(Error::EventThreadError)?, + ) -> Result<(std::thread::JoinHandle<()>, Arc), Error> { + let mut event_overlapped = Overlapped::new(Some( + Event::new(true, false).map_err(Error::EventThreadError)?, )) .map_err(Error::EventThreadError)?; - let quit_event = - Arc::new(windows::Event::new(true, false).map_err(Error::EventThreadError)?); + let quit_event = Arc::new(Event::new(true, false).map_err(Error::EventThreadError)?); let quit_event_copy = quit_event.clone(); let event_thread = std::thread::spawn(move || { @@ -237,11 +240,11 @@ impl SplitTunnel { fn fetch_next_event( device: &Arc, - quit_event: &windows::Event, - overlapped: &mut windows::Overlapped, + quit_event: &Event, + overlapped: &mut Overlapped, data_buffer: &mut Vec, ) -> io::Result { - if unsafe { driver::wait_for_single_object(quit_event.as_handle(), Some(Duration::ZERO)) } + if unsafe { driver::wait_for_single_object(quit_event.as_raw(), Some(Duration::ZERO)) } .is_ok() { return Ok(EventResult::Quit); @@ -268,8 +271,8 @@ impl SplitTunnel { })?; let event_objects = [ - overlapped.get_event().unwrap().as_handle(), - quit_event.as_handle(), + overlapped.get_event().unwrap().as_raw(), + quit_event.as_raw(), ]; let signaled_object = @@ -283,7 +286,7 @@ impl SplitTunnel { }, )?; - if signaled_object == quit_event.as_handle() { + if signaled_object == quit_event.as_raw() { // Quit event was signaled return Ok(EventResult::Quit); } diff --git a/talpid-core/src/split_tunnel/windows/windows.rs b/talpid-core/src/split_tunnel/windows/windows.rs index 77fafdc199bd..c20a83708960 100644 --- a/talpid-core/src/split_tunnel/windows/windows.rs +++ b/talpid-core/src/split_tunnel/windows/windows.rs @@ -1,6 +1,3 @@ -// TODO: The snapshot code could be combined with the mostly-identical code in -// the windows_exception_logging module. - use std::{ ffi::{OsStr, OsString}, fs, io, iter, mem, @@ -12,102 +9,15 @@ use std::{ ptr, }; use windows_sys::Win32::{ - Foundation::{ - CloseHandle, BOOL, ERROR_INSUFFICIENT_BUFFER, ERROR_NO_MORE_FILES, FILETIME, HANDLE, - INVALID_HANDLE_VALUE, - }, + Foundation::{CloseHandle, ERROR_INSUFFICIENT_BUFFER, FILETIME, HANDLE}, Storage::FileSystem::{GetFinalPathNameByHandleW, QueryDosDeviceW}, System::{ - Diagnostics::ToolHelp::{ - CreateToolhelp32Snapshot, Process32FirstW, Process32NextW, PROCESSENTRY32W, - }, ProcessStatus::GetProcessImageFileNameW, - Threading::{ - CreateEventW, GetProcessTimes, OpenProcess, SetEvent, PROCESS_QUERY_LIMITED_INFORMATION, - }, + Threading::{GetProcessTimes, OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION}, WindowsProgramming::VOLUME_NAME_NT, - IO::OVERLAPPED, }, }; -pub struct ProcessSnapshot { - handle: HANDLE, -} - -impl ProcessSnapshot { - pub fn new(flags: u32, process_id: u32) -> io::Result { - let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; - - if snap == INVALID_HANDLE_VALUE { - Err(io::Error::last_os_error()) - } else { - Ok(ProcessSnapshot { handle: snap }) - } - } - - pub fn handle(&self) -> HANDLE { - self.handle - } - - pub fn entries(&self) -> ProcessSnapshotEntries<'_> { - let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() }; - entry.dwSize = mem::size_of::() as u32; - - ProcessSnapshotEntries { - snapshot: self, - iter_started: false, - temp_entry: entry, - } - } -} - -impl Drop for ProcessSnapshot { - fn drop(&mut self) { - unsafe { - CloseHandle(self.handle); - } - } -} - -pub struct ProcessEntry { - pub pid: u32, - pub parent_pid: u32, -} - -pub struct ProcessSnapshotEntries<'a> { - snapshot: &'a ProcessSnapshot, - iter_started: bool, - temp_entry: PROCESSENTRY32W, -} - -impl Iterator for ProcessSnapshotEntries<'_> { - type Item = io::Result; - - fn next(&mut self) -> Option> { - if self.iter_started { - if unsafe { Process32NextW(self.snapshot.handle(), &mut self.temp_entry) } == 0 { - let last_error = io::Error::last_os_error(); - - return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { - None - } else { - Some(Err(last_error)) - }; - } - } else { - if unsafe { Process32FirstW(self.snapshot.handle(), &mut self.temp_entry) } == 0 { - return Some(Err(io::Error::last_os_error())); - } - self.iter_started = true; - } - - Some(Ok(ProcessEntry { - pid: self.temp_entry.th32ProcessID, - parent_pid: self.temp_entry.th32ParentProcessID, - })) - } -} - /// Obtains a device path without resolving links or mount points. pub fn get_device_path>(path: T) -> Result { // Preferentially, use GetFinalPathNameByHandleW. If the file does not exist @@ -299,95 +209,3 @@ fn get_process_device_path_inner( Ok(OsStringExt::from_wide(&buffer)) } - -/// Abstraction over `OVERLAPPED`, which is used for async I/O. -pub struct Overlapped { - overlapped: OVERLAPPED, - event: Option, -} - -unsafe impl Send for Overlapped {} -unsafe impl Sync for Overlapped {} - -impl Overlapped { - /// Creates an `OVERLAPPED` object with `hEvent` set. - pub fn new(event: Option) -> io::Result { - let mut overlapped = Overlapped { - overlapped: unsafe { mem::zeroed() }, - event: None, - }; - overlapped.set_event(event); - Ok(overlapped) - } - - /// Borrows the underlying `OVERLAPPED` object. - pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED { - &mut self.overlapped - } - - /// Returns a reference to the associated event. - pub fn get_event(&self) -> Option<&Event> { - self.event.as_ref() - } - - /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`) - fn set_event(&mut self, event: Option) { - match event { - Some(event) => { - self.overlapped.hEvent = event.0; - self.event = Some(event); - } - None => { - self.overlapped.hEvent = 0; - self.event = None; - } - } - } -} - -/// Abstraction over a Windows event object. -pub struct Event(HANDLE); - -unsafe impl Send for Event {} -unsafe impl Sync for Event {} - -impl Event { - pub fn new(manual_reset: bool, initial_state: bool) -> io::Result { - let event = unsafe { - CreateEventW( - ptr::null_mut(), - bool_to_winbool(manual_reset), - bool_to_winbool(initial_state), - ptr::null(), - ) - }; - if event == 0 { - return Err(io::Error::last_os_error()); - } - Ok(Self(event)) - } - - pub fn set(&self) -> io::Result<()> { - if unsafe { SetEvent(self.0) } == 0 { - return Err(io::Error::last_os_error()); - } - Ok(()) - } - - pub fn as_handle(&self) -> HANDLE { - self.0 - } -} - -impl Drop for Event { - fn drop(&mut self) { - unsafe { CloseHandle(self.0) }; - } -} - -const fn bool_to_winbool(val: bool) -> BOOL { - match val { - true => 1, - false => 0, - } -} diff --git a/talpid-openvpn/Cargo.toml b/talpid-openvpn/Cargo.toml index 6bfdb0c57741..a348ea748634 100644 --- a/talpid-openvpn/Cargo.toml +++ b/talpid-openvpn/Cargo.toml @@ -32,7 +32,7 @@ prost = { workspace = true } [target.'cfg(windows)'.dependencies] widestring = "1.0" winreg = { version = "0.51", features = ["transactions"] } -talpid-windows-net = { path = "../talpid-windows-net" } +talpid-windows = { path = "../talpid-windows" } [target.'cfg(windows)'.dependencies.windows-sys] workspace = true diff --git a/talpid-openvpn/src/lib.rs b/talpid-openvpn/src/lib.rs index 48e3b20b20ea..9fd0c317efe7 100644 --- a/talpid-openvpn/src/lib.rs +++ b/talpid-openvpn/src/lib.rs @@ -203,7 +203,7 @@ impl WintunContext for WintunContextImpl { async fn wait_for_interfaces(&self) -> io::Result<()> { let luid = self.adapter.luid(); - talpid_windows_net::wait_for_interfaces(luid, true, self.wait_v6_interface).await + talpid_windows::net::wait_for_interfaces(luid, true, self.wait_v6_interface).await } fn prepare_interface(&self) { @@ -867,11 +867,12 @@ mod event_server { #[cfg(windows)] { let tunnel_device = metadata.interface.clone(); - let luid = talpid_windows_net::luid_from_alias(tunnel_device).map_err(|error| { - log::error!("{}", error.display_chain_with_msg("luid_from_alias failed")); - tonic::Status::unavailable("failed to obtain interface luid") - })?; - talpid_windows_net::wait_for_addresses(luid) + let luid = + talpid_windows::net::luid_from_alias(tunnel_device).map_err(|error| { + log::error!("{}", error.display_chain_with_msg("luid_from_alias failed")); + tonic::Status::unavailable("failed to obtain interface luid") + })?; + talpid_windows::net::wait_for_addresses(luid) .await .map_err(|error| { log::error!( diff --git a/talpid-openvpn/src/wintun.rs b/talpid-openvpn/src/wintun.rs index afeeeecbf922..32d9ddc283c0 100644 --- a/talpid-openvpn/src/wintun.rs +++ b/talpid-openvpn/src/wintun.rs @@ -1,12 +1,5 @@ -use once_cell::sync::Lazy; -use std::{ - ffi::CStr, - fmt, io, mem, - os::windows::io::RawHandle, - path::Path, - ptr, - sync::{Arc, Mutex}, -}; +use once_cell::sync::OnceCell; +use std::{ffi::CStr, fmt, io, mem, os::windows::io::RawHandle, path::Path, ptr}; use talpid_types::{win32_err, ErrorExt}; use widestring::{U16CStr, U16CString}; use windows_sys::{ @@ -29,7 +22,7 @@ use winreg::{ }; /// Shared `WintunDll` instance -static WINTUN_DLL: Lazy>>> = Lazy::new(|| Mutex::new(None)); +static WINTUN_DLL: OnceCell = OnceCell::new(); type WintunCreateAdapterFn = unsafe extern "stdcall" fn( name: *const u16, @@ -67,7 +60,7 @@ unsafe impl Sync for WintunDll {} /// Represents a Wintun adapter. pub struct WintunAdapter { - dll_handle: Arc, + dll_handle: &'static WintunDll, handle: RawHandle, name: U16CString, } @@ -85,7 +78,7 @@ unsafe impl Sync for WintunAdapter {} impl WintunAdapter { pub fn create( - dll_handle: Arc, + dll_handle: &'static WintunDll, name: &U16CStr, tunnel_type: &U16CStr, requested_guid: Option, @@ -177,16 +170,8 @@ impl Drop for WintunAdapter { } impl WintunDll { - pub fn instance(resource_dir: &Path) -> io::Result> { - let mut dll = (*WINTUN_DLL).lock().expect("Wintun mutex poisoned"); - match &*dll { - Some(dll) => Ok(dll.clone()), - None => { - let new_dll = Arc::new(Self::new(resource_dir)?); - *dll = Some(new_dll.clone()); - Ok(new_dll) - } - } + pub fn instance(resource_dir: &Path) -> io::Result<&'static Self> { + WINTUN_DLL.get_or_try_init(|| Self::new(resource_dir)) } fn new(resource_dir: &Path) -> io::Result { @@ -268,8 +253,8 @@ impl WintunDll { luid.assume_init() } - pub fn activate_logging(self: &Arc) -> WintunLoggerHandle { - WintunLoggerHandle::from_handle(self.clone()) + pub fn activate_logging(&'static self) -> WintunLoggerHandle { + WintunLoggerHandle::from_handle(self) } fn set_logger(&self, logger: Option) { @@ -284,11 +269,11 @@ impl Drop for WintunDll { } pub struct WintunLoggerHandle { - dll_handle: Arc, + dll_handle: &'static WintunDll, } impl WintunLoggerHandle { - fn from_handle(dll_handle: Arc) -> Self { + fn from_handle(dll_handle: &'static WintunDll) -> Self { dll_handle.set_logger(Some(Self::callback)); Self { dll_handle } } diff --git a/talpid-routing/Cargo.toml b/talpid-routing/Cargo.toml index 56fb12b13c31..e5cfb9d9f904 100644 --- a/talpid-routing/Cargo.toml +++ b/talpid-routing/Cargo.toml @@ -36,7 +36,7 @@ system-configuration = "0.5.1" [target.'cfg(windows)'.dependencies] libc = "0.2" -talpid-windows-net = { path = "../talpid-windows-net" } +talpid-windows = { path = "../talpid-windows" } widestring = "1.0" [target.'cfg(windows)'.dependencies.windows-sys] diff --git a/talpid-routing/src/windows/default_route_monitor.rs b/talpid-routing/src/windows/default_route_monitor.rs index 0f7d64e3a892..4e452ef31911 100644 --- a/talpid-routing/src/windows/default_route_monitor.rs +++ b/talpid-routing/src/windows/default_route_monitor.rs @@ -22,7 +22,7 @@ use windows_sys::Win32::{ }, }; -use talpid_windows_net::AddressFamily; +use talpid_windows::net::AddressFamily; const WIN_FALSE: BOOLEAN = 0; diff --git a/talpid-routing/src/windows/get_best_default_route.rs b/talpid-routing/src/windows/get_best_default_route.rs index 4a1f254a79ac..f51313034895 100644 --- a/talpid-routing/src/windows/get_best_default_route.rs +++ b/talpid-routing/src/windows/get_best_default_route.rs @@ -1,7 +1,7 @@ use super::{Error, Result}; use std::{net::SocketAddr, slice}; use talpid_types::win32_err; -use talpid_windows_net::{ +use talpid_windows::net::{ get_ip_interface_entry, try_socketaddr_from_inet_sockaddr, AddressFamily, }; use widestring::{widecstr, WideCStr}; diff --git a/talpid-routing/src/windows/mod.rs b/talpid-routing/src/windows/mod.rs index 51ac345f82c5..7924d4c7abf4 100644 --- a/talpid-routing/src/windows/mod.rs +++ b/talpid-routing/src/windows/mod.rs @@ -12,7 +12,7 @@ use net::AddressFamily; pub use route_manager::{Callback, CallbackHandle, Route, RouteManagerInternal}; use std::{collections::HashSet, io, net::IpAddr}; use talpid_types::ErrorExt; -use talpid_windows_net as net; +use talpid_windows::net; mod default_route_monitor; mod get_best_default_route; @@ -284,12 +284,10 @@ fn get_mtu_for_route(addr_family: AddressFamily) -> Result> { match get_best_default_route(addr_family) { Ok(Some(route)) => { let interface_row = - talpid_windows_net::get_ip_interface_entry(addr_family, &route.iface).map_err( - |e| { - log::error!("Could not get ip interface entry: {}", e); - Error::GetMtu - }, - )?; + net::get_ip_interface_entry(addr_family, &route.iface).map_err(|e| { + log::error!("Could not get ip interface entry: {}", e); + Error::GetMtu + })?; let mtu = interface_row.NlMtu; let mtu = u16::try_from(mtu).map_err(|_| Error::GetMtu)?; Ok(Some(mtu)) diff --git a/talpid-routing/src/windows/route_manager.rs b/talpid-routing/src/windows/route_manager.rs index 277832eaa94a..a122a29003ce 100644 --- a/talpid-routing/src/windows/route_manager.rs +++ b/talpid-routing/src/windows/route_manager.rs @@ -11,7 +11,7 @@ use std::{ sync::{Arc, Mutex}, }; use talpid_types::win32_err; -use talpid_windows_net::{ +use talpid_windows::net::{ inet_sockaddr_from_socketaddr, try_socketaddr_from_inet_sockaddr, AddressFamily, }; use widestring::{WideCStr, WideCString}; @@ -824,7 +824,7 @@ impl<'a> Iterator for AdaptersIterator<'a> { pub fn win_ip_address_prefix_from_ipnetwork_port_zero(from: IpNetwork) -> IP_ADDRESS_PREFIX { // Port should not matter so we set it to 0 let prefix = - talpid_windows_net::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from.ip(), 0)); + talpid_windows::net::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from.ip(), 0)); IP_ADDRESS_PREFIX { Prefix: prefix, PrefixLength: from.prefix(), @@ -834,7 +834,7 @@ pub fn win_ip_address_prefix_from_ipnetwork_port_zero(from: IpNetwork) -> IP_ADD /// Convert to a windows defined `SOCKADDR_INET` from a `IpAddr` but set the port to 0 pub fn inet_sockaddr_from_ipaddr(from: IpAddr) -> SOCKADDR_INET { // Port should not matter so we set it to 0 - talpid_windows_net::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from, 0)) + talpid_windows::net::inet_sockaddr_from_socketaddr(std::net::SocketAddr::new(from, 0)) } /// Convert to a `AddressFamily` from a `ipnetwork::IpNetwork` diff --git a/talpid-tunnel/Cargo.toml b/talpid-tunnel/Cargo.toml index 1e8bbd939537..1f5ef36a1a6f 100644 --- a/talpid-tunnel/Cargo.toml +++ b/talpid-tunnel/Cargo.toml @@ -32,7 +32,7 @@ tun = "0.5.1" tun = "0.5.1" [target.'cfg(windows)'.dependencies] -talpid-windows-net = { path = "../talpid-windows-net" } +talpid-windows = { path = "../talpid-windows" } [target.'cfg(windows)'.dependencies.windows-sys] workspace = true diff --git a/talpid-tunnel/src/windows.rs b/talpid-tunnel/src/windows.rs index d9c54f69408e..bc5fffb3f0ec 100644 --- a/talpid-tunnel/src/windows.rs +++ b/talpid-tunnel/src/windows.rs @@ -1,5 +1,5 @@ use std::io; -use talpid_windows_net::{get_ip_interface_entry, set_ip_interface_entry, AddressFamily}; +use talpid_windows::net::{get_ip_interface_entry, set_ip_interface_entry, AddressFamily}; use windows_sys::Win32::{ Foundation::ERROR_NOT_FOUND, NetworkManagement::Ndis::NET_LUID_LH, Networking::WinSock::RouterDiscoveryDisabled, diff --git a/talpid-windows-net/src/lib.rs b/talpid-windows-net/src/lib.rs deleted file mode 100644 index 48dc8062ee4a..000000000000 --- a/talpid-windows-net/src/lib.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! Interface with low-level windows specific bits. - -#![deny(missing_docs)] -#![deny(rust_2018_idioms)] - -/// Nicer interfaces with Windows networking code. -#[cfg(windows)] -pub mod net; -#[cfg(windows)] -pub use net::*; diff --git a/talpid-windows-net/Cargo.toml b/talpid-windows/Cargo.toml similarity index 85% rename from talpid-windows-net/Cargo.toml rename to talpid-windows/Cargo.toml index 1c6fc02d9ab8..0765a1376b70 100644 --- a/talpid-windows-net/Cargo.toml +++ b/talpid-windows/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "talpid-windows-net" -description = "Work with Windows network interfaces and their configuration" +name = "talpid-windows" +description = "Nice abstractions for Windows" version.workspace = true authors.workspace = true repository.workspace = true diff --git a/talpid-windows/src/io.rs b/talpid-windows/src/io.rs new file mode 100644 index 000000000000..1bcffb30d345 --- /dev/null +++ b/talpid-windows/src/io.rs @@ -0,0 +1,49 @@ +use std::{io, mem}; +use windows_sys::Win32::System::IO::OVERLAPPED; + +use crate::sync::Event; + +/// Abstraction over `OVERLAPPED`. +pub struct Overlapped { + overlapped: OVERLAPPED, + event: Option, +} + +unsafe impl Send for Overlapped {} +unsafe impl Sync for Overlapped {} + +impl Overlapped { + /// Creates an `OVERLAPPED` object with `hEvent` set. + pub fn new(event: Option) -> io::Result { + let mut overlapped = Overlapped { + overlapped: unsafe { mem::zeroed() }, + event: None, + }; + overlapped.set_event(event); + Ok(overlapped) + } + + /// Borrows the underlying `OVERLAPPED` object. + pub fn as_mut_ptr(&mut self) -> *mut OVERLAPPED { + &mut self.overlapped + } + + /// Returns a reference to the associated event. + pub fn get_event(&self) -> Option<&Event> { + self.event.as_ref() + } + + /// Sets the event object for the underlying `OVERLAPPED` object (i.e., `hEvent`) + fn set_event(&mut self, event: Option) { + match event { + Some(event) => { + self.overlapped.hEvent = event.as_raw(); + self.event = Some(event); + } + None => { + self.overlapped.hEvent = 0; + self.event = None; + } + } + } +} diff --git a/talpid-windows/src/lib.rs b/talpid-windows/src/lib.rs new file mode 100644 index 000000000000..755f6b2e30f2 --- /dev/null +++ b/talpid-windows/src/lib.rs @@ -0,0 +1,17 @@ +//! Interface with low-level Windows-specific bits. + +#![deny(missing_docs)] +#![deny(rust_2018_idioms)] +#![cfg(windows)] + +/// I/O +pub mod io; + +/// Networking +pub mod net; + +/// Synchronization +pub mod sync; + +/// Processes +pub mod process; diff --git a/talpid-windows-net/src/net.rs b/talpid-windows/src/net.rs similarity index 100% rename from talpid-windows-net/src/net.rs rename to talpid-windows/src/net.rs diff --git a/talpid-windows/src/process.rs b/talpid-windows/src/process.rs new file mode 100644 index 000000000000..ecddbe09f200 --- /dev/null +++ b/talpid-windows/src/process.rs @@ -0,0 +1,157 @@ +use std::{ + ffi::{c_char, CStr}, + io, mem, +}; +use windows_sys::Win32::{ + Foundation::{CloseHandle, ERROR_NO_MORE_FILES, HANDLE, INVALID_HANDLE_VALUE}, + System::Diagnostics::ToolHelp::{ + CreateToolhelp32Snapshot, Module32First, Module32Next, Process32FirstW, Process32NextW, + MODULEENTRY32, PROCESSENTRY32W, + }, +}; + +/// A snapshot of process modules, threads, and heaps +pub struct ProcessSnapshot { + handle: HANDLE, +} + +impl ProcessSnapshot { + /// Create a new process snapshot using `CreateToolhelp32Snapshot` + pub fn new(flags: u32, process_id: u32) -> io::Result { + let snap = unsafe { CreateToolhelp32Snapshot(flags, process_id) }; + + if snap == INVALID_HANDLE_VALUE { + Err(io::Error::last_os_error()) + } else { + Ok(ProcessSnapshot { handle: snap }) + } + } + + /// Return the raw handle + pub fn as_raw(&self) -> HANDLE { + self.handle + } + + /// Return an iterator over the modules in the snapshot + pub fn modules(&self) -> ProcessSnapshotModules<'_> { + let mut entry: MODULEENTRY32 = unsafe { mem::zeroed() }; + entry.dwSize = mem::size_of::() as u32; + + ProcessSnapshotModules { + snapshot: self, + iter_started: false, + temp_entry: entry, + } + } + + /// Return an iterator over the processes in the snapshot + pub fn processes(&self) -> ProcessSnapshotEntries<'_> { + let mut entry: PROCESSENTRY32W = unsafe { mem::zeroed() }; + entry.dwSize = mem::size_of::() as u32; + + ProcessSnapshotEntries { + snapshot: self, + iter_started: false, + temp_entry: entry, + } + } +} + +impl Drop for ProcessSnapshot { + fn drop(&mut self) { + unsafe { + CloseHandle(self.handle); + } + } +} + +/// Description of a snapshot module entry. See `MODULEENTRY32` +pub struct ModuleEntry { + /// Module name + pub name: String, + /// Module base address (in the owning process) + pub base_address: *const u8, + /// Size of the module (in bytes) + pub size: usize, +} + +/// Module iterator for [ProcessSnapshot] +pub struct ProcessSnapshotModules<'a> { + snapshot: &'a ProcessSnapshot, + iter_started: bool, + temp_entry: MODULEENTRY32, +} + +impl Iterator for ProcessSnapshotModules<'_> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + if self.iter_started { + if unsafe { Module32Next(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + let last_error = io::Error::last_os_error(); + + return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { + None + } else { + Some(Err(last_error)) + }; + } + } else { + if unsafe { Module32First(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + return Some(Err(io::Error::last_os_error())); + } + self.iter_started = true; + } + + let cstr_ref = &self.temp_entry.szModule[0]; + let cstr = unsafe { CStr::from_ptr(cstr_ref as *const u8 as *const c_char) }; + Some(Ok(ModuleEntry { + name: cstr.to_string_lossy().into_owned(), + base_address: self.temp_entry.modBaseAddr, + size: self.temp_entry.modBaseSize as usize, + })) + } +} + +/// Description of a snapshot process entry. See `PROCESSENTRY32W` +pub struct ProcessEntry { + /// Process identifier + pub pid: u32, + /// Process identifier of the parent process + pub parent_pid: u32, +} + +/// Process iterator for [ProcessSnapshot] +pub struct ProcessSnapshotEntries<'a> { + snapshot: &'a ProcessSnapshot, + iter_started: bool, + temp_entry: PROCESSENTRY32W, +} + +impl Iterator for ProcessSnapshotEntries<'_> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + if self.iter_started { + if unsafe { Process32NextW(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + let last_error = io::Error::last_os_error(); + + return if last_error.raw_os_error().unwrap() as u32 == ERROR_NO_MORE_FILES { + None + } else { + Some(Err(last_error)) + }; + } + } else { + if unsafe { Process32FirstW(self.snapshot.as_raw(), &mut self.temp_entry) } == 0 { + return Some(Err(io::Error::last_os_error())); + } + self.iter_started = true; + } + + Some(Ok(ProcessEntry { + pid: self.temp_entry.th32ProcessID, + parent_pid: self.temp_entry.th32ParentProcessID, + })) + } +} diff --git a/talpid-windows/src/sync.rs b/talpid-windows/src/sync.rs new file mode 100644 index 000000000000..202c96524d28 --- /dev/null +++ b/talpid-windows/src/sync.rs @@ -0,0 +1,55 @@ +use std::{io, ptr}; +use windows_sys::Win32::{ + Foundation::{CloseHandle, BOOL, HANDLE}, + System::Threading::{CreateEventW, SetEvent}, +}; + +/// Windows event object +pub struct Event(HANDLE); + +unsafe impl Send for Event {} +unsafe impl Sync for Event {} + +impl Event { + /// Create a new event object using `CreateEventW` + pub fn new(manual_reset: bool, initial_state: bool) -> io::Result { + let event = unsafe { + CreateEventW( + ptr::null_mut(), + bool_to_winbool(manual_reset), + bool_to_winbool(initial_state), + ptr::null(), + ) + }; + if event == 0 { + return Err(io::Error::last_os_error()); + } + Ok(Self(event)) + } + + /// Signal the event object + pub fn set(&self) -> io::Result<()> { + if unsafe { SetEvent(self.0) } == 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + } + + /// Return raw event object + pub fn as_raw(&self) -> HANDLE { + self.0 + } +} + +impl Drop for Event { + fn drop(&mut self) { + unsafe { CloseHandle(self.0) }; + } +} + +const fn bool_to_winbool(val: bool) -> BOOL { + match val { + true => 1, + false => 0, + } +} diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index f8e854a7c56f..f9a446a700d7 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -50,7 +50,7 @@ tokio-stream = { version = "0.1", features = ["io-util"] } [target.'cfg(windows)'.dependencies] bitflags = "1.2" -talpid-windows-net = { path = "../talpid-windows-net" } +talpid-windows = { path = "../talpid-windows" } widestring = "1.0" # TODO: Figure out which features are needed and which are not diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index cdc09224ef4b..c95a5d371b67 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -105,7 +105,7 @@ pub enum Error { /// Failed to set IP addresses on WireGuard interface #[cfg(target_os = "windows")] #[error(display = "Failed to set IP addresses on WireGuard interface")] - SetIpAddressesError(#[error(source)] talpid_windows_net::Error), + SetIpAddressesError(#[error(source)] talpid_windows::net::Error), } /// Spawns and monitors a wireguard tunnel @@ -643,12 +643,12 @@ impl WireguardMonitor { })?; // TODO: The LUID can be obtained directly. - let luid = talpid_windows_net::luid_from_alias(iface_name).map_err(|error| { + let luid = talpid_windows::net::luid_from_alias(iface_name).map_err(|error| { log::error!("Failed to obtain tunnel interface LUID: {}", error); CloseMsg::SetupError(Error::IpInterfacesError) })?; for address in addresses { - talpid_windows_net::add_ip_address_for_interface(luid, *address) + talpid_windows::net::add_ip_address_for_interface(luid, *address) .map_err(|error| CloseMsg::SetupError(Error::SetIpAddressesError(error)))?; } Ok(()) diff --git a/talpid-wireguard/src/wireguard_nt.rs b/talpid-wireguard/src/wireguard_nt.rs index 1b4405eba290..daa7da65c86c 100644 --- a/talpid-wireguard/src/wireguard_nt.rs +++ b/talpid-wireguard/src/wireguard_nt.rs @@ -7,7 +7,7 @@ use super::{ use bitflags::bitflags; use futures::SinkExt; use ipnetwork::IpNetwork; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use std::{ ffi::CStr, fmt, @@ -22,7 +22,7 @@ use std::{ sync::{Arc, Mutex}, }; use talpid_types::{BoxedError, ErrorExt}; -use talpid_windows_net as net; +use talpid_windows::net; use widestring::{U16CStr, U16CString}; use windows_sys::{ core::GUID, @@ -38,7 +38,7 @@ use windows_sys::{ }, }; -static WG_NT_DLL: Lazy>>> = Lazy::new(|| Mutex::new(None)); +static WG_NT_DLL: OnceCell = OnceCell::new(); static ADAPTER_TYPE: Lazy = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); static ADAPTER_ALIAS: Lazy = Lazy::new(|| U16CString::from_str("Mullvad").unwrap()); @@ -163,7 +163,7 @@ pub enum Error { } pub struct WgNtTunnel { - device: Arc>>, + device: Option>, interface_name: String, setup_handle: tokio::task::JoinHandle<()>, _logger_handle: LoggerHandle, @@ -430,7 +430,7 @@ impl WgNtTunnel { mut done_tx: futures::channel::mpsc::Sender>, ) -> Result { let dll = load_wg_nt_dll(resource_dir)?; - let logger_handle = LoggerHandle::new(dll.clone(), log_path)?; + let logger_handle = LoggerHandle::new(dll, log_path)?; let device = WgNtAdapter::create(dll, &ADAPTER_ALIAS, &ADAPTER_TYPE, Some(ADAPTER_GUID)) .map_err(Error::CreateTunnelDevice)?; @@ -443,10 +443,11 @@ impl WgNtTunnel { ); } device.set_config(config)?; - let device = Arc::new(Mutex::new(Some(device))); + let device2 = Arc::new(device); + let device = Some(device2.clone()); let setup_future = setup_ip_listener( - device.clone(), + device2, u32::from(config.mtu), config.tunnel.addresses.iter().any(|addr| addr.is_ipv6()), ); @@ -466,16 +467,12 @@ impl WgNtTunnel { fn stop_tunnel(&mut self) { self.setup_handle.abort(); - let _ = self.device.lock().unwrap().take(); + let _ = self.device.take(); } } -async fn setup_ip_listener( - device: Arc>>, - mtu: u32, - has_ipv6: bool, -) -> Result<()> { - let luid = { device.lock().unwrap().as_ref().unwrap().luid() }; +async fn setup_ip_listener(device: Arc, mtu: u32, has_ipv6: bool) -> Result<()> { + let luid = device.luid(); let luid = NET_LUID_LH { Value: unsafe { luid.Value }, }; @@ -489,13 +486,9 @@ async fn setup_ip_listener( talpid_tunnel::network_interface::initialize_interfaces(luid, Some(mtu)) .map_err(Error::SetTunnelMtu)?; - if let Some(device) = &*device.lock().unwrap() { - device - .set_state(WgAdapterState::Up) - .map_err(Error::EnableTunnel) - } else { - Ok(()) - } + device + .set_state(WgAdapterState::Up) + .map_err(Error::EnableTunnel) } impl Drop for WgNtTunnel { @@ -507,12 +500,12 @@ impl Drop for WgNtTunnel { static LOG_CONTEXT: Lazy>> = Lazy::new(|| Mutex::new(None)); struct LoggerHandle { - dll: Arc, + dll: &'static WgNtDll, context: u32, } impl LoggerHandle { - fn new(dll: Arc, log_path: Option<&Path>) -> Result { + fn new(dll: &'static WgNtDll, log_path: Option<&Path>) -> Result { let context = logging::initialize_logging(log_path).map_err(Error::InitLogging)?; { *(LOG_CONTEXT.lock().unwrap()) = Some(context); @@ -547,7 +540,7 @@ impl Drop for LoggerHandle { } struct WgNtAdapter { - dll_handle: Arc, + dll_handle: &'static WgNtDll, handle: RawHandle, } @@ -564,7 +557,7 @@ unsafe impl Sync for WgNtAdapter {} impl WgNtAdapter { fn create( - dll_handle: Arc, + dll_handle: &'static WgNtDll, name: &U16CStr, tunnel_type: &U16CStr, requested_guid: Option, @@ -806,16 +799,8 @@ impl Drop for WgNtDll { } } -fn load_wg_nt_dll(resource_dir: &Path) -> Result> { - let mut dll = (*WG_NT_DLL).lock().expect("WireGuardNT mutex poisoned"); - match &*dll { - Some(dll) => Ok(dll.clone()), - None => { - let new_dll = Arc::new(WgNtDll::new(resource_dir).map_err(Error::LoadDll)?); - *dll = Some(new_dll.clone()); - Ok(new_dll) - } - } +fn load_wg_nt_dll(resource_dir: &Path) -> Result<&'static WgNtDll> { + WG_NT_DLL.get_or_try_init(|| WgNtDll::new(resource_dir).map_err(Error::LoadDll)) } fn serialize_config(config: &Config) -> Result>> { @@ -941,7 +926,7 @@ impl Tunnel for WgNtTunnel { } fn get_tunnel_stats(&self) -> std::result::Result { - if let Some(ref device) = &*self.device.lock().unwrap() { + if let Some(ref device) = self.device { let mut map = StatsMap::new(); let (_interface, peers) = device.get_config().map_err(|error| { log::error!( @@ -976,9 +961,12 @@ impl Tunnel for WgNtTunnel { config: Config, ) -> Pin> + Send>> { let device = self.device.clone(); + Box::pin(async move { - let guard = device.lock().unwrap(); - let device = guard.as_ref().ok_or(super::TunnelError::SetConfigError)?; + let Some(device) = device else { + log::error!("Failed to set config: No tunnel device"); + return Err(super::TunnelError::SetConfigError); + }; device.set_config(&config).map_err(|error| { log::error!( "{}", @@ -1043,7 +1031,7 @@ mod tests { public_key: *WG_PUBLIC_KEY.as_bytes(), preshared_key: [0; WIREGUARD_KEY_LENGTH], persistent_keepalive: 0, - endpoint: talpid_windows_net::inet_sockaddr_from_socketaddr( + endpoint: talpid_windows::net::inet_sockaddr_from_socketaddr( "1.2.3.4:1234".parse().unwrap(), ) .into(),