Skip to content

Commit

Permalink
fix: fix a bug on windows related to reconnecting to a location after…
Browse files Browse the repository at this point in the history
… system restart/ungraceful shutdown (#350)
  • Loading branch information
t-aleksander authored Nov 7, 2024
1 parent dcb343d commit e59eec9
Showing 1 changed file with 40 additions and 2 deletions.
42 changes: 40 additions & 2 deletions src-tauri/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ use winapi::{
um::{
errhandlingapi::GetLastError,
winsvc::{
OpenSCManagerW, OpenServiceW, QueryServiceStatus, SC_HANDLE__, SC_MANAGER_CONNECT,
SERVICE_QUERY_STATUS, SERVICE_RUNNING,
CloseServiceHandle, OpenSCManagerW, OpenServiceW, QueryServiceStatus, SC_HANDLE__,
SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS, SERVICE_RUNNING,
},
},
};
Expand Down Expand Up @@ -233,6 +233,7 @@ pub fn spawn_stats_thread(
handle: tauri::AppHandle,
interface_name: String,
connection_type: ConnectionType,
location_id: Id,
) {
tokio::spawn(async move {
let state = handle.state::<AppState>();
Expand All @@ -247,6 +248,16 @@ pub fn spawn_stats_thread(
.into_inner();

loop {
debug!("Checking if connection for interface {interface_name} still exists...");
if state
.find_connection(location_id, connection_type)
.await
.is_none()
{
debug!("Location connection for interface {interface_name} has been removed, stopping stats thread for that interface.");
break;
}
debug!("Connection for interface {interface_name} still exists, continuing to read network stats...");
match stream.message().await {
Ok(Some(interface_data)) => {
debug!("Received new network usage statistics for interface {interface_name}.");
Expand Down Expand Up @@ -647,6 +658,7 @@ pub async fn handle_connection_for_location(
handle.clone(),
interface_name.clone(),
ConnectionType::Location,
location.id,
);
debug!(
"Network usage stats thread for location {} spawned.",
Expand Down Expand Up @@ -707,6 +719,7 @@ pub async fn handle_connection_for_tunnel(
handle.clone(),
interface_name.clone(),
ConnectionType::Tunnel,
tunnel.id,
);
debug!("Stats thread for tunnel {} spawned", tunnel.name);

Expand Down Expand Up @@ -933,6 +946,23 @@ fn get_service_status(service_handle: *mut SC_HANDLE__) -> Result<DWORD, DWORD>
}
}

#[cfg(target_os = "windows")]
fn close_service_handle(
service_handle: *mut SC_HANDLE__,
service_name: &str,
) -> Result<i32, Error> {
let result = unsafe { CloseServiceHandle(service_handle) };
if result == 0 {
let error = unsafe { GetLastError() };
Err(Error::InternalError(format!(
"Failed to close service handle for service {service_name}, error code: {error}",
)))
} else {
info!("Service handle closed successfully");
Ok(result)
}
}

// TODO: Move the connection handling to a seperate, common function,
// so `handle_connection_for_location` and `handle_connection_for_tunnel` are not
// partially duplicated here.
Expand Down Expand Up @@ -979,6 +1009,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
match get_service_status(service) {
Ok(status) => {
// Only point where we don't jump to the next iteration of the loop and continue with the rest of the code below the match
close_service_handle(service, &service_name)?;
if status == SERVICE_RUNNING {
debug!("WireGuard tunnel {} is running, ", interface_name);
} else {
Expand All @@ -990,6 +1021,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
}
}
Err(err) => {
close_service_handle(service, &service_name)?;
warn!(
"Failed to query service status for interface {} while synchronizing active connections. \
This may cause the location {} state to display incorrectly in the client. Reconnect to it manually to fix it. Error: {err}",
Expand Down Expand Up @@ -1040,6 +1072,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
apphandle.clone(),
interface_name.clone(),
ConnectionType::Location,
location.id,
);
debug!(
"Network usage stats thread for location {} spawned.",
Expand Down Expand Up @@ -1090,6 +1123,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
match get_service_status(service) {
Ok(status) => {
// Only point where we don't jump to the next iteration of the loop and continue with the rest of the code below the match
close_service_handle(service, &service_name)?;
if status == SERVICE_RUNNING {
debug!("WireGuard tunnel {} is running", interface_name);
} else {
Expand All @@ -1101,6 +1135,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
}
}
Err(err) => {
close_service_handle(service, &service_name)?;
warn!(
"Failed to query service status for interface {}. \
This may cause the tunnel {} state to display incorrectly in the client. Reconnect to it manually to fix it. Error: {err}",
Expand Down Expand Up @@ -1149,6 +1184,7 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
apphandle.clone(),
interface_name.clone(),
ConnectionType::Tunnel,
tunnel.id,
);
debug!("Stats thread for tunnel {} spawned", tunnel.name);

Expand All @@ -1166,6 +1202,8 @@ pub async fn sync_connections(apphandle: &AppHandle) -> Result<(), Error> {
debug!("Log watcher for tunnel {} spawned", tunnel.name);
}

close_service_handle(service_control_manager, "SERVICE_CONTROL_MANAGER")?;

debug!("Active connections synchronized with the system state");

Ok(())
Expand Down

0 comments on commit e59eec9

Please sign in to comment.