Skip to content

Commit

Permalink
Removed GatewayMap
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Oct 29, 2024
1 parent b71d901 commit ea1138a
Show file tree
Hide file tree
Showing 16 changed files with 109 additions and 379 deletions.
7 changes: 2 additions & 5 deletions src/bin/defguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ use defguard::{
models::{settings::Settings, user::User, webhook::AppEvent, wireguard::ChangeEvent},
},
enterprise::license::{run_periodic_license_check, set_cached_license, License},
grpc::{
run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, GatewayMap, WorkerState,
},
grpc::{run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, WorkerState},
headers::create_user_agent_parser,
init_dev_env, init_vpn_location,
mail::{run_mail_handler, Mail},
Expand Down Expand Up @@ -83,7 +81,6 @@ async fn main() -> Result<(), anyhow::Error> {
let (events_tx, _events_rx) = broadcast::channel::<ChangeEvent>(256);
let (mail_tx, mail_rx) = unbounded_channel::<Mail>();
let worker_state = Arc::new(Mutex::new(WorkerState::new(webhook_tx.clone())));
let gateway_map = Arc::new(Mutex::new(GatewayMap::new()));
let user_agent_parser = create_user_agent_parser();

// initialize admin user
Expand Down Expand Up @@ -122,7 +119,7 @@ async fn main() -> Result<(), anyhow::Error> {
res = run_grpc_gateway_stream(pool.clone(), events_tx.clone(), mail_tx.clone()) => error!("Gateway gRPC stream returned early: {res:#?}"),
res = run_grpc_bidi_stream(pool.clone(), events_tx.clone(), mail_tx.clone(), user_agent_parser.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"),
res = run_web_server(worker_state, gateway_map, webhook_tx, webhook_rx, events_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"),
res = run_web_server(worker_state, webhook_tx, webhook_rx, events_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"),
res = run_mail_handler(mail_rx, pool.clone()) => error!("Mail handler returned early: {res:#?}"),
res = run_periodic_peer_disconnect(pool.clone(), events_tx) => error!("Periodic peer disconnect task returned early: {res:#?}"),
res = run_periodic_stats_purge(pool.clone(), config.stats_purge_frequency.into(), config.stats_purge_threshold.into()), if !config.disable_stats_purge => error!("Periodic stats purge task returned early: {res:#?}"),
Expand Down
13 changes: 0 additions & 13 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use crate::{
wireguard::WireguardNetworkError,
},
enterprise::license::LicenseError,
grpc::GatewayMapError,
ldap::error::LdapError,
templates::TemplateError,
};
Expand Down Expand Up @@ -105,18 +104,6 @@ impl From<DeviceError> for WebError {
}
}

impl From<GatewayMapError> for WebError {
fn from(error: GatewayMapError) -> Self {
match error {
GatewayMapError::NotFound(_, _)
| GatewayMapError::NetworkNotFound(_)
| GatewayMapError::UidNotFound(_) => Self::ObjectNotFound(error.to_string()),
GatewayMapError::RemoveActive(_) => Self::BadRequest(error.to_string()),
GatewayMapError::ConfigError => Self::ServerConfigMissing,
}
}
}

impl From<WireguardNetworkError> for WebError {
fn from(error: WireguardNetworkError) -> Self {
match error {
Expand Down
184 changes: 0 additions & 184 deletions src/grpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use chrono::NaiveDateTime;
use reqwest::Url;
use serde::Serialize;
use sqlx::{postgres::PgListener, PgPool};
use thiserror::Error;
use tokio::{
sync::{
broadcast::Sender,
Expand Down Expand Up @@ -75,189 +74,6 @@ pub(crate) mod proto {

use proto::{core_request, proxy_client::ProxyClient, CoreError, CoreResponse};

// Helper struct used to handle gateway state
// gateways are grouped by network
type GatewayHostname = String;
// TODO: save state to database
pub struct GatewayMap(HashMap<Id, HashMap<GatewayHostname, GatewayState>>);

#[derive(Debug, Error)]
pub enum GatewayMapError {
#[error("Gateway {1} for network {0} not found")]
NotFound(i64, GatewayHostname),
#[error("Network {0} not found")]
NetworkNotFound(i64),
#[error("Gateway with UID {0} not found")]
UidNotFound(Id),
#[error("Cannot remove. Gateway with UID {0} is still active")]
RemoveActive(Id),
#[error("Config missing")]
ConfigError,
}

impl GatewayMap {
#[must_use]
pub fn new() -> Self {
Self(HashMap::new())
}

// add a new gateway to map
// this method is meant to be called when a gateway requests a config
// as a sort of "registration"
// pub fn add_gateway(
// &mut self,
// network_id: Id,
// network_name: &str,
// hostname: String,
// name: Option<String>,
// mail_tx: UnboundedSender<Mail>,
// ) {
// info!("Adding gateway {hostname} with to gateway map for network {network_id}",);
// let gateway_state = GatewayState::new(network_id, network_name, &hostname, name, mail_tx);

// if let Some(network_gateway_map) = self.0.get_mut(&network_id) {
// network_gateway_map.entry(hostname).or_insert(gateway_state);
// } else {
// // no map for a given network exists yet
// let mut network_gateway_map = HashMap::new();
// network_gateway_map.insert(hostname, gateway_state);
// self.0.insert(network_id, network_gateway_map);
// }
// }

// remove gateway from map
pub(crate) fn remove_gateway(&mut self, network_id: Id, id: Id) -> Result<(), GatewayMapError> {
debug!("Removing gateway from network {network_id}");
if let Some(network_gateway_map) = self.0.get_mut(&network_id) {
// find gateway by uuid
let hostname = match network_gateway_map
.iter()
.find(|(_address, state)| state.id == id)
{
None => {
error!("Failed to find gateway with ID {id}");
return Err(GatewayMapError::UidNotFound(id));
}
Some((hostname, state)) => {
if state.connected {
error!("Cannot remove. Gateway with UID {id} is still active");
return Err(GatewayMapError::RemoveActive(id));
}
hostname.clone()
}
};
// remove matching gateway
network_gateway_map.remove(&hostname)
} else {
// no map for a given network exists yet
error!("Network ID {network_id} not found in gateway map");
return Err(GatewayMapError::NetworkNotFound(network_id));
};

info!("Gateway with UID {id} removed from network ID {network_id}");
Ok(())
}

// change gateway status to connected
// we assume that the gateway is already present in hashmap
// pub(crate) fn connect_gateway(
// &mut self,
// network_id: Id,
// hostname: &str,
// ) -> Result<(), GatewayMapError> {
// debug!("Connecting gateway {hostname} in network {network_id}");
// if let Some(network_gateway_map) = self.0.get_mut(&network_id) {
// if let Some(state) = network_gateway_map.get_mut(hostname) {
// state.connected = true;
// state.disconnected_at = None;
// state.connected_at = Some(Utc::now().naive_utc());
// // debug!(
// // "Gateway {hostname} found in gateway map, current state: {:#?}",
// // state
// // );
// } else {
// error!("Gateway {hostname} not found in gateway map for network {network_id}");
// return Err(GatewayMapError::NotFound(network_id, hostname.into()));
// }
// } else {
// // no map for a given network exists yet
// error!("Network ID {network_id} not found in gateway map");
// return Err(GatewayMapError::NetworkNotFound(network_id));
// };

// info!("Gateway {hostname} connected in network ID {network_id}");
// Ok(())
// }

// change gateway status to disconnected
// pub(crate) fn disconnect_gateway(
// &mut self,
// network_id: Id,
// hostname: String,
// pool: &PgPool,
// ) -> Result<(), GatewayMapError> {
// debug!("Disconnecting gateway {hostname} in network {network_id}");
// if let Some(network_gateway_map) = self.0.get_mut(&network_id) {
// if let Some(state) = network_gateway_map.get_mut(&hostname) {
// state.connected = false;
// state.disconnected_at = Some(Utc::now().naive_utc());
// // state.send_disconnect_notification(pool);
// // debug!("Gateway {hostname} found in gateway map, current state: {state:#?}");
// info!("Gateway {hostname} disconnected in network {network_id}");
// return Ok(());
// };
// };
// let err = GatewayMapError::NotFound(network_id, hostname);
// error!("Gateway disconnect failed: {err}");
// Err(err)
// }

// return `true` if at least one gateway in a given network is connected
// #[must_use]
// pub(crate) fn connected(&self, network_id: Id) -> bool {
// match self.0.get(&network_id) {
// Some(network_gateway_map) => network_gateway_map
// .values()
// .any(|gateway| gateway.connected),
// None => false,
// }
// }

// return a list af aff statuses af all gateways in a given network
// #[must_use]
// pub(crate) fn get_network_gateway_status(&self, network_id: Id) -> Vec<GatewayState> {
// match self.0.get(&network_id) {
// Some(network_gateway_map) => network_gateway_map.clone().into_values().collect(),
// None => Vec::new(),
// }
// }

// return gateway name
// #[must_use]
// pub(crate) fn get_network_gateway_name(
// &self,
// network_id: Id,
// hostname: &str,
// ) -> Option<String> {
// match self.0.get(&network_id) {
// Some(network_gateway_map) => {
// if let Some(state) = network_gateway_map.get(hostname) {
// state.name.clone()
// } else {
// None
// }
// }
// None => None,
// }
// }
}

impl Default for GatewayMap {
fn default() -> Self {
Self::new()
}
}

#[derive(Clone, Serialize)]
pub(crate) struct GatewayState {
pub id: Id,
Expand Down
27 changes: 2 additions & 25 deletions src/handlers/wireguard.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
use std::{
net::IpAddr,
str::FromStr,
sync::{Arc, Mutex},
};
use std::{net::IpAddr, str::FromStr};

use axum::{
extract::{Json, Path, Query, State},
http::StatusCode,
Extension,
};
use chrono::{DateTime, NaiveDateTime, TimeDelta, Utc};
use ipnetwork::IpNetwork;
Expand All @@ -34,7 +29,7 @@ use crate::{
Id,
},
enterprise::handlers::CanManageDevices,
grpc::{GatewayMap, GatewayState},
grpc::GatewayState,
handlers::mail::send_new_device_added_email,
server_config,
templates::TemplateLocation,
Expand Down Expand Up @@ -299,24 +294,6 @@ pub(crate) async fn gateway_status(
Ok(ApiResponse::new(json!(gateways), StatusCode::OK))
}

// TODO: gateway_id should be enough; remove network_id.
pub(crate) async fn remove_gateway(
Path((network_id, gateway_id)): Path<(Id, Id)>,
_role: VpnRole,
Extension(gateway_state): Extension<Arc<Mutex<GatewayMap>>>,
) -> Result<ApiResponse, WebError> {
debug!("Removing gateway {gateway_id} in network {network_id}");
let mut gateway_state = gateway_state
.lock()
.expect("Failed to acquire gateway state lock");

gateway_state.remove_gateway(network_id, gateway_id)?;

info!("Removed gateway {gateway_id} in network {network_id}");

Ok(ApiResponse::new(Value::Null, StatusCode::OK))
}

pub(crate) async fn import_network(
_role: VpnRole,
State(appstate): State<AppState>,
Expand Down
16 changes: 3 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use self::handlers::wireguard::{
add_device, add_user_devices, create_network, create_network_token, delete_device,
delete_network, download_config, gateway_status, get_device, import_network, list_devices,
list_networks, list_user_devices, modify_device, modify_network, network_details,
network_stats, remove_gateway, user_stats,
network_stats, user_stats,
};
#[cfg(feature = "worker")]
use self::handlers::worker::{
Expand Down Expand Up @@ -119,9 +119,7 @@ use self::{
};
#[cfg(any(feature = "openid", feature = "worker"))]
use self::{
auth::failed_login::FailedLoginMap,
db::models::oauth2client::OAuth2Client,
grpc::{GatewayMap, WorkerState},
auth::failed_login::FailedLoginMap, db::models::oauth2client::OAuth2Client, grpc::WorkerState,
handlers::app_info::get_app_info,
};

Expand Down Expand Up @@ -292,7 +290,6 @@ pub fn build_webapp(
events_tx: Sender<ChangeEvent>,
mail_tx: UnboundedSender<Mail>,
worker_state: Arc<Mutex<WorkerState>>,
gateway_state: Arc<Mutex<GatewayMap>>,
pool: PgPool,
user_agent_parser: Arc<UserAgentParser>,
failed_logins: Arc<Mutex<FailedLoginMap>>,
Expand Down Expand Up @@ -469,10 +466,6 @@ pub fn build_webapp(
.route("/network", get(list_networks))
.route("/network/:network_id", get(network_details))
.route("/network/:network_id/gateways", get(gateway_status))
.route(
"/network/:network_id/gateways/:gateway_id",
delete(remove_gateway),
)
.route("/network/import", post(import_network))
.route("/network/:network_id/devices", post(add_user_devices))
.route(
Expand All @@ -486,8 +479,7 @@ pub fn build_webapp(
.route("/gateway/:gateway_id", get(get_gateway))
.route("/gateway/:gateway_id", put(update_gateway))
.route("/gateway/:gateway_id", delete(delete_gateway))
.route("/network/:network_id/all_gateways", get(get_gateways))
.layer(Extension(gateway_state)),
.route("/network/:network_id/all_gateways", get(get_gateways)),
);

#[cfg(feature = "worker")]
Expand Down Expand Up @@ -532,7 +524,6 @@ pub fn build_webapp(
/// Runs core web server exposing REST API.
pub async fn run_web_server(
worker_state: Arc<Mutex<WorkerState>>,
gateway_state: Arc<Mutex<GatewayMap>>,
webhook_tx: UnboundedSender<AppEvent>,
webhook_rx: UnboundedReceiver<AppEvent>,
events_tx: Sender<ChangeEvent>,
Expand All @@ -547,7 +538,6 @@ pub async fn run_web_server(
events_tx,
mail_tx,
worker_state,
gateway_state,
pool,
user_agent_parser,
failed_logins,
Expand Down
4 changes: 1 addition & 3 deletions tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use defguard::{
Id,
},
enterprise::license::{set_cached_license, License},
grpc::{GatewayMap, WorkerState},
grpc::WorkerState,
headers::create_user_agent_parser,
mail::Mail,
SERVER_CONFIG,
Expand Down Expand Up @@ -120,7 +120,6 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie
let worker_state = Arc::new(Mutex::new(WorkerState::new(tx.clone())));
let (wg_tx, wg_rx) = broadcast::channel::<ChangeEvent>(16);
let (mail_tx, mail_rx) = unbounded_channel::<Mail>();
let gateway_state = Arc::new(Mutex::new(GatewayMap::new()));

let failed_logins = FailedLoginMap::new();
let failed_logins = Arc::new(Mutex::new(failed_logins));
Expand Down Expand Up @@ -167,7 +166,6 @@ pub async fn make_base_client(pool: PgPool, config: DefGuardConfig) -> (TestClie
wg_tx,
mail_tx,
worker_state,
gateway_state,
pool,
user_agent_parser,
failed_logins,
Expand Down
Loading

0 comments on commit ea1138a

Please sign in to comment.