Skip to content

Commit

Permalink
Handle peer stats
Browse files Browse the repository at this point in the history
  • Loading branch information
moubctez committed Oct 21, 2024
1 parent eea5f14 commit 769790a
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 134 deletions.
29 changes: 14 additions & 15 deletions src/auth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,22 @@ where

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let appstate = AppState::from_ref(state);
if let Ok(cookies) = CookieJar::from_request_parts(parts, state).await {
if let Some(session_cookie) = cookies.get(SESSION_COOKIE_NAME) {
return {
match Session::find_by_id(&appstate.pool, session_cookie.value()).await {
Ok(Some(session)) => {
if session.expired() {
let _result = session.delete(&appstate.pool).await;
Err(WebError::Authorization("Session expired".into()))
} else {
Ok(session)
}
let Ok(cookies) = CookieJar::from_request_parts(parts, state).await;
if let Some(session_cookie) = cookies.get(SESSION_COOKIE_NAME) {
return {
match Session::find_by_id(&appstate.pool, session_cookie.value()).await {
Ok(Some(session)) => {
if session.expired() {
let _result = session.delete(&appstate.pool).await;
Err(WebError::Authorization("Session expired".into()))
} else {
Ok(session)
}
Ok(None) => Err(WebError::Authorization("Session not found".into())),
Err(err) => Err(err.into()),
}
};
}
Ok(None) => Err(WebError::Authorization("Session not found".into())),
Err(err) => Err(err.into()),
}
};
}
Err(WebError::Authorization("Session is required".into()))
}
Expand Down
39 changes: 39 additions & 0 deletions src/db/models/wireguard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,45 @@ impl WireguardNetwork<Id> {
.and_then(|ep| Some(ep.split(':').next()?.to_owned()))
}

/// Get a list of all allowed peers
///
/// Each device is marked as allowed or not allowed in a given network,
/// which enables enforcing peer disconnect in MFA-protected networks.
pub async fn get_peers<'e, E>(&self, executor: E) -> Result<Vec<Peer>, SqlxError>
where
E: PgExecutor<'e>,
{
debug!("Fetching all peers for network {}", self.id);
let rows = query!(
"SELECT d.wireguard_pubkey pubkey, preshared_key, \
array[host(wnd.wireguard_ip)] \"allowed_ips!: Vec<String>\" \
FROM wireguard_network_device wnd \
JOIN device d ON wnd.device_id = d.id \
JOIN \"user\" u ON d.user_id = u.id \
WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \
AND u.is_active = true \
ORDER BY d.id ASC",
self.id,
self.mfa_enabled
)
.fetch_all(executor)
.await?;

// keepalive has to be added manually because Postgres
// doesn't support unsigned integers
let result = rows
.into_iter()
.map(|row| Peer {
pubkey: row.pubkey,
allowed_ips: row.allowed_ips,
preshared_key: row.preshared_key,
keepalive_interval: Some(self.keepalive_interval as u32),
})
.collect();

Ok(result)
}

/// Update `connected_at` to the current time and save it to the database.
pub(crate) async fn touch_connected_at<'e, E>(&mut self, executor: E) -> Result<(), SqlxError>
where
Expand Down
159 changes: 40 additions & 119 deletions src/grpc/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,47 +44,6 @@ pub struct GatewayServer {
mail_tx: UnboundedSender<Mail>,
}

impl WireguardNetwork<Id> {
/// Get a list of all allowed peers
///
/// Each device is marked as allowed or not allowed in a given network,
/// which enables enforcing peer disconnect in MFA-protected networks.
pub async fn get_peers<'e, E>(&self, executor: E) -> Result<Vec<Peer>, SqlxError>
where
E: PgExecutor<'e>,
{
debug!("Fetching all peers for network {}", self.id);
let rows = query!(
"SELECT d.wireguard_pubkey pubkey, preshared_key, \
array[host(wnd.wireguard_ip)] \"allowed_ips!: Vec<String>\" \
FROM wireguard_network_device wnd \
JOIN device d ON wnd.device_id = d.id \
JOIN \"user\" u ON d.user_id = u.id \
WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \
AND u.is_active = true \
ORDER BY d.id ASC",
self.id,
self.mfa_enabled
)
.fetch_all(executor)
.await?;

// keepalive has to be added manually because Postgres
// doesn't support unsigned integers
let result = rows
.into_iter()
.map(|row| Peer {
pubkey: row.pubkey,
allowed_ips: row.allowed_ips,
preshared_key: row.preshared_key,
keepalive_interval: Some(self.keepalive_interval as u32),
})
.collect();

Ok(result)
}
}

impl GatewayServer {
/// Create new gateway server instance
#[must_use]
Expand Down Expand Up @@ -179,6 +138,7 @@ impl WireguardPeerStats {
// TODO: merge with super.
const TEN_SECS: Duration = Duration::from_secs(10);

/// One instance per connected gateway.
pub(super) struct GatewayHandler {
endpoint: Endpoint,
message_id: AtomicU64,
Expand Down Expand Up @@ -214,8 +174,8 @@ impl GatewayHandler {
})
}

async fn config(&self) -> Result<Configuration, Status> {
debug!("Sending configuration to gateway client.");
async fn send_configuration(&self, tx: &UnboundedSender<CoreResponse>) -> Result<(), Status> {
debug!("Sending configuration to gateway.");
let network_id = self.network_id;
// let hostname = Self::get_gateway_hostname(request.metadata())?;

Expand Down Expand Up @@ -247,7 +207,7 @@ impl GatewayHandler {
// }

if let Err(err) = network.touch_connected_at(&self.pool).await {
error!("Failed to update network {network_id} in the database, status: {err}");
error!("Failed to update connected at for network {network_id} in the database, status: {err}");
}

let peers = network.get_peers(&self.pool).await.map_err(|error| {
Expand All @@ -258,37 +218,31 @@ impl GatewayHandler {
)
})?;

let payload = Some(core_response::Payload::Config(gen_config(&network, peers)));
let id = self.message_id.fetch_add(1, Ordering::Relaxed);
let req = CoreResponse { id, payload };
tx.send(req).unwrap();
info!("Configuration sent to gateway client, network {network}.");

Ok(gen_config(&network, peers))
Ok(())
}

pub(super) async fn handle_connection(&self) -> ! {
let uri = self.endpoint.uri();
loop {
debug!("Connecting to gateway at {uri}");
debug!("Connecting to gateway {uri}");
let mut client = gateway_client::GatewayClient::new(self.endpoint.connect_lazy());
let (tx, rx) = mpsc::unbounded_channel();
let Ok(response) = client.bidi(UnboundedReceiverStream::new(rx)).await else {
error!("Failed to connect to gateway at {uri}, retrying in 10s",);
error!("Failed to connect to gateway {uri}, retrying in 10s",);
sleep(TEN_SECS).await;
continue;
};
info!("Connected to gateway at {uri}");
info!("Connected to gateway {uri}");
let mut resp_stream = response.into_inner();

debug!("Sending configuration to gateway at {uri}");
match self.config().await {
Ok(config) => {
let payload = Some(core_response::Payload::Config(config));
let id = self.message_id.fetch_add(1, Ordering::Relaxed);
let req = CoreResponse { id, payload };
tx.send(req).unwrap();
}
Err(err) => {
error!("Failed to obtain configuration");
}
}
// TODO: probably fail on error
let _ = self.send_configuration(&tx).await;

'message: loop {
match resp_stream.message().await {
Expand All @@ -298,18 +252,38 @@ impl GatewayHandler {
}
Ok(Some(received)) => {
info!("Received message from gateway.");
debug!("Received the following message from gateway: {received:?}");
let payload: Option<i64> = match received.payload {
debug!("Message from gateway {uri}: {received:?}");
match received.payload {
Some(core_request::Payload::ConfigRequest(config_request)) => {
info!("*** ConfigurationRequest {config_request:?}");
None
}
Some(core_request::Payload::PeerStats(peer_stats)) => {
info!("*** PeerStats {peer_stats:?}");
None

let public_key = peer_stats.public_key.clone();
let mut stats = WireguardPeerStats::from_peer_stats(
peer_stats,
self.network_id,
);
// Get device by public key and fill in stats.device_id
// FIXME: keep an in-memory device map to avoid repeated DB requests
match Device::find_by_pubkey(&self.pool, &public_key).await {
Ok(Some(device)) => {
stats.device_id = device.id;
match stats.save(&self.pool).await {
Ok(_) => info!("Saved WireGuard peer stats to database."),
Err(err) => error!("Failed to save WireGuard peer stats to database: {err}"),
}
}
Ok(None) => {
error!("Device with public key {public_key} not found");
}
Err(err) => {
error!("Failed to retrieve device with public key {public_key}: {err}",);
}
};
}
// Reply without payload.
None => None,
None => (),
};
}
Err(err) => {
Expand Down Expand Up @@ -652,58 +626,6 @@ impl Drop for GatewayUpdatesStream {
// impl gateway_service_server::GatewayService for GatewayServer {
// type UpdatesStream = GatewayUpdatesStream;

// /// Retrieve stats from gateway and save it to database
// async fn stats(
// &self,
// request: Request<tonic::Streaming<StatsUpdate>>,
// ) -> Result<Response<()>, Status> {
// let network_id = Self::get_network_id(request.metadata())?;
// let mut stream = request.into_inner();
// while let Some(stats_update) = stream.message().await? {
// debug!("Received stats message: {stats_update:?}");
// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else {
// debug!("Received stats message is empty, skipping.");
// continue;
// };
// let public_key = peer_stats.public_key.clone();
// let mut stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id);
// // Get device by public key and fill in stats.device_id
// // FIXME: keep an in-memory device map to avoid repeated DB requests
// stats.device_id = match Device::find_by_pubkey(&self.pool, &public_key).await {
// Ok(Some(device)) => device.id,
// Ok(None) => {
// error!("Device with public key {public_key} not found");
// return Err(Status::new(
// Code::Internal,
// format!("Device with public key {public_key} not found"),
// ));
// }
// Err(err) => {
// error!("Failed to retrieve device with public key {public_key}: {err}",);
// return Err(Status::new(
// Code::Internal,
// format!("Failed to retrieve device with public key {public_key}: {err}",),
// ));
// }
// };
// // Save stats to db
// let stats = match stats.save(&self.pool).await {
// Ok(stats) => stats,
// Err(err) => {
// error!("Saving WireGuard peer stats to db failed: {err}");
// return Err(Status::new(
// Code::Internal,
// format!("Saving WireGuard peer stats to db failed: {err}"),
// ));
// }
// };
// info!("Saved WireGuard peer stats to db.");
// debug!("WireGuard peer stats: {stats:?}");
// }

// Ok(Response::new(()))
// }

// async fn updates(&self, request: Request<()>) -> Result<Response<Self::UpdatesStream>, Status> {
// let gateway_network_id = Self::get_network_id(request.metadata())?;
// let hostname = Self::get_gateway_hostname(request.metadata())?;
Expand All @@ -725,7 +647,6 @@ impl Drop for GatewayUpdatesStream {
// };

// info!("New client connected to updates stream: {hostname}, network {network}",);

// let (tx, rx) = mpsc::channel(4);
// let events_rx = self.wireguard_tx.subscribe();
// let mut state = self.state.lock().unwrap();
Expand Down

0 comments on commit 769790a

Please sign in to comment.