From 7f9906adbc2b7f7879520068dbfe21ee340040f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 16 Oct 2024 15:57:51 +0200 Subject: [PATCH] WIP: gateway RPC client --- .../20241015074303_network_gateways.down.sql | 2 +- .../20241015074303_network_gateways.up.sql | 6 +- src/bin/defguard.rs | 5 +- src/db/models/wireguard.rs | 11 +- src/grpc/gateway.rs | 84 +++++++++- src/grpc/mod.rs | 145 ++++++++++-------- src/handlers/mail.rs | 4 +- src/handlers/wireguard.rs | 9 +- 8 files changed, 184 insertions(+), 82 deletions(-) diff --git a/migrations/20241015074303_network_gateways.down.sql b/migrations/20241015074303_network_gateways.down.sql index 0197bef93..6bbc938ef 100644 --- a/migrations/20241015074303_network_gateways.down.sql +++ b/migrations/20241015074303_network_gateways.down.sql @@ -1 +1 @@ -ALTER TABLE wireguard_network DROP COLUMN gateways; +DROP TABLE gateway; diff --git a/migrations/20241015074303_network_gateways.up.sql b/migrations/20241015074303_network_gateways.up.sql index 9c866c4ef..000a79a0c 100644 --- a/migrations/20241015074303_network_gateways.up.sql +++ b/migrations/20241015074303_network_gateways.up.sql @@ -1 +1,5 @@ -ALTER TABLE wireguard_network ADD COLUMN gateways text[] NOT NULL DEFAULT array[]::text[]; +CREATE TABLE gateway ( + id bigserial PRIMARY KEY, + network_id bigint NOT NULL, + FOREIGN KEY(network_id) REFERENCES wireguard_network(id) +); diff --git a/src/bin/defguard.rs b/src/bin/defguard.rs index 5626e1ef0..9521ae363 100644 --- a/src/bin/defguard.rs +++ b/src/bin/defguard.rs @@ -8,7 +8,9 @@ use defguard::{ config::{Command, DefGuardConfig}, db::{init_db, AppEvent, GatewayEvent, Settings, User}, enterprise::license::{run_periodic_license_check, set_cached_license, License}, - grpc::{run_grpc_bidi_stream, run_grpc_server, GatewayMap, WorkerState}, + grpc::{ + run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, GatewayMap, WorkerState, + }, headers::create_user_agent_parser, init_dev_env, init_vpn_location, mail::{run_mail_handler, Mail}, @@ -114,6 +116,7 @@ async fn main() -> Result<(), anyhow::Error> { // run services tokio::select! { + res = run_grpc_gateway_stream(pool.clone()) => error!("Gateway gRPC stream returned early: {res:#?}"), res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone(), user_agent_parser.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"), res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_map), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"), res = run_web_server(worker_state, gateway_map, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"), diff --git a/src/db/models/wireguard.rs b/src/db/models/wireguard.rs index 0149e1eca..b384a18bb 100644 --- a/src/db/models/wireguard.rs +++ b/src/db/models/wireguard.rs @@ -88,9 +88,6 @@ pub struct WireguardNetwork { pub mfa_enabled: bool, pub keepalive_interval: i32, pub peer_disconnect_threshold: i32, - // URLs pointing to all gateways serving gRPC - #[model(ref)] - pub gateways: Vec, } pub struct WireguardKey { @@ -159,7 +156,6 @@ impl WireguardNetwork { mfa_enabled, keepalive_interval, peer_disconnect_threshold, - gateways: Vec::new(), } } @@ -182,7 +178,7 @@ impl WireguardNetwork { let networks = query_as!( Self, "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \ - connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \ + connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold \ FROM wireguard_network WHERE name = $1", name ) @@ -204,7 +200,7 @@ impl WireguardNetwork { query_as!( Self, "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \ - connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \ + connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold \ FROM wireguard_network WHERE mfa_enabled = true", ) .fetch_all(executor) @@ -959,12 +955,11 @@ impl Default for WireguardNetwork { mfa_enabled: false, keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL, peer_disconnect_threshold: DEFAULT_DISCONNECT_THRESHOLD, - gateways: Vec::default(), } } } -#[derive(Serialize, Clone, Debug)] +#[derive(Debug, Serialize)] pub struct WireguardNetworkInfo { #[serde(flatten)] pub network: WireguardNetwork, diff --git a/src/grpc/gateway.rs b/src/grpc/gateway.rs index 5fedabf81..41ef42192 100644 --- a/src/grpc/gateway.rs +++ b/src/grpc/gateway.rs @@ -1,7 +1,9 @@ use std::{ + fs::read_to_string, pin::Pin, sync::{Arc, Mutex}, task::{Context, Poll}, + time::Duration, }; use chrono::{DateTime, Utc}; @@ -12,9 +14,14 @@ use tokio::{ mpsc::{self, Receiver, UnboundedSender}, }, task::JoinHandle, + time::sleep, +}; +use tokio_stream::{wrappers::UnboundedReceiverStream, Stream}; +use tonic::{ + metadata::MetadataMap, + transport::{Certificate, ClientTlsConfig, Endpoint, Identity, Server}, + Code, Request, Response, Status, }; -use tokio_stream::Stream; -use tonic::{metadata::MetadataMap, Code, Request, Response, Status}; use super::GatewayMap; use crate::{ @@ -166,6 +173,79 @@ impl WireguardPeerStats { } } +// TODO: merge with super. +const TEN_SECS: Duration = Duration::from_secs(10); + +pub(super) struct GatewayHandler { + endpoint: Endpoint, +} + +impl GatewayHandler { + pub(super) fn new(url: &str, ca_path: Option<&str>) -> Result { + let endpoint = Endpoint::from_shared(url.to_string())?; + let endpoint = endpoint + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + let endpoint = if let Some(ca) = ca_path { + let ca = read_to_string(ca).unwrap(); // FIXME: use custom error + let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca)); + endpoint.tls_config(tls)? + } else { + endpoint + }; + + Ok(Self { endpoint }) + } + + pub(super) async fn handle_connection(&self) -> ! { + let uri = self.endpoint.uri(); + loop { + debug!("Connecting to gateway at {uri}"); + let mut client = gateway_client::GatewayClient::new(self.endpoint.connect_lazy()); + let (tx, rx) = mpsc::unbounded_channel(); + let Ok(response) = client.bidi(UnboundedReceiverStream::new(rx)).await else { + error!("Failed to connect to gateway @ {uri}, retrying in 10s",); + sleep(TEN_SECS).await; + continue; + }; + info!("Connected to proxy at {uri}"); + let mut resp_stream = response.into_inner(); + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("stream was closed by the sender"); + break 'message; + } + Ok(Some(received)) => { + info!("Received message from gateway."); + debug!("Received the following message from gateway: {received:?}"); + let payload: Option = match received.payload { + Some(core_request::Payload::ConfigRequest(config_request)) => { + info!("*** ConfigurationRequest {config_request:?}"); + None + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + info!("*** PeerStats {peer_stats:?}"); + None + } + // Reply without payload. + None => None, + }; + } + Err(err) => { + error!("Disconnected from gateway at {uri}"); + error!("stream error: {err}"); + debug!("waiting 10s to re-establish the connection"); + sleep(TEN_SECS).await; + break 'message; + } + } + } + } + } +} + /// Helper struct for handling gateway events struct GatewayUpdatesHandler { network_id: Id, diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 732e8b1a3..da749dd80 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -20,6 +20,7 @@ use tokio::{ broadcast::Sender, mpsc::{self, UnboundedSender}, }, + task::JoinSet, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -28,10 +29,9 @@ use tonic::{ Code, Status, }; use uaparser::UserAgentParser; -use uuid::Uuid; #[cfg(feature = "wireguard")] -use self::gateway::{gateway_service_server::GatewayServiceServer, GatewayServer}; +use self::gateway::GatewayHandler; use self::{ auth::{auth_service_server::AuthServiceServer, AuthServer}, desktop_client_mfa::ClientMfaServer, @@ -78,7 +78,6 @@ use proto::{core_request, proxy_client::ProxyClient, CoreError, CoreResponse}; // Helper struct used to handle gateway state // gateways are grouped by network type GatewayHostname = String; -#[derive(Debug)] pub struct GatewayMap(HashMap>); #[derive(Error, Debug)] @@ -88,9 +87,9 @@ pub enum GatewayMapError { #[error("Network {0} not found")] NetworkNotFound(i64), #[error("Gateway with UID {0} not found")] - UidNotFound(Uuid), + UidNotFound(Id), #[error("Cannot remove. Gateway with UID {0} is still active")] - RemoveActive(Uuid), + RemoveActive(Id), #[error("Config missing")] ConfigError, } @@ -104,44 +103,44 @@ impl GatewayMap { // add a new gateway to map // this method is meant to be called when a gateway requests a config // as a sort of "registration" - pub fn add_gateway( - &mut self, - network_id: Id, - network_name: &str, - hostname: String, - name: Option, - mail_tx: UnboundedSender, - ) { - info!("Adding gateway {hostname} with to gateway map for network {network_id}",); - let gateway_state = GatewayState::new(network_id, network_name, &hostname, name, mail_tx); - - if let Some(network_gateway_map) = self.0.get_mut(&network_id) { - network_gateway_map.entry(hostname).or_insert(gateway_state); - } else { - // no map for a given network exists yet - let mut network_gateway_map = HashMap::new(); - network_gateway_map.insert(hostname, gateway_state); - self.0.insert(network_id, network_gateway_map); - } - } + // pub fn add_gateway( + // &mut self, + // network_id: Id, + // network_name: &str, + // hostname: String, + // name: Option, + // mail_tx: UnboundedSender, + // ) { + // info!("Adding gateway {hostname} with to gateway map for network {network_id}",); + // let gateway_state = GatewayState::new(network_id, network_name, &hostname, name, mail_tx); + + // if let Some(network_gateway_map) = self.0.get_mut(&network_id) { + // network_gateway_map.entry(hostname).or_insert(gateway_state); + // } else { + // // no map for a given network exists yet + // let mut network_gateway_map = HashMap::new(); + // network_gateway_map.insert(hostname, gateway_state); + // self.0.insert(network_id, network_gateway_map); + // } + // } // remove gateway from map - pub fn remove_gateway(&mut self, network_id: Id, uid: Uuid) -> Result<(), GatewayMapError> { + pub fn remove_gateway(&mut self, network_id: Id, id: Id) -> Result<(), GatewayMapError> { debug!("Removing gateway from network {network_id}"); if let Some(network_gateway_map) = self.0.get_mut(&network_id) { // find gateway by uuid let hostname = match network_gateway_map .iter() - .find(|(_address, state)| state.uid == uid) + .find(|(_address, state)| state.id == id) { None => { - error!("Failed to find gateway with UID {uid}"); - return Err(GatewayMapError::UidNotFound(uid)); + error!("Failed to find gateway with ID {id}"); + return Err(GatewayMapError::UidNotFound(id)); } Some((hostname, state)) => { if state.connected { - error!("Cannot remove. Gateway with UID {uid} is still active"); - return Err(GatewayMapError::RemoveActive(uid)); + error!("Cannot remove. Gateway with UID {id} is still active"); + return Err(GatewayMapError::RemoveActive(id)); } hostname.clone() } @@ -153,7 +152,8 @@ impl GatewayMap { error!("Network {network_id} not found in gateway map"); return Err(GatewayMapError::NetworkNotFound(network_id)); }; - info!("Gateway with UID {uid} removed from network {network_id}"); + + info!("Gateway with UID {id} removed from network {network_id}"); Ok(()) } @@ -183,6 +183,7 @@ impl GatewayMap { error!("Network {network_id} not found in gateway map"); return Err(GatewayMapError::NetworkNotFound(network_id)); }; + info!("Gateway {hostname} connected in network {network_id}"); Ok(()) } @@ -252,9 +253,9 @@ impl Default for GatewayMap { } } -#[derive(Serialize, Clone, Debug)] -pub struct GatewayState { - pub uid: Uuid, +#[derive(Clone, Debug, Serialize)] +pub(crate) struct GatewayState { + pub id: Id, pub connected: bool, pub network_id: Id, pub network_name: String, @@ -269,27 +270,27 @@ pub struct GatewayState { } impl GatewayState { - #[must_use] - pub fn new>( - network_id: Id, - network_name: S, - hostname: S, - name: Option, - mail_tx: UnboundedSender, - ) -> Self { - Self { - uid: Uuid::new_v4(), - connected: false, - network_id, - network_name: network_name.into(), - name, - hostname: hostname.into(), - connected_at: None, - disconnected_at: None, - mail_tx, - last_email_notification: None, - } - } + // #[must_use] + // pub fn new>( + // network_id: Id, + // network_name: S, + // hostname: S, + // name: Option, + // mail_tx: UnboundedSender, + // ) -> Self { + // Self { + // uid: Uuid::new_v4(), + // connected: false, + // network_id, + // network_name: network_name.into(), + // name, + // hostname: hostname.into(), + // connected_at: None, + // disconnected_at: None, + // mail_tx, + // last_email_notification: None, + // } + // } /// Send gateway disconnected notification /// Sends notification only if last notification time is bigger than specified in config @@ -315,11 +316,11 @@ impl GatewayState { // FIXME: Try to get rid of spawn and use something like block_on // To return result instead of logging tokio::spawn(async move { - if let Err(e) = - send_gateway_disconnected_email(name, network_name, &hostname, &mail_tx, &pool) + if let Err(err) = + send_gateway_disconnected_email(name, &network_name, &hostname, &mail_tx, &pool) .await { - error!("Failed to send gateway disconnect notification: {e}"); + error!("Failed to send gateway disconnect notification: {err}"); } else { info!("Gateway {hostname} disconnected. Email notification sent",); } @@ -344,6 +345,28 @@ impl From for CoreError { } } +/// Bi-directional gRPC stream for comminication with Defguard proxy. +pub async fn run_grpc_gateway_stream(pool: PgPool) -> Result<(), anyhow::Error> { + // TODO: for each gateway... + let gateway_url = "http://localhost:50066"; + + let config = server_config(); + + let mut tasks = JoinSet::new(); + + tasks.spawn(async { + let gateway_client = + GatewayHandler::new(gateway_url, config.proxy_grpc_ca.as_deref()).unwrap(); + gateway_client.handle_connection().await; + }); + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) +} + /// Bi-directional gRPC stream for comminication with Defguard proxy. pub async fn run_grpc_bidi_stream( pool: PgPool, @@ -377,11 +400,11 @@ pub async fn run_grpc_bidi_stream( let uri = endpoint.uri(); loop { - debug!("Connecting to proxy at {uri}",); + debug!("Connecting to proxy at {uri}"); let mut client = ProxyClient::new(endpoint.connect_lazy()); let (tx, rx) = mpsc::unbounded_channel(); let Ok(response) = client.bidi(UnboundedReceiverStream::new(rx)).await else { - error!("Failed to connect to proxy @ {uri}, retrying in 10s",); + error!("Failed to connect to proxy @ {uri}, retrying in 10s"); sleep(TEN_SECS).await; continue; }; diff --git a/src/handlers/mail.rs b/src/handlers/mail.rs index f2cc35f90..6efa29e73 100644 --- a/src/handlers/mail.rs +++ b/src/handlers/mail.rs @@ -210,7 +210,7 @@ pub fn send_new_device_added_email( pub async fn send_gateway_disconnected_email( gateway_name: Option, - network_name: String, + network_name: &str, gateway_adress: &str, mail_tx: &UnboundedSender, pool: &PgPool, @@ -225,7 +225,7 @@ pub async fn send_gateway_disconnected_email( content: templates::gateway_disconnected_mail( &gateway_name, gateway_adress, - &network_name, + network_name, )?, attachments: Vec::new(), result_tx: None, diff --git a/src/handlers/wireguard.rs b/src/handlers/wireguard.rs index 07969d3e6..1e9fd0322 100644 --- a/src/handlers/wireguard.rs +++ b/src/handlers/wireguard.rs @@ -305,8 +305,9 @@ pub async fn gateway_status( }) } +// TODO: gateway_id should be enough; remove network_id. pub async fn remove_gateway( - Path((network_id, gateway_id)): Path<(i64, String)>, + Path((network_id, gateway_id)): Path<(i64, i64)>, _role: VpnRole, Extension(gateway_state): Extension>>, ) -> ApiResult { @@ -315,11 +316,7 @@ pub async fn remove_gateway( .lock() .expect("Failed to acquire gateway state lock"); - gateway_state.remove_gateway( - network_id, - Uuid::from_str(&gateway_id) - .map_err(|_| WebError::Http(StatusCode::INTERNAL_SERVER_ERROR))?, - )?; + gateway_state.remove_gateway(network_id, gateway_id)?; info!("Removed gateway {gateway_id} in network {network_id}");