Skip to content

Commit

Permalink
WIP: gateway RPC client
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Oct 21, 2024
1 parent a374e31 commit 7f9906a
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 82 deletions.
2 changes: 1 addition & 1 deletion migrations/20241015074303_network_gateways.down.sql
Original file line number Diff line number Diff line change
@@ -1 +1 @@
ALTER TABLE wireguard_network DROP COLUMN gateways;
DROP TABLE gateway;
6 changes: 5 additions & 1 deletion migrations/20241015074303_network_gateways.up.sql
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
ALTER TABLE wireguard_network ADD COLUMN gateways text[] NOT NULL DEFAULT array[]::text[];
CREATE TABLE gateway (
id bigserial PRIMARY KEY,
network_id bigint NOT NULL,
FOREIGN KEY(network_id) REFERENCES wireguard_network(id)
);
5 changes: 4 additions & 1 deletion src/bin/defguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use defguard::{
config::{Command, DefGuardConfig},
db::{init_db, AppEvent, GatewayEvent, Settings, User},
enterprise::license::{run_periodic_license_check, set_cached_license, License},
grpc::{run_grpc_bidi_stream, run_grpc_server, GatewayMap, WorkerState},
grpc::{
run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, GatewayMap, WorkerState,
},
headers::create_user_agent_parser,
init_dev_env, init_vpn_location,
mail::{run_mail_handler, Mail},
Expand Down Expand Up @@ -114,6 +116,7 @@ async fn main() -> Result<(), anyhow::Error> {

// run services
tokio::select! {
res = run_grpc_gateway_stream(pool.clone()) => error!("Gateway gRPC stream returned early: {res:#?}"),
res = run_grpc_bidi_stream(pool.clone(), wireguard_tx.clone(), mail_tx.clone(), user_agent_parser.clone()), if config.proxy_url.is_some() => error!("Proxy gRPC stream returned early: {res:#?}"),
res = run_grpc_server(Arc::clone(&worker_state), pool.clone(), Arc::clone(&gateway_map), wireguard_tx.clone(), mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone()) => error!("gRPC server returned early: {res:#?}"),
res = run_web_server(worker_state, gateway_map, webhook_tx, webhook_rx, wireguard_tx.clone(), mail_tx, pool.clone(), user_agent_parser, failed_logins) => error!("Web server returned early: {res:#?}"),
Expand Down
11 changes: 3 additions & 8 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,6 @@ pub struct WireguardNetwork<I = NoId> {
pub mfa_enabled: bool,
pub keepalive_interval: i32,
pub peer_disconnect_threshold: i32,
// URLs pointing to all gateways serving gRPC
#[model(ref)]
pub gateways: Vec<String>,
}

pub struct WireguardKey {
Expand Down Expand Up @@ -159,7 +156,6 @@ impl WireguardNetwork {
mfa_enabled,
keepalive_interval,
peer_disconnect_threshold,
gateways: Vec::new(),
}
}

Expand All @@ -182,7 +178,7 @@ impl WireguardNetwork<Id> {
let networks = query_as!(
Self,
"SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold \
FROM wireguard_network WHERE name = $1",
name
)
Expand All @@ -204,7 +200,7 @@ impl WireguardNetwork<Id> {
query_as!(
Self,
"SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, allowed_ips, \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold, gateways \
connected_at, mfa_enabled, keepalive_interval, peer_disconnect_threshold \
FROM wireguard_network WHERE mfa_enabled = true",
)
.fetch_all(executor)
Expand Down Expand Up @@ -959,12 +955,11 @@ impl Default for WireguardNetwork {
mfa_enabled: false,
keepalive_interval: DEFAULT_KEEPALIVE_INTERVAL,
peer_disconnect_threshold: DEFAULT_DISCONNECT_THRESHOLD,
gateways: Vec::default(),
}
}
}

#[derive(Serialize, Clone, Debug)]
#[derive(Debug, Serialize)]
pub struct WireguardNetworkInfo {
#[serde(flatten)]
pub network: WireguardNetwork<Id>,
Expand Down
84 changes: 82 additions & 2 deletions src/grpc/gateway.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::{
fs::read_to_string,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};

use chrono::{DateTime, Utc};
Expand All @@ -12,9 +14,14 @@ use tokio::{
mpsc::{self, Receiver, UnboundedSender},
},
task::JoinHandle,
time::sleep,
};
use tokio_stream::{wrappers::UnboundedReceiverStream, Stream};
use tonic::{
metadata::MetadataMap,
transport::{Certificate, ClientTlsConfig, Endpoint, Identity, Server},
Code, Request, Response, Status,
};
use tokio_stream::Stream;
use tonic::{metadata::MetadataMap, Code, Request, Response, Status};

use super::GatewayMap;
use crate::{
Expand Down Expand Up @@ -166,6 +173,79 @@ impl WireguardPeerStats {
}
}

// TODO: merge with super.
const TEN_SECS: Duration = Duration::from_secs(10);

pub(super) struct GatewayHandler {
endpoint: Endpoint,
}

impl GatewayHandler {
pub(super) fn new(url: &str, ca_path: Option<&str>) -> Result<Self, tonic::transport::Error> {
let endpoint = Endpoint::from_shared(url.to_string())?;
let endpoint = endpoint
.http2_keep_alive_interval(TEN_SECS)
.tcp_keepalive(Some(TEN_SECS))
.keep_alive_while_idle(true);
let endpoint = if let Some(ca) = ca_path {
let ca = read_to_string(ca).unwrap(); // FIXME: use custom error
let tls = ClientTlsConfig::new().ca_certificate(Certificate::from_pem(ca));
endpoint.tls_config(tls)?
} else {
endpoint
};

Ok(Self { endpoint })
}

pub(super) async fn handle_connection(&self) -> ! {
let uri = self.endpoint.uri();
loop {
debug!("Connecting to gateway at {uri}");
let mut client = gateway_client::GatewayClient::new(self.endpoint.connect_lazy());
let (tx, rx) = mpsc::unbounded_channel();
let Ok(response) = client.bidi(UnboundedReceiverStream::new(rx)).await else {
error!("Failed to connect to gateway @ {uri}, retrying in 10s",);
sleep(TEN_SECS).await;
continue;
};
info!("Connected to proxy at {uri}");
let mut resp_stream = response.into_inner();
'message: loop {
match resp_stream.message().await {
Ok(None) => {
info!("stream was closed by the sender");
break 'message;
}
Ok(Some(received)) => {
info!("Received message from gateway.");
debug!("Received the following message from gateway: {received:?}");
let payload: Option<i64> = match received.payload {
Some(core_request::Payload::ConfigRequest(config_request)) => {
info!("*** ConfigurationRequest {config_request:?}");
None
}
Some(core_request::Payload::PeerStats(peer_stats)) => {
info!("*** PeerStats {peer_stats:?}");
None
}
// Reply without payload.
None => None,
};
}
Err(err) => {
error!("Disconnected from gateway at {uri}");
error!("stream error: {err}");
debug!("waiting 10s to re-establish the connection");
sleep(TEN_SECS).await;
break 'message;
}
}
}
}
}
}

/// Helper struct for handling gateway events
struct GatewayUpdatesHandler {
network_id: Id,
Expand Down
Loading

0 comments on commit 7f9906a

Please sign in to comment.